In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence 


from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from scipy.stats import spearmanr, pearsonr


from transformers import AutoModel, AutoTokenizer, AutoConfig
from peft import LoraConfig, get_peft_model, TaskType 

In [2]:
# --- Pretrained protein language model ---------------------------------
# Hugging Face identifier handed to AutoTokenizer / AutoModel.from_pretrained.
IGBERT_MODEL_NAME = "Exscientia/IgBERT"
# Passed to RegressionModelWithLoRA as `plm_native_hidden_dim`; it sizes the
# pooling layer and the regression head input (2 * this value).
# NOTE(review): must equal the encoder's actual hidden size — confirm against the model config.
IGBERT_BASE_HIDDEN_DIM = 1024
# Each chain is truncated and padded to exactly this many tokens by AntibodyDataset.
MAX_LENGTH = 512

# --- LoRA adapter settings (forwarded verbatim to peft.LoraConfig) ------
LORA_R = 8
LORA_ALPHA = 16
LORA_TARGET_MODULES = ["query", "value"]
LORA_DROPOUT = 0.1
LORA_BIAS = "none"

# --- Regression head -----------------------------------------------------
# One of "simple", "medium", "deep", "deeper" (see RegressionModelWithLoRA).
CHOSEN_HEAD_TYPE = "medium"
# Default dropout probability used by the multi-layer heads.
HEAD_DROPOUT = 0.1

# --- Optimisation --------------------------------------------------------
# Upper bound on epochs per fold; early stopping may end a fold sooner.
NUM_EPOCHS = 15
BATCH_SIZE = 8
# Learning rate and weight decay given to AdamW.
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 0.01
# When True, a ReduceLROnPlateau scheduler is created for every fold.
USE_SCHEDULER = True
SCHEDULER_PATIENCE = 2
# Epochs without sufficient validation-loss improvement before a fold stops.
EARLY_STOPPING_PATIENCE = 5
# Minimum decrease in validation loss that counts as an improvement.
MIN_DELTA_EARLY_STOP = 0.001

# Number of KFold splits applied to the train/validation portion.
N_SPLITS_CV = 5

# Single seed shared by Python, NumPy and PyTorch RNGs; also used as
# `random_state` for train_test_split and KFold.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:

# Pick the compute device once: Apple MPS first, then CUDA, otherwise CPU.
# The conditional expression is lazy, so CUDA is only probed when MPS is unusable.
_mps_ready = torch.backends.mps.is_available() and torch.backends.mps.is_built()
device = torch.device(
    "mps" if _mps_ready
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)


In [None]:
# Read the three supplementary spreadsheets and join them on the antibody name.
_supplement_paths = (
    'pnas.1616408114.sd01.xlsx',
    'pnas.1616408114.sd02.xlsx',
    'pnas.1616408114.sd03.xlsx',
)
df1, df2, df3 = map(pd.read_excel, _supplement_paths)

# Outer joins keep every 'Name' that appears in any sheet.
# NOTE(review): a name missing from one sheet would yield NaN in that sheet's
# columns — confirm the sheets cover the same antibodies.
merged_df = pd.merge(
    pd.merge(df1, df2, on='Name', how='outer'),
    df3,
    on='Name',
    how='outer',
)

# Only the two chain sequences and the regression target are used downstream.
df = merged_df.loc[:, ['VH', 'VL', 'HEK Titer (mg/L)']].copy()


In [None]:
# Build the regression target: log of the titer, then standardised.
_titer_col = "HEK Titer (mg/L)"

min_label_val = df[_titer_col].min()
if min_label_val <= 0:
    print(f"Warning: Minimum label value is {min_label_val}. Adding a small shift before log transform.")
    # Every entry that is not strictly positive is replaced by 1e-6 so that
    # the logarithm below is defined. NaN entries fail the `> 0` test and are
    # therefore replaced as well.
    # NOTE(review): confirm that treating missing titers as ~0 is intended.
    df[_titer_col] = df[_titer_col].where(df[_titer_col] > 0, 1e-6)

df["log_label"] = np.log(df[_titer_col])

# NOTE(review): the scaler is fitted on all rows, i.e. before the train/test
# split in a later cell, so its mean/std include the test labels.
scaler = StandardScaler()
df["scaled_label"] = scaler.fit_transform(df[["log_label"]]).ravel()


In [None]:

