<a href="https://colab.research.google.com/github/yifan-grace-tang/final-project/blob/main/Renee/renee's_work_from_grace_mlcb_project_04_24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## __Protein Function Prediction Leveraging Sequence and Structure Information__

> __Author:__ Grace Tang

> __Edited:__ `04.23.25`

---


#### __Description__

Predicting the functions of proteins is an essential problem in the study of proteins. Many previous works developed computational methods that utilize protein sequence information to predict protein functions. With the advancement of `AlphaFold2`, accurate protein structure data has been available for hundreds of millions of proteins.

Incorporating structure information should be able to further boost the performance of protein function prediction methods. There are several potential ways to utilize the structure information: use computer-vision-like models to extract information from contact maps; use graph neural networks to encode 3D structures; use pretrained protein structure model (e.g., `ESM-IF`).

In addition, with the advancement of LLMs, the text annotations of proteins can also be leveraged as an extra information source for protein function prediction, e.g., use LLM as encoder to encode the text annotations of proteins. Utilizing the information from sequence, structure, and text annotations, we can develop a model that accurately predicts protein functions.

__Using a subset of the dataset used for [DeepFRI](https://github.com/flatironinstitute/DeepFRI) (a similar protein function prediction model) we aim to coallesce sequence data and structure data to predict protein function in the form of [EC Numbers](https://en.wikipedia.org/wiki/Enzyme_Commission_number).__

[[1]](ttps://www.biorxiv.org/content/10.1101/2022.11.29.518451v1), [[2]](https://www.nature.com/articles/s42003-024-07359-z), [[3]](https://www.nature.com/articles/s41467-021-23303-9), [[4]](https://www.biorxiv.org/content/10.1101/2024.05.14.594226v1 )

#### __Methodology__

> _This section relies on having_ `annotations.tsv`, `train.txt`, `validation.txt`, _and_ `sequences.fasta` _in your runtime_


In [1]:
!pip install torch-geometric
!pip install transformers biopython scipy

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfu

In [2]:
import os
import pandas as pd
import re
import numpy as np
import requests

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim

from transformers import BertTokenizer, BertModel

from Bio.PDB import MMCIFParser

from scipy.spatial.distance import pdist, squareform

from sklearn.metrics import precision_score, recall_score, f1_score

We start by defining some utility functions that help parse the input data into a representation that can be fed into our downstream `Dataset` classes.

In [3]:
def parse_fasta(path):
    seqs = {}
    with open(path) as f:
        curr, lines = None, []
        for line in f:
            if line.startswith('>'):
                if curr:
                    seqs[curr] = ''.join(lines)
                curr = line[1:].split()[0]
                lines = []
            else:
                lines.append(line.strip())
        if curr:
            seqs[curr] = ''.join(lines)
    return seqs

def parse_annotations(path):
    ann_map, all_ec = {}, set()
    with open(path) as f:
        for ln in f:
            if ln.startswith('### PDB-chain'):
                break
        for ln in f:
            if not ln.strip() or ln.startswith('#'): continue
            pid, ecs = ln.strip().split('\t')
            ec_list = [e.strip() for e in ecs.split(',') if e.strip()]
            ann_map[pid] = ec_list
            all_ec.update(ec_list)
    classes = sorted(all_ec)
    return ann_map, classes

def fetch_cif(pdb_chain, out_dir="structures"):
    pdb_id, chain = pdb_chain.split('-')
    os.makedirs(out_dir, exist_ok=True)
    local_path = os.path.join(out_dir, f"{pdb_chain}.cif")
    if not os.path.exists(local_path):
        url = f"https://files.rcsb.org/download/{pdb_id}.cif"
        resp = requests.get(url)
        resp.raise_for_status()

        with open(local_path, "wb") as f:
            f.write(resp.content)
    return local_path

def build_df(ids, seqs, ann_map, classes):
    rows = []
    for pid in ids:
        seq = seqs.get(pid, '')
        if not seq:
            print(f"Warning: {pid} not found")
            continue
        labels = [1 if ec in ann_map.get(pid, []) else 0 for ec in classes]
        rows.append({'id': pid, 'sequence': seq, 'labels': labels})
    return pd.DataFrame(rows)

Next we read the `ids` from the `train.txt` and `validation.txt` to perform our _splitting_ into a train and validation dataset.

In [4]:
train_ids = [l.strip() for l in open('train.txt') if l.strip()]
val_ids   = [l.strip() for l in open('validation.txt') if l.strip()]
seqs_map  = parse_fasta('sequences.fasta')
ann_map, classes = parse_annotations('annotations.tsv')

This `Dataset` class feeds each protein through two parallel pipelines — sequence and structure — and returns a single `torch_geometric.data.Data` object per example for graph‐based training.

1. **Structure Fetching:** Given an ID like `4PR3-A`, it downloads the corresponding `mmCIF` file from `RCSB` if not already cached. Uses Biopython’s `PDBParser` to extract all Cα atom coordinates for the specified chain.

2. **Graph Construction:** Builds an undirected residue‐level graph by connecting any two Cα atoms within 8 Å.  Encodes edge weights as the inverse distance (`1/d`) to reflect spatial proximity.

3. **Node Features:** Represents each residue by a 20‐dimensional one-hot vector (one position per standard amino acid). Ensures that non-standard letters (U, Z, O, B) are mapped to “X” and treated uniformly.

4. **`ProtBert` Embedding** Cleans and “space-separates” the amino-acid string for `Rostlab/prot_bert`.  Tokenizes, pads/truncates, and runs the model in `eval()` mode to get the last hidden states.  Applies mean-pooling (masking out padding) to produce a fixed‐size (1024-dim) global embedding.

5. **Label Vector:** Builds a multi-hot target vector for all EC classes based on `annotations.tsv`.

6. **Output**  
   - Returns a `Data` object with fields:  
     - `x`: `[L, 20]` residue features  
     - `edge_index`: `[2, E]` graph connectivity  
     - `edge_attr`: `[E, 1]` distance weights  
     - `seq_emb`: `[1024]` `ProtBert` fingerprint  
     - `y`: `[num_ec]` multi-hot EC labels  


In [5]:
class ProteinDataset(Dataset):
    def __init__(self, ids, seqs_map, ann_map, classes,
                 max_length=512, device=None, ca_cutoff=8.0,
                 pdb_dir="structures"):
        self.ids      = ids
        self.seqs_map = seqs_map
        self.labels   = {
            pid: torch.tensor([1 if ec in ann_map.get(pid, []) else 0
                               for ec in classes], dtype=torch.float)
            for pid in ids
        }
        self.classes = classes
        self.max_length = max_length
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ca_cutoff = ca_cutoff
        self.pdb_dir = pdb_dir

        self.tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.bert      = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device).eval()

        self.parser = MMCIFParser(QUIET=True)

        aa_list = list("ACDEFGHIKLMNPQRSTVWY")
        self.aa2idx = {aa:i for i, aa in enumerate(aa_list)}
        self.num_tokens = len(aa_list)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        pid = self.ids[idx]

        seq = self.seqs_map[pid].upper()
        clean = re.sub(r"[^ACDEFGHIKLMNPQRSTVWY]", "X", seq)
        spaced = " ".join(clean)
        enc = self.tokenizer(
            spaced, padding="max_length", truncation=True,
            max_length=self.max_length, return_tensors="pt"
        )
        input_ids      = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)
        with torch.no_grad():
            out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            last = out.last_hidden_state                  # [1, L, H]
            mask = attention_mask.unsqueeze(-1).float()   # [1, L, 1]
            summed = (last * mask).sum(1)                 # [1, H]
            lengths = mask.sum(1)
            seq_emb = (summed / lengths).squeeze(0).cpu() # [H]

        cif_path = fetch_cif(pid, out_dir=self.pdb_dir)
        struct = self.parser.get_structure(pid, cif_path)
        ca_coords = []
        chain_id = pid.split("-")[1]
        for model in struct:
            for chain in model:
                if chain.id == chain_id:
                    for res in chain:
                        if "CA" in res:
                            ca_coords.append(res["CA"].get_coord())
        coords = np.array(ca_coords)

        dmat = squareform(pdist(coords))
        row, col = np.where((dmat < self.ca_cutoff) & (dmat > 0))
        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr  = torch.tensor((1.0 / dmat[row, col])[:, None], dtype=torch.float)

        feats = []
        for aa in clean:
            vec = np.zeros(self.num_tokens, dtype=float)
            if aa in self.aa2idx:
                vec[self.aa2idx[aa]] = 1.0
            feats.append(vec)
        x = torch.tensor(feats, dtype=torch.float)

        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            seq_emb=seq_emb,
            y=self.labels[pid]
        )

