In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Load the three supplementary spreadsheets (sd01-sd03) that sit next to the notebook.
supplement_paths = (
    'pnas.1616408114.sd01.xlsx',
    'pnas.1616408114.sd02.xlsx',
    'pnas.1616408114.sd03.xlsx',
)
df1, df2, df3 = (pd.read_excel(path) for path in supplement_paths)

# Outer-join on the shared 'Name' column so a row missing from one sheet is kept
# (its columns from that sheet become NaN).
merged_df = pd.merge(
    pd.merge(df1, df2, on='Name', how='outer'),
    df3,
    on='Name',
    how='outer',
)

# Keep only the two sequence columns and the titer column used as the target;
# .copy() detaches the frame so later column assignments do not touch merged_df.
df = merged_df[['VH', 'VL', 'HEK Titer (mg/L)']].copy()

In [20]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr

from transformers import AutoModel, AutoTokenizer

from tqdm import tqdm

from sklearn.metrics import mean_squared_error
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from xgboost import XGBRegressor
from sklearn.svm import SVR


In [None]:
# Pick the compute backend in priority order: MPS, then CUDA, then CPU.
# The availability checks are evaluated lazily, so CUDA is only probed when
# MPS is unavailable.
_backend_probes = (
    ("mps", torch.backends.mps.is_available),
    ("cuda", torch.cuda.is_available),
)
device = torch.device(
    next((name for name, is_available in _backend_probes if is_available()), "cpu")
)

print("Using device:", device)

In [4]:
# Build the regression target: log-transform the titer, then standardise it.
titer = df["HEK Titer (mg/L)"]
min_label_val = titer.min()

# When the smallest titer is <= 0, shift every value so the minimum becomes 1
# before the extra +1.0 below; otherwise no shift is applied.
if min_label_val <= 0:
    shift_val = 1.0 - min_label_val
else:
    shift_val = 0
df["log_label"] = np.log(titer + shift_val + 1.0)  # log( label + offset )

# NOTE(review): the scaler is fitted on every row of `df`, i.e. before any
# train/test split, so held-out rows contribute to the mean/std used for
# scaling. Consider fitting on the training rows only.
scaler = StandardScaler()
scaled_values = scaler.fit_transform(df[["log_label"]])  # shape => (n,1)
df["scaled_label"] = scaled_values

In [None]:
# Hold out 20% of the rows as a test set; the fixed seed keeps the split stable.
train_val_df, test_df = train_test_split(
    df,
    random_state=42,
    test_size=0.2,
)
n_train_val, n_test = len(train_val_df), len(test_df)
print(f"Train+Val: {n_train_val}, Test: {n_test}")

