In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the three supplementary spreadsheets; they are joined on their shared 'Name' column.
df1 = pd.read_excel('pnas.1616408114.sd01.xlsx')
df2 = pd.read_excel('pnas.1616408114.sd02.xlsx')
df3 = pd.read_excel('pnas.1616408114.sd03.xlsx')

# Outer joins keep every name that appears in any of the three sheets.
merged_df = df1.merge(df2, on='Name', how='outer').merge(df3, on='Name', how='outer')

# Keep only the two sequence columns and the regression target.
# An outer join leaves NaN wherever a name is absent from one sheet; such rows
# cannot be tokenized or log-transformed downstream, so they are removed here.
required_columns = ['VH', 'VL', 'HEK Titer (mg/L)']
df = merged_df[required_columns].dropna(subset=required_columns).copy()

n_dropped = len(merged_df) - len(df)
if n_dropped:
    print(f"Dropped {n_dropped} merged rows with a missing VH, VL or titer value.")

print("\nFirst few rows of df_yield:")
print(df.head())

In [4]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr


from transformers import AutoModel, AutoTokenizer


In [None]:
# Pick the compute device in order of preference: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print("Using device:", device)

In [6]:
# If any titer is <= 0, shift all titers so the smallest becomes 1 before the log;
# otherwise no shift is applied. A further +1.0 is always added inside the log.
min_label_val = df["HEK Titer (mg/L)"].min()
if min_label_val <= 0:
    shift_val = 1.0 - min_label_val
else:
    shift_val = 0
df["log_label"] = np.log(df["HEK Titer (mg/L)"] + shift_val + 1.0)

# Standardize the log-titers to zero mean / unit variance; `scaler` is kept at
# module level so the transform can be reused or inverted later.
scaler = StandardScaler()
scaler.fit(df[["log_label"]])
df["scaled_label"] = scaler.transform(df[["log_label"]])

In [None]:
# Hold out 20% of the rows as a final test set (fixed seed for a repeatable split).
train_val_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Take explicit copies so the column assignments below never write into `df`.
train_val_df = train_val_df.copy()
test_df = test_df.copy()

# The "scaled_label" column inherited from `df` was standardized with statistics
# computed over ALL rows, i.e. including the held-out test rows. Re-fit the scaler
# on the train+val portion only and apply that same transform to the test rows so
# no test-set information leaks into the training targets.
scaler = StandardScaler()
train_val_df["scaled_label"] = scaler.fit_transform(train_val_df[["log_label"]])
test_df["scaled_label"] = scaler.transform(test_df[["log_label"]])

print(f"Train+Val: {len(train_val_df)}, Test: {len(test_df)}")


In [9]:
class AntibodyDataset(Dataset):
    """Map-style dataset yielding tokenized VH/VL pairs with their scaled label.

    Args:
        df: Frame providing the "VH", "VL" and "scaled_label" columns.
        tokenizer: Callable applied to one sequence at a time. It is invoked with
            ``truncation=True``, ``max_length=...`` and ``return_tensors="pt"``
            and must return a mapping with "input_ids" and "attention_mask".
        max_length: Truncation length forwarded to the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Discard the incoming index; rows are addressed purely by position.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def _encode(self, sequence):
        """Tokenize one sequence and strip the leading batch axis of size 1."""
        encoded = self.tokenizer(
            sequence,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        target = record["scaled_label"]

        heavy_ids, heavy_mask = self._encode(record["VH"])
        light_ids, light_mask = self._encode(record["VL"])

        # Sequences keep their own lengths here; padding happens in the collate step.
        return {
            "heavy_input_ids": heavy_ids,
            "heavy_attention_mask": heavy_mask,
            "light_input_ids": light_ids,
            "light_attention_mask": light_mask,
            "label": torch.tensor(target, dtype=torch.float)
        }

In [10]:
def collate_fn(batch):
    """Merge per-sample dicts into one batch, right-padding to the longest sequence.

    Token ids are padded with the pad id of the module-level ``tokenizer``;
    attention masks are padded with 0. Heavy and light chains are padded
    independently of each other. Labels are stacked into a 1-D tensor.
    """
    labels = torch.stack([sample["label"] for sample in batch])

    def padded(key, fill_value):
        # Gather one field from every sample and pad it to a (batch, max_len) tensor.
        return pad_sequence(
            [sample[key] for sample in batch],
            batch_first=True,
            padding_value=fill_value,
        )

    collated = {
        "heavy_input_ids": padded("heavy_input_ids", tokenizer.pad_token_id),
        "heavy_attention_mask": padded("heavy_attention_mask", 0),
        "light_input_ids": padded("light_input_ids", tokenizer.pad_token_id),
        "light_attention_mask": padded("light_attention_mask", 0),
    }
    collated["label"] = labels
    return collated

In [11]:
class AttentionPooling(nn.Module):
    """
    Collapses the sequence axis into a single vector using learned softmax weights.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar relevance score per token.
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        """
        token_embeddings: (batch, seq_len, hidden_dim)
        attention_mask:   (batch, seq_len), non-zero where a token is real
        Returns the weighted sum over seq_len, shape (batch, hidden_dim).
        """
        raw_scores = self.attention(token_embeddings).squeeze(-1)

        # Masked positions get -inf so the softmax assigns them zero weight.
        is_padding = ~attention_mask.bool()
        masked_scores = raw_scores.masked_fill(is_padding, float("-inf"))

        weights = masked_scores.softmax(dim=-1)
        return (token_embeddings * weights.unsqueeze(-1)).sum(dim=1)

