In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
import ast  # for parsing string representations
from sklearn.utils.class_weight import compute_class_weight

from sklearn.metrics import (
    balanced_accuracy_score,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    multilabel_confusion_matrix
)

# Pick the compute device once for the whole notebook: the GPU when CUDA is
# visible (e.g. an AWS EC2 GPU instance), otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("Using CPU")

print("using device:", device)

Using GPU: Tesla T4
using device: cuda


In [None]:
class ProteinDataset(Dataset):
    """Multi-label dataset pairing per-protein segment embeddings with labels.

    Each item is ``(seg_tensors, label)`` where ``seg_tensors`` is a list of
    float tensors (one per stored segment embedding of the protein) and
    ``label`` is a float multi-hot vector produced by ``MultiLabelBinarizer``.
    """

    def __init__(self, embeddings_path, csv_path):
        """Load embeddings and labels, keeping only usable proteins.

        Args:
            embeddings_path: path to a ``.npy`` file holding a pickled dict
                that maps protein IDs to lists of segment embeddings.
            csv_path: path to a ``.csv`` file with an ``Entry`` column
                (protein ID) and a ``molecular_functions`` column.
        """
        from sklearn.preprocessing import MultiLabelBinarizer

        # NOTE(security): allow_pickle=True unpickles arbitrary Python objects,
        # which can execute code. Only load embedding files from a trusted source.
        self.embeddings_dict = np.load(embeddings_path, allow_pickle=True).item()
        df = pd.read_csv(csv_path)

        def has_segments(entry):
            # True only for IDs present in the embeddings with >= 1 segment.
            segments = self.embeddings_dict.get(entry)
            return segments is not None and len(segments) > 0

        # Filter in one pass and rebuild a fresh, contiguous index so that the
        # column assignment below writes to an independent frame (not a slice)
        # and row positions line up with ``self.entries`` / ``self.labels``.
        keep = df["Entry"].map(has_segments).astype(bool)
        df = df.loc[keep].reset_index(drop=True)

        # ensure molecular_functions column is parsed into lists
        df["molecular_functions"] = df["molecular_functions"].apply(self._parse_functions)

        self.df = df
        self.entries = df["Entry"].tolist()

        # binarize labels; row i of self.labels belongs to self.entries[i]
        self.mlb = MultiLabelBinarizer()
        self.labels = self.mlb.fit_transform(df["molecular_functions"])
        self.num_classes = self.labels.shape[1]

    @staticmethod
    def _parse_functions(value):
        """Turn one raw ``molecular_functions`` cell into a list of labels.

        Strings are first tried as a Python literal (e.g. ``"['a', 'b']"``);
        if that fails they are split on ``;``. A missing cell (``None``/NaN)
        becomes an empty list, i.e. a protein with no positive labels.
        Non-string, non-missing values are returned unchanged.
        """
        if isinstance(value, str):
            try:
                parsed = ast.literal_eval(value)
            except Exception:
                parsed = None
            if isinstance(parsed, (list, tuple, set, frozenset)):
                return list(parsed)
            if isinstance(parsed, str):
                # A single quoted label: wrap it so it is not iterated
                # character by character by the binarizer.
                return [parsed.strip()] if parsed.strip() else []
            # Not a literal collection: treat as a ';'-separated list.
            return [item.strip() for item in value.split(';') if item.strip()]
        if value is None or (np.isscalar(value) and pd.isna(value)):
            return []
        return value

    def __len__(self):
        """Return the number of proteins kept after filtering."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``(list of float segment tensors, float multi-hot label)``."""
        entry = self.entries[idx]
        seg_list = self.embeddings_dict[entry]  # list of segment embeddings
        # convert each segment to a torch tensor
        seg_tensors = [torch.tensor(seg, dtype=torch.float) for seg in seg_list]
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return seg_tensors, label