# Hold out 20% of the rows as the final test set; the rest feeds cross-validation.
train_val_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=SEED,
)
_n_train_val, _n_test = len(train_val_df), len(test_df)
print(f"\nTrain/Validation set size: {_n_train_val}")
print(f"Test set size: {_n_test}")

In [7]:
class AntibodyDataset(Dataset):
    """Dataset yielding separately tokenized heavy and light chains plus a label.

    Args:
        df: Frame with columns 'VH', 'VL' and 'scaled_label'. It is re-indexed
            into a private copy, so the caller's frame is not modified.
        tokenizer: Callable tokenizer invoked with a space-separated residue
            string and the keyword arguments shown in `_encode_chain`.
        max_length: Token length every chain is truncated / padded to.
    """

    def __init__(self, df, tokenizer, max_length=MAX_LENGTH):
        frame = df.reset_index(drop=True)
        # Force both sequence columns to str so they can be split per character.
        for chain_col in ('VH', 'VL'):
            frame[chain_col] = frame[chain_col].astype(str)
        self.df = frame
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def _encode_chain(self, sequence):
        """Tokenize one chain after inserting a space between every character."""
        spaced = " ".join(sequence)
        return self.tokenizer(
            spaced,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return the tensors for row `idx` as a flat dict (batch dim removed)."""
        record = self.df.iloc[idx]
        heavy_enc = self._encode_chain(record["VH"])
        light_enc = self._encode_chain(record["VL"])
        target = torch.tensor(record["scaled_label"], dtype=torch.float)

        # The tokenizer returns tensors with a leading batch axis of size 1.
        return {
            "heavy_input_ids": heavy_enc["input_ids"].squeeze(0),
            "heavy_attention_mask": heavy_enc["attention_mask"].squeeze(0),
            "light_input_ids": light_enc["input_ids"].squeeze(0),
            "light_attention_mask": light_enc["attention_mask"].squeeze(0),
            "label": target
        }

In [8]:

def collate_fn(batch):
    """Stack per-sample tensors into batch tensors, field by field.

    Args:
        batch: List of dicts as produced by `AntibodyDataset.__getitem__`;
            for each field, the tensors of all samples must share one shape.

    Returns:
        Dict with the same five keys, each holding the samples stacked along
        a new leading dimension.
    """
    field_names = (
        "heavy_input_ids",
        "heavy_attention_mask",
        "light_input_ids",
        "light_attention_mask",
        "label",
    )
    return {
        name: torch.stack([sample[name] for sample in batch])
        for name in field_names
    }

In [9]:
class AttentionPooling(nn.Module):
    """Collapse a token sequence into one vector using learned softmax weights.

    A single linear layer scores every token; positions whose attention mask
    is 0 are pushed to a very negative score so they receive (effectively)
    zero weight after the softmax.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attention_net = nn.Linear(hidden_dim, 1)

    def forward(self, token_embeddings, attention_mask):
        """Return the weighted sum of `token_embeddings` over dimension 1."""
        raw_scores = self.attention_net(token_embeddings).squeeze(-1)
        is_padding = attention_mask == 0
        masked_scores = raw_scores.masked_fill(is_padding, -1e9)
        weights = torch.softmax(masked_scores, dim=-1)
        # Broadcast the per-token weight across the embedding dimension.
        return torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)

