In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr


from transformers import AutoModel, AutoTokenizer

In [None]:
# Load the three supplementary spreadsheets and join them on the shared 'Name' column.
df1 = pd.read_excel('pnas.1616408114.sd01.xlsx')
df2 = pd.read_excel('pnas.1616408114.sd02.xlsx')
df3 = pd.read_excel('pnas.1616408114.sd03.xlsx')

# Outer joins keep every name that appears in any of the three sheets.
merged_df = df1.merge(df2, on='Name', how='outer').merge(df3, on='Name', how='outer')

# Keep only the two sequence columns and the regression target.
# Because the joins are outer joins, a name missing from one sheet yields NaN
# in that sheet's columns; such rows cannot be tokenized (NaN is not a string)
# and would turn the training loss into NaN, so they are dropped here.
model_columns = ['VH', 'VL', 'HEK Titer (mg/L)']
df = merged_df[model_columns].dropna(subset=model_columns).copy()

n_dropped_rows = len(merged_df) - len(df)
if n_dropped_rows:
    print(f"Dropped {n_dropped_rows} row(s) with missing VH, VL or HEK titer.")



In [None]:
# Pick the compute device in priority order: Apple MPS, then CUDA, then CPU.
# The CUDA check is only evaluated when MPS is unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)


In [4]:
# Build the regression target: log-transform the titer, then standardise it.
min_label_val = df["HEK Titer (mg/L)"].min()

# A non-positive minimum gets an extra offset of (1 - min) so the argument of
# the logarithm stays strictly positive; otherwise no extra offset is applied.
if min_label_val <= 0:
    shift_val = 1.0 - min_label_val
else:
    shift_val = 0

# log(titer + shift + 1); the trailing +1 is applied in both cases.
df["log_label"] = np.log(df["HEK Titer (mg/L)"] + shift_val + 1.0)

# Zero-mean / unit-variance version of the log label, used as the training target.
# `scaler` is kept at module level so the transform can be inverted later.
scaler = StandardScaler()
df["scaled_label"] = scaler.fit_transform(df[["log_label"]])

In [None]:

# Hold out 20% of the rows as a test set; the rest is used for cross-validation.
train_val_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

# Work on independent copies so the column assignments below do not write
# through to (or warn about) the parent frame.
train_val_df = train_val_df.copy()
test_df = test_df.copy()

# Standardise the log label with statistics from the training portion only,
# so the held-out test labels do not influence the scaling (no label leakage).
# `scaler` is rebound to this train-only scaler for any later inverse transform.
scaler = StandardScaler().fit(train_val_df[["log_label"]])
train_val_df["scaled_label"] = scaler.transform(train_val_df[["log_label"]]).ravel()
test_df["scaled_label"] = scaler.transform(test_df[["log_label"]]).ravel()

print(f"Train/Validation set size: {len(train_val_df)}")
print(f"Test set size: {len(test_df)}")

In [6]:
class AntibodyDataset(Dataset):
    """Map-style dataset yielding tokenized heavy/light chains and a label.

    Each row of ``df`` must provide the columns ``"VH"``, ``"VL"`` and
    ``"scaled_label"``. The two sequences are tokenized independently after
    inserting a single space between consecutive characters.

    Args:
        df: Source frame; its index is reset so items are addressed 0..len-1.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            max_length=..., padding='max_length', return_tensors="pt")`` and
            returning a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, df, tokenizer, max_length=512):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return self.df.shape[0]

    def _encode(self, sequence):
        """Space-separate the characters of ``sequence`` and tokenize the result."""
        return self.tokenizer(
            " ".join(sequence),
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        The tokenizer output's leading dimension is squeezed away, and the
        label is returned as a 0-d float tensor.
        """
        record = self.df.iloc[idx]
        vh_seq, vl_seq, target = record["VH"], record["VL"], record["scaled_label"]

        vh_encoded = self._encode(vh_seq)
        vl_encoded = self._encode(vl_seq)

        sample = {}
        for prefix, encoded in (("heavy", vh_encoded), ("light", vl_encoded)):
            sample[f"{prefix}_input_ids"] = encoded["input_ids"].squeeze(0)
            sample[f"{prefix}_attention_mask"] = encoded["attention_mask"].squeeze(0)
        sample["label"] = torch.tensor(target, dtype=torch.float)
        return sample

