In [11]:
import functools
import gc
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from peft import LoraConfig, get_peft_model
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split, GroupKFold, GroupShuffleSplit
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer


In [5]:
# Pick the compute device by preference: Apple-silicon MPS, then CUDA, else CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [None]:
# ESM + MLP regression model

# Load the mutation table. It should contain a mutation string plus the columns the
# code below relies on: "sequence" (tokenized model input), "scaled_activity"
# (regression target) and "group" (used for group-aware splitting).
mutations = pd.read_csv('your_path/mutations.csv')
_missing_columns = {"sequence", "scaled_activity", "group"} - set(mutations.columns)
if _missing_columns:
    # Fail here with a clear message instead of a KeyError deep inside a DataLoader.
    raise ValueError(f"mutations.csv is missing required columns: {sorted(_missing_columns)}")

In [None]:
# Hold out ~20% of the mutation *groups* (not rows) as a final test set so that no
# group appears in both train/val and test. A row-wise train_test_split would leak
# groups across the split, undermining the group-aware GroupKFold CV that follows.
# Note: GroupShuffleSplit's test_size is a fraction of groups, so the row fraction
# in test_df is only approximately 0.2 when group sizes differ.
holdout_splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
holdout_train_idx, holdout_test_idx = next(holdout_splitter.split(mutations, groups=mutations["group"]))
train_val_df = mutations.iloc[holdout_train_idx]
test_df = mutations.iloc[holdout_test_idx]
assert set(train_val_df["group"]).isdisjoint(test_df["group"]), "group leak between train/val and test"

In [None]:

class EnzymeDataset(Dataset):
    """Map-style dataset of tokenized protein sequences and their activity labels.

    Each item is a dict holding unpadded 1-D ``input_ids`` / ``attention_mask``
    tensors (batch padding happens later in ``collate_fn``) and a float32 scalar
    ``label`` read from the ``scaled_activity`` column. Sequences are truncated to
    ``max_length`` tokens by the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length=1024):
        # Re-index 0..n-1 so positional lookups line up with __len__.
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoding = self.tokenizer(
            record["sequence"],
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a batch of one; strip that leading axis.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(record["scaled_activity"], dtype=torch.float),
        }

def collate_fn(batch, tokenizer):
    """Pad a list of EnzymeDataset items into one batch dict.

    ``input_ids`` are right-padded with ``tokenizer.pad_token_id`` and
    ``attention_mask`` with 0 (so padding is masked out); labels are stacked into
    a 1-D tensor.
    """
    padded_ids = pad_sequence(
        [sample["input_ids"] for sample in batch],
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    )
    padded_masks = pad_sequence(
        [sample["attention_mask"] for sample in batch],
        batch_first=True,
        padding_value=0,
    )
    stacked_labels = torch.stack([sample["label"] for sample in batch])
    return {"input_ids": padded_ids, "attention_mask": padded_masks, "label": stacked_labels}

class AttentionPooling(nn.Module):
    """Learned attention pooling that collapses per-token embeddings into one vector.

    A linear layer assigns each token a scalar score; positions where
    ``attention_mask == 0`` get a score of -inf so the softmax gives them zero
    weight. The result is the softmax-weighted sum over the sequence axis (dim 1).
    A row whose mask is entirely 0 yields NaN (softmax over all -inf).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # hidden_dim -> one score per token
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        # assumes token_embeddings is (batch, seq, hidden_dim) and attention_mask is (batch, seq) — TODO confirm
        scores = self.attention(token_embeddings).squeeze(-1)
        is_padding = attention_mask == 0
        scores = scores.masked_fill(is_padding, float("-inf"))
        weights = scores.softmax(dim=-1)
        return (weights.unsqueeze(-1) * token_embeddings).sum(dim=1)

class MediumHead(nn.Module):
    """Two-layer MLP regression head: input_dim -> 128 -> 1 with ReLU and dropout."""

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, 1),
        ]
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        # Drop the trailing size-1 axis so predictions match 1-D label tensors.
        out = self.regressor(x)
        return out.squeeze(-1)

class EnzymeRegressionModel(nn.Module):
    """Frozen ESM backbone + trainable attention pooling + MLP regression head.

    The backbone forward pass runs under ``torch.no_grad()``, so no gradients
    reach it; only the pooler and head receive gradients.

    Args:
        esm_model: module called as ``esm_model(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        hidden_dim: width of ``last_hidden_state``'s last axis.
        dropout: dropout rate inside the head.
    """

    def __init__(self, esm_model, hidden_dim, dropout=0.1):
        super().__init__()
        self.esm_model = esm_model
        self.pooler = AttentionPooling(hidden_dim)
        self.head = MediumHead(hidden_dim, dropout)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            token_states = self.esm_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            ).last_hidden_state
        return self.head(self.pooler(token_states, attention_mask))