class SimpleHead(nn.Module):
    """Regression head made of a single linear layer with one output."""

    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return one scalar per row of `x` (the trailing size-1 axis is dropped)."""
        projected = self.fc(x)
        return projected.squeeze(-1)

class MediumHead(nn.Module):
    """Regression head with one hidden layer: input -> 128 -> 1."""

    def __init__(self, input_dim, dropout=HEAD_DROPOUT):
        super().__init__()
        hidden_stage = [nn.Linear(input_dim, 128), nn.ReLU(), nn.Dropout(dropout)]
        # Layer positions inside the Sequential determine the state_dict keys.
        self.regressor = nn.Sequential(*hidden_stage, nn.Linear(128, 1))

    def forward(self, x):
        """Return one scalar per row of `x` (the trailing size-1 axis is dropped)."""
        scores = self.regressor(x)
        return scores.squeeze(-1)

class DeepHead(nn.Module):
    """Regression head with two hidden layers: input -> 256 -> 128 -> 1."""

    def __init__(self, input_dim, dropout=HEAD_DROPOUT):
        super().__init__()
        widths = (input_dim, 256, 128)
        layers = []
        # One Linear/ReLU/Dropout stage per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        layers.append(nn.Linear(widths[-1], 1))
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        """Return one scalar per row of `x` (the trailing size-1 axis is dropped)."""
        scores = self.regressor(x)
        return scores.squeeze(-1)

class DeeperHead(nn.Module):
    """Regression head with four hidden layers: input -> 512 -> 256 -> 128 -> 64 -> 1."""

    def __init__(self, input_dim, dropout=HEAD_DROPOUT):
        super().__init__()
        widths = (input_dim, 512, 256, 128, 64)
        layers = []
        # One Linear/ReLU/Dropout stage per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        layers.append(nn.Linear(widths[-1], 1))
        self.regressor = nn.Sequential(*layers)

    def forward(self, x):
        """Return one scalar per row of `x` (the trailing size-1 axis is dropped)."""
        scores = self.regressor(x)
        return scores.squeeze(-1)


class RegressionModelWithLoRA(nn.Module):
    """Encode heavy and light chains with one shared encoder and regress a scalar.

    Both chains go through the same `peft_plm_model`; each token sequence is
    pooled to a single vector (attention pooling or masked mean), the two
    vectors are concatenated and fed to the selected regression head.

    Args:
        peft_plm_model: Encoder whose output exposes `last_hidden_state`.
        plm_native_hidden_dim: Width of the encoder's token embeddings.
        head_type: "simple", "medium", "deep" or "deeper".
        dropout: Dropout probability forwarded to the multi-layer heads.
        use_attention_pool: Use `AttentionPooling` when truthy, otherwise a
            mask-weighted mean.

    Raises:
        ValueError: If `head_type` is not one of the supported names.
    """

    def __init__(
        self,
        peft_plm_model,
        plm_native_hidden_dim,
        head_type=CHOSEN_HEAD_TYPE,
        dropout=HEAD_DROPOUT,
        use_attention_pool=True
    ):
        super().__init__()
        self.plm_model = peft_plm_model

        self.use_attention_pool = use_attention_pool
        self.pooler = AttentionPooling(plm_native_hidden_dim) if self.use_attention_pool else None

        # Heavy and light representations are concatenated before the head.
        head_input_dim = 2 * plm_native_hidden_dim
        # Factories are lazy so that only the requested head is instantiated.
        head_builders = (
            ("simple", lambda: SimpleHead(head_input_dim)),
            ("medium", lambda: MediumHead(head_input_dim, dropout)),
            ("deep", lambda: DeepHead(head_input_dim, dropout)),
            ("deeper", lambda: DeeperHead(head_input_dim, dropout)),
        )
        for known_type, build_head in head_builders:
            if head_type == known_type:
                self.head = build_head()
                break
        else:
            raise ValueError(f"Invalid head type: {head_type}")

        for summary_line in (
            f"Initialized RegressionModelWithLoRA:",
            f"  PLM Native Hidden Dim: {plm_native_hidden_dim}",
            f"  Using Attention Pooling: {self.use_attention_pool}",
            f"  Head Input Dim: {head_input_dim}",
            f"  Head Type: {head_type}",
        ):
            print(summary_line)

    def _pool(self, hidden, mask):
        """Reduce `hidden` over dimension 1, ignoring positions where `mask` is 0."""
        if self.pooler is not None:
            return self.pooler(hidden, mask)
        weights = mask.unsqueeze(-1).float()
        summed = (hidden * weights).sum(dim=1)
        # Clamp guards against division by zero for an all-zero mask.
        counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, heavy_ids, heavy_mask, light_ids, light_mask):
        """Return one prediction per sample for the given heavy/light token batches."""
        heavy_hidden = self.plm_model(input_ids=heavy_ids, attention_mask=heavy_mask).last_hidden_state
        light_hidden = self.plm_model(input_ids=light_ids, attention_mask=light_mask).last_hidden_state

        heavy_repr = self._pool(heavy_hidden, heavy_mask)
        light_repr = self._pool(light_hidden, light_mask)

        return self.head(torch.cat([heavy_repr, light_repr], dim=1))

In [10]:

class EarlyStopping:
    """Track validation loss, keep a snapshot of the best weights, signal when to stop.

    Call the instance once per epoch with the epoch's validation loss and the
    model. An epoch counts as an improvement only if the loss drops by more
    than `min_delta` below the best loss seen so far.

    Attributes:
        best_val_loss: Lowest validation loss that counted as an improvement.
        best_model_state: CPU copy of the model's state dict taken at that
            epoch, or None if no epoch has improved yet.
        early_stop: True once `patience` consecutive epochs failed to improve.
    """

    def __init__(self, patience=EARLY_STOPPING_PATIENCE, min_delta=MIN_DELTA_EARLY_STOP, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_val_loss = float("inf")
        self.early_stop = False
        self.best_model_state = None

    @staticmethod
    def _snapshot(model):
        """Return an independent CPU copy of `model.state_dict()`."""
        # state_dict() hands back tensors that share storage with the live
        # parameters; without cloning, later optimizer steps would silently
        # overwrite the "best" weights. Keeping the copy on CPU also avoids
        # holding a second set of weights on the accelerator.
        return {
            name: value.detach().cpu().clone() if torch.is_tensor(value) else value
            for name, value in model.state_dict().items()
        }

    def __call__(self, val_loss, model):
        """Update the stopper with this epoch's validation loss.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Model whose weights are snapshotted on improvement.
        """
        if val_loss < self.best_val_loss - self.min_delta:
            self.best_val_loss = val_loss
            self.counter = 0
            self.best_model_state = self._snapshot(model)
            if self.verbose:
                print(f"EarlyStopping: New best validation loss: {self.best_val_loss:.4f}. Saving model state.")
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter}/{self.patience}. Best loss: {self.best_val_loss:.4f}")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("EarlyStopping: Stopping training.")

def print_trainable_parameters(model):
    """Print how many of `model`'s parameters require gradients.

    Args:
        model: Module whose parameters are counted.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    print(
        f"  Trainable params: {trainable_params} || All params: {all_param} || Trainable %: {100 * trainable_params / all_param:.2f}"
    )