def collate_fn(batch):
    """Stack a list of per-sample dicts into one dict of batched tensors.

    Args:
        batch: Sequence of dicts, each holding the five tensor fields listed
            below; tensors under the same key must have matching shapes.

    Returns:
        Dict with the same five keys, where each value is ``torch.stack`` of
        the corresponding per-sample tensors (new leading batch dimension).
    """
    field_names = (
        "heavy_input_ids",
        "heavy_attention_mask",
        "light_input_ids",
        "light_attention_mask",
        "label",
    )
    return {
        name: torch.stack([sample[name] for sample in batch])
        for name in field_names
    }


class AttentionPooling(nn.Module):
    """Pool a token sequence into one vector with learned attention weights.

    A single linear layer scores every token; masked-out positions receive a
    score of ``-inf`` so that the softmax assigns them zero weight.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        """Return the attention-weighted sum of ``token_embeddings`` over dim 1.

        Args:
            token_embeddings: Tensor whose last dimension is ``hidden_dim`` and
                whose dim 1 is the token axis that is summed over.
            attention_mask: Mask matching the token axes; positions that are
                falsy after ``.bool()`` are excluded from the pooling.
        """
        scores = self.attention(token_embeddings).squeeze(-1)
        is_padding = attention_mask.bool().logical_not()
        scores = scores.masked_fill(is_padding, float("-inf"))
        weights = scores.softmax(dim=-1)
        return (token_embeddings * weights.unsqueeze(-1)).sum(dim=1)

class SimpleHead(nn.Module):
    """Regression head consisting of a single linear layer to one output."""

    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Apply the linear layer and drop the trailing size-1 dimension."""
        projected = self.fc(x)
        return torch.squeeze(projected, dim=-1)

class MediumHead(nn.Module):
    """Regression head with one hidden layer of width 128.

    Layout: Linear(input_dim, 128) -> ReLU -> Dropout(dropout) -> Linear(128, 1).
    """

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        # Layers are created in network order so parameter names and
        # initialisation order match the sequential layout above.
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 1),
        ]
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP and drop the trailing size-1 dimension."""
        out = self.regressor(x)
        return torch.squeeze(out, dim=-1)

class DeepHead(nn.Module):
    """Regression head with two hidden layers (256 then 128 units).

    Every hidden Linear layer is followed by ReLU and Dropout(dropout); the
    final Linear layer maps 128 features to a single output.
    """

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        widths = (input_dim, 256, 128)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        layers.append(nn.Linear(128, 1))
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP and drop the trailing size-1 dimension."""
        out = self.regressor(x)
        return torch.squeeze(out, dim=-1)

