In [1]:
!pip3 install torch transformers fair-esm biopython scikit-learn pandas



In [None]:
import torch
# Report GPU availability. get_device_name(0) raises when no CUDA device is
# present, so the name is only queried when CUDA is actually available; this
# keeps the cell runnable on CPU-only machines.
print(torch.cuda.is_available())  # should return true
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # should say T4

import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from collections import Counter
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.utils.class_weight import compute_class_weight
import ast  # parses string representations

# imports for evaluation
from sklearn.metrics import (
    balanced_accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix
)

True
Tesla T4


In [3]:
class SegmentedDataset(Dataset):
    """Multi-label dataset pairing per-protein segment embeddings with labels.

    Each item is ``(seg_tensors, label)``: a list with one float tensor per
    segment of the protein, and a float multi-hot label vector over the classes
    discovered by ``MultiLabelBinarizer``.

    Attributes set by ``__init__``:
        embeddings_dict: object loaded from ``embeddings_path``; indexed by the
            CSV's ``Entry`` value to get that protein's segment embeddings.
        df: CSV rows whose ``Entry`` is in ``embeddings_dict`` with at least one
            segment, with ``molecular_functions`` parsed into lists.
        entries: ``Entry`` values of ``df`` in row order.
        mlb: the fitted ``MultiLabelBinarizer``.
        labels: multi-hot label matrix, one row per element of ``entries``.
    """

    def __init__(self, embeddings_path, csv_path):
        """Load embeddings and annotations and build the multi-hot label matrix.

        Args:
            embeddings_path: path to a .npy file holding a pickled dict mapping
                protein IDs to lists of segment embeddings.
            csv_path: path to a CSV file with ``Entry`` and
                ``molecular_functions`` columns.
        """
        # NOTE(review): allow_pickle=True unpickles arbitrary objects and can
        # execute code from a malicious file; only load trusted embeddings.
        self.embeddings_dict = np.load(embeddings_path, allow_pickle=True).item()

        annotations = pd.read_csv(csv_path)

        # keep only rows where Entry exists in embeddings and has at least one segment
        keep = np.array(
            [
                entry in self.embeddings_dict and len(self.embeddings_dict[entry]) > 0
                for entry in annotations["Entry"]
            ],
            dtype=bool,
        )
        # .copy() makes the filtered frame independent of the full CSV frame so
        # the column assignment below cannot trigger SettingWithCopyWarning.
        self.df = annotations.loc[keep].copy()

        self.df["molecular_functions"] = self.df["molecular_functions"].apply(self._parse_labels)
        self.entries = self.df["Entry"].tolist()

        # create multi-label binarized labels
        self.mlb = MultiLabelBinarizer()
        self.labels = self.mlb.fit_transform(self.df["molecular_functions"])

    @staticmethod
    def _parse_labels(raw):
        """Turn one ``molecular_functions`` cell into a list of label strings.

        Accepts either a Python-literal collection (e.g. ``"['a', 'b']"``) or a
        ``;``-separated string. Non-string cells (such as NaN) yield ``[]``.
        """
        if not isinstance(raw, str):
            return []
        try:
            parsed = ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError):
            # not a Python literal: treat as ';'-separated text and drop empty
            # pieces so a trailing ';' does not create an empty-string class
            return [piece.strip() for piece in raw.split(';') if piece.strip()]
        if isinstance(parsed, str):
            # a bare quoted string is one label, not an iterable of characters
            return [parsed]
        if isinstance(parsed, (list, tuple, set, frozenset)):
            return list(parsed)
        # any other literal (e.g. a number) cannot be iterated by the
        # binarizer; fall back to the raw text as a single label
        stripped = raw.strip()
        return [stripped] if stripped else []

    def __len__(self):
        """Return the number of proteins kept after filtering."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``(seg_tensors, label)`` for the protein at position ``idx``.

        ``seg_tensors`` is a list of float tensors, one per stored segment;
        ``label`` is the float multi-hot row of ``self.labels``.
        """
        entry = self.entries[idx]
        seg_list = self.embeddings_dict[entry]
        seg_tensors = [torch.tensor(seg, dtype=torch.float) for seg in seg_list]
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return seg_tensors, label