In [11]:
def train_one_fold(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    scheduler,
    early_stopper,
    num_epochs=NUM_EPOCHS,
    fold_num=0
):
    """Train `model` on one cross-validation fold with per-epoch validation.

    Each epoch runs one optimisation pass over `train_loader` and one
    gradient-free pass over `val_loader`, prints a progress line, steps the
    scheduler (if enabled) and consults `early_stopper`.

    Args:
        model: Module called as `model(heavy_ids, heavy_mask, light_ids, light_mask)`.
        train_loader: Iterable of training batches (dicts from `collate_fn`).
        val_loader: Iterable of validation batches.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: Loss taking `(predictions, labels)`.
        scheduler: LR scheduler or None; only used when `USE_SCHEDULER` is true.
        early_stopper: Callable `(val_loss, model)` exposing `early_stop`,
            `best_val_loss` and `best_model_state`.
        num_epochs: Maximum number of epochs.
        fold_num: Zero-based fold index, used only in log messages.

    Returns:
        Tuple `(best_model_state, history)` where `history` holds the
        per-epoch 'train_loss', 'val_loss', 'val_spearman' lists and the
        stopper's 'best_val_loss'.
    """
    input_keys = (
        "heavy_input_ids",
        "heavy_attention_mask",
        "light_input_ids",
        "light_attention_mask",
    )

    def _predict(batch):
        """Move one batch to `device`; return (predictions, labels, loss)."""
        inputs = [batch[key].to(device) for key in input_keys]
        targets = batch["label"].to(device)
        outputs = model(*inputs)
        return outputs, targets, criterion(outputs, targets)

    history = {'train_loss': [], 'val_loss': [], 'val_spearman': []}

    for epoch in range(num_epochs):
        # ---- optimisation pass ----
        model.train()
        train_batch_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            _, _, loss = _predict(batch)
            loss.backward()
            optimizer.step()
            train_batch_losses.append(loss.item())

        avg_epoch_train_loss = np.mean(train_batch_losses)
        history['train_loss'].append(avg_epoch_train_loss)

        # ---- validation pass ----
        model.eval()
        val_batch_losses, val_preds, val_targets = [], [], []
        with torch.no_grad():
            for batch in val_loader:
                outputs, targets, loss = _predict(batch)
                val_batch_losses.append(loss.item())
                val_preds.extend(outputs.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

        avg_epoch_val_loss = np.mean(val_batch_losses)
        history['val_loss'].append(avg_epoch_val_loss)

        # Rank correlation is reported as 0.0 when it cannot be computed
        # (fewer than two points or a constant vector).
        current_spearman = 0.0
        if len(val_preds) > 1 and np.std(val_preds) > 1e-6 and np.std(val_targets) > 1e-6:
            current_spearman, _ = spearmanr(val_preds, val_targets)
        history['val_spearman'].append(current_spearman)

        print(f"[Fold {fold_num+1}/Epoch {epoch+1:02d}] Train Loss: {avg_epoch_train_loss:.4f} | "
              f"Val Loss: {avg_epoch_val_loss:.4f} | Val Spearman: {current_spearman:.4f} | "
              f"LR: {optimizer.param_groups[0]['lr']:.2e}")

        if USE_SCHEDULER and scheduler is not None:
            # ReduceLROnPlateau needs the monitored metric; other schedulers do not.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_epoch_val_loss)
            else:
                scheduler.step()

        early_stopper(avg_epoch_val_loss, model)
        if early_stopper.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1} for fold {fold_num+1}.")
            break

    history['best_val_loss'] = early_stopper.best_val_loss
    return early_stopper.best_model_state, history