class DeeperHead(nn.Module):
    """Regression head with four hidden layers (512, 256, 128, 64 units).

    Every hidden Linear layer is followed by ReLU and Dropout(dropout); the
    final Linear layer maps 64 features to a single output.
    """

    def __init__(self, input_dim, dropout=0.1):
        super().__init__()
        widths = (input_dim, 512, 256, 128, 64)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        layers.append(nn.Linear(64, 1))
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP and drop the trailing size-1 dimension."""
        out = self.regressor(x)
        return torch.squeeze(out, dim=-1)

class RegressionModel(nn.Module):
    """Paired-chain regressor built on a shared protein language model (PLM).

    Heavy and light chains are encoded separately by the same ``plm_model``,
    each pooled to one vector (attention pooling or masked mean), concatenated,
    and passed to a regression head that emits one value per sample.

    Args:
        plm_model: Encoder called as ``plm_model(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        hidden_dim: Width of the encoder's ``last_hidden_state``; the head
            receives ``2 * hidden_dim`` features.
        head_type: One of ``"simple"``, ``"medium"``, ``"deep"``, ``"deeper"``.
        dropout: Dropout probability used by the non-simple heads.
        use_attention_pool: Use ``AttentionPooling`` if True, else masked mean.
        freeze_plm: If True, disable gradients for every PLM parameter and keep
            the PLM in eval mode even while the rest of the model trains.

    Raises:
        ValueError: If ``head_type`` is not one of the supported names.
    """

    def __init__(
        self,
        plm_model,
        hidden_dim,
        head_type="medium",
        dropout=0.1,
        use_attention_pool=True,
        freeze_plm=True
    ):
        super().__init__()
        self.plm_model = plm_model
        self.hidden_dim = hidden_dim
        # Remembered so train() can keep a frozen backbone out of training mode.
        self.freeze_plm = freeze_plm

        if freeze_plm:
            print("Freezing base PLM parameters.")
            for param in self.plm_model.parameters():
                param.requires_grad = False
            # A frozen feature extractor should be deterministic: no dropout.
            self.plm_model.eval()
        else:
            print("NOT freezing base PLM parameters (fine-tuning enabled).")

        if use_attention_pool:
            self.pooler = AttentionPooling(self.hidden_dim)
        else:
            self.pooler = None

        # Heavy and light representations are concatenated before the head.
        input_dim = 2 * self.hidden_dim
        if head_type == "simple":
            self.head = SimpleHead(input_dim)
        elif head_type == "medium":
            self.head = MediumHead(input_dim, dropout)
        elif head_type == "deep":
            self.head = DeepHead(input_dim, dropout)
        elif head_type == "deeper":
            self.head = DeeperHead(input_dim, dropout)
        else:
            raise ValueError(f"Invalid head type: {head_type}")

        print(f"Using head type: {head_type}")
        print(f"Using attention pooling: {use_attention_pool}")

    def train(self, mode=True):
        """Set training mode, but keep a frozen PLM in eval mode.

        ``nn.Module.train`` switches every submodule, which would re-enable
        dropout inside the frozen backbone and add noise to features that are
        supposed to be fixed. ``eval()`` is routed through here as well.
        """
        super().train(mode)
        if getattr(self, "freeze_plm", False):
            self.plm_model.eval()
        return self

    def _pool(self, hidden, mask):
        """Reduce token states (dim 1) to one vector per sample."""
        if self.pooler is not None:
            return self.pooler(hidden, mask)
        # Masked mean: padded positions contribute nothing; clamp avoids 0/0.
        mask_f = mask.unsqueeze(-1).float()
        token_sum = (hidden * mask_f).sum(dim=1)
        token_count = mask_f.sum(dim=1).clamp(min=1e-9)
        return token_sum / token_count

    def forward(self, heavy_ids, heavy_mask, light_ids, light_mask):
        """Predict one value per sample from tokenized heavy and light chains.

        Args:
            heavy_ids, heavy_mask: Token ids and attention mask of the heavy chain.
            light_ids, light_mask: Token ids and attention mask of the light chain.

        Returns:
            Output of the regression head applied to the concatenated
            heavy/light representations (trailing size-1 dimension removed by
            the head).
        """
        heavy_out = self.plm_model(input_ids=heavy_ids, attention_mask=heavy_mask)
        light_out = self.plm_model(input_ids=light_ids, attention_mask=light_mask)

        heavy_repr = self._pool(heavy_out.last_hidden_state, heavy_mask)
        light_repr = self._pool(light_out.last_hidden_state, light_mask)

        combined = torch.cat([heavy_repr, light_repr], dim=1)
        return self.head(combined)

class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    An epoch counts as an improvement only if the loss drops below the best
    loss seen so far by more than ``min_delta``. After ``patience``
    consecutive non-improving calls to :meth:`step`, ``early_stop`` is set.
    """

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_val_loss = float("inf")
        self.counter = 0
        self.early_stop = False

    def step(self, val_loss):
        """Record one epoch's validation loss and update the stop flag."""
        threshold = self.best_val_loss - self.min_delta
        if val_loss < threshold:
            # Sufficient improvement: remember it and restart the countdown.
            self.best_val_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def _run_batch(model, batch, criterion):
    """Move one collated batch to ``device``, run the model, return (preds, labels, loss)."""
    inputs = [
        batch[key].to(device)
        for key in (
            "heavy_input_ids",
            "heavy_attention_mask",
            "light_input_ids",
            "light_attention_mask",
        )
    ]
    labels = batch["label"].to(device)
    preds = model(*inputs)
    return preds, labels, criterion(preds, labels)


def _snapshot_state(model, skip_keys):
    """Return an independent CPU copy of ``model.state_dict()`` without ``skip_keys``."""
    return {
        key: tensor.detach().to("cpu", copy=True)
        for key, tensor in model.state_dict().items()
        if key not in skip_keys
    }


def train_one_fold(
    model,
    train_loader,
    val_loader,
    num_epochs=10,
    lr=1e-3,
    weight_decay=0.01,
    use_scheduler=True,
    patience=3
):
    """Train ``model`` on one fold with MSE loss and restore its best weights.

    Each epoch runs one pass over ``train_loader`` and one evaluation pass over
    ``val_loader`` (mean of per-batch MSE, plus Spearman correlation). The
    weights of the epoch with the lowest validation loss are snapshotted and
    loaded back into the model before returning.

    Args:
        model: Model exposing ``plm_model``, ``pooler`` and (optionally) ``head``,
            callable as ``model(heavy_ids, heavy_mask, light_ids, light_mask)``.
        train_loader: Iterable of batches (dicts as produced by ``collate_fn``).
        val_loader: Iterable of validation batches with the same layout.
        num_epochs: Maximum number of epochs.
        lr: Learning rate for head/pooler; PLM parameters use ``lr / 10`` when
            any of them is trainable.
        weight_decay: AdamW weight decay.
        use_scheduler: If True, halve the learning rate after 2 epochs without
            validation-loss improvement (``ReduceLROnPlateau``).
        patience: Early-stopping patience in epochs.

    Returns:
        The same ``model`` instance, moved to ``device``, carrying the weights
        of its best validation epoch (or its final weights if no epoch ever
        produced a finite improvement).
    """
    model.to(device)
    criterion = nn.MSELoss()

    params_to_optimize = []
    if hasattr(model, 'head'):
        params_to_optimize.extend(model.head.parameters())
    if isinstance(model.pooler, AttentionPooling):
        params_to_optimize.extend(model.pooler.parameters())

    if any(p.requires_grad for p in model.plm_model.parameters()):
        print("Including base PLM parameters in optimizer.")
        optimizer = optim.AdamW([
                {'params': params_to_optimize, 'lr': lr},
                {'params': model.plm_model.parameters(), 'lr': lr / 10.0}
            ], weight_decay=weight_decay)
    else:
        print("Optimizing only head and/or pooler parameters.")
        optimizer = optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)

    scheduler = None
    if use_scheduler:
        # `verbose` is deprecated in recent PyTorch; LR changes are reported
        # explicitly after each scheduler step instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=2
        )

    early_stopper = EarlyStopping(patience=patience)
    best_fold_model_state = None
    best_val_loss = float('inf')

    # Parameters without gradients never change during training, so they are
    # left out of the snapshots (this avoids copying a frozen backbone).
    frozen_keys = {
        name for name, param in model.named_parameters() if not param.requires_grad
    }

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            _, _, loss = _run_batch(model, batch, criterion)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)

        # ---- validation pass ----
        model.eval()
        val_losses = []
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for batch in val_loader:
                preds, labels, loss = _run_batch(model, batch, criterion)
                val_losses.append(loss.item())
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = np.mean(val_losses)
        # Spearman is undefined for (near-)constant inputs; report 0 instead.
        if np.std(all_preds) > 1e-6 and np.std(all_labels) > 1e-6:
            spearman_corr, _ = spearmanr(all_preds, all_labels)
        else:
            spearman_corr = 0.0

        if scheduler is not None:
            lrs_before = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(avg_val_loss)
            lrs_after = [group['lr'] for group in optimizer.param_groups]
            if lrs_after != lrs_before:
                print(f"  -> Reducing learning rate: {lrs_before} -> {lrs_after}")

        print(f"[Epoch {epoch+1:02d}] Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | Spearman: {spearman_corr:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place; copy them to keep a real
            # snapshot of this epoch.
            best_fold_model_state = _snapshot_state(model, frozen_keys)
            print(f"  -> New best validation loss: {best_val_loss:.4f}")

        early_stopper.step(avg_val_loss)
        if early_stopper.early_stop:
            print("Early stopping triggered!")
            break

    if best_fold_model_state is not None:
        print("Loading best model state from training.")
        # strict=False because frozen parameters were deliberately not snapshotted.
        model.load_state_dict(best_fold_model_state, strict=False)

    return model

In [14]:

# Model identifier passed to AutoTokenizer / AutoModel.from_pretrained below.
igbert_model_name = "Exscientia/IgBERT"
# Passed to RegressionModel as hidden_dim; it sizes the pooling layer that is
# applied to the encoder's last_hidden_state, so it must equal that width.
igbert_hidden_dim = 1024

# Regression head variant; RegressionModel accepts "simple", "medium", "deep" or "deeper".
chosen_head_type = "deep"

In [None]:


# Tokenizer used by every AntibodyDataset created in this notebook.
tokenizer = AutoTokenizer.from_pretrained(igbert_model_name)

# Single encoder instance; the same object is passed as plm_model to the
# RegressionModel built in each cross-validation fold.
base_igbert_model = AutoModel.from_pretrained(igbert_model_name)

In [None]:


# Held-out test split wrapped for batched evaluation; shuffle=False keeps the
# batches in the row order of test_df.
test_dataset = AntibodyDataset(test_df, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn) 

In [None]:
N_SPLITS = 5
kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

fold_results = []    # one (validation loss, validation Spearman) tuple per fold
trained_models = []  # one state_dict per fold

train_val_data = train_val_df.reset_index(drop=True)

batch_size = 8
# Pinned (page-locked) host memory only speeds up host-to-CUDA copies;
# requesting it when running on MPS or CPU merely triggers DataLoader warnings.
use_pinned_memory = device.type == "cuda"

print(f"\nStarting {N_SPLITS}-Fold Cross-Validation...")
for fold_idx, (train_index, val_index) in enumerate(kfold.split(train_val_data)):
    print(f"\n==== Fold {fold_idx+1}/{N_SPLITS} ====")

    # The splits are already seeded; seed torch as well so head initialisation
    # and the training shuffle order are repeatable for each fold.
    torch.manual_seed(42 + fold_idx)

    fold_train_df = train_val_data.iloc[train_index]
    fold_val_df = train_val_data.iloc[val_index]

    train_dataset = AntibodyDataset(fold_train_df, tokenizer)
    val_dataset = AntibodyDataset(fold_val_df, tokenizer)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        pin_memory=use_pinned_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        pin_memory=use_pinned_memory,
    )

    # A fresh head/pooler is created per fold; the encoder object is shared.
    print(f"DEBUG: Initializing RegressionModel with hidden_dim = {igbert_hidden_dim}")
    fold_model = RegressionModel(
        plm_model=base_igbert_model,
        hidden_dim=igbert_hidden_dim,
        head_type=chosen_head_type,
        dropout=0.1,
        use_attention_pool=True,
        freeze_plm=True
    )
    fold_model.to(device)

    trained_fold_model = train_one_fold(
        model=fold_model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=15,
        lr=5e-4,
        weight_decay=0.01,
        use_scheduler=True,
        patience=3
    )

    # Re-evaluate the returned model on this fold's validation split.
    trained_fold_model.eval()
    val_preds = []
    val_labels = []
    criterion = nn.MSELoss()
    fold_val_losses = []

    with torch.no_grad():
        for batch in val_loader:
            heavy_ids = batch["heavy_input_ids"].to(device)
            heavy_mask = batch["heavy_attention_mask"].to(device)
            light_ids = batch["light_input_ids"].to(device)
            light_mask = batch["light_attention_mask"].to(device)
            labels = batch["label"].to(device)

            preds = trained_fold_model(heavy_ids, heavy_mask, light_ids, light_mask)
            loss = criterion(preds, labels)
            fold_val_losses.append(loss.item())
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    avg_fold_val_loss = np.mean(fold_val_losses)

    # Spearman is undefined for (near-)constant inputs; report 0 instead.
    if np.std(val_preds) > 1e-6 and np.std(val_labels) > 1e-6:
        fold_val_spearman, _ = spearmanr(val_preds, val_labels)
    else:
        fold_val_spearman = 0.0

    fold_results.append((avg_fold_val_loss, fold_val_spearman))
    trained_models.append(trained_fold_model.state_dict())
    print(f"Fold {fold_idx+1} Validation Loss: {avg_fold_val_loss:.4f}, Spearman: {fold_val_spearman:.4f}")

    # Drop per-fold references before the next fold allocates its own.
    del fold_model, trained_fold_model, train_loader, val_loader
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:

# Split the per-fold (loss, spearman) tuples into two parallel lists.
cv_val_losses = [fold_loss for fold_loss, _ in fold_results]
cv_val_spearmans = [fold_rho for _, fold_rho in fold_results]

# Mean and standard deviation (NumPy default, ddof=0) across folds.
avg_cv_loss, std_cv_loss = np.mean(cv_val_losses), np.std(cv_val_losses)
avg_cv_spearman, std_cv_spearman = np.mean(cv_val_spearmans), np.std(cv_val_spearmans)

print(
    "\n--- Cross-Validation Summary ---\n"
    f"Average Validation Loss: {avg_cv_loss:.4f} +/- {std_cv_loss:.4f}\n"
    f"Average Validation Spearman: {avg_cv_spearman:.4f} +/- {std_cv_spearman:.4f}\n"
    + "-" * 30
)