In [1]:
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import esm
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import PeptidePredictorDataset
from models import PeptidePredictor

max_seq_length = 12
esm_model_pretrained, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
batch_converter = alphabet.get_batch_converter(truncation_seq_length=max_seq_length - 2)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device, torch.cuda.device_count())

cuda 1


In [4]:
def get_stratification_key(length):
    return str(length)

In [5]:
regression_df = pd.read_csv("../../clean_data/merged_all.csv", keep_default_na=False, na_values=[''])[['peptide', 'length', 'ap']]
classification_df = pd.read_csv("../../clean_data/merged_all.csv", keep_default_na=False, na_values=[''])[['peptide', 'length', 'is_assembled']]
regression_df = regression_df.dropna(subset=['ap'])
classification_df = classification_df.dropna(subset=['is_assembled'])

classification_df['stratify_key'] = classification_df['length'].apply(get_stratification_key)
regression_df['stratify_key'] = regression_df['length'].apply(get_stratification_key)

def split_dataset(df, stratify_col, test_size=0.1, val_size=0.1, random_state=42):
    train_val_df, test_df = train_test_split(
        df, test_size=test_size, stratify=df[stratify_col], random_state=random_state
    )
    train_df, val_df = train_test_split(
        train_val_df, test_size=val_size, stratify=train_val_df[stratify_col], random_state=random_state
    )
    return train_df, val_df, test_df

train_cls, val_cls, test_cls = split_dataset(classification_df, 'stratify_key')
train_reg, val_reg, test_reg = split_dataset(regression_df, 'stratify_key')

print("Classification splits:", len(train_cls), len(val_cls), len(test_cls))
print("Regression splits:", len(train_reg), len(val_reg), len(test_reg))

Classification splits: 75870 8431 9367
Regression splits: 98537 10949 12166


In [6]:
train_cls_dataset = PeptidePredictorDataset(train_cls, task='classification')
val_cls_dataset   = PeptidePredictorDataset(val_cls, task='classification')
test_cls_dataset  = PeptidePredictorDataset(test_cls, task='classification')

train_reg_dataset = PeptidePredictorDataset(train_reg, task='regression')
val_reg_dataset   = PeptidePredictorDataset(val_reg, task='regression')
test_reg_dataset  = PeptidePredictorDataset(test_reg, task='regression')

In [7]:
def convert_and_pad(data, seq_length):
    """
    Converts a list of (name, sequence) tuples into tokenized sequences using the batch converter,
    and pads them to ensure a fixed length of max_seq_length tokens.
    """
    _, _, tokens = batch_converter(data)

    current_len = tokens.size(1)
    if current_len < seq_length:
        pad_length = seq_length - current_len
        padding = torch.full((tokens.size(0), pad_length), alphabet.padding_idx, dtype=tokens.dtype)
        tokens = torch.cat([tokens, padding], dim=1)
    
    return tokens


In [9]:
def esm_collate_fn(batch):
    """
    batch: list of tuples (sequence, target)
    Returns:
      - tokens: a tensor of shape (batch_size, seq_len) ready for the ESM model
      - targets: a tensor of the corresponding targets
    """
    data = [(f"peptide_{i}", seq) for i, (seq, _) in enumerate(batch)]

    tokens = convert_and_pad(data, seq_length=max_seq_length)

    targets = torch.tensor([target for _, target in batch])
    return tokens, targets

In [10]:
batch_size = 1024

train_cls_loader = DataLoader(train_cls_dataset, batch_size=batch_size, shuffle=True, collate_fn=esm_collate_fn)
val_cls_loader   = DataLoader(val_cls_dataset, batch_size=batch_size, shuffle=False, collate_fn=esm_collate_fn)
test_cls_loader  = DataLoader(test_cls_dataset, batch_size=batch_size, shuffle=False, collate_fn=esm_collate_fn)

train_reg_loader = DataLoader(train_reg_dataset, batch_size=batch_size, shuffle=True, collate_fn=esm_collate_fn)
val_reg_loader   = DataLoader(val_reg_dataset, batch_size=batch_size, shuffle=False, collate_fn=esm_collate_fn)
test_reg_loader  = DataLoader(test_reg_dataset, batch_size=batch_size, shuffle=False, collate_fn=esm_collate_fn)

