In [1]:
# Relax the MPS allocator's memory cap so training is less likely to abort with
# an out-of-memory error. PyTorch documents a ratio of "0.0" as disabling the
# upper limit (confirm against the installed version; with no limit, memory
# pressure may surface system-wide instead).
# Kept in the first cell so the value is set ahead of the later `import torch`.
import os

os.environ.update({"PYTORCH_MPS_HIGH_WATERMARK_RATIO": "0.0"})

In [2]:
!pip3 install torch transformers fair-esm biopython scikit-learn

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [3]:
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from collections import Counter
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import ast  # parses string representations

In [4]:
class ProteinDataset(Dataset):
    """Multi-label dataset pairing per-protein segment embeddings with label sets.

    Each item is ``(seg_tensors, label)``:
      * ``seg_tensors`` - list of float tensors, one per segment embedding
        stored for the protein in the embeddings dictionary;
      * ``label`` - float multi-hot vector produced by a ``MultiLabelBinarizer``
        fitted on the ``molecular_functions`` column.

    Attributes set by ``__init__``: ``embeddings_dict``, ``df``, ``entries``,
    ``mlb``, ``labels`` and ``num_classes``.
    """

    def __init__(self, embeddings_path, csv_path):
        """Load embeddings and annotations and build the multi-hot label matrix.

        Args:
            embeddings_path: path to a ``.npy`` file holding a pickled dict that
                maps protein IDs to lists of segment embeddings.
            csv_path: path to a CSV file with at least the columns ``Entry``
                (protein ID) and ``molecular_functions`` (label collection,
                either a Python-literal list or a ``;``-separated string).
        """
        # NOTE(security): allow_pickle=True unpickles arbitrary objects, which can
        # execute code on load. Only open embedding files from a trusted source.
        self.embeddings_dict = np.load(embeddings_path, allow_pickle=True).item()

        annotations = pd.read_csv(csv_path)

        # Keep only rows whose Entry exists in the embeddings ...
        known = annotations["Entry"].isin(self.embeddings_dict.keys())
        annotations = annotations[known]
        # ... and has at least one segment.
        has_segments = annotations["Entry"].apply(
            lambda entry: len(self.embeddings_dict[entry]) > 0
        ).astype(bool)  # astype keeps the mask boolean even when no rows remain

        # .copy() detaches the filtered rows from the original frame so the column
        # assignment below does not trigger pandas' SettingWithCopyWarning.
        # reset_index keeps the row position equal to the row of self.labels.
        self.df = annotations[has_segments].copy().reset_index(drop=True)
        self.entries = self.df["Entry"].tolist()

        # Normalise every molecular_functions cell to a list of labels.
        self.df["molecular_functions"] = self.df["molecular_functions"].apply(self._parse_labels)

        # Create multi-label binarized labels (one column per distinct label).
        self.mlb = MultiLabelBinarizer()
        self.labels = self.mlb.fit_transform(self.df["molecular_functions"])
        self.num_classes = self.labels.shape[1]

    @staticmethod
    def _parse_labels(raw):
        """Normalise one ``molecular_functions`` cell to a list of labels.

        * A string holding a Python literal list/tuple/set is evaluated with
          ``ast.literal_eval`` and returned as a list.
        * Any other string is split on ``;`` with surrounding whitespace and
          empty items dropped.
        * ``None`` and NaN (a missing CSV cell) become an empty list so the
          binarizer receives an iterable instead of a float.
        * Every other value is returned unchanged.
        """
        if isinstance(raw, str):
            try:
                parsed = ast.literal_eval(raw)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal: treat it as a ';'-separated list below.
                parsed = None
            if isinstance(parsed, (list, tuple, set, frozenset)):
                return list(parsed)
            # Scalars (e.g. "123") are not label collections; fall back to the
            # textual split so the result is always a list of strings.
            return [item.strip() for item in raw.split(';') if item.strip()]
        if raw is None or (isinstance(raw, float) and np.isnan(raw)):
            return []
        return raw

    def __len__(self):
        """Return the number of proteins kept after filtering."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return ``(seg_tensors, label)`` for the protein at position ``idx``."""
        entry = self.entries[idx]
        seg_list = self.embeddings_dict[entry]
        seg_tensors = [torch.tensor(seg, dtype=torch.float) for seg in seg_list]
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        return seg_tensors, label

def collate_fn(batch):
    """Collate ``(seg_tensors, label)`` samples into padded batch tensors.

    Args:
        batch: iterable of ``(seg_tensors, label)`` pairs, where ``seg_tensors``
            is a list of equally sized 1-D tensors and ``label`` is a tensor.

    Returns:
        Tuple ``(padded_segments, labels, pad_mask)``:
          * ``padded_segments`` - (batch_size, max_seq_len, embedding_dim),
            zero-padded after each protein's last real segment;
          * ``labels`` - the per-sample labels stacked along dim 0;
          * ``pad_mask`` - bool (batch_size, max_seq_len), ``True`` exactly at
            the padded positions.
    """
    # One (num_segments, embedding_dim) tensor per protein.
    stacked = [torch.stack(segments) for segments, _ in batch]
    padded = pad_sequence(stacked, batch_first=True)
    targets = torch.stack([target for _, target in batch])

    _, longest, _ = padded.size()
    seg_counts = torch.tensor([len(segments) for segments, _ in batch])

    # A position is padding when its index is at or past the protein's segment count.
    positions = torch.arange(longest)
    mask = positions.unsqueeze(0) >= seg_counts.unsqueeze(1)
    return padded, targets, mask