In [12]:
def evaluate_model(model, data_loader, criterion, device, scaler_obj, data_type="Test"):
    """Evaluate `model` on `data_loader` and report loss and correlations.

    Predictions and labels are compared on the scaled target and again after
    mapping both back through `scaler_obj.inverse_transform` followed by
    `np.exp`.

    Args:
        model: Module called as `model(heavy_ids, heavy_mask, light_ids, light_mask)`.
        data_loader: Iterable of batches (dicts from `collate_fn`).
        criterion: Loss taking `(predictions, labels)`; its per-batch value is
            weighted by batch size when averaging.
        device: Device the batch tensors are moved to.
        scaler_obj: Object with `inverse_transform` accepting an (n, 1) array.
        data_type: Label used in the printed report.

    Returns:
        Dict with 'loss_scaled', four correlation entries and the scaled and
        original-scale prediction / label lists. When the loader yields no
        batch, the numeric entries are NaN and the lists are empty.
    """

    def _corr_or_zero(corr_fn, predicted, observed):
        """Return corr_fn's statistic, or 0.0 for <2 points or constant input."""
        usable = len(predicted) > 1 and np.std(predicted) > 1e-6 and np.std(observed) > 1e-6
        if not usable:
            return 0.0
        statistic, _ = corr_fn(predicted, observed)
        return statistic

    def _to_original(scaled_values):
        """Undo the scaling, then the log transform."""
        as_column = scaled_values.reshape(-1, 1)
        return np.exp(scaler_obj.inverse_transform(as_column).flatten())

    model.eval()
    pred_chunks = []
    label_chunks = []
    total_loss = 0

    with torch.no_grad():
        for batch in data_loader:
            heavy_ids = batch["heavy_input_ids"].to(device)
            heavy_mask = batch["heavy_attention_mask"].to(device)
            light_ids = batch["light_input_ids"].to(device)
            light_mask = batch["light_attention_mask"].to(device)
            labels_scaled = batch["label"].to(device)

            preds_scaled = model(heavy_ids, heavy_mask, light_ids, light_mask)
            batch_loss = criterion(preds_scaled, labels_scaled)
            # Weight by batch size so a smaller last batch does not skew the mean.
            total_loss += batch_loss.item() * labels_scaled.size(0)

            pred_chunks.append(preds_scaled.cpu().numpy())
            label_chunks.append(labels_scaled.cpu().numpy())

    if len(pred_chunks) == 0:
        print(f"Warning: No data in {data_type} loader for evaluation.")
        nan = float('nan')
        return {
            "loss_scaled": nan, "spearman_scaled": nan, "pearson_scaled": nan,
            "spearman_original": nan, "pearson_original": nan,
            "preds_scaled": [], "labels_scaled": [], "preds_original": [], "labels_original": []
        }

    all_preds_scaled = np.concatenate(pred_chunks)
    all_labels_scaled = np.concatenate(label_chunks)

    n_samples = len(all_labels_scaled)
    avg_loss = total_loss / n_samples if n_samples > 0 else float('nan')

    all_preds_original = _to_original(all_preds_scaled)
    all_labels_original = _to_original(all_labels_scaled)

    spearman_corr_scaled = _corr_or_zero(spearmanr, all_preds_scaled, all_labels_scaled)
    pearson_corr_scaled = _corr_or_zero(pearsonr, all_preds_scaled, all_labels_scaled)
    spearman_corr_original = _corr_or_zero(spearmanr, all_preds_original, all_labels_original)
    pearson_corr_original = _corr_or_zero(pearsonr, all_preds_original, all_labels_original)

    print(f"\n--- {data_type} Set Evaluation ---")
    print(f"  Avg Loss (scaled): {avg_loss:.4f}")
    print(f"  Spearman Correlation (scaled): {spearman_corr_scaled:.4f}")
    print(f"  Pearson Correlation (scaled): {pearson_corr_scaled:.4f}")
    print(f"  Spearman Correlation (original scale): {spearman_corr_original:.4f}")
    print(f"  Pearson Correlation (original scale): {pearson_corr_original:.4f}")

    return {
        "loss_scaled": avg_loss,
        "spearman_scaled": spearman_corr_scaled,
        "pearson_scaled": pearson_corr_scaled,
        "spearman_original": spearman_corr_original,
        "pearson_original": pearson_corr_original,
        "preds_scaled": all_preds_scaled.tolist(),
        "labels_scaled": all_labels_scaled.tolist(),
        "preds_original": all_preds_original.tolist(),
        "labels_original": all_labels_original.tolist()
    }