Using our `Dataset` class now we can input our training and validation data and define some `DataLoaders` to be used by our chosen model.

> We are leveraging a `df_train_mini` subset for faster development for our final model we will use the _full_ training set.

In [6]:
df_train_mini = train_ids[:1000]

train_ds = ProteinDataset(df_train_mini, seqs_map, ann_map, classes, pdb_dir='structures')
val_ds   = ProteinDataset(val_ids,       seqs_map, ann_map, classes, pdb_dir='structures')

train_loader = GeometricDataLoader(train_ds, batch_size=16, shuffle=True)
val_loader   = GeometricDataLoader(val_ds,   batch_size=16)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/361 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

This `ConvClassifer` ingests the `Data` objects produced by `ProteinDataset` and fuses structure and sequence information to predict one or more EC numbers.

1. **Structure Branch (GCN)**  
   - **GCNConv(20 → 128)** + ReLU  
   - **GCNConv(128 → 128)** + ReLU  
   - **Global Mean Pooling** over residues → a 128-dim graph embedding

2. **Sequence Branch (MLP)**  
   - **Linear(1024 → 128)** + ReLU  
   - Takes the `ProtBert` pooled embedding and projects it down to 128 dims

3. **Fusion**  
   - **Concatenate** the 128-dim GCN output and 128-dim MLP output → 256-dim vector  
   - **MLP Head**: Linear(256 → 256) + ReLU + Dropout(0.1) + Linear(256 → `num_ec`)  
   - Outputs one raw logit per EC class