In [12]:
class SimpleHead(nn.Module):
    """Single linear projection from the input features to one scalar per sample."""

    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        projected = self.fc(x)
        # Remove the trailing size-1 axis produced by the 1-unit output layer.
        return projected.squeeze(-1)

class MediumHead(nn.Module):
    """One hidden layer (128 units, ReLU, dropout) followed by a scalar output."""

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        ]
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.regressor(x)
        # Remove the trailing size-1 axis so the result is one value per sample.
        return scores.squeeze(-1)

class DeepHead(nn.Module):
    """Two hidden layers (256 then 128 units, each ReLU + dropout) and a scalar output."""

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        widths = (input_dim, 256, 128)
        layers = []
        # Each hidden stage is Linear -> ReLU -> Dropout.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        layers.append(nn.Linear(widths[-1], 1))
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.regressor(x)
        # Remove the trailing size-1 axis so the result is one value per sample.
        return scores.squeeze(-1)
    
class DeeperHead(nn.Module):
    """Four hidden layers (512, 256, 128, 64 units, each ReLU + dropout) and a scalar output."""

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        widths = (input_dim, 512, 256, 128, 64)
        layers = []
        # Each hidden stage is Linear -> ReLU -> Dropout.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)]
        layers.append(nn.Linear(widths[-1], 1))
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.regressor(x)
        # Remove the trailing size-1 axis so the result is one value per sample.
        return scores.squeeze(-1)


In [13]:
class RegressionModel(nn.Module):
    """Frozen sequence encoder + pooling + regression head over a heavy/light pair.

    Both chains are passed through the same ``esm_model`` and the same pooling
    step; the two pooled vectors are concatenated and fed to the head.

    Args:
        esm_model: Encoder called as ``esm_model(input_ids=..., attention_mask=...)``
            whose result exposes ``last_hidden_state``. All of its parameters are
            frozen in the constructor.
        hidden_dim: Width of the encoder's hidden states.
        head_type: One of "simple", "medium", "deep", "deeper".
        dropout: Dropout rate passed to every head except "simple".
        use_attention_pool: Use learned attention pooling when True, otherwise
            a mask-weighted mean over tokens.

    Raises:
        ValueError: If ``head_type`` is not one of the supported names.
    """

    def __init__(
        self,
        esm_model,
        hidden_dim,
        head_type="simple",
        dropout=0.1,
        use_attention_pool=True
    ):
        super().__init__()
        self.esm_model = esm_model
        # Freeze the encoder: only the pooling layer and head stay trainable.
        for weight in self.esm_model.parameters():
            weight.requires_grad = False

        # None means "fall back to masked mean pooling" in forward().
        self.pooler = AttentionPooling(hidden_dim) if use_attention_pool else None

        # The head sees the heavy and light representations side by side.
        input_dim = 2 * hidden_dim
        head_builders = {
            "simple": lambda: SimpleHead(input_dim),
            "medium": lambda: MediumHead(input_dim, dropout),
            "deep": lambda: DeepHead(input_dim, dropout),
            "deeper": lambda: DeeperHead(input_dim, dropout),
        }
        if head_type not in head_builders:
            raise ValueError(f"Invalid head type: {head_type}")
        self.head = head_builders[head_type]()

    def _pool(self, hidden, mask):
        """Reduce (batch, seq_len, hidden) token states to (batch, hidden)."""
        if self.pooler is not None:
            return self.pooler(hidden, mask)

        # Masked mean: padded positions contribute nothing to the sum or the count.
        mask_f = mask.unsqueeze(-1).float()
        summed = (hidden * mask_f).sum(dim=1)
        n_tokens = mask_f.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
        return summed / n_tokens

    def forward(self, heavy_ids, heavy_mask, light_ids, light_mask):
        # Per-token embeddings for each chain from the shared encoder.
        heavy_hidden = self.esm_model(
            input_ids=heavy_ids, attention_mask=heavy_mask
        ).last_hidden_state
        light_hidden = self.esm_model(
            input_ids=light_ids, attention_mask=light_mask
        ).last_hidden_state

        chain_reprs = [
            self._pool(heavy_hidden, heavy_mask),
            self._pool(light_hidden, light_mask),
        ]
        return self.head(torch.cat(chain_reprs, dim=1))