In [5]:
class ProteinTransformerClassifier(nn.Module):
    """Transformer encoder over segment embeddings with a linear multi-label head.

    The encoder output is mean-pooled over the non-padded positions of each
    sequence and projected to ``num_classes`` logits (no sigmoid is applied).

    Args:
        input_dim: size of each segment embedding (used as ``d_model``).
        num_classes: number of output logits.
        num_layers: number of stacked encoder layers.
        nhead: number of attention heads.
        dim_feedforward: hidden size of each layer's feed-forward sub-block.
        dropout: dropout probability inside the encoder layers.
    """

    def __init__(self, input_dim, num_classes, num_layers=2, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(input_dim, num_classes)

    def forward(self, x, src_key_padding_mask=None):
        """Compute class logits for a batch of segment sequences.

        Args:
            x: tensor of shape (batch_size, seq_len, input_dim).
            src_key_padding_mask: optional bool tensor (batch_size, seq_len) with
                ``True`` at padded positions. When omitted, every position is
                treated as real and a plain mean over the sequence is used.

        Returns:
            Tensor of shape (batch_size, num_classes) holding raw logits.
        """
        # The encoder layer is built with the default sequence-first layout, so
        # swap to (seq_len, batch_size, input_dim) and back around the call.
        encoded = self.transformer_encoder(
            x.transpose(0, 1), src_key_padding_mask=src_key_padding_mask
        ).transpose(0, 1)

        if src_key_padding_mask is None:
            pooled = encoded.mean(dim=1)  # batch_size, input_dim
        else:
            keep = (~src_key_padding_mask).unsqueeze(-1).float()  # batch_size, seq_len, 1
            # clamp guards the divisor so a row with no real positions cannot
            # introduce a division by zero in the pooling step.
            pooled = (encoded * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        return self.classifier(pooled)

In [6]:
def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` in place and report the mean per-sample loss of each epoch.

    Args:
        model: module called as ``model(inputs, src_key_padding_mask=pad_mask)``.
        dataloader: iterable yielding ``(inputs, labels, pad_mask)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; its value is
            weighted by the batch size when averaging.
        optimizer: optimizer stepping the model's parameters.
        device: device the batch tensors are moved to before the forward pass.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        List with one average loss per epoch. An epoch that yields no samples
        contributes ``nan``.
    """
    model.train()
    history = []

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        samples_seen = 0

        for inputs, labels, pad_mask in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            pad_mask = pad_mask.to(device)

            optimizer.zero_grad()
            outputs = model(inputs, src_key_padding_mask=pad_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            epoch_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Average over the samples actually processed: len(dataloader.dataset)
        # over-counts when the loader skips samples (drop_last, samplers) and
        # is zero or undefined for empty / length-less datasets.
        avg_loss = epoch_loss / samples_seen if samples_seen else float("nan")
        history.append(avg_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    return history

In [7]:
def main(
    embeddings_path="data/esm2_segmented_embeddings.npy",
    csv_path="data/filtered_parsed_data.csv",
    model_path="protein_transformer_model.pt",
    batch_size=16,
    lr=1e-4,
    num_epochs=10,
    seed=42,
):
    """Build the dataset, train the transformer classifier and save its weights.

    Args:
        embeddings_path: ``.npy`` file with the segment-embedding dictionary.
        csv_path: CSV file with the filtered annotations.
        model_path: destination of the trained ``state_dict``.
        batch_size: number of proteins per training batch.
        lr: learning rate of the Adam optimizer.
        num_epochs: number of training epochs.
        seed: value passed to ``torch.manual_seed`` before the model and the
            shuffling loader are created; ``None`` leaves the RNG untouched.

    Raises:
        FileNotFoundError: if either input file does not exist.
        ValueError: if no protein is left after the dataset's filtering.
    """
    # Fail early with an actionable message: the paths are relative, so they
    # depend on the directory the notebook kernel was started in.
    missing = [str(path) for path in (embeddings_path, csv_path) if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(
            "Missing input file(s): " + ", ".join(missing)
            + f" (relative paths are resolved against {os.getcwd()!r})"
        )

    if seed is not None:
        torch.manual_seed(seed)

    # Create dataset and dataloader
    dataset = ProteinDataset(embeddings_path, csv_path)
    if len(dataset) == 0:
        raise ValueError("No proteins with both annotations and segment embeddings were found")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    print(f"Dataset size: {len(dataset)} proteins")
    print(f"Number of classes: {dataset.num_classes}")

    # Get the input dimension from the first segment of the first protein
    # (the dataset keeps only proteins with at least one segment).
    sample_segments, _ = dataset[0]
    input_dim = sample_segments[0].shape[0]

    # Instantiate transformer classifier model
    model = ProteinTransformerClassifier(input_dim=input_dim, num_classes=dataset.num_classes)

    # Use Apple's MPS backend when available, otherwise fall back to the CPU
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model.to(device)

    # Set up loss function and optimizer for multi-label classification
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train the model
    train_model(model, dataloader, criterion, optimizer, device, num_epochs=num_epochs)

    # Save the trained model state
    torch.save(model.state_dict(), model_path)
    print(f"Model saved as {model_path}")

if __name__ == "__main__":
    main()

FileNotFoundError: [Errno 2] No such file or directory: 'data/esm2_segmented_embeddings.npy'