def collate_fn(batch):
    """Collate ``(segment_list, label)`` pairs into padded batch tensors.

    Proteins have a variable number of segments, so each protein's segments
    are stacked and then zero-padded along the segment axis.

    Returns:
        padded: ``(batch_size, max_num_segments, embedding_dim)`` tensor.
        targets: ``(batch_size, num_classes)`` tensor of stacked labels.
        pad_mask: ``(batch_size, max_num_segments)`` bool tensor, ``True``
            at padded (non-real) segment positions.
    """
    pairs = list(batch)

    # one (num_segments, embedding_dim) tensor per protein
    per_protein = [torch.stack(segments) for segments, _ in pairs]
    seg_counts = torch.tensor([len(segments) for segments, _ in pairs])

    padded = pad_sequence(per_protein, batch_first=True)
    targets = torch.stack([target for _, target in pairs])

    n_items, longest, _ = padded.size()
    # position j of protein i is padding exactly when j >= its segment count
    positions = torch.arange(longest).unsqueeze(0).expand(n_items, longest)
    pad_mask = positions >= seg_counts.unsqueeze(1)

    return padded, targets, pad_mask

In [13]:
class BaselineProteinClassifier(nn.Module):
    """Mean-pool segment embeddings, then classify with a 3-layer MLP.

    The network returns raw logits (no sigmoid), one per class.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=512, hidden_dim2=256):
        """Build the MLP.

        Args:
            input_dim: dimension of each segment embedding
                (like 1280 for esm2_t33_650M_UR50D).
            num_classes: number of output classes (from MultiLabelBinarizer).
            hidden_dim: width of the first hidden layer.
            hidden_dim2: width of the second hidden layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x, src_key_padding_mask=None):
        """Compute class logits for a batch of proteins.

        Args:
            x: ``(batch_size, seq_len, input_dim)`` segment embeddings.
            src_key_padding_mask: optional bool tensor ``(batch_size, seq_len)``
                that is ``True`` at padded segments; those positions are
                excluded from the mean.

        Returns:
            ``(batch_size, num_classes)`` logits.
        """
        if src_key_padding_mask is not None:
            # 1 for real segments, 0 for padding; shape (batch_size, seq_len, 1)
            keep = (~src_key_padding_mask).unsqueeze(-1).to(x.dtype)
            # A fully padded row has zero real segments; clamping the count to
            # 1 yields a zero vector for it instead of 0/0 = NaN logits.
            real_counts = keep.sum(dim=1).clamp(min=1.0)
            x = (x * keep).sum(dim=1) / real_counts
        else:
            x = x.mean(dim=1)  # (batch_size, input_dim)

        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits

# for training, we're going to use BCEWithLogitsLoss for multi-label classification