In [None]:
# Tokenizer shared by every dataset built in the cells below.
tokenizer = AutoTokenizer.from_pretrained(IGBERT_MODEL_NAME)

# The held-out test loader stays None when the test split is empty.
test_loader = None
if not test_df.empty:
    test_dataset = AntibodyDataset(test_df, tokenizer, max_length=MAX_LENGTH)
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
    )

# Containers filled once per cross-validation fold.
cv_fold_results_summary, cv_best_model_states, all_fold_histories = [], [], []

kfold = KFold(n_splits=N_SPLITS_CV, shuffle=True, random_state=SEED)

# KFold yields positional indices, so work on a frame with a 0..n-1 index.
train_val_data_reset = train_val_df.reset_index(drop=True)

# Cross-validation: train a fresh LoRA-adapted encoder + head on every fold
# and keep each fold's best weights for the test-time ensemble.
for fold_idx, (train_indices, val_indices) in enumerate(kfold.split(train_val_data_reset)):
    fold_train_df = train_val_data_reset.iloc[train_indices]
    fold_val_df = train_val_data_reset.iloc[val_indices]

    train_dataset_fold = AntibodyDataset(fold_train_df, tokenizer, max_length=MAX_LENGTH)
    val_dataset_fold = AntibodyDataset(fold_val_df, tokenizer, max_length=MAX_LENGTH)

    # Pinned host memory is only requested when running on CUDA.
    use_pinned_memory = device.type == 'cuda'
    train_loader_fold = DataLoader(train_dataset_fold, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, pin_memory=use_pinned_memory, num_workers=0)
    val_loader_fold = DataLoader(val_dataset_fold, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, pin_memory=use_pinned_memory, num_workers=0)

    # Every fold starts from the untouched pretrained weights.
    base_igbert_model_fold = AutoModel.from_pretrained(IGBERT_MODEL_NAME)

    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=LORA_TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias=LORA_BIAS,
        task_type=None
    )

    peft_igbert_model_fold = get_peft_model(base_igbert_model_fold, lora_config)

    fold_model = RegressionModelWithLoRA(
        peft_plm_model=peft_igbert_model_fold,
        plm_native_hidden_dim=IGBERT_BASE_HIDDEN_DIM,
        head_type=CHOSEN_HEAD_TYPE,
        dropout=HEAD_DROPOUT,
        use_attention_pool=True
    ).to(device)

    print_trainable_parameters(fold_model)

    # Only parameters that require gradients are handed to the optimizer.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, fold_model.parameters()), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.MSELoss()
    scheduler = None
    if USE_SCHEDULER:
        # `verbose` is deprecated in recent PyTorch releases; False was its
        # default, so omitting it keeps the behaviour and avoids the warning.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=SCHEDULER_PATIENCE)

    early_stopper = EarlyStopping(patience=EARLY_STOPPING_PATIENCE, min_delta=MIN_DELTA_EARLY_STOP, verbose=True)

    best_fold_model_state, fold_history = train_one_fold(
        model=fold_model,
        train_loader=train_loader_fold,
        val_loader=val_loader_fold,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        early_stopper=early_stopper,
        num_epochs=NUM_EPOCHS,
        fold_num=fold_idx
    )

    all_fold_histories.append(fold_history)
    if best_fold_model_state:
        # Keep the saved weights on the CPU: holding them on the accelerator
        # would pin one full model per fold in device memory.
        best_fold_model_state = {
            name: value.detach().cpu() if torch.is_tensor(value) else value
            for name, value in best_fold_model_state.items()
        }
        cv_best_model_states.append(best_fold_model_state)
        best_val_loss_this_fold = fold_history['best_val_loss']
        # Report the Spearman value of the epoch that achieved the best loss.
        best_val_loss_epoch_idx = fold_history['val_loss'].index(best_val_loss_this_fold)
        best_val_spearman_this_fold = fold_history['val_spearman'][best_val_loss_epoch_idx]

        cv_fold_results_summary.append({
            "fold": fold_idx + 1,
            "best_val_loss": best_val_loss_this_fold,
            "best_val_spearman": best_val_spearman_this_fold
        })
    else:
        print(f"Warning: No best model state saved for fold {fold_idx+1}. Training might have been too short or an issue occurred.")

    # Drop every reference to this fold's parameters before emptying the
    # cache: the scheduler holds the optimizer (and through it the
    # parameters), and the early stopper holds its own copy of the state.
    early_stopper.best_model_state = None
    del fold_model, peft_igbert_model_fold, base_igbert_model_fold, optimizer, scheduler, early_stopper, train_loader_fold, val_loader_fold
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    elif device.type == 'mps':
        torch.mps.empty_cache()

