In [None]:
import random

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.base import clone
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import mean_squared_error

from transformers import AutoModel, AutoTokenizer

from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from xgboost import XGBRegressor
from sklearn.svm import SVR

from tqdm import tqdm

In [2]:
# Pretrained encoder identifier (passed to `from_pretrained`) and the hidden size we expect it to have.
IGBERT_MODEL_NAME = "Exscientia/IgBERT"
IGBERT_EMBEDDING_DIM = 1024
# Per-chain token length used for tokenizer truncation and padding.
MAX_LENGTH = 512

# Embedding-extraction batch size and number of cross-validation folds.
BATCH_SIZE, N_SPLITS_CV = 8, 5

# One seed, applied to every RNG this notebook touches (Python, NumPy, PyTorch), in that order.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)


In [None]:

# Prefer Apple MPS (when available and built), then CUDA, otherwise fall back to CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)


In [None]:
# Load the three PNAS supplementary tables and combine them on antibody name.
df1 = pd.read_excel('pnas.1616408114.sd01.xlsx')
df2 = pd.read_excel('pnas.1616408114.sd02.xlsx')
df3 = pd.read_excel('pnas.1616408114.sd03.xlsx')

merged_df = df1.merge(df2, on='Name', how='outer').merge(df3, on='Name', how='outer')

# An outer join keeps names that are missing from some table, leaving NaN in their columns.
# Such rows are unusable here: a NaN sequence would be stringified to "nan" and tokenized as
# residues, and a NaN label makes the downstream regressors fail to fit.
REQUIRED_COLUMNS = ['VH', 'VL', 'HEK Titer (mg/L)']
df = merged_df[REQUIRED_COLUMNS].dropna().reset_index(drop=True)

n_dropped_rows = len(merged_df) - len(df)
if n_dropped_rows:
    print(f"Dropped {n_dropped_rows} rows with missing VH, VL or HEK Titer (mg/L).")

In [None]:
# Titers are modelled on a log scale, so non-positive values must be replaced before np.log.
titer = df["HEK Titer (mg/L)"]
min_label_val = titer.min()
if min_label_val <= 0:
    print(f"Warning: Minimum label value is {min_label_val}. Replacing non-positive values with 1e-6 before log transform.")
    # `mask` only replaces rows where the condition holds; NaN <= 0 is False, so missing
    # titers stay NaN instead of silently becoming 1e-6 (which the old lambda did).
    df["HEK Titer (mg/L)"] = titer.mask(titer <= 0, 1e-6)

df["log_label"] = np.log(df["HEK Titer (mg/L)"])

# Standardize the log titer; model predictions are mapped back through `scaler.inverse_transform`.
# NOTE: this fit uses every row, including rows that are later held out as the test set;
# refitting on the training split only avoids leaking test-label statistics.
scaler = StandardScaler()
df["scaled_label"] = scaler.fit_transform(df[["log_label"]])


In [None]:

# Hold out 20% of antibodies as a final test set; the remainder is used for cross-validation.
train_val_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)
# Explicit copies so the column assignments below never act on a view of `df`.
train_val_df = train_val_df.copy()
test_df = test_df.copy()

# Refit the label scaler on the training portion only, so test-set label statistics do not
# leak into the training targets. `scaler` is what later inverse-transforms predictions.
scaler = StandardScaler()
train_val_df["scaled_label"] = scaler.fit_transform(train_val_df[["log_label"]]).ravel()
if not test_df.empty:
    test_df["scaled_label"] = scaler.transform(test_df[["log_label"]]).ravel()

print(f"\nTrain/Validation set size: {len(train_val_df)}")
print(f"Test set size: {len(test_df)}")