In [14]:
def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` and print training-set metrics after every epoch.

    Every epoch makes one optimisation pass over ``dataloader`` and then a
    second, gradient-free pass over the same loader to compute metrics, so the
    reported numbers describe fit on the training data, not generalisation.

    Args:
        model: module called as ``model(inputs, src_key_padding_mask=pad_mask)``
            and returning logits.
        dataloader: iterable of ``(inputs, labels, pad_mask)`` batches that
            also exposes ``.dataset`` (used for the sample count).
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of epochs to run.

    Returns:
        None. Results are printed.
    """
    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_loss = 0.0

        for inputs, labels, pad_mask in dataloader:
            inputs, labels, pad_mask = inputs.to(device), labels.to(device), pad_mask.to(device)

            optimizer.zero_grad()
            outputs = model(inputs, src_key_padding_mask=pad_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # weight by batch size so the epoch average is per-sample
            epoch_loss += loss.item() * inputs.size(0)

        avg_loss = epoch_loss / len(dataloader.dataset)

        # print training results for the epoch
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for inputs, labels, pad_mask in dataloader:
                inputs, labels, pad_mask = inputs.to(device), labels.to(device), pad_mask.to(device)
                logits = model(inputs, src_key_padding_mask=pad_mask)
                probs = torch.sigmoid(logits)
                preds = (probs > 0.5).cpu().numpy()
                all_preds.append(preds)
                all_labels.append(labels.cpu().numpy())

        all_preds = np.vstack(all_preds)
        all_labels = np.vstack(all_labels)

        # Balanced accuracy over every (sample, class) cell treated as one
        # binary problem. The ground truth must be the FIRST argument:
        # the metric averages recall per true class, so swapping the
        # arguments silently reports a different quantity.
        balanced_acc = balanced_accuracy_score(all_labels.flatten(), all_preds.flatten())

        # calculate metrics
        acc   = accuracy_score(all_labels, all_preds)  # exact-match (subset) accuracy
        prec  = precision_score(all_labels, all_preds, average='macro', zero_division=0)
        rec   = recall_score(all_labels, all_preds, average='macro', zero_division=0)
        f1    = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        cms   = multilabel_confusion_matrix(all_labels, all_preds)

        print(f"Epoch {epoch}/{num_epochs} — Loss: {avg_loss:.4f}", end="  ")
        print(f"Acc: {acc:.4f}  Prec: {prec:.4f}  Rec: {rec:.4f}  F1: {f1:.4f} Balanced Acc: {balanced_acc:.4f}")

        # print out the matrices for the first 10 classes
        for idx, cm in enumerate(cms[:10]):
            tn, fp, fn, tp = cm.ravel()
            print(f"Class {idx:2d} ── TN: {tn:5d}  FP: {fp:5d}  FN: {fn:5d}  TP: {tp:5d}")

# main training routine
def main():
    """Build the dataset, train the baseline classifier and save its weights.

    Reads the embeddings and label files from ``Data/``, prints the class
    index-to-term mapping and dataset size, trains for 10 epochs and writes
    the model's ``state_dict`` to ``protein_model.pt``.
    """
    embeddings_path = "Data/esm2_segmented_embeddings.npy"
    csv_path = "Data/filtered_parsed_data.csv"

    # create dataset and dataloader
    dataset = ProteinDataset(embeddings_path, csv_path)
    all_terms = dataset.mlb.classes_

    for idx, term in enumerate(all_terms):
        print(f"{idx:3d}: {term}")

    # Built once, outside the printing loop: inside it the loader was rebuilt
    # for every class and left undefined when there were no classes at all.
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

    print(f"Dataset size: {len(dataset)} proteins")
    print(f"Number of classes: {dataset.num_classes}")

    # get input dimension from one sample (assuming at least one segment exists)
    sample_segments, _ = dataset[0]
    input_dim = sample_segments[0].shape[0]

    # instantiate the baseline model
    model = BaselineProteinClassifier(input_dim=input_dim, num_classes=dataset.num_classes)
    model.to(device)

    # set up loss function and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # train model
    train_model(model, dataloader, criterion, optimizer, device, num_epochs=10)

    # save the trained model
    torch.save(model.state_dict(), "protein_model.pt")
    print("Model saved as protein_model.pt")

if __name__ == "__main__":
    main()

  0: 3-phosphoinositide-dependent protein kinase activity [GO:0004676]
  1: AMP-activated protein kinase activity [GO:0004679]
  2: ATP binding [GO:0005524]
  3: ATP hydrolysis activity [GO:0016887]
  4: DNA binding [GO:0003677]
  5: DNA-binding transcription activator activity, RNA polymerase II-specific [GO:0001228]
  6: DNA-binding transcription factor activity [GO:0003700]
  7: DNA-binding transcription factor activity, RNA polymerase II-specific [GO:0000981]
  8: DNA-binding transcription repressor activity, RNA polymerase II-specific [GO:0001227]
  9: DNA-dependent protein kinase activity [GO:0004677]
 10: G protein-coupled receptor activity [GO:0004930]
 11: GTP binding [GO:0005525]
 12: GTPase activator activity [GO:0005096]
 13: GTPase activity [GO:0003924]
 14: RNA binding [GO:0003723]
 15: RNA polymerase II cis-regulatory region sequence-specific DNA binding [GO:0000978]
 16: RNA polymerase II transcription regulatory region sequence-specific DNA binding [GO:0000977]
 17: ac