def collate_fn(batch):
    """Collate ``(segment_tensors, label)`` pairs into padded batch tensors.

    Args:
        batch: sequence of ``(segment_tensors, label)`` pairs, where
            ``segment_tensors`` is a list of equally sized 1-D tensors and
            ``label`` is a tensor.

    Returns:
        padded: ``(B, S_max, emb_dim)`` tensor, zero-padded along the segment axis.
        labels: the labels stacked along a new first dimension.
        mask: ``(B, S_max)`` bool tensor, True at padded positions.
    """
    segment_lists, label_list = zip(*batch)

    # one (num_segments, emb_dim) matrix per protein, then pad to the longest
    per_protein = [torch.stack(segments) for segments in segment_lists]
    padded = pad_sequence(per_protein, batch_first=True)
    labels = torch.stack(label_list)

    batch_size, max_segments, _ = padded.shape

    # a position is padding when its index is at or beyond the protein's own
    # segment count
    segment_counts = torch.tensor([len(segments) for segments in segment_lists])
    positions = torch.arange(max_segments).unsqueeze(0).expand(batch_size, max_segments)
    mask = positions >= segment_counts.unsqueeze(1)

    return padded, labels, mask

class TransformerClassifier(nn.Module):
    """Single-layer transformer encoder followed by masked mean pooling and a linear head.

    Args:
        input_dim: width of each segment embedding; also used as the encoder's ``d_model``.
        nhead: number of attention heads.
        num_classes: size of the output logit vector.
        dim_feedforward: hidden width of the encoder layer's feed-forward block.
        dropout: dropout probability inside the encoder layer.
    """

    def __init__(self, input_dim, nhead=4, num_classes=50, dim_feedforward=256, dropout=0.1):
        super().__init__()

        # the encoder works directly at the embedding width (no input projection)
        layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=False,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x, pad_mask):
        """Return class logits for a padded batch.

        Args:
            x: ``(B, S, d_model)`` batch of segment embeddings.
            pad_mask: ``(B, S)`` bool tensor, True where a position is padding.

        Returns:
            ``(B, num_classes)`` tensor of logits.
        """
        # the encoder was built with batch_first=False, so feed it (S, B, d_model)
        # and flip the result back to (B, S, d_model)
        seq_first = x.transpose(1, 0)
        encoded = self.encoder(seq_first, src_key_padding_mask=pad_mask).transpose(1, 0)

        # zero out padded positions, then average over the real segments only
        valid = pad_mask.logical_not().unsqueeze(-1)  # (B, S, 1)
        summed = (encoded * valid).sum(dim=1)
        pooled = summed / valid.sum(dim=1)  # (B, d_model)

        return self.classifier(pooled)
    