In [7]:
class AntibodyDatasetIgBERT(Dataset):
    """Map-style dataset yielding tokenized heavy/light chains plus the scaled label.

    Each chain is tokenized separately as a space-separated residue string (one
    character per residue), truncated/padded to ``max_length``.

    Args:
        df: frame with ``VH``, ``VL`` and ``scaled_label`` columns.
        tokenizer: callable tokenizer (assumed Hugging Face-style) accepting
            ``truncation``, ``max_length``, ``padding`` and ``return_tensors="pt"``.
        max_length: per-chain token length after truncation/padding.
    """

    def __init__(self, df, tokenizer, max_length=MAX_LENGTH):
        # reset_index returns a new 0-based frame, so positional iloc lookups line up
        # and the caller's frame is not modified by the casts below.
        frame = df.reset_index(drop=True)
        frame['VH'] = frame['VH'].astype(str)
        frame['VL'] = frame['VL'].astype(str)
        self.df = frame
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.df.shape[0]

    def _encode(self, sequence):
        """Tokenize one chain; returns (input_ids, attention_mask) without the batch axis."""
        encoded = self.tokenizer(
            " ".join(sequence),
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        target = record["scaled_label"]
        heavy_ids, heavy_mask = self._encode(record["VH"])
        light_ids, light_mask = self._encode(record["VL"])
        return {
            "heavy_input_ids": heavy_ids,
            "heavy_attention_mask": heavy_mask,
            "light_input_ids": light_ids,
            "light_attention_mask": light_mask,
            "label": torch.tensor(target, dtype=torch.float),
        }

def collate_fn_igbert(batch):
    """Stack per-example tensors from ``AntibodyDatasetIgBERT`` into batch tensors.

    Only the five keys below are kept (in this order); every example must provide
    them with matching shapes, as required by ``torch.stack``.
    """
    keys = (
        "heavy_input_ids",
        "heavy_attention_mask",
        "light_input_ids",
        "light_attention_mask",
        "label",
    )
    return {key: torch.stack([example[key] for example in batch]) for key in keys}

class AttentionPooling(nn.Module):
    """Pool token states into one vector with a learned linear scoring head.

    Each token gets a scalar score from ``attention_net``; padded positions
    (mask == 0) are pushed to -1e9 before the softmax so they receive ~0 weight,
    and the output is the softmax-weighted sum over the token axis (dim 1).
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention_net = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        logits = self.attention_net(token_embeddings)[..., 0]
        padding = attention_mask.eq(0)
        logits = logits.masked_fill(padding, -1e9)
        weights = logits.softmax(dim=-1)
        return (token_embeddings * weights[..., None]).sum(dim=1)

class IgBERTEmbedder(nn.Module):
    """Frozen IgBERT encoder that turns a (heavy, light) chain pair into one feature vector.

    Each chain is encoded separately, its last hidden states are pooled (attention
    pooling or masked mean), and the two pooled vectors are concatenated along dim 1,
    so the output width is twice the encoder's hidden size.

    Note: when enabled, the ``AttentionPooling`` head is freshly initialised here and
    shared by both chains; nothing in this class trains it.
    """

    def __init__(self, igbert_model, embedding_dim, use_attention_pool=True):
        super().__init__()
        self.igbert_model = igbert_model
        # Freeze every encoder parameter: the encoder is used purely as a feature extractor.
        self.igbert_model.requires_grad_(False)

        self.use_attention_pool = use_attention_pool
        self.pooler = AttentionPooling(embedding_dim) if use_attention_pool else None
        print(f"IgBERTEmbedder initialized. Using attention pooling: {self.use_attention_pool}")
        print(f"Expected embedding dimension for pooling: {embedding_dim}")

    @staticmethod
    def _masked_mean(hidden, mask):
        """Average token states over positions where ``mask`` is non-zero (guarding against /0)."""
        weights = mask.unsqueeze(-1).float()
        return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    def _pool(self, hidden, mask):
        """Apply the configured pooling strategy to one chain's hidden states."""
        if self.pooler is None:
            return self._masked_mean(hidden, mask)
        return self.pooler(hidden, mask)

    def forward(self, heavy_ids, heavy_mask, light_ids, light_mask):
        heavy_hidden = self.igbert_model(input_ids=heavy_ids, attention_mask=heavy_mask).last_hidden_state
        light_hidden = self.igbert_model(input_ids=light_ids, attention_mask=light_mask).last_hidden_state

        pooled = [self._pool(heavy_hidden, heavy_mask), self._pool(light_hidden, light_mask)]
        return torch.cat(pooled, dim=1)

def extract_igbert_embeddings(embedder_model, data_loader, device):
    """Run ``embedder_model`` over every batch and return (features, labels) as numpy arrays.

    The model is switched to eval mode and run without gradient tracking. Input
    tensors are moved to ``device``; outputs and labels are brought back to CPU.

    Returns:
        (X, y): X stacks the embedder outputs along axis 0, y stacks ``batch["label"]``,
        both in loader order. ``np.concatenate`` raises ValueError if the loader yields
        no batches.
    """
    embedder_model.eval()
    input_keys = ("heavy_input_ids", "heavy_attention_mask", "light_input_ids", "light_attention_mask")
    feature_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Extracting IgBERT embeddings"):
            model_inputs = [batch[key].to(device) for key in input_keys]
            batch_labels = batch["label"].cpu().numpy()

            features = embedder_model(*model_inputs)
            feature_chunks.append(features.cpu().numpy())
            label_chunks.append(batch_labels)

    return np.concatenate(feature_chunks, axis=0), np.concatenate(label_chunks, axis=0)

In [None]:
# Load the pretrained IgBERT tokenizer and encoder.
igbert_tokenizer = AutoTokenizer.from_pretrained(IGBERT_MODEL_NAME)
base_igbert_model = AutoModel.from_pretrained(IGBERT_MODEL_NAME)

# The loaded model's config is authoritative; warn if the configured constant disagrees.
actual_igbert_hidden_size = base_igbert_model.config.hidden_size
if actual_igbert_hidden_size != IGBERT_EMBEDDING_DIM:
    print(f"WARNING: Configured IGBERT_EMBEDDING_DIM ({IGBERT_EMBEDDING_DIM}) "
          f"does not match loaded model's hidden_size ({actual_igbert_hidden_size}). "
          f"Using actual_model_hidden_size: {actual_igbert_hidden_size} for embedder.")
current_igbert_embedding_dim = actual_igbert_hidden_size

# Masked mean pooling instead of AttentionPooling: the attention head is a freshly
# initialised nn.Linear that is never trained (embeddings are extracted under no_grad and
# only the downstream sklearn regressors are fitted), so its pooling weights would be random.
igbert_embedder = IgBERTEmbedder(
    igbert_model=base_igbert_model,
    embedding_dim=current_igbert_embedding_dim,
    use_attention_pool=False,
).to(device)

# Both loaders keep row order (shuffle=False) so extracted embeddings stay aligned with their frames.
_loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn_igbert)