# Aggregate the per-fold best validation metrics and report them; previously
# these statistics were computed but never shown.
if cv_fold_results_summary:
    _fold_best_losses = [r['best_val_loss'] for r in cv_fold_results_summary]
    _fold_best_spearmans = [r['best_val_spearman'] for r in cv_fold_results_summary]

    avg_cv_loss = np.mean(_fold_best_losses)
    std_cv_loss = np.std(_fold_best_losses)
    avg_cv_spearman = np.mean(_fold_best_spearmans)
    std_cv_spearman = np.std(_fold_best_spearmans)

    print(f"\n--- Cross-Validation Summary ({len(cv_fold_results_summary)} folds) ---")
    print(f"  Best Val Loss (scaled): {avg_cv_loss:.4f} +/- {std_cv_loss:.4f}")
    print(f"  Val Spearman at best loss: {avg_cv_spearman:.4f} +/- {std_cv_spearman:.4f}")

# One validation-loss curve per fold, drawn on a single set of axes.
if all_fold_histories:
    loss_fig, loss_ax = plt.subplots(figsize=(12, 6))
    for fold_no, history in enumerate(all_fold_histories, start=1):
        loss_ax.plot(history['val_loss'], label=f'Fold {fold_no} Val Loss')
    loss_ax.set_title('Validation Loss per Fold')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Validation Loss (MSE)')
    loss_ax.legend()
    loss_ax.grid(True)
    plt.show()

def _guarded_corr(corr_fn, predicted, observed):
    """Return corr_fn's statistic, or 0.0 when it is undefined.

    Uses the same guard as `evaluate_model`: fewer than two points or a
    (near-)constant vector yields 0.0 instead of NaN.
    """
    if len(predicted) > 1 and np.std(predicted) > 1e-6 and np.std(observed) > 1e-6:
        statistic, _ = corr_fn(predicted, observed)
        return statistic
    return 0.0