def train_one_fold(model, train_loader, val_loader, num_epochs=10, lr=1e-4, device=None, verbose=False):
    """Train ``model`` in place with MSE loss + Adam; validation MSE drives LR decay.

    After every epoch the model is evaluated on ``val_loader`` and the validation
    MSE is passed to ``ReduceLROnPlateau`` (LR halved after 2 epochs without
    improvement). The final-epoch model is returned; no best-epoch restore is done.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning one
            prediction per example.
        train_loader, val_loader: yield dicts with "input_ids", "attention_mask", "label".
        num_epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        device: where batches are moved; defaults to the device of the model's
            first parameter (the model is expected to already live there).
        verbose: if True, print per-epoch mean train loss, val MSE, val Spearman and LR.

    Returns:
        The same ``model`` object, after training.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device

    criterion = nn.MSELoss()
    # Frozen parameters (requires_grad=False) never receive gradients, so only
    # hand the trainable ones to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        n_train_batches = 0
        for batch in train_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            preds = model(ids, mask)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            n_train_batches += 1

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)
                preds = model(ids, mask)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch["label"].cpu().numpy())

        if not all_preds:
            raise ValueError("val_loader yielded no batches")

        val_preds = np.asarray(all_preds, dtype=np.float64)
        val_labels = np.asarray(all_labels, dtype=np.float64)
        val_mse = float(np.mean((val_labels - val_preds) ** 2))
        scheduler.step(val_mse)

        if verbose:
            spearman_corr, _ = spearmanr(val_preds, val_labels)
            mean_train_loss = total_train_loss / max(n_train_batches, 1)
            current_lr = optimizer.param_groups[0]["lr"]
            print(
                f"epoch {epoch + 1}/{num_epochs}: train_mse={mean_train_loss:.4f} "
                f"val_mse={val_mse:.4f} val_spearman={spearman_corr:.4f} lr={current_lr:.2e}"
            )

    return model

def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts and the trainable percentage.

    A module with no parameters reports 0.00% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        n_elements = param.numel()
        all_param += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(f"trainable params: {trainable_params} | all params: {all_param} | trainable%: {trainable_pct:.2f}")

In [None]:
# Backbone checkpoint and the width of its last_hidden_state; the two must agree.
chosen_model_name, hidden_dim = "facebook/esm2_t33_650M_UR50D", 1280

In [None]:
# Tokenizer matching the chosen checkpoint (downloaded/cached by transformers).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=chosen_model_name)

In [None]:
gkf = GroupKFold(n_splits=5)
fold_results = []

train_val_data = train_val_df.reset_index(drop=True)
train_val_groups = train_val_data['group']

# functools.partial, unlike a lambda, is picklable, so these loaders keep working if
# num_workers > 0 is enabled later (worker processes must pickle collate_fn).
collate_wrapper = functools.partial(collate_fn, tokenizer=tokenizer)

# The backbone is frozen (requires_grad=False) and EnzymeRegressionModel only runs it
# under torch.no_grad(), so its weights are identical in every fold. Load it once and
# share it instead of re-loading the checkpoint per fold (which also briefly kept the
# previous fold's copy alive next to the new one).
base_esm_model = AutoModel.from_pretrained(chosen_model_name)
for param in base_esm_model.parameters():
    param.requires_grad = False
# Fail fast if hidden_dim from the config cell does not match the checkpoint.
assert base_esm_model.config.hidden_size == hidden_dim, (
    f"hidden_dim={hidden_dim} but {chosen_model_name} has hidden_size="
    f"{base_esm_model.config.hidden_size}"
)

fold_splits = gkf.split(train_val_data, groups=train_val_groups)
for fold_idx, (train_index, val_index) in enumerate(fold_splits, start=1):
    # Per-fold seed: reproducible pooler/head initialisation and batch shuffling.
    torch.manual_seed(42 + fold_idx)

    fold_train_df = train_val_data.iloc[train_index]
    fold_val_df = train_val_data.iloc[val_index]

    train_dataset = EnzymeDataset(fold_train_df, tokenizer)
    val_dataset = EnzymeDataset(fold_val_df, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_wrapper)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_wrapper)

    # Fresh (untrained) pooler + head for this fold on top of the shared frozen backbone.
    model = EnzymeRegressionModel(
        esm_model=base_esm_model,
        hidden_dim=hidden_dim
    ).to(device)

    print(f"Fold {fold_idx} model parameter status:")
    print_trainable_parameters(model)

    trained_model = train_one_fold(model, train_loader, val_loader, num_epochs=15, lr=1e-3)

    # Final evaluation of this fold's model on its held-out groups.
    trained_model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            preds = trained_model(ids, mask)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(batch["label"].numpy())

    val_labels = np.array(val_labels)
    val_preds = np.array(val_preds)

    spearman_corr, _ = spearmanr(val_labels, val_preds)
    pearson_corr, _ = pearsonr(val_labels, val_preds)
    rmse = np.sqrt(mean_squared_error(val_labels, val_preds))
    r2 = r2_score(val_labels, val_preds)

    results_dict = {'spearman': spearman_corr, 'pearson': pearson_corr, 'rmse': rmse, 'r2': r2}
    fold_results.append(results_dict)
    print(f"Fold {fold_idx}: " + ", ".join(f"{name}={value:.4f}" for name, value in results_dict.items()))
    

In [None]:

# Cross-validation summary for the frozen-ESM + MLP model: per-metric mean and
# population standard deviation (np.std, ddof=0) across folds. Stored under its own
# name because the LoRA section re-binds `fold_results`.
assert fold_results, "fold_results is empty; run the cross-validation cell first"
frozen_cv_summary_rows = {}
for metric_name in fold_results[0].keys():
    metric_values = [r[metric_name] for r in fold_results]
    mean_val = np.mean(metric_values)
    std_val = np.std(metric_values)
    frozen_cv_summary_rows[metric_name] = {"mean": mean_val, "std": std_val}

frozen_cv_summary = pd.DataFrame.from_dict(frozen_cv_summary_rows, orient="index")
frozen_cv_summary
        

In [None]:
# ESM_LoRA

# Load the mutation table. It should contain a mutation string plus the columns the
# code below relies on: "sequence" (tokenized model input), "scaled_activity"
# (regression target) and "group" (used for group-aware splitting).
mutations = pd.read_csv('your_path/mutations.csv')
_missing_columns = {"sequence", "scaled_activity", "group"} - set(mutations.columns)
if _missing_columns:
    # Fail here with a clear message instead of a KeyError deep inside a DataLoader.
    raise ValueError(f"mutations.csv is missing required columns: {sorted(_missing_columns)}")

In [None]:

# Hold out ~20% of the mutation *groups* (not rows) as a final test set so that no
# group appears in both train/val and test. A row-wise train_test_split would leak
# groups across the split, undermining the group-aware GroupKFold CV that follows.
# Note: GroupShuffleSplit's test_size is a fraction of groups, so the row fraction
# in test_df is only approximately 0.2 when group sizes differ.
holdout_splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
holdout_train_idx, holdout_test_idx = next(holdout_splitter.split(mutations, groups=mutations["group"]))
train_val_df = mutations.iloc[holdout_train_idx]
test_df = mutations.iloc[holdout_test_idx]
assert set(train_val_df["group"]).isdisjoint(test_df["group"]), "group leak between train/val and test"

In [9]:
class EnzymeDataset(Dataset):
    """Map-style dataset of tokenized protein sequences and their activity labels.

    Each item is a dict holding unpadded 1-D ``input_ids`` / ``attention_mask``
    tensors (batch padding happens later in ``collate_fn``) and a float32 scalar
    ``label`` read from the ``scaled_activity`` column. Sequences are truncated to
    ``max_length`` tokens by the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length=1024):
        # Re-index 0..n-1 so positional lookups line up with __len__.
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        encoding = self.tokenizer(
            record["sequence"],
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a batch of one; strip that leading axis.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(record["scaled_activity"], dtype=torch.float),
        }