train_val_dataset = AntibodyDatasetIgBERT(train_val_df, igbert_tokenizer, max_length=MAX_LENGTH)
train_val_loader = DataLoader(train_val_dataset, **_loader_kwargs)

if test_df.empty:
    test_loader = None
    print("Test set is empty. Skipping final test set evaluation with classical models.")
else:
    test_dataset = AntibodyDatasetIgBERT(test_df, igbert_tokenizer, max_length=MAX_LENGTH)
    test_loader = DataLoader(test_dataset, **_loader_kwargs)

# Train/validation features are always needed; test features only when a test loader exists.
print("\nExtracting IgBERT embeddings for training/validation set...")
X_train_val_embeddings, y_train_val_labels = extract_igbert_embeddings(igbert_embedder, train_val_loader, device)
for _what, _arr in (("embeddings (X)", X_train_val_embeddings), ("labels (y)", y_train_val_labels)):
    print(f"Shape of training/validation {_what}: {_arr.shape}")

X_test_embeddings = y_test_labels = None
if test_loader:
    print("\nExtracting IgBERT embeddings for test set...")
    X_test_embeddings, y_test_labels = extract_igbert_embeddings(igbert_embedder, test_loader, device)
    for _what, _arr in (("embeddings (X)", X_test_embeddings), ("labels (y)", y_test_labels)):
        print(f"Shape of test {_what}: {_arr.shape}")

# Regressors fitted on the frozen IgBERT features; the tree ensembles are seeded with SEED.
classical_models = dict(
    RandomForest=RandomForestRegressor(n_estimators=100, random_state=SEED, n_jobs=-1),
    GradientBoosting=GradientBoostingRegressor(
        n_estimators=100, learning_rate=0.1, max_depth=3, random_state=SEED,
    ),
    XGBoost=XGBRegressor(
        n_estimators=100, learning_rate=0.1, max_depth=3,
        random_state=SEED, verbosity=0, n_jobs=-1,
    ),
    SVR_RBF=SVR(kernel='rbf', C=1.0, epsilon=0.1),
)

print(f"\n--- Starting {N_SPLITS_CV}-Fold Cross-Validation with Classical Models (using IgBERT Embeddings) ---")
kf = KFold(n_splits=N_SPLITS_CV, shuffle=True, random_state=SEED)
# Per-model lists of per-fold validation scores, on the standardized log-titer scale.
cv_results_classical = {name: {'mse': [], 'spearman': [], 'pearson': []} for name in classical_models}

for fold_no, (train_indices, val_indices) in enumerate(kf.split(X_train_val_embeddings), start=1):
    print(f"\n==== Fold {fold_no}/{N_SPLITS_CV} ====")
    X_train_fold, y_train_fold = X_train_val_embeddings[train_indices], y_train_val_labels[train_indices]
    X_val_fold, y_val_fold = X_train_val_embeddings[val_indices], y_train_val_labels[val_indices]

    for model_name, model_instance in classical_models.items():
        print(f"  Training {model_name}...")
        # The same estimator objects are refit in every fold (standard sklearn fit semantics).
        model_instance.fit(X_train_fold, y_train_fold)
        val_preds = model_instance.predict(X_val_fold)

        val_mse = mean_squared_error(y_val_fold, val_preds)
        # Correlations are undefined when either side is constant; report 0.0 in that case.
        if len(np.unique(y_val_fold)) > 1 and len(np.unique(val_preds)) > 1:
            val_spearman, _ = spearmanr(y_val_fold, val_preds)
            val_pearson, _ = pearsonr(y_val_fold, val_preds)
        else:
            val_spearman = val_pearson = 0.0

        fold_scores = cv_results_classical[model_name]
        fold_scores['mse'].append(val_mse)
        fold_scores['spearman'].append(val_spearman)
        fold_scores['pearson'].append(val_pearson)

        print(f"    {model_name} - Val MSE: {val_mse:.4f}, Val Spearman: {val_spearman:.4f}, Val Pearson: {val_pearson:.4f}")