# Test-set evaluation: rebuild one model per fold, load that fold's best
# weights, and average the per-fold predictions into an ensemble prediction.
if test_loader is not None and cv_best_model_states:
    all_test_preds_scaled_ensemble = []

    for fold_idx, model_state in enumerate(cv_best_model_states):
        # Same architecture as in training so the saved state dict fits.
        base_igbert_test = AutoModel.from_pretrained(IGBERT_MODEL_NAME)
        lora_config_test = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGET_MODULES, lora_dropout=LORA_DROPOUT, bias=LORA_BIAS)
        peft_igbert_test = get_peft_model(base_igbert_test, lora_config_test)

        test_eval_model = RegressionModelWithLoRA(
            peft_plm_model=peft_igbert_test,
            plm_native_hidden_dim=IGBERT_BASE_HIDDEN_DIM,
            head_type=CHOSEN_HEAD_TYPE
        ).to(device)

        test_eval_model.load_state_dict(model_state)
        test_eval_model.eval()

        fold_test_preds_scaled = []
        with torch.no_grad():
            for batch in test_loader:
                heavy_ids = batch["heavy_input_ids"].to(device)
                heavy_mask = batch["heavy_attention_mask"].to(device)
                light_ids = batch["light_input_ids"].to(device)
                light_mask = batch["light_attention_mask"].to(device)

                preds_scaled = test_eval_model(heavy_ids, heavy_mask, light_ids, light_mask)
                fold_test_preds_scaled.extend(preds_scaled.cpu().numpy())
        all_test_preds_scaled_ensemble.append(fold_test_preds_scaled)

        del test_eval_model, peft_igbert_test, base_igbert_test
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        elif device.type == 'mps':
            torch.mps.empty_cache()

    if all_test_preds_scaled_ensemble:
        # Rows are folds, columns are test samples; average across folds.
        avg_test_preds_scaled = np.mean(np.array(all_test_preds_scaled_ensemble), axis=0)

        # The loader is not shuffled, so labels come back in prediction order.
        true_test_labels_scaled = []
        for batch in test_loader:
            true_test_labels_scaled.extend(batch["label"].cpu().numpy())
        true_test_labels_scaled_arr = np.array(true_test_labels_scaled)

        test_mse_scaled = np.mean((avg_test_preds_scaled - true_test_labels_scaled_arr)**2)
        test_spearman_scaled = _guarded_corr(spearmanr, avg_test_preds_scaled, true_test_labels_scaled_arr)
        test_pearson_scaled = _guarded_corr(pearsonr, avg_test_preds_scaled, true_test_labels_scaled_arr)

        avg_test_preds_scaled_2d = np.array(avg_test_preds_scaled).reshape(-1, 1)
        true_test_labels_scaled_2d = true_test_labels_scaled_arr.reshape(-1, 1)

        # Undo the standardisation, then the log transform.
        avg_test_preds_log = scaler.inverse_transform(avg_test_preds_scaled_2d).flatten()
        true_test_labels_log = scaler.inverse_transform(true_test_labels_scaled_2d).flatten()

        avg_test_preds_original = np.exp(avg_test_preds_log)
        true_test_labels_original = np.exp(true_test_labels_log)

        test_spearman_original = _guarded_corr(spearmanr, avg_test_preds_original, true_test_labels_original)
        test_pearson_original = _guarded_corr(pearsonr, avg_test_preds_original, true_test_labels_original)
        test_mse_original = np.mean((avg_test_preds_original - true_test_labels_original)**2)

        # These metrics were previously computed but never shown.
        print(f"\n--- Test Set Evaluation (ensemble of {len(all_test_preds_scaled_ensemble)} fold models) ---")
        print(f"  MSE (scaled): {test_mse_scaled:.4f}")
        print(f"  Spearman Correlation (scaled): {test_spearman_scaled:.4f}")
        print(f"  Pearson Correlation (scaled): {test_pearson_scaled:.4f}")
        print(f"  MSE (original scale): {test_mse_original:.4f}")
        print(f"  Spearman Correlation (original scale): {test_spearman_original:.4f}")
        print(f"  Pearson Correlation (original scale): {test_pearson_original:.4f}")

        # Shared limits for the y = x reference line.
        diag_low = min(true_test_labels_original.min(), avg_test_preds_original.min())
        diag_high = max(true_test_labels_original.max(), avg_test_preds_original.max())

        plt.figure(figsize=(8, 8))
        sns.scatterplot(x=true_test_labels_original, y=avg_test_preds_original, alpha=0.7)
        plt.plot([diag_low, diag_high], [diag_low, diag_high], color='red', linestyle='--')
        plt.xlabel("True HEK Titer (mg/L) - Test Set")
        plt.ylabel("Predicted HEK Titer (mg/L) - Test Set (Ensemble)")
        plt.title(f"Test Set: True vs. Predicted (Original Scale)\nSpearman: {test_spearman_original:.3f}, Pearson: {test_pearson_original:.3f}")
        plt.grid(True)
        plt.show()