In [7]:
class AntibodyDataset(Dataset):
    """Dataset that tokenises the "VH" and "VL" columns of a DataFrame row by row.

    Each item is a dict with un-padded 1-D token ids and attention masks for the
    heavy ("VH") and light ("VL") sequence plus the row's "scaled_label" as a
    float tensor. Padding to a common length is left to the collate function.

    Args:
        df: DataFrame with "VH", "VL" and "scaled_label" columns; its index is
            reset so items are addressed positionally.
        tokenizer: callable invoked as
            ``tokenizer(seq, truncation=True, max_length=..., return_tensors="pt")``
            whose result exposes "input_ids" and "attention_mask".
        max_length: truncation length handed to the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length=1024):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def _encode(self, sequence):
        """Tokenise one sequence and drop the leading batch axis of both outputs."""
        encoded = self.tokenizer(
            sequence,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        heavy_seq, light_seq = record["VH"], record["VL"]
        target = record["scaled_label"]

        heavy_ids, heavy_mask = self._encode(heavy_seq)
        light_ids, light_mask = self._encode(light_seq)

        return {
            "heavy_input_ids": heavy_ids,
            "heavy_attention_mask": heavy_mask,
            "light_input_ids": light_ids,
            "light_attention_mask": light_mask,
            "label": torch.tensor(target, dtype=torch.float)
        }
        
def collate_fn(batch, pad_token_id=None):
    """Pad a list of dataset items to a common length and stack them into a batch.

    Args:
        batch: list of dicts, each holding 1-D tensors under "heavy_input_ids",
            "heavy_attention_mask", "light_input_ids", "light_attention_mask"
            and a scalar tensor under "label".
        pad_token_id: value used to pad the token-id tensors. When omitted, the
            notebook-level ``tokenizer.pad_token_id`` is used, which keeps the
            function usable directly as ``DataLoader(collate_fn=collate_fn)``.
            Pass it explicitly (e.g. through ``functools.partial``) to avoid
            depending on whichever tokenizer is currently bound at module level.

    Returns:
        dict with the same five keys; id/mask tensors are padded batch-first,
        masks are padded with 0, and labels are stacked into a 1-D tensor.
    """
    if pad_token_id is None:
        # Backward-compatible default: fall back to the global tokenizer.
        pad_token_id = tokenizer.pad_token_id

    def _pad(key, padding_value):
        # Right-pad every sequence stored under `key` to the longest in the batch.
        return pad_sequence(
            [item[key] for item in batch],
            batch_first=True,
            padding_value=padding_value,
        )

    labels = torch.stack([item["label"] for item in batch])

    return {
        "heavy_input_ids": _pad("heavy_input_ids", pad_token_id),
        "heavy_attention_mask": _pad("heavy_attention_mask", 0),
        "light_input_ids": _pad("light_input_ids", pad_token_id),
        "light_attention_mask": _pad("light_attention_mask", 0),
        "label": labels
    }
    
class AttentionPooling(nn.Module):
    """Collapse per-token embeddings into one vector using learned attention weights.

    A single linear layer assigns a scalar score to every token; the scores of
    masked-out positions are excluded from the softmax and the output is the
    weighted sum of the token embeddings.
    """

    def __init__(self, hidden_dim):
        """
        Args:
            hidden_dim: size of the last axis of the token embeddings.
        """
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        """Pool token embeddings over the token axis.

        Args:
            token_embeddings: tensor indexed as (batch, tokens, hidden_dim).
            attention_mask: (batch, tokens) tensor, non-zero where a token is real.

        Returns:
            Tensor of shape (batch, hidden_dim). A row whose mask is entirely
            zero pools to a zero vector instead of NaN.
        """
        valid = attention_mask.bool()
        # A row with no valid token would give softmax over all -inf, i.e. NaN
        # weights and NaN gradients. Leave such rows' scores finite here and
        # zero their weights after the softmax instead.
        has_tokens = valid.any(dim=-1, keepdim=True)

        att_scores = self.attention(token_embeddings).squeeze(-1)
        att_scores = att_scores.masked_fill(~valid & has_tokens, float("-inf"))
        att_weights = torch.softmax(att_scores, dim=-1)
        # Masked positions already have weight 0 in normal rows, so this only
        # changes fully masked rows (all weights become 0).
        att_weights = att_weights.masked_fill(~valid, 0.0).unsqueeze(-1)
        pooled = torch.sum(token_embeddings * att_weights, dim=1)
        return pooled

In [8]:
class ESMEmbedder(torch.nn.Module):
    """Frozen encoder wrapper that turns a heavy/light sequence pair into one vector.

    Both chains are passed through the same `esm_model`, each chain's token
    embeddings are pooled to a single vector (attention pooling or masked mean),
    and the two vectors are concatenated along the feature axis.

    Args:
        esm_model: encoder called with ``input_ids`` / ``attention_mask`` whose
            output exposes ``last_hidden_state``. All of its parameters are
            frozen here.
        hidden_dim: embedding width handed to the attention pooler.
        use_attention_pool: pool with ``AttentionPooling`` when True, otherwise
            with a mask-weighted mean.
    """

    def __init__(self, esm_model, hidden_dim, use_attention_pool=True):
        super().__init__()
        self.esm_model = esm_model
        # Freeze the encoder: none of its weights receive gradients.
        for weight in self.esm_model.parameters():
            weight.requires_grad_(False)

        self.use_attention_pool = use_attention_pool
        self.pooler = AttentionPooling(hidden_dim) if use_attention_pool else None

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average token embeddings over positions where `mask` is non-zero."""
        mask_f = mask.unsqueeze(-1).float()
        total = (hidden * mask_f).sum(dim=1)
        # clamp avoids a division by zero when a row has no valid token
        count = mask_f.sum(dim=1).clamp(min=1e-9)
        return total / count

    def _pool(self, hidden, mask):
        """Reduce one chain's token embeddings to a single vector per row."""
        if self.pooler is not None:
            return self.pooler(hidden, mask)
        return self._masked_mean(hidden, mask)

    def forward(self, heavy_ids, heavy_mask, light_ids, light_mask):
        # Run both encoder passes first, then pool heavy before light.
        heavy_hidden = self.esm_model(
            input_ids=heavy_ids, attention_mask=heavy_mask
        ).last_hidden_state
        light_hidden = self.esm_model(
            input_ids=light_ids, attention_mask=light_mask
        ).last_hidden_state

        chain_reprs = [
            self._pool(heavy_hidden, heavy_mask),
            self._pool(light_hidden, light_mask),
        ]
        return torch.cat(chain_reprs, dim=1)