print("\n--- Cross-Validation Summary (IgBERT Embeddings + Classical Models) ---")
# (printed prefix, metric key) pairs; each line reports the fold mean +/- population std.
_summary_lines = (
    ("    Avg Val MSE:       ", 'mse'),
    ("    Avg Val Spearman:  ", 'spearman'),
    ("    Avg Val Pearson:   ", 'pearson'),
)
for model_name, metrics in cv_results_classical.items():
    print(f"  {model_name}:")
    for _prefix, _key in _summary_lines:
        _fold_values = metrics[_key]
        print(f"{_prefix}{np.mean(_fold_values):.4f} +/- {np.std(_fold_values):.4f}")

if X_test_embeddings is not None and y_test_labels is not None:
    print("\n--- Final Test Set Evaluation (IgBERT Embeddings + Classical Models) ---")
    print("  (Models retrained on the full train_val_embeddings set)")

    def _safe_corr(corr_fn, y_true, y_pred):
        """Return the statistic of ``corr_fn(y_true, y_pred)``, or 0.0 if either input is constant.

        Spearman/Pearson correlations are undefined for constant inputs, so 0.0 is
        reported instead of letting scipy return NaN.
        """
        if len(np.unique(y_true)) > 1 and len(np.unique(y_pred)) > 1:
            statistic, _ = corr_fn(y_true, y_pred)
            return statistic
        return 0.0

    # True labels back on the original titer scale: undo standardization, then the log.
    # They do not depend on the model, so convert them once.
    y_test_labels_log = scaler.inverse_transform(y_test_labels.reshape(-1, 1)).flatten()
    y_test_labels_original = np.exp(y_test_labels_log)

    for model_name, model_prototype in classical_models.items():
        print(f"  Training {model_name} on full train_val_embeddings set...")
        # clone() returns an unfitted estimator with exactly the prototype's hyperparameters,
        # so the final model cannot drift from the cross-validated configuration and the
        # prototype object itself is left untouched.
        model_instance = clone(model_prototype)
        model_instance.fit(X_train_val_embeddings, y_train_val_labels)

        test_preds_scaled = model_instance.predict(X_test_embeddings)
        test_preds_log = scaler.inverse_transform(test_preds_scaled.reshape(-1, 1)).flatten()
        test_preds_original = np.exp(test_preds_log)

        test_mse_scaled = mean_squared_error(y_test_labels, test_preds_scaled)
        test_spearman_scaled = _safe_corr(spearmanr, y_test_labels, test_preds_scaled)
        test_pearson_scaled = _safe_corr(pearsonr, y_test_labels, test_preds_scaled)

        test_mse_original = mean_squared_error(y_test_labels_original, test_preds_original)
        test_spearman_original = _safe_corr(spearmanr, y_test_labels_original, test_preds_original)
        test_pearson_original = _safe_corr(pearsonr, y_test_labels_original, test_preds_original)

        print(f"  Model: {model_name}")
        print(f"    Test MSE (scaled):       {test_mse_scaled:.4f}")
        print(f"    Test Spearman (scaled):  {test_spearman_scaled:.4f}")
        print(f"    Test Pearson (scaled):   {test_pearson_scaled:.4f}")
        print(f"    Test MSE (original):     {test_mse_original:.4f}")
        print(f"    Test Spearman (original):{test_spearman_original:.4f}")
        print(f"    Test Pearson (original): {test_pearson_original:.4f}")

        # Predicted vs. true titers with a y = x reference line.
        fig, ax = plt.subplots(figsize=(6, 5))
        sns.scatterplot(x=y_test_labels_original, y=test_preds_original, alpha=0.6, ax=ax)
        min_val = min(y_test_labels_original.min(), test_preds_original.min())
        max_val = max(y_test_labels_original.max(), test_preds_original.max())
        ax.plot([min_val, max_val], [min_val, max_val], color='red', linestyle='--')
        ax.set_xlabel("True HEK Titer (original scale)")
        ax.set_ylabel("Predicted HEK Titer (original scale)")
        ax.set_title(f"{model_name} - Test Set (IgBERT Embeddings)\nSpearman: {test_spearman_original:.3f}")
        ax.grid(True)
        fig.tight_layout()
        plt.show()
else:
    print("\nTest set embeddings not available. Skipping final test set evaluation.")