4. **Training Details**  
   - **Loss**: `BCEWithLogitsLoss` for multi-label classification  
   - **Optimizer**: AdamW with `lr=1e-3`, `weight_decay=1e-5`  
   - **DataLoader**: uses `torch_geometric.loader.DataLoader` to batch graphs  

In [14]:
class ConvClassifier(nn.Module):
    def __init__(self, seq_emb_dim, node_feat_dim, gcn_hidden, num_classes):
        super().__init__()
        self.seq_mlp = nn.Sequential(
            nn.Linear(seq_emb_dim, gcn_hidden),  # was backward before
            nn.ReLU()
        )

        self.conv1 = GCNConv(node_feat_dim, gcn_hidden)
        self.conv2 = GCNConv(gcn_hidden,      gcn_hidden)

        self.classifier = nn.Sequential(
            nn.Linear(2 * gcn_hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, data):
        x, edge_index, edge_attr, batch = (
            data.x, data.edge_index, data.edge_attr, data.batch
        )

        h = self.conv1(x, edge_index, edge_weight=edge_attr.squeeze())
        h = torch.relu(h)
        h = self.conv2(h, edge_index, edge_weight=edge_attr.squeeze())
        h = torch.relu(h)
        gcn_emb = global_mean_pool(h, batch)  # → [B, gcn_hidden]

        seq_emb_raw = data.seq_emb             # [B, seq_emb_dim]
        seq_emb = seq_emb_raw.to(gcn_emb.device)
        seq_emb = self.seq_mlp(seq_emb)        # → [B, gcn_hidden]

        comb = torch.cat([gcn_emb, seq_emb], dim=1)  # → [B, 256]
        return self.classifier(comb)


In [15]:
example   = train_ds[0]
seq_dim   = example.seq_emb.shape[0]
node_dim  = example.x.shape[1]

model = ConvClassifier(seq_emb_dim=seq_dim,
                      node_feat_dim=node_dim,
                      gcn_hidden=128,
                      num_classes=len(classes)).to(torch.device('cuda'))
opt   = torch.optim.AdamW(model.parameters(), lr=1e-3)
crit  = nn.BCEWithLogitsLoss()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [9]:
for epoch in range(1, 11):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(model.classifier[0].weight.device)
        logits = model(batch)
        loss = crit(logits, batch.y)
        opt.zero_grad(); loss.backward(); opt.step()
        total_loss += loss.item() * batch.num_graphs
    print(f"Epoch {epoch} Train Loss: {total_loss/len(train_ds):.4f}")

    model.eval()
    all_preds, all_targs = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(model.classifier[0].weight.device)
            probs = torch.sigmoid(model(batch)).cpu()
            all_preds.append(probs)
            all_targs.append(batch.y.cpu())
    preds = torch.vstack(all_preds).numpy() > 0.5
    targs = torch.vstack(all_targs).numpy() > 0.5
    print("Val P, R, F1=",
          precision_score(targs, preds, average='micro', zero_division=0),
          recall_score(targs, preds, average='micro', zero_division=0),
          f1_score(targs, preds, average='micro', zero_division=0))

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
def find_best_thresholds_per_ec(y_true, y_probs, metric='f1'):
    thresholds = np.linspace(0.1, 0.9, 81)
    best_thresholds = []

    for i in range(y_true.shape[1]):
        best_score = -1
        best_t = 0.5
        for t in thresholds:
            preds = (y_probs[:, i] >= t).astype(int)
            if metric == 'f1':
                score = f1_score(y_true[:, i], preds, zero_division=0)
            elif metric == 'precision':
                score = precision_score(y_true[:, i], preds)
            elif metric == 'recall':
                score = recall_score(y_true[:, i], preds)
            if score > best_score:
                best_score = score
                best_t = t
        best_thresholds.append(best_t)

    return np.array(best_thresholds)

In [16]:
model.eval()
all_true, all_probs = [], []
with torch.no_grad():
    for seqs, labs in val_loader:
        seqs = seqs.to(device)
        logits = model(seqs)
        probs  = torch.sigmoid(logits).cpu()  # probabilities
        all_true.append(labs)
        all_probs.append(probs)

y_val_true = torch.cat(all_true, dim=0).numpy()
y_val_probs = torch.cat(all_probs, dim=0).numpy()
best_thresholds = find_best_thresholds_per_ec(y_val_true, y_val_probs, metric='f1')

true = torch.cat(all_true, dim=0)
pred = (torch.cat(all_probs, dim=0).numpy() >= best_thresholds).astype(int)

rows = []
for pid, trow, prow in zip(val_ids, true, pred):
    true_ecs = [c for c, flag in zip(classes, trow) if flag]
    pred_ecs = [c for c, flag in zip(classes, prow) if flag]
    rows.append({
        'strain': pid,
        'validation EC class': ';'.join(true_ecs) or '-',
        'predicted': ';'.join(pred_ecs) or '-'
    })

results_df = pd.DataFrame(rows, columns=['strain', 'validation EC class', 'predicted'])
print(results_df)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