def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` in place and print training-set metrics after every epoch.

    Args:
        model: module called as ``model(inputs, pad_mask)`` and returning logits.
        dataloader: iterable of ``(inputs, labels, pad_mask)`` batches; its
            ``dataset`` attribute must support ``len()``.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer stepping the model's parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``dataloader``.

    Per epoch this prints subset accuracy, average loss, macro F1 / precision /
    recall, and balanced accuracy computed over all flattened label entries.
    Predictions are thresholded at a sigmoid probability of 0.5.
    """
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        truth_batches, pred_batches = [], []

        for inputs, labels, pad_mask in dataloader:
            inputs, labels, pad_mask = inputs.to(device), labels.to(device), pad_mask.to(device)

            optimizer.zero_grad()
            logits = model(inputs, pad_mask)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

            # weight the batch loss by batch size so the epoch figure is a
            # per-sample average (assumes criterion averages within a batch)
            running_loss += batch_loss.item() * inputs.size(0)

            # hard predictions for the epoch-level metrics
            with torch.no_grad():
                hard_preds = (torch.sigmoid(logits) > 0.5).long()
            truth_batches.append(labels.cpu().numpy())
            pred_batches.append(hard_preds.cpu().numpy())

        avg_loss = running_loss / len(dataloader.dataset)

        y_true = np.vstack(truth_batches)
        y_pred = np.vstack(pred_batches)

        # subset accuracy: a sample counts only if every label matches
        acc = (y_true == y_pred).all(axis=1).mean()

        macro_kwargs = {"average": "macro", "zero_division": 0}
        f1 = f1_score(y_true, y_pred, **macro_kwargs)
        precision = precision_score(y_true, y_pred, **macro_kwargs)
        recall = recall_score(y_true, y_pred, **macro_kwargs)

        # treat every (sample, class) entry as one binary sample
        balanced_acc = balanced_accuracy_score(y_true.ravel(), y_pred.ravel())

        print(f"Epoch {epoch+1}/{num_epochs} | Subset Accuracy: {acc:.4f}")
        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {avg_loss:.4f} | F1: {f1:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | Balanced Acc: {balanced_acc:.4f}")

def evaluate_model(model, dataloader, device, threshold=0.5):
    """Evaluate ``model`` and print per-class confusion counts for up to 10 classes.

    Puts the model in eval mode, runs it over ``dataloader`` without gradients,
    prints the average unweighted ``BCEWithLogitsLoss`` and, for each of the
    first ten classes, the TN / FP / FN / TP counts.

    Args:
        model: module called as ``model(inputs, pad_mask)`` and returning logits.
        dataloader: iterable of ``(inputs, labels, pad_mask)`` batches whose
            ``dataset.mlb.classes_`` provides the class names.
        device: device the batches are moved to.
        threshold: sigmoid probability above which a class is predicted positive.

    Raises:
        ValueError: if ``dataloader`` yields no batches (nothing to stack).
    """
    model.eval()
    all_labels = []
    all_preds = []
    total_loss = 0.0
    num_samples = 0
    # unweighted loss: this is independent of any pos_weight used for training
    criterion = nn.BCEWithLogitsLoss()

    with torch.no_grad():
        for inputs, labels, pad_mask in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            pad_mask = pad_mask.to(device)

            outputs = model(inputs, pad_mask)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * inputs.size(0)
            num_samples += inputs.size(0)

            probs = torch.sigmoid(outputs)
            preds = (probs > threshold).long()

            all_labels.append(labels.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

    y_true = np.vstack(all_labels)
    y_pred = np.vstack(all_preds)

    # the loss was previously accumulated but never reported
    print(f"Eval Loss (unweighted BCE): {total_loss / max(num_samples, 1):.4f}")

    # confusion matrices for first 10 GO classes
    classes = dataloader.dataset.mlb.classes_
    for i in range(min(10, y_true.shape[1])):
        # labels=[0, 1] forces a 2x2 matrix even when a class has only one
        # value present in both arrays; without it the matrix can be 1x1 and
        # the four-way unpack below raises ValueError
        tn, fp, fn, tp = confusion_matrix(
            y_true[:, i].astype(int), y_pred[:, i], labels=[0, 1]
        ).ravel()
        print(f"[{i:2d}] {classes[i]}")
        print(f"    TN={tn:5d}  FP={fp:5d}  FN={fn:5d}  TP={tp:5d}\n")

In [4]:
def main(
    embeddings_path="Data/esm2_segmented_embeddings.npy",
    csv_path="Data/filtered_parsed_data.csv",
    model_path="protein_transformer_model.pt",
    batch_size=32,
    num_epochs=10,
    lr=1e-4,
    seed=42,
):
    """Build the dataset, train the transformer classifier, evaluate it and save its weights.

    Called with no arguments this uses the same paths and hyperparameters as before.

    Args:
        embeddings_path: .npy file with the segment-embedding dictionary.
        csv_path: CSV file with the filtered labels.
        model_path: where the trained ``state_dict`` is written.
        batch_size: number of proteins per batch.
        num_epochs: number of training epochs.
        lr: Adam learning rate.
        seed: seed for the torch and numpy RNGs so weight initialisation and
            batch shuffling are repeatable; pass ``None` to leave the RNGs unseeded.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # create dataset and dataloader
    dataset = SegmentedDataset(embeddings_path, csv_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    num_classes = dataset.labels.shape[1]
    print(f"Dataset size: {len(dataset)} proteins")
    print(f"Number of classes: {num_classes}")

    # get input dimension from one sample
    sample_segments, _ = dataset[0]
    input_dim = sample_segments[0].shape[0]

    # instantiate transformer classifier model
    model = TransformerClassifier(input_dim=input_dim, nhead=4, num_classes=num_classes, dim_feedforward=256, dropout=0.1)

    # use CUDA if available, AWS EC2 with GPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using GPU:", torch.cuda.get_device_name(0))
    else:
        device = torch.device("cpu")
        print("Using CPU")

    model.to(device)

    print(f"model is on: {next(model.parameters()).device}")

    # set up loss function and optimizer for multi-label classification
    # per-class positive weight = negatives / positives, so rare classes count more
    label_counts = dataset.labels.sum(axis=0)  # positive counts per class
    total_counts = len(dataset)  # total samples

    # avoid division by zero
    pos_weights = (total_counts - label_counts) / (label_counts + 1e-5)
    class_weights_tensor = torch.tensor(pos_weights, dtype=torch.float).to(device)

    # weighted loss
    criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights_tensor)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # train the model
    train_model(model, dataloader, criterion, optimizer, device, num_epochs=num_epochs)
    # NOTE(review): evaluation runs on the same dataloader used for training,
    # so the reported confusion counts are training-set figures, not a
    # held-out estimate. Consider a separate validation split.
    evaluate_model(model, dataloader, device)

    # save the trained model state
    torch.save(model.state_dict(), model_path)
    print(f"Model saved as {model_path}")

