## __Protein Function Prediction Leveraging Sequence and Structure Information__

> __Author:__ Grace Tang

> __Edited:__ `04.24.25`

> __Runtime:__ `A100`

---


#### __Description__

Predicting protein function is a central problem in protein science. Many previous works have developed computational methods that use protein sequence information to predict protein function. With the advancement of `AlphaFold2`, accurate predicted structures are now available for hundreds of millions of proteins.

Incorporating structure information should further boost the performance of protein function prediction methods. There are several potential ways to use structure information: computer-vision-style models that extract features from contact maps; graph neural networks that encode 3D structures; or a pretrained protein structure model (e.g., `ESM-IF`).

In addition, with the advancement of LLMs, the text annotations of proteins can serve as an extra information source for protein function prediction, e.g., by using an LLM to encode the annotation text. By combining sequence, structure, and text annotations, we aim to develop a model that accurately predicts protein functions.

__Using a subset of the dataset used for [DeepFRI](https://github.com/flatironinstitute/DeepFRI) (a similar protein function prediction model), we aim to coalesce sequence data and structure data to predict protein function in the form of [EC Numbers](https://en.wikipedia.org/wiki/Enzyme_Commission_number).__

[[1]](https://www.biorxiv.org/content/10.1101/2022.11.29.518451v1), [[2]](https://www.nature.com/articles/s42003-024-07359-z), [[3]](https://www.nature.com/articles/s41467-021-23303-9), [[4]](https://www.biorxiv.org/content/10.1101/2024.05.14.594226v1)

#### __Methodology__

> _This section relies on having_ `annotations.tsv`, `train.txt`, `validation.txt`, _and_ `sequences.fasta` _in your runtime_

> _While not strictly necessary, we **highly** recommend also importing a_ `structures.zip` _with all the structure information into your runtime; this greatly reduces computation time. Better still, import a_ `cache.zip` _with all pre-computed PyTorch tensors. Note that the_ `ProteinDataset` _class below only loads from_ `cache/`_, so as written this notebook needs the cache._


In [3]:

from google.colab import files
uploaded = files.upload()

Saving cache.zip to cache.zip
Saving annotations.tsv to annotations.tsv
Saving sequences.fasta to sequences.fasta
Saving train.txt to train.txt
Saving validation.txt to validation.txt


In [4]:
!unzip cache.zip -d cache

Archive:  cache.zip
   creating: cache/cache/

In [17]:
!mv cache/cache/* cache/


In [18]:
!rmdir cache/cache

In [5]:
!pip install torch-geometric
!pip install transformers biopython scipy

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-

In [6]:
import os
import pandas as pd
import re
import numpy as np
import requests

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.serialization import add_safe_globals
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage

from transformers import BertTokenizer, BertModel

from Bio.PDB import MMCIFParser

from scipy.spatial.distance import pdist, squareform

from sklearn.metrics import precision_score, recall_score, f1_score

# Allowlist the PyG classes stored in the cached .pt files so that torch.load
# (which defaults to weights_only=True in PyTorch >= 2.6) can unpickle them.
add_safe_globals([Data, DataEdgeAttr, DataTensorAttr, GlobalStorage])

We start by defining some utility functions that help parse the input data into a representation that can be fed into our downstream `Dataset` classes.

In [8]:
def parse_fasta(path):
    seqs = {}
    with open(path) as f:
        curr, lines = None, []
        for line in f:
            if line.startswith('>'):
                if curr:
                    seqs[curr] = ''.join(lines)
                curr = line[1:].split()[0]
                lines = []
            else:
                lines.append(line.strip())
        if curr:
            seqs[curr] = ''.join(lines)
    return seqs

def parse_annotations(path):
    ann_map, all_ec = {}, set()
    with open(path) as f:
        for ln in f:
            if ln.startswith('### PDB-chain'):
                break
        for ln in f:
            if not ln.strip() or ln.startswith('#'): continue
            pid, ecs = ln.strip().split('\t')
            ec_list = [e.strip() for e in ecs.split(',') if e.strip()]
            ann_map[pid] = ec_list
            all_ec.update(ec_list)
    classes = sorted(all_ec)
    return ann_map, classes

def fetch_cif(pdb_chain, out_dir="structures"):
    pdb_id, chain = pdb_chain.split('-')
    os.makedirs(out_dir, exist_ok=True)
    local_path = os.path.join(out_dir, f"{pdb_chain}.cif")
    if not os.path.exists(local_path):
        url = f"https://files.rcsb.org/download/{pdb_id}.cif"
        resp = requests.get(url)
        resp.raise_for_status()

        with open(local_path, "wb") as f:
            f.write(resp.content)
    return local_path

def build_df(ids, seqs, ann_map, classes):
    rows = []
    for pid in ids:
        seq = seqs.get(pid, '')
        if not seq:
            print(f"Warning: {pid} not found")
            continue
        labels = [1 if ec in ann_map.get(pid, []) else 0 for ec in classes]
        rows.append({'id': pid, 'sequence': seq, 'labels': labels})
    return pd.DataFrame(rows)
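`build_df` encodes labels by checking every class against each protein's EC list. The same multi-hot encoding (the **Label Vector** step described below) can be sketched with a column-index lookup that only touches the ECs a protein actually has. This is an illustrative sketch, not part of the original pipeline; `multi_hot_labels` is a hypothetical helper and the IDs and EC numbers in the toy example are made up.

```python
import numpy as np

def multi_hot_labels(ids, ann_map, classes):
    """Return an [N, C] float32 multi-hot matrix; row i marks the ECs of ids[i]."""
    col = {ec: j for j, ec in enumerate(classes)}  # EC string -> column index
    Y = np.zeros((len(ids), len(classes)), dtype=np.float32)
    for i, pid in enumerate(ids):
        for ec in ann_map.get(pid, []):
            if ec in col:  # ECs outside the class vocabulary are ignored
                Y[i, col[ec]] = 1.0
    return Y

# Toy inputs (hypothetical IDs and EC numbers, for illustration only)
toy_classes = ['1.1.1.1', '2.7.11.1', '3.4.21.4']
toy_ann = {'1ABC-A': ['1.1.1.1'], '2XYZ-B': ['2.7.11.1', '3.4.21.4']}
Y = multi_hot_labels(['1ABC-A', '2XYZ-B', '3NOP-C'], toy_ann, toy_classes)
print(Y)  # the unannotated '3NOP-C' gets an all-zero row
```

Building a NumPy matrix directly also avoids converting per-row Python lists later.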

Next, we read the IDs from `train.txt` and `validation.txt` to _split_ the data into training and validation sets.

In [9]:
train_ids = [l.strip() for l in open('train.txt') if l.strip()]
val_ids   = [l.strip() for l in open('validation.txt') if l.strip()]
seqs_map  = parse_fasta('sequences.fasta')
ann_map, classes = parse_annotations('annotations.tsv')
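Before building datasets, it can help to confirm that no ID appears in both splits and that every ID has a sequence (`build_df` otherwise only prints a warning). A minimal sketch; `check_split` is a hypothetical helper, not part of the original notebook.

```python
def check_split(train_ids, val_ids, seqs_map):
    """Report train/val overlap and IDs missing from the parsed FASTA map."""
    train_set, val_set = set(train_ids), set(val_ids)
    known = set(seqs_map)  # keys of the parsed FASTA dictionary
    return {
        'overlap': sorted(train_set & val_set),
        'missing_train': sorted(train_set - known),
        'missing_val': sorted(val_set - known),
    }

# Toy example; in the notebook this would be check_split(train_ids, val_ids, seqs_map)
report = check_split(['A-1', 'B-1'], ['B-1', 'C-1'], {'A-1': 'MK', 'B-1': 'GA'})
print(report)  # {'overlap': ['B-1'], 'missing_train': [], 'missing_val': ['C-1']}
```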

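This `Dataset` class returns a single `torch_geometric.data.Data` object per protein for graph-based training, combining two parallel pipelines: sequence and structure. The class as written only loads these objects from the pre-computed `cache/` folder and raises `FileNotFoundError` if one is missing; each cached object was built as follows: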

1. **Structure Fetching:** Given an ID like `4PR3-A`, it downloads the corresponding `mmCIF` file from `RCSB` if not already cached, then uses Biopython’s `MMCIFParser` to extract all Cα atom coordinates for the specified chain.

2. **Graph Construction:** Builds an undirected residue‐level graph by connecting any two Cα atoms within 8 Å.  Encodes edge weights as the inverse distance (`1/d`) to reflect spatial proximity.

3. **Node Features:** Represents each residue by a 21-dimensional one-hot vector: one position per standard amino acid plus one for “X”. Non-standard letters (U, Z, O, B) are mapped to “X” and treated uniformly.

4. **`ProtBert` Embedding:** Cleans and “space-separates” the amino-acid string for `Rostlab/prot_bert`. Tokenizes, pads/truncates, and runs the model in `eval()` mode to get the last hidden states. Applies mean-pooling (masking out padding) to produce a fixed-size (1024-dim) global embedding.

5. **Label Vector:** Builds a multi-hot target vector for all EC classes based on `annotations.tsv`.

6. **Output:**  
   - Returns a `Data` object with fields:  
     - `x`: `[L, 21]` residue features  
     - `edge_index`: `[2, E]` graph connectivity  
     - `edge_attr`: `[E, 1]` inverse-distance edge weights  
     - `seq_emb`: `[1024]` `ProtBert` fingerprint  
     - `y`: `[num_ec]` multi-hot EC labels  
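The featurization code for these steps is not shown in this notebook; only its cached outputs are loaded. The NumPy sketch below illustrates steps 2 to 4 under stated assumptions: the channel order (`ACDEFGHIKLMNPQRSTVWY`, then `X`), the `<= 8.0` Å cutoff with self-pairs excluded, and plain masked mean-pooling are guesses, not the original pipeline's exact choices (for example, whether the ProtBert `[CLS]`/`[SEP]` tokens were also masked out is not stated). `one_hot`, `contact_graph`, and `masked_mean` are hypothetical helper names.

```python
import numpy as np

# Assumed channel order: 20 standard amino acids, then 'X' for anything else
VOCAB = 'ACDEFGHIKLMNPQRSTVWY' + 'X'
AA_INDEX = {a: i for i, a in enumerate(VOCAB)}

def one_hot(seq):
    """[L, 21] one-hot matrix; U, Z, O, B (and any unknown letter) map to 'X'."""
    idx = [AA_INDEX.get(a, AA_INDEX['X']) for a in seq.upper()]
    return np.eye(len(VOCAB), dtype=np.float32)[idx]

def contact_graph(ca_coords, cutoff=8.0):
    """Undirected C-alpha contact graph with inverse-distance edge weights."""
    diff = ca_coords[:, None, :] - ca_coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))            # [L, L] pairwise distances
    adj = dist <= cutoff
    np.fill_diagonal(adj, False)                   # no self-loops
    src, dst = np.nonzero(adj)                     # symmetric, so both directions are kept
    edge_index = np.stack([src, dst])              # [2, E]
    edge_attr = (1.0 / dist[src, dst])[:, None]    # [E, 1], weight = 1/d
    return edge_index, edge_attr.astype(np.float32)

def masked_mean(hidden, mask):
    """Mean of token embeddings [T, D] over positions where mask == 1."""
    m = mask[:, None].astype(hidden.dtype)
    return (hidden * m).sum(0) / max(m.sum(), 1.0)

# Toy usage: three residues, the third far from the other two
coords = np.array([[0.0, 0.0, 0.0], [3.8, 0.0, 0.0], [20.0, 0.0, 0.0]])
edge_index, edge_attr = contact_graph(coords)
x = one_hot('ACU')
pooled = masked_mean(np.array([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]),
                     np.array([1, 1, 0]))
print(edge_index.tolist(), x.shape, pooled)
```

The notebook imports `scipy.spatial.distance.pdist` and `squareform`, which compute the same distance matrix; broadcasting is used here only to keep the sketch dependency-light.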


In [None]:
!rm -rf structures/*
!rm -rf cache/*

In [20]:
import os
import torch
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, ids, seqs_map, ann_map, classes, pdb_dir=None, cache_dir='cache'):
        self.ids = ids
        self.seqs_map = seqs_map
        self.ann_map = ann_map
        self.classes = classes
        self.pdb_dir = pdb_dir
        self.cache_dir = cache_dir

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        pid = self.ids[idx]  # avoid shadowing the built-in id()
        cache_path = os.path.join(self.cache_dir, f"{pid}.pt")

        if not os.path.exists(cache_path):
            raise FileNotFoundError(f"Cached tensor not found for ID: {pid}. Please check your cache folder.")
        # Each cache file holds one pre-built torch_geometric Data object
        return torch.load(cache_path)
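`__getitem__` raises `FileNotFoundError` only when a missing item is first requested, which can be partway through an epoch. A minimal fail-fast check, assuming the `<id>.pt` naming used above; `find_missing_cache` is a hypothetical helper.

```python
import os
import tempfile

def find_missing_cache(ids, cache_dir='cache'):
    """Return the IDs whose '<id>.pt' file is absent from cache_dir."""
    present = set(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else set()
    return [pid for pid in ids if f"{pid}.pt" not in present]

# Toy check against a temporary folder holding a single cache file
with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, '1ABC-A.pt'), 'wb').close()
    missing = find_missing_cache(['1ABC-A', '2XYZ-B'], cache_dir=d)
print(missing)  # ['2XYZ-B']
```

In the notebook, `find_missing_cache(train_ids + val_ids)` could be run once before creating the datasets, and the missing IDs either re-featurized or filtered out.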



With the `Dataset` class in place, we can wrap the training and validation IDs and define `DataLoader`s for the model.

> `df_train_mini` (the first 1,000 training IDs) can be swapped in for faster development; the run below uses the _full_ training set.

In [22]:
df_train_mini = train_ids[:1000]

train_ds = ProteinDataset(train_ids, seqs_map, ann_map, classes, cache_dir='cache')
val_ds   = ProteinDataset(val_ids, seqs_map, ann_map, classes, cache_dir='cache')

train_loader = GeometricDataLoader(train_ds, batch_size=16, shuffle=True)
val_loader   = GeometricDataLoader(val_ds,   batch_size=16)
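`GeometricDataLoader` merges each mini-batch of graphs into one large disconnected graph. The NumPy sketch below is a simplified model of that behavior, written for illustration rather than taken from PyG: node features are stacked, each graph's `edge_index` is shifted by the number of nodes before it, a `batch` vector records which graph each node belongs to, and 1-D attributes such as `seq_emb` are concatenated end to end. That last point is why `ConvClassifier` reshapes `seq_emb` with `view(B, -1)`. `collate` is a hypothetical name.

```python
import numpy as np

def collate(graphs):
    """graphs: list of (x [n_i, F], edge_index [2, e_i], seq_emb [D]) tuples."""
    xs, eis, seqs, batch = [], [], [], []
    offset = 0
    for g, (x, edge_index, seq_emb) in enumerate(graphs):
        xs.append(x)
        eis.append(edge_index + offset)      # shift node ids into the merged graph
        seqs.append(seq_emb)                 # 1-D attributes are concatenated along dim 0
        batch.append(np.full(len(x), g))     # graph id for every node
        offset += len(x)
    return (np.concatenate(xs), np.concatenate(eis, axis=1),
            np.concatenate(seqs), np.concatenate(batch))

# Two toy graphs with 2 and 3 nodes, 4-dim node features, 5-dim "sequence embeddings"
g1 = (np.ones((2, 4)), np.array([[0, 1], [1, 0]]), np.zeros(5))
g2 = (np.ones((3, 4)), np.array([[0, 2], [2, 0]]), np.ones(5))
x, edge_index, seq_emb, batch = collate([g1, g2])
print(edge_index.tolist())                          # [[0, 1, 2, 4], [1, 0, 4, 2]]
print(seq_emb.shape, seq_emb.reshape(2, -1).shape)  # (10,) (2, 5)
```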

This `ConvClassifier` ingests the `Data` objects produced by `ProteinDataset` and fuses structure and sequence information to predict one or more EC numbers.

1. **Structure Branch (GCN)**  
   - **GCNConv(21 → 128)** + ReLU  
   - **GCNConv(128 → 128)** + ReLU  
   - **Global Mean Pooling** over residues → a 128-dim graph embedding

2. **Sequence Branch (MLP)**  
   - **Linear(1024 → 128)** + ReLU  
   - Takes the `ProtBert` pooled embedding and projects it down to 128 dims

3. **Fusion**  
   - **Concatenate** the 128-dim GCN output and 128-dim MLP output → 256-dim vector  
   - **MLP Head**: Linear(256 → 256) + ReLU + Dropout(0.1) + Linear(256 → `num_ec`)  
   - Outputs one raw logit per EC class

4. **Training Details**  
   - **Loss**: `BCEWithLogitsLoss` for multi-label classification, with a uniform `pos_weight` of 10 for every EC class  
   - **Optimizer**: AdamW with `lr=1e-5` and PyTorch's default `weight_decay` (0.01)  
   - **DataLoader**: uses `torch_geometric.loader.DataLoader` to batch graphs  

In [23]:
class ConvClassifier(nn.Module):
    def __init__(self, seq_emb_dim, node_feat_dim, gcn_hidden, num_classes):
        super().__init__()
        self.seq_mlp = nn.Sequential(
            nn.Linear(seq_emb_dim, gcn_hidden),  # project the ProtBert embedding down to gcn_hidden
            nn.ReLU()
        )

        self.conv1 = GCNConv(node_feat_dim, gcn_hidden)
        self.conv2 = GCNConv(gcn_hidden,      gcn_hidden)

        self.classifier = nn.Sequential(
            nn.Linear(2 * gcn_hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, data):
        x, edge_index, edge_attr, batch = (
            data.x, data.edge_index, data.edge_attr, data.batch
        )

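        # Guard against out-of-range node indices in cached graphs. Note that this
        # silently remaps bad indices to the last node instead of raising an error.
        edge_index = edge_index.clamp(0, x.size(0) - 1)
        edge_weight = edge_attr.reshape(-1)  # [E, 1] -> [E]; unlike squeeze(), safe when E == 1

        h = torch.relu(self.conv1(x, edge_index, edge_weight=edge_weight))
        h = torch.relu(self.conv2(h, edge_index, edge_weight=edge_weight))
        gcn_emb = global_mean_pool(h, batch)  # [B, gcn_hidden]

        # PyG concatenates each graph's 1-D seq_emb, giving [B * seq_emb_dim];
        # reshape back to one row per graph.
        B = gcn_emb.size(0)
        seq_emb = data.seq_emb.view(B, -1).to(gcn_emb.device)  # [B, seq_emb_dim]
        seq_emb = self.seq_mlp(seq_emb)                        # [B, gcn_hidden]

        comb = torch.cat([gcn_emb, seq_emb], dim=1)  # [B, 2 * gcn_hidden]
        return self.classifier(comb)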
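As a rough size check, the learnable parameters of `ConvClassifier` can be counted by hand. The sketch assumes the dimensions printed for `train_ds[0]` below (21 node features, a 1,024-dim `seq_emb`, 538 EC classes) and that each `GCNConv` holds an `in x out` weight plus an `out` bias, as a dense layer does. `linear_params` and `conv_classifier_params` are hypothetical helpers; in the notebook, `sum(p.numel() for p in model.parameters())` gives the authoritative count.

```python
def linear_params(n_in, n_out):
    """Weight matrix plus bias vector of a dense layer."""
    return n_in * n_out + n_out

def conv_classifier_params(node_feat_dim=21, seq_emb_dim=1024,
                           gcn_hidden=128, head=256, num_classes=538):
    parts = {
        'seq_mlp':       linear_params(seq_emb_dim, gcn_hidden),
        'conv1':         linear_params(node_feat_dim, gcn_hidden),
        'conv2':         linear_params(gcn_hidden, gcn_hidden),
        'classifier[0]': linear_params(2 * gcn_hidden, head),
        'classifier[3]': linear_params(head, num_classes),
    }
    return parts, sum(parts.values())

parts, total = conv_classifier_params()
print(parts)
print(total)  # 354586
```

Under these assumptions the two GCN layers hold under 20k parameters, while the ProtBert projection and the 538-way output layer each hold over 130k.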


In [87]:
device      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = len(classes)

example  = train_ds[0]
seq_dim  = example.seq_emb.shape[0]   # ProtBert embedding size (1024)
node_dim = example.x.shape[1]         # per-residue feature size

print(example)

model = ConvClassifier(seq_emb_dim=seq_dim,
                       node_feat_dim=node_dim,
                       gcn_hidden=128,
                       num_classes=num_classes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-5)
# Up-weight positive labels 10x, since each EC class is positive for few proteins
pos_weight = torch.ones(num_classes, device=device) * 10
crit = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

Data(x=[145, 21], edge_index=[2, 1294], edge_attr=[1294, 1], y=[538], seq_emb=[1024])
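The cell above uses the same `pos_weight` of 10 for every EC class. The PyTorch `BCEWithLogitsLoss` documentation suggests setting a class's `pos_weight` to its ratio of negative to positive examples; a per-class version is sketched below. `pos_weight_from_labels` and the cap `max_weight=100.0` are hypothetical choices, not tuned values; the cap keeps classes with very few positives from dominating the loss.

```python
import numpy as np

def pos_weight_from_labels(Y, max_weight=100.0):
    """Per-class negatives/positives ratio from a multi-hot label matrix Y [N, C]."""
    pos = Y.sum(axis=0)
    neg = Y.shape[0] - pos
    ratio = neg / np.maximum(pos, 1.0)    # classes with no positives count as having 1
    return np.minimum(ratio, max_weight)  # cap the weight for very rare classes

Y_toy = np.array([[1, 0, 1],
                  [0, 0, 1],
                  [0, 0, 0],
                  [0, 0, 1]], dtype=np.float32)
weights = pos_weight_from_labels(Y_toy)
print(weights)
```

In the notebook, `Y` would be the stacked training labels, and the result could be passed as `pos_weight=torch.tensor(weights, dtype=torch.float32, device=device)`.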


In [88]:
num_classes = len(classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for epoch in range(1, 11):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        logits = model(batch)                           # [B, num_classes]
        labels = batch.y.view(logits.size()).float()    # [B, num_classes]; BCE expects float targets
        loss = crit(logits, labels)
        opt.zero_grad(); loss.backward(); opt.step()
        total_loss += loss.item() * batch.num_graphs
    print(f"Epoch {epoch} Train Loss: {total_loss/len(train_ds):.4f}")

    model.eval()
    all_preds, all_targs = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            probs = torch.sigmoid(model(batch)).cpu()
            preds = (probs > 0.5).int()
            targs = batch.y.view(probs.size())          # [B, num_classes]
            all_preds.append(preds)
            all_targs.append(targs)
    preds = torch.vstack(all_preds).numpy()
    targs = torch.vstack(all_targs).cpu().numpy()
    print("Val P, R, F1=",
          precision_score(targs, preds, average='micro', zero_division=0),
          recall_score(targs, preds, average='micro', zero_division=0),
          f1_score(targs, preds, average='micro', zero_division=0))


Epoch 1 Train Loss: 0.7128
Val P, R, F1= 0.003272587686482196 0.5101419878296146 0.0065034553647042095
Epoch 2 Train Loss: 0.7104
Val P, R, F1= 0.0033562953443457754 0.47058823529411764 0.006665054668291434
Epoch 3 Train Loss: 0.7068
Val P, R, F1= 0.0034721907712188596 0.3887762001352265 0.006882909735127678
Epoch 4 Train Loss: 0.7009
Val P, R, F1= 0.0037465712183046766 0.2839756592292089 0.007395570581476737
Epoch 5 Train Loss: 0.6907
Val P, R, F1= 0.003858005968990367 0.1612576064908722 0.007535723595345861
Epoch 6 Train Loss: 0.6742
Val P, R, F1= 0.0035753060536857476 0.06457065584854632 0.0067754522880454065
Epoch 7 Train Loss: 0.6492
Val P, R, F1= 0.002168769716088328 0.014874915483434753 0.003785597522154349
Epoch 8 Train Loss: 0.6144
Val P, R, F1= 0.002584721424468696 0.006085192697768763 0.0036283007458173754
Epoch 9 Train Loss: 0.5703
Val P, R, F1= 0.010089686098654708 0.0030425963488843813 0.004675324675324675
Epoch 10 Train Loss: 0.5180
Val P, R, F1= 0.0 0.0 0.0
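The log reports micro-averaged metrics: true positives, false positives, and false negatives are summed over all EC classes before computing precision and recall. A NumPy sketch of that computation (intended to mirror `average='micro'` with `zero_division=0`) helps read the numbers. In epoch 1, recall near 0.51 with precision near 0.003 means the model predicted far more positives than exist; by epoch 10 no true positives remain at the 0.5 threshold, so every metric is 0. `micro_prf` is a hypothetical helper.

```python
import numpy as np

def micro_prf(y_true, y_pred):
    """Micro-averaged precision, recall, F1 over multi-hot [N, C] matrices."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    p = tp / (tp + fp) if tp + fp else 0.0   # 0 when nothing is predicted
    r = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return float(p), float(r), float(f1)

res_some = micro_prf([[1, 0], [0, 1]], [[1, 1], [0, 0]])
res_none = micro_prf([[1, 0], [0, 1]], [[0, 0], [0, 0]])
print(res_some)  # (0.5, 0.5, 0.5)
print(res_none)  # (0.0, 0.0, 0.0)
```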


In [36]:
model.eval()
all_probs, all_targs = [], []

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        probs = torch.sigmoid(model(batch)).cpu()
        targs = batch.y.view(probs.size()).cpu()
        all_probs.append(probs)
        all_targs.append(targs)

probs_val = torch.vstack(all_probs).numpy()
targs_val = torch.vstack(all_targs).numpy()

print("Validation probs shape:", probs_val.shape)
print("Validation targets shape:", targs_val.shape)

Validation probs shape: (1729, 538)
Validation targets shape: (1729, 538)


In [38]:
def find_best_thresholds_per_ec(y_true, y_probs, metric='f1'):
    thresholds = np.linspace(0.1, 0.9, 81)
    best_thresholds = []

    for i in range(y_true.shape[1]):
        best_score = -1  # any real score (>= 0) beats this, so best_t never stays at 0.5
        best_t = 0.5
        for t in thresholds:
            preds = (y_probs[:, i] >= t).astype(int)
            if metric == 'f1':
                score = f1_score(y_true[:, i], preds, zero_division=0)
            elif metric == 'precision':
                score = precision_score(y_true[:, i], preds, zero_division=0)
            elif metric == 'recall':
                score = recall_score(y_true[:, i], preds, zero_division=0)
            else:
                raise ValueError(f"Unknown metric: {metric}")
            if score > best_score:
                best_score = score
                best_t = t
        best_thresholds.append(best_t)

    return np.array(best_thresholds)
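One detail of `find_best_thresholds_per_ec` helps explain the nearly all-`0.1` output below. Because `best_score` starts at -1 and only a strictly higher score replaces it, a class that scores the same F1 at every threshold (for example F1 = 0 when it has no validation positives or no correct predictions) keeps the first grid value, 0.1. The vectorized sketch below is a hypothetical alternative, not the original method: it scores the whole grid at once and falls back to `default=0.5` for classes whose best F1 is 0. It builds `[T, N, C]` boolean arrays (about 75 MB each for 81 thresholds x 1,729 proteins x 538 classes), so chunking over classes may be needed if memory is tight.

```python
import numpy as np

def best_thresholds_vectorized(y_true, y_probs, grid=None, default=0.5):
    """Per-class threshold maximizing F1; falls back to `default` when best F1 is 0."""
    if grid is None:
        grid = np.linspace(0.1, 0.9, 81)
    t = np.asarray(y_true).astype(bool)[None]                  # [1, N, C]
    preds = np.asarray(y_probs)[None] >= grid[:, None, None]   # [T, N, C]
    tp = (preds & t).sum(axis=1)                               # [T, C]
    fp = (preds & ~t).sum(axis=1)
    fn = (~preds & t).sum(axis=1)
    f1 = 2 * tp / np.maximum(2 * tp + fp + fn, 1)              # F1 = 2TP / (2TP + FP + FN)
    best = grid[f1.argmax(axis=0)]                             # first grid value at the max
    return np.where(f1.max(axis=0) > 0, best, default)

y_true = np.array([[1, 0], [0, 0], [1, 0]])
y_probs = np.array([[0.8, 0.3], [0.2, 0.6], [0.7, 0.1]])
b = best_thresholds_vectorized(y_true, y_probs)
print(b)  # class 1 has no positives, so it gets the 0.5 fallback
```

In the notebook this would be called as `best_thresholds_vectorized(targs_val, probs_val)`.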

To speed up future runs, save the `structures/` and `cache/` folders as zip files (see the last cell) so the downloaded structures and pre-computed embeddings can be re-uploaded next time.

In [39]:
best_thresholds = find_best_thresholds_per_ec(targs_val, probs_val, metric='f1')
print("Best thresholds:", best_thresholds)


Best thresholds: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0

In [40]:
preds_val = (probs_val >= best_thresholds).astype(int)

precision = precision_score(targs_val, preds_val, average='micro', zero_division=0)
recall    = recall_score(targs_val, preds_val, average='micro', zero_division=0)
f1        = f1_score(targs_val, preds_val, average='micro', zero_division=0)

print(f"Validation Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")


Validation Precision: 0.1356, Recall: 0.0233, F1 Score: 0.0398
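These thresholds were tuned and scored on the same validation set, so the F1 above is likely optimistic. One way to get a less biased estimate is to tune on one half of the validation proteins and score the other half. The sketch below does this with a single global threshold for brevity (per-class thresholds would follow the same split); the 50/50 split, the seed, and the helper names are assumptions, and the synthetic data only demonstrates the mechanics.

```python
import numpy as np

def micro_f1(y_true, y_pred):
    """Micro-F1 = 2TP / (2TP + FP + FN) over boolean [N, C] arrays."""
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    return 2 * tp / max(2 * tp + fp + fn, 1)

def split_half_threshold_eval(y_true, y_probs, grid=None, seed=0):
    """Pick a global threshold on a random half; report micro-F1 on the other half."""
    if grid is None:
        grid = np.linspace(0.1, 0.9, 81)
    y_true = np.asarray(y_true).astype(bool)
    y_probs = np.asarray(y_probs)
    idx = np.random.default_rng(seed).permutation(len(y_true))
    tune, held = idx[: len(idx) // 2], idx[len(idx) // 2:]
    scores = [micro_f1(y_true[tune], y_probs[tune] >= t) for t in grid]
    t_best = grid[int(np.argmax(scores))]
    return t_best, micro_f1(y_true[held], y_probs[held] >= t_best)

# Synthetic stand-in; real use: split_half_threshold_eval(targs_val, probs_val)
rng = np.random.default_rng(1)
y = rng.random((200, 10)) < 0.1
p = np.clip(0.6 * y + 0.5 * rng.random((200, 10)), 0.0, 1.0)
t_best, f1_held = split_half_threshold_eval(y, p)
print(f"threshold={t_best:.2f}, held-out micro-F1={f1_held:.3f}")
```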


#### __Saving the Trained Model__

In [52]:
import torch
import numpy as np
import pandas as pd

torch.save(model.state_dict(), 'trained_model.pth')
print("Saved model to trained_model.pth")

np.save('best_thresholds.npy', best_thresholds)
print("Saved best thresholds to best_thresholds.npy")


rows = []
for pid, trow, prow in zip(val_ids, targs_val, (probs_val >= best_thresholds).astype(int)):
    true_ecs = [c for c, flag in zip(classes, trow) if flag]
    pred_ecs = [c for c, flag in zip(classes, prow) if flag]
    rows.append({
        'strain': pid,
        'validation EC class': ';'.join(true_ecs) or '-',
        'predicted': ';'.join(pred_ecs) or '-'
    })

results_df = pd.DataFrame(rows, columns=['strain', 'validation EC class', 'predicted'])
results_df.to_csv('validation_results.csv', index=False)
print("Saved validation results to validation_results.csv")

np.save('probs_val.npy', probs_val)
np.save('targs_val.npy', targs_val)
print("Saved validation probs and targets")


Saved model to trained_model.pth
Saved best thresholds to best_thresholds.npy
Saved validation results to validation_results.csv
Saved validation probs and targets
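To reuse the saved artifacts later, the thresholds can be reloaded and applied to fresh probabilities, then decoded back to EC strings in the same `;`-joined format as `validation_results.csv`. This is a minimal NumPy-only sketch; reloading `trained_model.pth` would additionally require re-creating `ConvClassifier` and calling `load_state_dict`. `decode_predictions` is a hypothetical helper and the toy classes are illustrative.

```python
import os
import tempfile
import numpy as np

def decode_predictions(probs, thresholds, classes):
    """Turn [N, C] probabilities into ';'-joined EC strings ('-' when none pass)."""
    hits = probs >= thresholds  # per-class thresholds broadcast over rows
    return [';'.join(c for c, h in zip(classes, row) if h) or '-' for row in hits]

toy_classes = ['1.1.1.1', '2.7.11.1', '3.4.21.4']
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'best_thresholds.npy')
    np.save(path, np.array([0.5, 0.3, 0.9]))  # stand-in for the saved file
    thresholds = np.load(path)

probs = np.array([[0.7, 0.1, 0.95], [0.2, 0.2, 0.2]])
decoded = decode_predictions(probs, thresholds, toy_classes)
print(decoded)  # ['1.1.1.1;3.4.21.4', '-']
```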


In [53]:
from google.colab import files

files.download('trained_model.pth')
files.download('best_thresholds.npy')
files.download('validation_results.csv')
files.download('probs_val.npy')
files.download('targs_val.npy')



In [None]:
!zip -r structures.zip structures
!zip -r cache.zip cache