In [14]:
class EarlyStopping:
    """Tracks validation loss and raises a flag after `patience` non-improving steps.

    Attributes:
        patience: Number of consecutive non-improving steps tolerated.
        best_val_loss: Lowest loss passed to `step` so far (starts at +inf).
        counter: Consecutive steps without a strict improvement.
        early_stop: Set to True once `counter` reaches `patience`; never reset.
    """

    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.early_stop = False
        self.best_val_loss = float("inf")

    def step(self, val_loss):
        """Record one validation loss; only a strictly lower loss counts as progress."""
        improved = val_loss < self.best_val_loss
        if improved:
            self.best_val_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


In [15]:
def train_one_fold(
    model,
    train_loader,
    val_loader,
    num_epochs=10,
    lr=1e-3,
    use_scheduler=True,
    patience=3,
    restore_best=True
):
    """Train the head (and attention pooler, if any) of ``model`` with early stopping.

    Only ``model.head`` and, when it is an ``AttentionPooling`` instance,
    ``model.pooler`` are handed to the optimizer. After every epoch the model is
    scored on ``val_loader`` (MSE averaged over batches, plus Spearman correlation)
    and a progress line is printed.

    Args:
        model: Module exposing ``head`` and ``pooler`` and callable as
            ``model(heavy_ids, heavy_mask, light_ids, light_mask)``.
        train_loader: Iterable of collated batches used for optimization.
        val_loader: Iterable of collated batches used for validation.
        num_epochs: Maximum number of epochs.
        lr: Adam learning rate.
        use_scheduler: When True, halve the learning rate every 3 epochs.
        patience: Consecutive epochs without validation improvement tolerated
            before training stops.
        restore_best: When True, the trained modules are reset to the weights
            of the epoch with the lowest validation loss before returning.
            Without this, early stopping would hand back weights that are
            `patience` epochs past the best ones.

    Returns:
        The same ``model`` object, trained in place.
    """
    criterion = nn.MSELoss()

    # Build the optimizer once over exactly the modules that are meant to learn.
    trainable_modules = [model.head]
    if isinstance(model.pooler, AttentionPooling):
        trainable_modules.append(model.pooler)
    optimizer = optim.Adam(
        [param for module in trainable_modules for param in module.parameters()],
        lr=lr
    )

    if use_scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5)
    else:
        scheduler = None

    early_stopper = EarlyStopping(patience=patience)
    best_val_loss = float("inf")
    best_states = None  # per-module state-dict snapshots from the best epoch

    def _run_batch(batch):
        """Move one collated batch to `device` and return (predictions, labels)."""
        heavy_ids = batch["heavy_input_ids"].to(device)
        heavy_mask = batch["heavy_attention_mask"].to(device)
        light_ids = batch["light_input_ids"].to(device)
        light_mask = batch["light_attention_mask"].to(device)
        labels = batch["label"].to(device)
        return model(heavy_ids, heavy_mask, light_ids, light_mask), labels

    for epoch in range(num_epochs):
        # ---- optimization pass ----
        model.train()
        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            preds, labels = _run_batch(batch)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)

        # ---- validation pass ----
        model.eval()
        val_losses = []
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in val_loader:
                preds, labels = _run_batch(batch)
                val_losses.append(criterion(preds, labels).item())
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = np.mean(val_losses)
        spearman_corr, _ = spearmanr(all_preds, all_labels)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1:02d}] Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | Spearman: {spearman_corr:.4f}")

        # Snapshot the trained modules whenever validation strictly improves.
        # Tensors are cloned because state_dict() shares storage with the live weights.
        if restore_best and avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_states = [
                {key: value.detach().clone() for key, value in module.state_dict().items()}
                for module in trainable_modules
            ]

        early_stopper.step(avg_val_loss)
        if early_stopper.early_stop:
            print("Early stopping triggered!")
            break

    # Roll back to the best epoch (no-op if no finite improvement was ever seen).
    if best_states is not None:
        for module, state in zip(trainable_modules, best_states):
            module.load_state_dict(state)

    return model