In [9]:
def extract_embeddings(model, data_loader, device):
    """Run `model` over every batch of `data_loader` and gather features and labels.

    The model is switched to eval mode and called without gradient tracking.

    Args:
        model: callable taking (heavy_ids, heavy_mask, light_ids, light_mask)
            tensors and returning one embedding tensor per batch.
        data_loader: iterable of dict batches with the four input keys used
            below plus "label".
        device: device the input tensors are moved to before the forward pass.

    Returns:
        (X, y): NumPy arrays of the per-batch embeddings and labels,
        concatenated along axis 0 in loader order.
    """
    input_keys = (
        "heavy_input_ids",
        "heavy_attention_mask",
        "light_input_ids",
        "light_attention_mask",
    )
    feature_chunks, label_chunks = [], []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Extracting embeddings"):
            heavy_ids, heavy_mask, light_ids, light_mask = (
                batch[key].to(device) for key in input_keys
            )

            # Labels are converted straight to NumPy without a device transfer.
            batch_labels = batch["label"].numpy()
            batch_feats = model(heavy_ids, heavy_mask, light_ids, light_mask).cpu().numpy()

            feature_chunks.append(batch_feats)
            label_chunks.append(batch_labels)

    return np.concatenate(feature_chunks, axis=0), np.concatenate(label_chunks, axis=0)


In [10]:
# One row per ESM-2 checkpoint: (checkpoint name, hidden size, head label).
# The hidden size is later unpacked as `hidden_dim`; the head label is unpacked
# as `head_type` but is not otherwise used in the visible cells.
_ESM2_CHECKPOINTS = (
    ("facebook/esm2_t6_8M_UR50D", 320, "simple"),
    ("facebook/esm2_t12_35M_UR50D", 480, "medium"),
    ("facebook/esm2_t30_150M_UR50D", 640, "medium"),
    ("facebook/esm2_t33_650M_UR50D", 1280, "deep"),
    ("facebook/esm2_t36_3B_UR50D", 2560, "deeper"),
)

# checkpoint name -> (hidden size, head label)
model_configurations = {
    name: (hidden_size, head_label)
    for name, hidden_size, head_label in _ESM2_CHECKPOINTS
}

In [None]:

# Select the checkpoint for this run and look up its settings.
chosen_model_name = "facebook/esm2_t6_8M_UR50D"
model_settings = model_configurations[chosen_model_name]
hidden_dim, head_type = model_settings

# Load the matching tokenizer and encoder weights by checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(chosen_model_name)
base_esm_model = AutoModel.from_pretrained(chosen_model_name)

In [29]:

# The attention pooler inside ESMEmbedder is a randomly initialised nn.Linear,
# and nothing in the visible cells trains it before embeddings are extracted.
# Seed torch first so the pooling weights, and therefore the extracted features
# and every downstream CV score, are the same after a kernel restart.
torch.manual_seed(42)

frozen_esm_embedder = ESMEmbedder(
    esm_model=base_esm_model,
    hidden_dim=hidden_dim,
    use_attention_pool=True        # or False for mean pooling
).to(device)

In [30]:

# Wrap the train+validation rows; shuffle is off so the extracted features stay
# in the same row order as train_val_df.
full_dataset = AntibodyDataset(df=train_val_df, tokenizer=tokenizer)
full_loader = DataLoader(
    dataset=full_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:

# Embed every train+validation row once; the regressors below reuse these arrays.
X_full, y_full = extract_embeddings(
    model=frozen_esm_embedder,
    data_loader=full_loader,
    device=device,
)

In [32]:

# Five shuffled folds with a fixed seed; materialised once so every regressor
# is evaluated on identical train/validation index pairs.
kf = KFold(
    n_splits=5,
    random_state=42,
    shuffle=True,
)
folds = [(fold_train_idx, fold_val_idx) for fold_train_idx, fold_val_idx in kf.split(X_full)]

In [33]:
# Hyper-parameters shared by both gradient-boosted tree regressors.
shared_boosting_params = dict(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=3,
    random_state=42,
)

# Baseline regressors, keyed by the name used in the printed reports.
models = dict(
    RandomForest=RandomForestRegressor(
        random_state=42,
        max_depth=None,
        n_estimators=100,
    ),
    GradientBoosting=GradientBoostingRegressor(**shared_boosting_params),
    XGBoost=XGBRegressor(verbosity=0, **shared_boosting_params),
    SVR=SVR(
        epsilon=0.1,
        C=1.0,
        kernel='rbf',
    ),
)

In [None]:
# model name -> list of (validation MSE, validation Spearman), one entry per fold
cv_results = {name: list() for name in models}

for i, fold in enumerate(folds, start=1):
    train_idx, val_idx = fold
    print(f"\n=== Fold {i} ===")

    # Slice the pre-computed embeddings and labels for this fold.
    X_train, y_train = X_full[train_idx], y_full[train_idx]
    X_val, y_val = X_full[val_idx], y_full[val_idx]

    for model_name, model in models.items():
        # The same estimator object is refitted on each fold.
        model.fit(X_train, y_train)

        train_preds = model.predict(X_train)
        val_preds = model.predict(X_val)

        train_mse = mean_squared_error(y_train, train_preds)
        val_mse = mean_squared_error(y_val, val_preds)
        val_spear, _ = spearmanr(val_preds, y_val)

        fold_scores = (val_mse, val_spear)
        cv_results[model_name].append(fold_scores)

        print(f"  [{model_name}]  Train MSE: {train_mse:.4f} | Val MSE: {val_mse:.4f} | Spearman: {val_spear:.4f}")

        # Heuristic flag: validation error more than twice the training error.
        overfit_threshold = 2.0 * train_mse
        if val_mse > overfit_threshold:
            print(f"    -> Possible Overfitting Detected (Val MSE >> Train MSE)")

In [None]:
# Report mean and sample standard deviation (ddof=1) of each metric across folds.
for model_name in models.keys():
    per_fold = cv_results[model_name]
    all_mses = [fold_mse for fold_mse, _fold_rho in per_fold]
    all_spears = [fold_rho for _fold_mse, fold_rho in per_fold]

    avg_mse, sd_mse = np.mean(all_mses), np.std(all_mses, ddof=1)
    avg_spear, sd_spear = np.mean(all_spears), np.std(all_spears, ddof=1)

    print(f"{model_name:>20s} => Avg Val MSE: {avg_mse:.4f} +/- {sd_mse:.4f}, Avg Spearman: {avg_spear:.4f} +/- {sd_spear:.4f}")

In [None]:
def build_baseline_regressors(random_state=42):
    """Return fresh, unfitted baseline regressors keyed by their report name.

    Args:
        random_state: seed passed to the tree-based models (SVR takes none).
    """
    return {
        "RandomForest": RandomForestRegressor(
            n_estimators=100,
            max_depth=None,
            random_state=random_state
        ),
        "GradientBoosting": GradientBoostingRegressor(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=3,
            random_state=random_state
        ),
        "XGBoost": XGBRegressor(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=3,
            random_state=random_state,
            verbosity=0
        ),
        "SVR": SVR(
            kernel='rbf',
            C=1.0,
            epsilon=0.1
        )
    }


def cross_validate_regressors(models, X, y, folds, overfit_ratio=2.0):
    """Fit every regressor on every fold and collect its validation metrics.

    Args:
        models: dict of name -> estimator exposing ``fit`` / ``predict``; each
            estimator is refitted in place on every fold.
        X: feature array indexable by the fold index arrays.
        y: target array aligned with `X`.
        folds: iterable of (train_idx, val_idx) pairs.
        overfit_ratio: a warning line is printed when the validation MSE
            exceeds this multiple of the training MSE.

    Returns:
        dict of name -> list of (val_mse, val_spearman), one tuple per fold.
    """
    results = {model_name: [] for model_name in models}

    for i, (train_idx, val_idx) in enumerate(folds, start=1):
        print(f"\n=== Fold {i} ===")

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        for model_name, model in models.items():
            model.fit(X_train, y_train)

            train_preds = model.predict(X_train)
            val_preds = model.predict(X_val)

            train_mse = mean_squared_error(y_train, train_preds)
            val_mse = mean_squared_error(y_val, val_preds)
            val_spear, _ = spearmanr(val_preds, y_val)

            results[model_name].append((val_mse, val_spear))

            print(f"  [{model_name}]  Train MSE: {train_mse:.4f} | Val MSE: {val_mse:.4f} | Spearman: {val_spear:.4f}")

            if val_mse > overfit_ratio * train_mse:
                print(f"    -> Possible Overfitting Detected (Val MSE >> Train MSE)")

    return results


def summarize_cv_results(cv_results):
    """Print mean and sample standard deviation (ddof=1) of each metric per model.

    Args:
        cv_results: dict of name -> list of (val_mse, val_spearman) tuples.
    """
    for model_name, per_fold in cv_results.items():
        all_mses = [res[0] for res in per_fold]
        all_spears = [res[1] for res in per_fold]
        avg_mse = np.mean(all_mses)
        avg_spear = np.mean(all_spears)
        sd_mse = np.std(all_mses, ddof=1)
        sd_spear = np.std(all_spears, ddof=1)
        print(f"{model_name:>20s} => Avg Val MSE: {avg_mse:.4f} +/- {sd_mse:.4f}, Avg Spearman: {avg_spear:.4f} +/- {sd_spear:.4f}")


# Same experiment as above, now with the 35M-parameter checkpoint.
chosen_model_name = "facebook/esm2_t12_35M_UR50D"
hidden_dim, head_type = model_configurations[chosen_model_name]

# Rebinding `tokenizer` also changes the pad id that collate_fn reads from the
# notebook namespace.
tokenizer = AutoTokenizer.from_pretrained(chosen_model_name)
base_esm_model = AutoModel.from_pretrained(chosen_model_name)

# The attention pooler is a randomly initialised nn.Linear that is not trained
# in the visible cells; seed torch so the extracted features are reproducible.
torch.manual_seed(42)
frozen_esm_embedder = ESMEmbedder(
    esm_model=base_esm_model,
    hidden_dim=hidden_dim,
    use_attention_pool=True
).to(device)

full_dataset = AntibodyDataset(train_val_df, tokenizer)
full_loader = DataLoader(full_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

X_full, y_full = extract_embeddings(frozen_esm_embedder, full_loader, device=device)

kf = KFold(n_splits=5, shuffle=True, random_state=42)
folds = list(kf.split(X_full))

models = build_baseline_regressors()
cv_results = cross_validate_regressors(models, X_full, y_full, folds)
summarize_cv_results(cv_results)