def collate_fn(batch, tokenizer):
    """Pad a list of EnzymeDataset items into one batch dict.

    ``input_ids`` are right-padded with ``tokenizer.pad_token_id`` and
    ``attention_mask`` with 0 (so padding is masked out); labels are stacked into
    a 1-D tensor.
    """
    padded_ids = pad_sequence(
        [sample["input_ids"] for sample in batch],
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    )
    padded_masks = pad_sequence(
        [sample["attention_mask"] for sample in batch],
        batch_first=True,
        padding_value=0,
    )
    stacked_labels = torch.stack([sample["label"] for sample in batch])
    return {"input_ids": padded_ids, "attention_mask": padded_masks, "label": stacked_labels}


class AttentionPooling(nn.Module):
    """Learned attention pooling that collapses per-token embeddings into one vector.

    A linear layer assigns each token a scalar score; positions where
    ``attention_mask == 0`` get a score of -inf so the softmax gives them zero
    weight. The result is the softmax-weighted sum over the sequence axis (dim 1).
    A row whose mask is entirely 0 yields NaN (softmax over all -inf).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # hidden_dim -> one score per token
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        # assumes token_embeddings is (batch, seq, hidden_dim) and attention_mask is (batch, seq) — TODO confirm
        scores = self.attention(token_embeddings).squeeze(-1)
        is_padding = attention_mask == 0
        scores = scores.masked_fill(is_padding, float("-inf"))
        weights = scores.softmax(dim=-1)
        return (weights.unsqueeze(-1) * token_embeddings).sum(dim=1)

class MediumHead(nn.Module):
    """Two-layer MLP regression head: input_dim -> 128 -> 1 with ReLU and dropout."""

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_units, 1),
        ]
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        # Drop the trailing size-1 axis so predictions match 1-D label tensors.
        out = self.regressor(x)
        return out.squeeze(-1)

class EnzymeRegressionModel(nn.Module):
    """ESM backbone + attention pooling + MLP regression head, trained end to end.

    The backbone forward pass is NOT wrapped in ``torch.no_grad()``, so gradients
    reach every backbone parameter that has ``requires_grad=True`` (e.g. LoRA
    adapters when the backbone comes from ``get_peft_model``).

    Args:
        esm_model: module called as ``esm_model(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        hidden_dim: width of ``last_hidden_state``'s last axis.
        head_type: only "medium" (MediumHead) is implemented.
        dropout: dropout rate inside the head.

    Raises:
        ValueError: if ``head_type`` is not "medium" (previously it was silently ignored).
    """

    def __init__(self, esm_model, hidden_dim, head_type="medium", dropout=0.1):
        if head_type != "medium":
            raise ValueError(f"Unsupported head_type {head_type!r}; only 'medium' is implemented")
        super().__init__()
        self.esm_model = esm_model
        self.pooler = AttentionPooling(hidden_dim)
        self.head = MediumHead(hidden_dim, dropout)

    def forward(self, input_ids, attention_mask):
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = outputs.last_hidden_state
        pooled_repr = self.pooler(hidden_state, attention_mask)
        return self.head(pooled_repr)

def train_one_fold(model, train_loader, val_loader, num_epochs=10, lr=1e-4, device=None, verbose=False):
    """Train ``model`` in place with MSE loss + Adam; validation MSE drives LR decay.

    After every epoch the model is evaluated on ``val_loader`` and the validation
    MSE is passed to ``ReduceLROnPlateau`` (LR halved after 2 epochs without
    improvement). The final-epoch model is returned; no best-epoch restore is done.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` returning one
            prediction per example.
        train_loader, val_loader: yield dicts with "input_ids", "attention_mask", "label".
        num_epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        device: where batches are moved; defaults to the device of the model's
            first parameter (the model is expected to already live there).
        verbose: if True, print per-epoch mean train loss, val MSE, val Spearman and LR.

    Returns:
        The same ``model`` object, after training.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device

    criterion = nn.MSELoss()
    # Frozen parameters (requires_grad=False, e.g. the non-LoRA backbone weights)
    # never receive gradients, so only hand the trainable ones to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        n_train_batches = 0
        for batch in train_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            preds = model(ids, mask)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            n_train_batches += 1

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)
                preds = model(ids, mask)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(batch["label"].cpu().numpy())

        if not all_preds:
            raise ValueError("val_loader yielded no batches")

        val_preds = np.asarray(all_preds, dtype=np.float64)
        val_labels = np.asarray(all_labels, dtype=np.float64)
        val_mse = float(np.mean((val_labels - val_preds) ** 2))
        scheduler.step(val_mse)

        if verbose:
            spearman_corr, _ = spearmanr(val_preds, val_labels)
            mean_train_loss = total_train_loss / max(n_train_batches, 1)
            current_lr = optimizer.param_groups[0]["lr"]
            print(
                f"epoch {epoch + 1}/{num_epochs}: train_mse={mean_train_loss:.4f} "
                f"val_mse={val_mse:.4f} val_spearman={spearman_corr:.4f} lr={current_lr:.2e}"
            )

    return model

def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts and the trainable percentage.

    A module with no parameters reports 0.00% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        n_elements = param.numel()
        all_param += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}")

In [None]:
# Backbone checkpoint and the width of its last_hidden_state; the two must agree.
chosen_model_name = "facebook/esm2_t33_650M_UR50D"
hidden_dim = 1280

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=chosen_model_name)

# LoRA adapters (rank 8, alpha 16, dropout 0.1) on the submodules named "query" and
# "value" — presumably the attention projections of the HF ESM implementation; verify
# with model.named_modules(). bias="none": no bias parameters are made trainable.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "value"],
    bias="none",
)

In [None]:
gkf = GroupKFold(n_splits=5)
fold_results = []

train_val_data = train_val_df.reset_index(drop=True)
train_val_groups = train_val_data['group']

# functools.partial, unlike a lambda, is picklable, so these loaders keep working if
# num_workers > 0 is enabled later (worker processes must pickle collate_fn).
collate_wrapper = functools.partial(collate_fn, tokenizer=tokenizer)

fold_splits = gkf.split(train_val_data, groups=train_val_groups)
for fold_idx, (train_index, val_index) in enumerate(fold_splits, start=1):
    # Every fold fine-tunes its own LoRA-wrapped copy of the backbone. Release the
    # previous fold's copy *before* loading the next one so two full models are never
    # resident on the accelerator at once; the last fold's model stays available
    # after the loop.
    model = trained_model = peft_model_fold = base_esm_model_fold = None
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and hasattr(torch, "mps"):
        torch.mps.empty_cache()

    # Per-fold seed: reproducible LoRA/head initialisation and batch shuffling.
    torch.manual_seed(42 + fold_idx)

    fold_train_df = train_val_data.iloc[train_index]
    fold_val_df = train_val_data.iloc[val_index]

    train_dataset = EnzymeDataset(fold_train_df, tokenizer)
    val_dataset = EnzymeDataset(fold_val_df, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_wrapper)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_wrapper)

    base_esm_model_fold = AutoModel.from_pretrained(chosen_model_name)
    # Fail fast if hidden_dim from the config cell does not match the checkpoint.
    assert base_esm_model_fold.config.hidden_size == hidden_dim, (
        f"hidden_dim={hidden_dim} but {chosen_model_name} has hidden_size="
        f"{base_esm_model_fold.config.hidden_size}"
    )
    peft_model_fold = get_peft_model(base_esm_model_fold, lora_config)

    model = EnzymeRegressionModel(
        esm_model=peft_model_fold,
        hidden_dim=hidden_dim
    ).to(device)

    print(f"Fold {fold_idx} model parameter status:")
    print_trainable_parameters(model)
    trained_model = train_one_fold(model, train_loader, val_loader, num_epochs=15, lr=5e-5)

    # Final evaluation of this fold's model on its held-out groups.
    trained_model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            preds = trained_model(ids, mask)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(batch["label"].numpy())

    val_labels = np.array(val_labels)
    val_preds = np.array(val_preds)

    spearman_corr, _ = spearmanr(val_labels, val_preds)
    pearson_corr, _ = pearsonr(val_labels, val_preds)
    rmse = np.sqrt(mean_squared_error(val_labels, val_preds))
    r2 = r2_score(val_labels, val_preds)

    results_dict = {
        'spearman': spearman_corr,
        'pearson': pearson_corr,
        'rmse': rmse,
        'r2': r2
    }
    fold_results.append(results_dict)
    print(f"Fold {fold_idx}: " + ", ".join(f"{name}={value:.4f}" for name, value in results_dict.items()))
 

In [None]:

# Cross-validation summary for the ESM + LoRA model: per-metric mean and population
# standard deviation (np.std, ddof=0) across folds.
assert fold_results, "fold_results is empty; run the cross-validation cell first"
lora_cv_summary_rows = {}
for metric_name in fold_results[0].keys():
    metric_values = [r[metric_name] for r in fold_results]
    mean_val = np.mean(metric_values)
    std_val = np.std(metric_values)
    lora_cv_summary_rows[metric_name] = {"mean": mean_val, "std": std_val}

lora_cv_summary = pd.DataFrame.from_dict(lora_cv_summary_rows, orient="index")
lora_cv_summary