In [16]:
# Supported ESM-2 checkpoints, smallest to largest.
# Each entry maps a checkpoint name to (hidden size, regression-head type).
model_configurations = {
    "facebook/esm2_t6_8M_UR50D": (320, "simple"),
    "facebook/esm2_t12_35M_UR50D": (480, "medium"),
    "facebook/esm2_t30_150M_UR50D": (640, "medium"),
    "facebook/esm2_t33_650M_UR50D": (1280, "deep"),
    "facebook/esm2_t36_3B_UR50D": (2560, "deeper"),
}

In [None]:

# Checkpoint to use; must be one of the keys of `model_configurations`.
chosen_model_name = "facebook/esm2_t6_8M_UR50D"

# Look up the hidden size and head type registered for that checkpoint.
selected_config = model_configurations[chosen_model_name]
hidden_dim, head_type = selected_config

# Load the encoder first, then its matching tokenizer, from the same checkpoint name.
base_esm_model = AutoModel.from_pretrained(chosen_model_name)
tokenizer = AutoTokenizer.from_pretrained(chosen_model_name)

In [None]:
# The fold assignment is already seeded through KFold(random_state=...), but
# weight initialisation and DataLoader shuffling draw from the global RNGs.
# Seeding them per fold makes a Restart & Run All reproduce the same numbers.
SEED = 42
BATCH_SIZE = 8

# Loader over the held-out test rows (not consumed inside the CV loop below).
test_dataset = AntibodyDataset(test_df, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

kfold = KFold(n_splits=5, shuffle=True, random_state=42)

fold_results = []  # one (validation MSE, validation Spearman) tuple per fold

train_val_data = train_val_df.reset_index(drop=True)

for fold_idx, (train_index, val_index) in enumerate(kfold.split(train_val_data), start=1):
    print(f"\n==== Fold {fold_idx} ====")

    # Distinct but deterministic seed for every fold.
    fold_seed = SEED + fold_idx
    random.seed(fold_seed)
    np.random.seed(fold_seed)
    torch.manual_seed(fold_seed)

    fold_train_df = train_val_data.iloc[train_index]
    fold_val_df = train_val_data.iloc[val_index]

    train_dataset = AntibodyDataset(fold_train_df, tokenizer)
    val_dataset = AntibodyDataset(fold_val_df, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    # A fresh pooler and head are created for every fold; the frozen encoder is shared.
    model = RegressionModel(
        esm_model=base_esm_model,
        hidden_dim=hidden_dim,
        head_type=head_type,
        dropout=0.1,
        use_attention_pool=True
    ).to(device)

    trained_model = train_one_fold(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=10,
        lr=1e-3,
        use_scheduler=True
    )

    # Score the returned model once more on this fold's validation split.
    trained_model.eval()
    val_preds = []
    val_labels = []
    criterion = nn.MSELoss()
    val_losses = []

    with torch.no_grad():
        for batch in val_loader:
            heavy_ids = batch["heavy_input_ids"].to(device)
            heavy_mask = batch["heavy_attention_mask"].to(device)
            light_ids = batch["light_input_ids"].to(device)
            light_mask = batch["light_attention_mask"].to(device)
            labels = batch["label"].to(device)

            preds = trained_model(heavy_ids, heavy_mask, light_ids, light_mask)
            loss = criterion(preds, labels)
            val_losses.append(loss.item())
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    avg_val_loss = np.mean(val_losses)
    val_spearman, _ = spearmanr(val_preds, val_labels)
    fold_results.append((avg_val_loss, val_spearman))
    print(f"Fold Validation Loss: {avg_val_loss:.4f}, Spearman: {val_spearman:.4f}")

# Aggregate across folds: mean and sample standard deviation (ddof=1).
val_losses = [r[0] for r in fold_results]
val_spearmans = [r[1] for r in fold_results]
cv_val_loss = np.mean(val_losses)
cv_val_spearman = np.mean(val_spearmans)
sd_val_loss = np.std(val_losses, ddof=1)
sd_val_spearman = np.std(val_spearmans, ddof=1)

print("\n=== Cross-Validation Results ===")
print(f"Average Validation Loss: {cv_val_loss:.4f} +/- {sd_val_loss:.4f}")
print(f"Average Validation Spearman: {cv_val_spearman:.4f} +/- {sd_val_spearman:.4f}")