## __Protein Function Prediction Leveraging Sequence and Structure Information__

> __Author:__ Grace Tang

> __Edited:__ `04.24.25`

> __Runtime:__ `A100`

---


#### __Description__

Predicting the functions of proteins is an essential problem in the study of proteins. Many previous works developed computational methods that utilize protein sequence information to predict protein functions. With the advancement of `AlphaFold2`, accurate protein structure data has been available for hundreds of millions of proteins.

Incorporating structure information should be able to further boost the performance of protein function prediction methods. There are several potential ways to utilize the structure information: use computer-vision-like models to extract information from contact maps; use graph neural networks to encode 3D structures; use pretrained protein structure model (e.g., `ESM-IF`).

In addition, with the advancement of LLMs, the text annotations of proteins can also be leveraged as an extra information source for protein function prediction, e.g., use LLM as encoder to encode the text annotations of proteins. Utilizing the information from sequence, structure, and text annotations, we can develop a model that accurately predicts protein functions.

__Using a subset of the dataset used for [DeepFRI](https://github.com/flatironinstitute/DeepFRI) (a similar protein function prediction model) we aim to coallesce sequence data and structure data to predict protein function in the form of [EC Numbers](https://en.wikipedia.org/wiki/Enzyme_Commission_number).__

[[1]](ttps://www.biorxiv.org/content/10.1101/2022.11.29.518451v1), [[2]](https://www.nature.com/articles/s42003-024-07359-z), [[3]](https://www.nature.com/articles/s41467-021-23303-9), [[4]](https://www.biorxiv.org/content/10.1101/2024.05.14.594226v1 )

#### __Methodology__

> _This section relies on having_ `annotations.tsv`, `train.txt`, `validation.txt`, _and_ `sequences.fasta` _in your runtime_

> _While not strictly necessary, we **highly** recommend importing a_ `structures.zip` _into your runtime as well with all the structure information - this will greatly assist the computation time. Even better if you have a_  `cache.zip` _with all pre-computed pytorch vectors that will assist greatly!_


In [1]:

from google.colab import files
uploaded = files.upload()

Saving cache.zip to cache.zip
Saving annotations.tsv to annotations.tsv
Saving sequences.fasta to sequences.fasta
Saving train.txt to train.txt
Saving validation.txt to validation.txt


In [2]:
!unzip cache.zip -d cache

Archive:  cache.zip
   creating: cache/cache/
  inflating: cache/cache/2QYU-A.pt   
  inflating: cache/cache/2DCE-A.pt   
  inflating: cache/cache/1DIR-A.pt   
  inflating: cache/cache/5AA5-C.pt   
  inflating: cache/cache/2RS7-A.pt   
  inflating: cache/cache/2YV2-A.pt   
  inflating: cache/cache/3WRX-C.pt   
  inflating: cache/cache/3PPS-A.pt   
  inflating: cache/cache/3A24-A.pt   
  inflating: cache/cache/2YWX-A.pt   
  inflating: cache/cache/5Y4R-A.pt   
  inflating: cache/cache/3KV5-D.pt   
  inflating: cache/cache/3DKO-A.pt   
  inflating: cache/cache/6OJO-A.pt   
  inflating: cache/cache/4H05-A.pt   
  inflating: cache/cache/2D5R-A.pt   
  inflating: cache/cache/3N6R-A.pt   
  inflating: cache/cache/1O54-A.pt   
  inflating: cache/cache/2JEA-B.pt   
  inflating: cache/cache/1I5E-A.pt   
  inflating: cache/cache/1F7U-A.pt   
  inflating: cache/cache/3M7I-A.pt   
  inflating: cache/cache/5AG8-A.pt   
  inflating: cache/cache/4PWV-A.pt   
  inflating: cache/cache/4M1N-A.pt   
  in

In [3]:
!mv cache/cache/* cache/
!rmdir cache/cache


In [4]:
!pip install torch-geometric
!pip install transformers biopython scipy

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfu

In [5]:
import os
import pandas as pd
import re
import numpy as np
import requests

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.serialization import add_safe_globals
from torch_geometric.data import Data
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage

from transformers import BertTokenizer, BertModel

from Bio.PDB import MMCIFParser

from scipy.spatial.distance import pdist, squareform

from sklearn.metrics import precision_score, recall_score, f1_score

add_safe_globals([Data, DataEdgeAttr, DataTensorAttr, GlobalStorage])

We start by defining some utility functions that help parse the input data into a representation that can be fed into our downstream `Dataset` classes.

In [6]:
def parse_fasta(path):
    seqs = {}
    with open(path) as f:
        curr, lines = None, []
        for line in f:
            if line.startswith('>'):
                if curr:
                    seqs[curr] = ''.join(lines)
                curr = line[1:].split()[0]
                lines = []
            else:
                lines.append(line.strip())
        if curr:
            seqs[curr] = ''.join(lines)
    return seqs

def parse_annotations(path):
    ann_map, all_ec = {}, set()
    with open(path) as f:
        for ln in f:
            if ln.startswith('### PDB-chain'):
                break
        for ln in f:
            if not ln.strip() or ln.startswith('#'): continue
            pid, ecs = ln.strip().split('\t')
            ec_list = [e.strip() for e in ecs.split(',') if e.strip()]
            ann_map[pid] = ec_list
            all_ec.update(ec_list)
    classes = sorted(all_ec)
    return ann_map, classes

def fetch_cif(pdb_chain, out_dir="structures"):
    pdb_id, chain = pdb_chain.split('-')
    os.makedirs(out_dir, exist_ok=True)
    local_path = os.path.join(out_dir, f"{pdb_chain}.cif")
    if not os.path.exists(local_path):
        url = f"https://files.rcsb.org/download/{pdb_id}.cif"
        resp = requests.get(url)
        resp.raise_for_status()

        with open(local_path, "wb") as f:
            f.write(resp.content)
    return local_path

def build_df(ids, seqs, ann_map, classes):
    rows = []
    for pid in ids:
        seq = seqs.get(pid, '')
        if not seq:
            print(f"Warning: {pid} not found")
            continue
        labels = [1 if ec in ann_map.get(pid, []) else 0 for ec in classes]
        rows.append({'id': pid, 'sequence': seq, 'labels': labels})
    return pd.DataFrame(rows)

Next we read the `ids` from the `train.txt` and `validation.txt` to perform our _splitting_ into a train and validation dataset.

In [133]:
train_ids = [l.strip() for l in open('train.txt') if l.strip()]
val_ids   = [l.strip() for l in open('validation.txt') if l.strip()]
seqs_map  = parse_fasta('sequences.fasta')
ann_map, classes = parse_annotations('annotations.tsv')

This `Dataset` class feeds each protein through two parallel pipelines — sequence and structure — and returns a single `torch_geometric.data.Data` object per example for graph‐based training.

1. **Structure Fetching:** Given an ID like `4PR3-A`, it downloads the corresponding `mmCIF` file from `RCSB` if not already cached. Uses Biopython’s `PDBParser` to extract all Cα atom coordinates for the specified chain.

2. **Graph Construction:** Builds an undirected residue‐level graph by connecting any two Cα atoms within 8 Å.  Encodes edge weights as the inverse distance (`1/d`) to reflect spatial proximity.

3. **Node Features:** Represents each residue by a 20‐dimensional one-hot vector (one position per standard amino acid). Ensures that non-standard letters (U, Z, O, B) are mapped to “X” and treated uniformly.

4. **`ProtBert` Embedding** Cleans and “space-separates” the amino-acid string for `Rostlab/prot_bert`.  Tokenizes, pads/truncates, and runs the model in `eval()` mode to get the last hidden states.  Applies mean-pooling (masking out padding) to produce a fixed‐size (1024-dim) global embedding.

5. **Label Vector:** Builds a multi-hot target vector for all EC classes based on `annotations.tsv`.

6. **Output**  
   - Returns a `Data` object with fields:  
     - `x`: `[L, 20]` residue features  
     - `edge_index`: `[2, E]` graph connectivity  
     - `edge_attr`: `[E, 1]` distance weights  
     - `seq_emb`: `[1024]` `ProtBert` fingerprint  
     - `y`: `[num_ec]` multi-hot EC labels  


In [135]:
print(f"Available cached proteins: {len(cached_proteins)}")
print(f"Train IDs after filtering: {len(train_ids)}")
print(f"Val IDs after filtering: {len(val_ids)}")


Available cached proteins: 2729
Train IDs after filtering: 1000
Val IDs after filtering: 1729


In [134]:
import os

cached_proteins = set(f.split('.')[0] for f in os.listdir('cache') if f.endswith('.pt'))

train_ids = [pid for pid in train_ids if pid in cached_proteins]
val_ids = [pid for pid in val_ids if pid in cached_proteins]

print(f"Filtered Train IDs: {len(train_ids)} proteins")
print(f"Filtered Val IDs: {len(val_ids)} proteins")


Filtered Train IDs: 1000 proteins
Filtered Val IDs: 1729 proteins


In [129]:
import os
import torch
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, ids, seqs_map, ann_map, classes, pdb_dir=None, cache_dir='cache'):
        self.ids = ids
        self.seqs_map = seqs_map
        self.ann_map = ann_map
        self.classes = classes
        self.pdb_dir = pdb_dir
        self.cache_dir = cache_dir

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        id = self.ids[idx]
        cache_path = os.path.join(self.cache_dir, f"{id}.pt")

        if os.path.exists(cache_path):
            data = torch.load(cache_path)
            return data
        else:
            raise FileNotFoundError(f"Cached tensor not found for ID: {id}. Please check your cache folder.")


Using our `Dataset` class now we can input our training and validation data and define some `DataLoaders` to be used by our chosen model.

> We are leveraging a `df_train_mini` subset for faster development for our final model we will use the _full_ training set.

In [128]:
df_train_mini = train_ids[:2000]

train_ds = ProteinDataset(train_ids, seqs_map, ann_map, classes, cache_dir='cache')
val_ds   = ProteinDataset(val_ids, seqs_map, ann_map, classes, cache_dir='cache')

train_loader = GeometricDataLoader(train_ds, batch_size=16, shuffle=True)
val_loader   = GeometricDataLoader(val_ds,   batch_size=16)

This `ConvClassifer` ingests the `Data` objects produced by `ProteinDataset` and fuses structure and sequence information to predict one or more EC numbers.

1. **Structure Branch (GCN)**  
   - **GCNConv(20 → 128)** + ReLU  
   - **GCNConv(128 → 128)** + ReLU  
   - **Global Mean Pooling** over residues → a 128-dim graph embedding

2. **Sequence Branch (MLP)**  
   - **Linear(1024 → 128)** + ReLU  
   - Takes the `ProtBert` pooled embedding and projects it down to 128 dims

3. **Fusion**  
   - **Concatenate** the 128-dim GCN output and 128-dim MLP output → 256-dim vector  
   - **MLP Head**: Linear(256 → 256) + ReLU + Dropout(0.1) + Linear(256 → `num_ec`)  
   - Outputs one raw logit per EC class

4. **Training Details**  
   - **Loss**: `BCEWithLogitsLoss` for multi-label classification  
   - **Optimizer**: AdamW with `lr=1e-3`, `weight_decay=1e-5`  
   - **DataLoader**: uses `torch_geometric.loader.DataLoader` to batch graphs  

In [200]:
import os
import torch

cache_dir = 'cache'

for fname in os.listdir(cache_dir):
    if fname.endswith('.pt'):
        path = os.path.join(cache_dir, fname)
        data = torch.load(path)

        if data.y.dim() == 1:
            data.y = data.y.view(1, -1)  # make it [1, 538]

            torch.save(data, path)
            print(f"✔️ Fixed y shape in {fname}")
        else:
            print(f"Already OK: {fname}")


✔️ Fixed y shape in 5K22-A.pt
✔️ Fixed y shape in 3Q7I-A.pt
✔️ Fixed y shape in 4JZU-A.pt
✔️ Fixed y shape in 1WA3-A.pt
✔️ Fixed y shape in 2KXE-A.pt
✔️ Fixed y shape in 5B46-A.pt
✔️ Fixed y shape in 2EDW-A.pt
✔️ Fixed y shape in 1RWG-A.pt
✔️ Fixed y shape in 1JNR-B.pt
✔️ Fixed y shape in 1BXB-A.pt
✔️ Fixed y shape in 5CGD-A.pt
✔️ Fixed y shape in 3QO6-A.pt
✔️ Fixed y shape in 2YEV-A.pt
✔️ Fixed y shape in 2CWH-A.pt
✔️ Fixed y shape in 2EBN-A.pt
✔️ Fixed y shape in 4B94-A.pt
✔️ Fixed y shape in 2M85-A.pt
✔️ Fixed y shape in 5JI2-A.pt
✔️ Fixed y shape in 1CPM-A.pt
✔️ Fixed y shape in 6QIN-A.pt
✔️ Fixed y shape in 6IQY-A.pt
✔️ Fixed y shape in 3F9I-A.pt
✔️ Fixed y shape in 1QBA-A.pt
✔️ Fixed y shape in 4YUU-A1.pt
✔️ Fixed y shape in 4UIR-A.pt
✔️ Fixed y shape in 1YN9-A.pt
✔️ Fixed y shape in 3K8C-A.pt
✔️ Fixed y shape in 3IU0-A.pt
✔️ Fixed y shape in 4OEL-A.pt
✔️ Fixed y shape in 6D24-C.pt
✔️ Fixed y shape in 2ZOO-A.pt
✔️ Fixed y shape in 3GYQ-A.pt
✔️ Fixed y shape in 1XFI-A.pt
✔️ Fixed 

In [201]:
import os

cached_proteins = sorted([f.split('.')[0] for f in os.listdir('cache') if f.endswith('.pt')])

train_ids = cached_proteins[:2000]   # 2000 for training
val_ids   = cached_proteins[2000:]   # rest for validation

print(f"Train IDs: {len(train_ids)} proteins")
print(f"Val IDs: {len(val_ids)} proteins")


Train IDs: 2000 proteins
Val IDs: 729 proteins


In [202]:
class ConvClassifier(nn.Module):
    def __init__(self, seq_emb_dim, node_feat_dim, gcn_hidden, num_classes):
        super().__init__()
        self.seq_mlp = nn.Sequential(
            nn.Linear(seq_emb_dim, gcn_hidden),  # was backward before
            nn.ReLU()
        )

        self.conv1 = GCNConv(node_feat_dim, gcn_hidden)
        self.conv2 = GCNConv(gcn_hidden,      gcn_hidden)

        self.classifier = nn.Sequential(
            nn.Linear(2 * gcn_hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, data):
        x, edge_index, edge_attr, batch = (
            data.x, data.edge_index, data.edge_attr, data.batch
        )

        edge_index = edge_index.clamp(0, x.size(0) - 1)

        h = self.conv1(x, edge_index, edge_weight=edge_attr.squeeze())
        h = torch.relu(h)
        h = self.conv2(h, edge_index, edge_weight=edge_attr.squeeze())
        h = torch.relu(h)
        gcn_emb = global_mean_pool(h, batch)  # → [B, gcn_hidden]

        seq_emb_raw = data.seq_emb             # [B, seq_emb_dim]
        B = gcn_emb.size(0)
        seq_emb = seq_emb_raw.view(B, -1).to(gcn_emb.device)

        seq_emb = self.seq_mlp(seq_emb)        # → [B, gcn_hidden]

        comb = torch.cat([gcn_emb, seq_emb], dim=1)  # → [B, 256]
        return self.classifier(comb)


In [214]:
example   = train_ds[0]
seq_dim   = example.seq_emb.shape[0]
node_dim  = example.x.shape[1]
num_classes = len(classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(example)

model = ConvClassifier(seq_emb_dim=seq_dim,
                      node_feat_dim=node_dim,
                      gcn_hidden=32,
                      num_classes=len(classes)).to(torch.device('cuda'))
opt   = torch.optim.AdamW(model.parameters(), lr=1e-5)
# ADding positive weight here
pos_weight = torch.ones([num_classes]).to(device) * 10
crit = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

Data(x=[330, 21], edge_index=[2, 3198], edge_attr=[3198, 1], y=[1, 538], seq_emb=[1024])


In [204]:
import os

cache_dir = 'cache'

cached_ids = set(f.replace('.pt', '') for f in os.listdir(cache_dir) if f.endswith('.pt'))
print(f"Found {len(cached_ids)} cached proteins.")


Found 2729 cached proteins.


In [205]:
print(f"Number of proteins in train_ds: {len(train_ds)}")


Number of proteins in train_ds: 2000


In [206]:
for i in range(5):
    try:
        ex = train_ds[i]
        print(f"Item {i}: x={ex.x.shape}, edge_index={ex.edge_index.shape}, seq_emb={ex.seq_emb.shape}, y={ex.y.shape}")
    except Exception as e:
        print(f"Item {i} ERROR:", e)


Item 0: x=torch.Size([330, 21]), edge_index=torch.Size([2, 3198]), seq_emb=torch.Size([1024]), y=torch.Size([1, 538])
Item 1: x=torch.Size([164, 21]), edge_index=torch.Size([2, 1484]), seq_emb=torch.Size([1024]), y=torch.Size([1, 538])
Item 2: x=torch.Size([683, 21]), edge_index=torch.Size([2, 7226]), seq_emb=torch.Size([1024]), y=torch.Size([1, 538])
Item 3: x=torch.Size([130, 21]), edge_index=torch.Size([2, 128288]), seq_emb=torch.Size([1024]), y=torch.Size([1, 538])
Item 4: x=torch.Size([87, 21]), edge_index=torch.Size([2, 2016630]), seq_emb=torch.Size([1024]), y=torch.Size([1, 538])


In [207]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)


Using device: cuda


In [208]:
print("seq_emb_dim =", seq_dim)
print("node_feat_dim =", node_dim)
print("gcn_hidden = 128")
print("fusion_hidden = 256")
print("num_classes =", num_classes)


seq_emb_dim = 1024
node_feat_dim = 21
gcn_hidden = 128
fusion_hidden = 256
num_classes = 538


In [209]:
from torch_geometric.loader import DataLoader as GeometricDataLoader

def custom_collate(batch):
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None
    return torch_geometric.data.Batch.from_data_list(batch)

# Build datasets
train_ds = ProteinDataset(train_ids, seqs_map, ann_map, classes, cache_dir='cache')
val_ds   = ProteinDataset(val_ids, seqs_map, ann_map, classes, cache_dir='cache')

print(f"Number of proteins in train_ds: {len(train_ds)}")
print(f"Number of proteins in val_ds: {len(val_ds)}")


Number of proteins in train_ds: 2000
Number of proteins in val_ds: 729


In [187]:
for epoch in range(1, 51):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        logits = model(batch)                           # [B, num_classes]
        labels = batch.y.view(logits.size())            # reshape to [B, num_classes]
        loss = crit(logits, labels)
        opt.zero_grad(); loss.backward(); opt.step()
        total_loss += loss.item() * batch.num_graphs
    print(f"Epoch {epoch} Train Loss: {total_loss/len(train_ds):.4f}")

    model.eval()
    all_preds, all_targs = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            probs = torch.sigmoid(model(batch)).cpu()
            preds = (probs > 0.2).int()
            targs = batch.y.view(probs.size())          # [B, num_classes]
            all_preds.append(preds)
            all_targs.append(targs)
    preds = torch.vstack(all_preds).numpy()
    targs = torch.vstack(all_targs).cpu().numpy()
    print("Val P, R, F1=",
          precision_score(targs, preds, average='micro', zero_division=0),
          recall_score(targs, preds, average='micro', zero_division=0),
          f1_score(targs, preds, average='micro', zero_division=0))


Epoch 1 Train Loss: 0.3509
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 2 Train Loss: 0.3484
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 3 Train Loss: 0.3458
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 4 Train Loss: 0.3428
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 5 Train Loss: 0.3395
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 6 Train Loss: 0.3356
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 7 Train Loss: 0.3312
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 8 Train Loss: 0.3259
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 9 Train Loss: 0.3200
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 10 Train Loss: 0.3133
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 11 Train Loss: 0.3059
Val P, R, F1= 0.0031799544615040603 1.0 0.006339748810493377
Epoch 12 Train Loss: 0.2977
Va

In [210]:
example = train_ds[0]
print("Sequence embedding shape:", example.seq_emb.shape)
print("Node feature shape:", example.x.shape)


Sequence embedding shape: torch.Size([1024])
Node feature shape: torch.Size([330, 21])


In [211]:
model.eval()
all_probs, all_targs = [], []

with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        probs = torch.sigmoid(model(batch)).cpu()
        targs = batch.y.view(probs.size()).cpu()
        all_probs.append(probs)
        all_targs.append(targs)

probs_val = torch.vstack(all_probs).numpy()
targs_val = torch.vstack(all_targs).numpy()

print("Validation probs shape:", probs_val.shape)
print("Validation targets shape:", targs_val.shape)

Validation probs shape: (1729, 538)
Validation targets shape: (1729, 538)


In [190]:
def find_best_thresholds_per_ec(y_true, y_probs, metric='f1'):
    thresholds = np.linspace(0.1, 0.9, 81)
    best_thresholds = []

    for i in range(y_true.shape[1]):
        best_score = -1
        best_t = 0.5
        for t in thresholds:
            preds = (y_probs[:, i] >= t).astype(int)
            if metric == 'f1':
                score = f1_score(y_true[:, i], preds, zero_division=0)
            elif metric == 'precision':
                score = precision_score(y_true[:, i], preds)
            elif metric == 'recall':
                score = recall_score(y_true[:, i], preds)
            if score > best_score:
                best_score = score
                best_t = t
        best_thresholds.append(best_t)

    return np.array(best_thresholds)

In [191]:
best_thresholds = find_best_thresholds_per_ec(targs_val, probs_val, metric='f1')
print("Best thresholds:", best_thresholds)


Best thresholds: [0.17 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.11 0.1  0.13 0.1  0.1  0.1  0.13 0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.28 0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.16 0.12 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.25
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1
 0.1  0.1  0.23 0.1  0.1  0.1  0.1  0.14 0.1  0.1  0.1  0.14 0.1  0.1
 0.23 0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1  0.1  0.1  0.15 0.1  0.1
 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.11 0.1  0.1
 0.1  0.19 0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.14
 

In [212]:
adjusted_preds = (probs_val >= best_thresholds).astype(int)
precision = precision_score(targs_val, adjusted_preds, average='micro', zero_division=0)
recall    = recall_score(targs_val, adjusted_preds, average='micro', zero_division=0)
f1        = f1_score(targs_val, adjusted_preds, average='micro', zero_division=0)

print(f"Adjusted Validation Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")


Adjusted Validation Precision: 0.0032, Recall: 1.0000, F1 Score: 0.0063


Saving the Trained Model

In [None]:
import torch
import numpy as np
import pandas as pd

torch.save(model.state_dict(), 'trained_model.pth')
print("Saved model to trained_model.pth")

np.save('best_thresholds.npy', best_thresholds)
print("Saved best thresholds to best_thresholds.npy")


rows = []
for pid, trow, prow in zip(val_ids, targs_val, (probs_val >= best_thresholds).astype(int)):
    true_ecs = [c for c, flag in zip(classes, trow) if flag]
    pred_ecs = [c for c, flag in zip(classes, prow) if flag]
    rows.append({
        'strain': pid,
        'validation EC class': ';'.join(true_ecs) or '-',
        'predicted': ';'.join(pred_ecs) or '-'
    })

results_df = pd.DataFrame(rows, columns=['strain', 'validation EC class', 'predicted'])
results_df.to_csv('validation_results.csv', index=False)
print("Saved validation results to validation_results.csv")

np.save('probs_val.npy', probs_val)
np.save('targs_val.npy', targs_val)
print("Saved validation probs and targets")


In [None]:
from google.colab import files

files.download('trained_model.pth')
files.download('best_thresholds.npy')
files.download('validation_results.csv')
files.download('probs_val.npy')
files.download('targs_val.npy')


In [218]:
d = torch.load("cache/1A63-A.pt")
print(d.y.shape)


torch.Size([1, 538])


Re-running Model with best Thresholds


In [228]:
import torch


best_thresholds_tensor = torch.tensor(best_thresholds, dtype=torch.float32)

for epoch in range(1, 11):
    model.train()
    total_loss = 0

    for batch in train_loader:
      batch = batch.to(device)

      logits = model(batch)
      labels = batch.y.float()

      loss = crit(logits, labels)

      opt.zero_grad()
      loss.backward()
      opt.step()


    avg_loss = total_loss / len(train_ds)
    print(f"Epoch {epoch} Train Loss: {avg_loss:.4f}")

    model.eval()
    all_preds, all_targs = [], []


    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            logits = model(batch)
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).int()
            targs = batch.y.view(logits.shape).int()   # <-- fix here!

            all_preds.append(preds.cpu())
            all_targs.append(targs.cpu())


    if len(all_preds) > 0 and len(all_targs) > 0:
        preds = torch.vstack(all_preds).numpy()
        targs = torch.vstack(all_targs).numpy()

        precision = precision_score(targs, preds, average='micro', zero_division=0)
        recall = recall_score(targs, preds, average='micro', zero_division=0)
        f1 = f1_score(targs, preds, average='micro', zero_division=0)

        print(f"Validation P: {precision:.4f} R: {recall:.4f} F1: {f1:.4f}")
    else:
        print(" Warning: No valid batches for validation this epoch.")


Epoch 1 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 2 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 3 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 4 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 5 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 6 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 7 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 8 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 9 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000
Epoch 10 Train Loss: 0.0000
Validation P: 0.0000 R: 0.0000 F1: 0.0000


In [227]:
model.eval()
with torch.no_grad():
    batch = next(iter(train_loader)).to(device)
    logits = model(batch)
    print("Sample logits:", logits[0][:10].detach().cpu())  # first 10 values


Sample logits: tensor([-0.8893, -2.4346, -2.0931, -2.3130, -2.9819, -2.2988, -2.4646, -2.5680,
        -2.0719, -2.5941])


In [None]:
!zip -r structures.zip structures
!zip -r cache.zip cache