In [11]:
model = PeptidePredictor(esm_model_pretrained, alphabet, hidden_dim=128, dropout=0.10)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)
model.to(device)

model.freeze_encoder()
optimizer = torch.optim.AdamW([
    {"params": list(model.shared.parameters()) + 
               list(model.ap_head.parameters()) + 
               list(model.cls_head.parameters()), "lr": 1e-3}
])
criterion_cls = nn.BCEWithLogitsLoss()
criterion_reg = nn.SmoothL1Loss(beta=0.5)

num_epochs = 5
alpha, beta = 1, 1

reg_iterator = iter(train_reg_loader)


In [12]:
model.freeze_encoder()

for epoch in range(num_epochs):
    epoch_loss = 0.0

    pbar = tqdm(train_cls_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    
    for cls_batch in pbar:
        tokens_cls, cls_labels = cls_batch
        #print(tokens_cls, cls_labels)
        tokens_cls = tokens_cls.to(device)
        cls_labels = cls_labels.to(device)
        #print(tokens_cls, cls_labels)
        ap_pred_cls, cls_pred = model(tokens_cls)
        loss_cls = criterion_cls(cls_pred.squeeze(), cls_labels.float())
        
        try:
            reg_batch = next(reg_iterator)
        except StopIteration:
            reg_iterator = iter(train_reg_loader)
            reg_batch = next(reg_iterator)
            
        tokens_reg, reg_labels = reg_batch
        tokens_reg = tokens_reg.to(device)
        reg_labels = reg_labels.to(device)
        
        ap_pred_reg, _ = model(tokens_reg)
        loss_reg = criterion_reg(ap_pred_reg.squeeze(), reg_labels.float())
        
        total_loss = alpha * loss_cls + beta * loss_reg
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        epoch_loss += total_loss.item()
        pbar.set_postfix(loss=total_loss.item())
    
    avg_loss = epoch_loss / len(train_cls_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

                                                                      

Epoch 1/5 - Average Loss: 0.4069


                                                                      

Epoch 2/5 - Average Loss: 0.1737


                                                                      

Epoch 3/5 - Average Loss: 0.1520


                                                                      

Epoch 4/5 - Average Loss: 0.1403


                                                                       

Epoch 5/5 - Average Loss: 0.1328




In [13]:
#underlying_model = model.module if hasattr(model, "module") else model

encoder_lr = 1e-5
head_lr = 1e-3

model.unfreeze_encoder()
opt = torch.optim.Adam([
    {"params": model.esm.parameters(), "lr": encoder_lr},
    {"params": list(model.shared.parameters()) + 
               list(model.ap_head.parameters()) + 
               list(model.cls_head.parameters()), "lr": head_lr}
])

scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.75)

num_epochs = 6
for epoch in range(num_epochs):
    epoch_loss = 0.0
    pbar = tqdm(train_cls_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for cls_batch in pbar:
        tokens_cls, cls_labels = cls_batch
        tokens_cls = tokens_cls.to(device)
        cls_labels = cls_labels.to(device)
        
        ap_pred_cls, cls_pred = model(tokens_cls)
        loss_cls = criterion_cls(cls_pred.squeeze(), cls_labels.float())
        
        try:
            reg_batch = next(reg_iterator)
        except StopIteration:
            reg_iterator = iter(train_reg_loader)
            reg_batch = next(reg_iterator)
            
        tokens_reg, reg_labels = reg_batch
        tokens_reg = tokens_reg.to(device)
        reg_labels = reg_labels.to(device)
        
        ap_pred_reg, _ = model(tokens_reg)
        loss_reg = criterion_reg(ap_pred_reg.squeeze(), reg_labels.float())
        
        total_loss = alpha * loss_cls + beta * loss_reg
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        epoch_loss += total_loss.item()
        pbar.set_postfix(loss=total_loss.item())
    
    avg_loss = epoch_loss / len(train_cls_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")
    
    scheduler.step()

                                                                      

Epoch 1/6 - Average Loss: 0.1258


                                                                       

Epoch 2/6 - Average Loss: 0.1181


                                                                       

Epoch 3/6 - Average Loss: 0.1109


                                                                       

Epoch 4/6 - Average Loss: 0.1062


                                                                       

Epoch 5/6 - Average Loss: 0.1005


                                                                       

Epoch 6/6 - Average Loss: 0.0975




In [21]:
model_to_save = model.module if hasattr(model, "module") else model
torch.save(model_to_save.state_dict(), "peptide_predictor.pt")
print("Model saved to peptide_predictor.pt")

Model saved to peptide_predictor.pt


In [14]:
def validate_model(model, cls_loader, reg_loader, device):
    model.eval()
    all_cls_preds = []
    all_cls_labels = []
    all_reg_preds = []
    all_reg_labels = []
    accuracy, f1, auc, mse, mae, r2 = float('nan'), float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
    
    with torch.no_grad():
        for tokens, targets in cls_loader:
            tokens = tokens.to(device)
            targets = targets.to(device)
            _, cls_pred = model(tokens)
            all_cls_preds.append(cls_pred.cpu())
            all_cls_labels.append(targets.cpu())
            
        for tokens, targets in reg_loader:
            tokens = tokens.to(device)
            targets = targets.to(device)
            ap_pred, _ = model(tokens)
            all_reg_preds.append(ap_pred.cpu())
            all_reg_labels.append(targets.cpu())
    
    if cls_loader:
        all_cls_preds = torch.cat(all_cls_preds, dim=0).numpy().squeeze()
        all_cls_labels = torch.cat(all_cls_labels, dim=0).numpy().squeeze()

        cls_pred_labels = (all_cls_preds > 0.5).astype(int)
        accuracy = accuracy_score(all_cls_labels, cls_pred_labels)
        f1 = f1_score(all_cls_labels, cls_pred_labels)
        try:
            auc = roc_auc_score(all_cls_labels, all_cls_preds)
        except Exception:
            auc = float('nan')
    if reg_loader:
        all_reg_preds = torch.cat(all_reg_preds, dim=0).numpy().squeeze()
        all_reg_labels = torch.cat(all_reg_labels, dim=0).numpy().squeeze()

        mse = np.mean((all_reg_preds - all_reg_labels) ** 2)
        mae = np.mean(np.abs(all_reg_preds - all_reg_labels))
        ss_res = np.sum((all_reg_preds - all_reg_labels) ** 2)
        ss_tot = np.sum((all_reg_labels - np.mean(all_reg_labels)) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot != 0 else float('nan')    
    
    return accuracy, f1, auc, mse, mae, r2, all_cls_preds, all_cls_labels, all_reg_preds, all_reg_labels

accuracy, f1, auc, mse, mae, r2, all_cls_preds, all_cls_labels, all_reg_preds, all_reg_labels = validate_model(model, val_cls_loader, val_reg_loader, device)
print("Validation Metrics:")
print(f"Classification -> Accuracy: {accuracy:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")
print(f"Regression     -> MSE: {mse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")


Validation Metrics:
Classification -> Accuracy: 0.9669, F1: 0.9754, AUC: 0.9954
Regression     -> MSE: 0.0026, MAE: 0.0396, R²: 0.8858


In [16]:
accuracy, f1, auc, mse, mae, r2, all_cls_preds, all_cls_labels, all_reg_preds, all_reg_labels = validate_model(model, test_cls_loader, test_reg_loader, device)
print("Test Metrics:")
print(f"Classification -> Accuracy: {accuracy:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")
print(f"Regression     -> MSE: {mse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

Test Metrics:
Classification -> Accuracy: 0.9672, F1: 0.9754, AUC: 0.9951
Regression     -> MSE: 0.0026, MAE: 0.0393, R²: 0.8867


In [17]:
# load experimental data
experimental_df = pd.read_csv("../../clean_data/experimental.csv", keep_default_na=False, na_values=[''], sep=';')
experimental_df

Unnamed: 0,peptide,label
0,AEAEAEAEAKAKAKAK,1
1,GGGGDD,1
2,GFFLGLDD,1
3,QEIARLEQEIARLEYEIARLE,1
4,NNQQNY,1
...,...,...
363,YTEYK,0
364,EPYYK,0
365,YDPKY,0
366,KDPYY,0


In [18]:
experimental_df['len'] = experimental_df['peptide'].apply(len)
experimental_df.rename(columns={'label': 'is_assembled'}, inplace=True)
experimental_df['stratify_key'] = experimental_df['len'].apply(get_stratification_key)

experimental_df = experimental_df[experimental_df['len'] <= max_seq_length - 2]
experimental_dataset = PeptidePredictorDataset(experimental_df, task='classification')
experimental_loader = DataLoader(experimental_dataset, batch_size=batch_size, shuffle=False, collate_fn=esm_collate_fn)
experimental_df

Unnamed: 0,peptide,is_assembled,len,stratify_key
1,GGGGDD,1,6,6
2,GFFLGLDD,1,8,8
4,NNQQNY,1,6,6
5,AILSS,1,5,5
6,YVIFL,1,5,5
...,...,...,...,...
363,YTEYK,0,5,5
364,EPYYK,0,5,5
365,YDPKY,0,5,5
366,KDPYY,0,5,5


In [19]:
accuracy, f1, auc, mse, mae, r2, all_cls_preds, all_cls_labels, all_reg_preds, all_reg_labels = validate_model(model, experimental_loader, [], device)
print("Experimental Data Metrics:")
print(f"Classification -> Accuracy: {accuracy:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")
print(f"Regression     -> MSE: {mse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

Experimental Data Metrics:
Classification -> Accuracy: 0.5521, F1: 0.6033, AUC: 0.5601
Regression     -> MSE: nan, MAE: nan, R²: nan


In [20]:
# show df with predictions
experimental_df['predicted_is_assembled'] = all_cls_preds

experimental_df[['peptide', 'is_assembled', 'predicted_is_assembled']]


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  experimental_df['predicted_is_assembled'] = all_cls_preds


Unnamed: 0,peptide,is_assembled,predicted_is_assembled
1,GGGGDD,1,-2.919494
2,GFFLGLDD,1,2.590684
4,NNQQNY,1,-6.989876
5,AILSS,1,3.247889
6,YVIFL,1,12.237429
...,...,...,...
363,YTEYK,0,4.245285
364,EPYYK,0,3.144766
365,YDPKY,0,4.263178
366,KDPYY,0,4.366505


In [33]:
# which peptides are in the experimental data and also classification_df?
peptides_in_experimental = set(experimental_df['peptide'])
peptides_in_regression = set(regression_df['peptide'])
common_peptides = peptides_in_experimental.intersection(peptides_in_regression)
print(f"Number of common peptides in experimental and regression data: {len(common_peptides)}")

Number of common peptides in experimental and regression data: 9


In [34]:
# for the common peptides, show the peptide, is_assembled from experiments, predicted is_assembled, and ap from classification_df
common_peptides_df = regression_df[regression_df['peptide'].isin(common_peptides)]
common_peptides_df = common_peptides_df[['peptide', 'ap']]
common_peptides_df = common_peptides_df.merge(experimental_df[['peptide', 'is_assembled', 'predicted_is_assembled']], on='peptide')
common_peptides_df = common_peptides_df[['peptide', 'is_assembled', 'predicted_is_assembled', 'ap']]
print("Common peptides in experimental and classification data:")
print(common_peptides_df)

Common peptides in experimental and classification data:
  peptide  is_assembled  predicted_is_assembled        ap
0     DFY             1                0.180612  0.314015
1    GFIL             1                0.999707  0.490410
2     IID             1                0.005936  0.260011
3     IVD             1                0.005513  0.172776
4     KFG             1                0.200675  0.336357
5     KYF             1                0.990598  0.609441
6     LFF             1                0.999992  0.839569
7  VQIVYK             1                0.983682  0.430783
8   VVVVV             1                0.999842  0.591718


In [None]:
cmp_reg = pd.DataFrame(
    {"peptide":['MWYWKF'], 'ap':[2.194], "len":[6], 'stratify_key':[6] }
)
cmp_dataset  = PeptidePredictorDataset(cmp_reg, task='regression')
cmp_loader  = DataLoader(cmp_dataset, batch_size=1, shuffle=False, collate_fn=esm_collate_fn)

In [None]:
with torch.no_grad():
    for tokens, targets in cmp_loader:
        tokens = tokens.to(device)
        targets = targets.to(device)
        ap_pred, clas_pred = model(tokens)
        print(tokens, ap_pred, clas_pred)

tensor([[ 0, 20, 22, 19, 22, 15, 18,  2,  1,  1,  1,  1]], device='cuda:0') tensor([[0.5457]], device='cuda:0') tensor([[1.0000]], device='cuda:0')