if __name__ == "__main__":
    main()

Dataset size: 10540 proteins
Number of classes: 50
Using GPU: Tesla T4




model is on: cuda:0
Epoch 1/10 | Subset Accuracy: 0.0312
Epoch 1/10 | Loss: 0.8183 | F1: 0.2628 | Precision: 0.1719 | Recall: 0.7936 | Balanced Acc: 0.7657
Epoch 2/10 | Subset Accuracy: 0.0462
Epoch 2/10 | Loss: 0.5734 | F1: 0.4376 | Precision: 0.3173 | Recall: 0.8725 | Balanced Acc: 0.8470
Epoch 3/10 | Subset Accuracy: 0.0546
Epoch 3/10 | Loss: 0.5128 | F1: 0.4725 | Precision: 0.3509 | Recall: 0.8889 | Balanced Acc: 0.8652
Epoch 4/10 | Subset Accuracy: 0.0651
Epoch 4/10 | Loss: 0.4751 | F1: 0.4951 | Precision: 0.3728 | Recall: 0.8985 | Balanced Acc: 0.8754
Epoch 5/10 | Subset Accuracy: 0.0647
Epoch 5/10 | Loss: 0.4446 | F1: 0.5144 | Precision: 0.3920 | Recall: 0.9069 | Balanced Acc: 0.8843
Epoch 6/10 | Subset Accuracy: 0.0762
Epoch 6/10 | Loss: 0.4182 | F1: 0.5310 | Precision: 0.4090 | Recall: 0.9135 | Balanced Acc: 0.8913
Epoch 7/10 | Subset Accuracy: 0.0792
Epoch 7/10 | Loss: 0.3953 | F1: 0.5420 | Precision: 0.4189 | Recall: 0.9207 | Balanced Acc: 0.8976
Epoch 8/10 | Subset Accuracy