In [None]:
import os
import math
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import esm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from scipy.stats import spearmanr

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    PyTorch's default generator, and the generators of all CUDA devices.

    Args:
        seed: Integer seed passed unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


class SeqRegressionDataset(Dataset):
    """Map-style dataset pairing each sequence with a float32 target.

    Attributes:
        seqs: List copy of the sequences passed to the constructor.
        ys: ``numpy`` array of targets, converted to ``float32``.
    """

    def __init__(self, seqs, ys):
        # Store targets as one contiguous float32 array so every item yields
        # a NumPy float32 scalar.
        self.ys = np.asarray(ys, dtype=np.float32)
        # Materialise the sequences so integer indexing works for any iterable.
        self.seqs = [*seqs]

    def __len__(self):
        """Return the number of stored sequences."""
        n_items = len(self.seqs)
        return n_items

    def __getitem__(self, idx):
        """Return the ``(sequence, target)`` pair stored at position ``idx``."""
        sequence = self.seqs[idx]
        target = self.ys[idx]
        return sequence, target

def make_collate_fn(alphabet):
    """Build a ``DataLoader`` collate function around ``alphabet``'s tokenizer.

    Args:
        alphabet: Object exposing ``get_batch_converter()`` and ``padding_idx``.

    Returns:
        A function that maps a list of ``(sequence, target)`` pairs to
        ``(tokens, residue_mask, lengths, ys)``:

        * ``tokens``: the token tensor produced by the batch converter;
        * ``residue_mask``: bool tensor with the same shape as ``tokens``, True
          at columns ``1 .. length`` of each row;
        * ``lengths``: per-row count of non-padding tokens minus two, floored
          at zero;
        * ``ys``: float32 tensor of the targets.
    """
    to_tokens = alphabet.get_batch_converter()
    padding_token = alphabet.padding_idx

    def collate(batch):
        seqs, ys = zip(*batch)
        labelled = [(f"protein{n}", seq) for n, seq in enumerate(seqs)]
        _labels, _strs, tokens = to_tokens(labelled)

        # Residue count per row: non-padding tokens minus two.
        # NOTE(review): assumes the converter adds exactly one leading and one
        # trailing special token to every sequence -- confirm for the alphabet.
        n_non_pad = tokens.ne(padding_token).sum(dim=1)
        lengths = (n_non_pad - 2).clamp(min=0)

        # Mark columns 1..length of each row by comparing a column index
        # against the per-row length; rows with length 0 stay all False.
        n_rows, n_cols = tokens.shape
        column = torch.arange(n_cols).unsqueeze(0).expand(n_rows, n_cols)
        residue_mask = (column >= 1) & (column <= lengths.unsqueeze(1))

        ys_t = torch.tensor(ys, dtype=torch.float32)
        return tokens, residue_mask, lengths, ys_t

    return collate


class GlobalAttnPlusCTermAttnRegressor(nn.Module):
    """Scalar regression head on top of a frozen ESM encoder.

    Two attention-pooled summaries of the encoder's per-token representations
    are concatenated and fed to an MLP that emits one value per sequence:

    * a global summary over every position flagged in ``residue_mask``;
    * a C-terminal summary over the last ``cterm_k`` residue positions.

    The encoder's parameters are frozen and the encoder is kept in eval mode
    at all times, including after ``.train()`` is called on this module.
    """

    def __init__(
        self,
        esm_model,
        repr_layer=33,
        emb_dim=1280,
        cterm_k=6,
        attn_hidden=256,
        cterm_attn_hidden=128,
        head_hidden=(512, 256),
        dropout=0.2,
    ):
        """
        Args:
            esm_model: Encoder module; called as
                ``esm_model(tokens, repr_layers=[...], return_contacts=False)``
                and expected to return a dict with a ``"representations"`` entry.
            repr_layer: Key of the representation layer to read.
            emb_dim: Size of the last dimension of the representations; it is
                the input width of both attention scorers.
            cterm_k: Number of trailing residue positions in the C-terminal window.
            attn_hidden: Hidden width of the global attention scorer.
            cterm_attn_hidden: Hidden width of the C-terminal attention scorer.
            head_hidden: Hidden widths of the MLP head, in order.
            dropout: Dropout probability after each hidden layer of the head.
        """
        super().__init__()
        self.esm = esm_model
        self.repr_layer = repr_layer
        self.cterm_k = cterm_k

        self.global_attn = nn.Sequential(
            nn.Linear(emb_dim, attn_hidden),
            nn.Tanh(),
            nn.Linear(attn_hidden, 1),
        )

        self.cterm_attn = nn.Sequential(
            nn.Linear(emb_dim, cterm_attn_hidden),
            nn.Tanh(),
            nn.Linear(cterm_attn_hidden, 1),
        )

        # The head consumes the concatenation [global_vec, cterm_vec].
        layers = []
        in_dim = emb_dim * 2
        for h in head_hidden:
            layers += [nn.Linear(in_dim, h), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = h
        layers += [nn.Linear(in_dim, 1)]
        self.head = nn.Sequential(*layers)

        # Freeze the encoder: no gradients, and eval mode for any
        # train-time-only layers it may contain.
        for p in self.esm.parameters():
            p.requires_grad = False
        self.esm.eval()

    def train(self, mode=True):
        """Switch train/eval mode for the head while keeping the encoder in eval.

        ``nn.Module.train`` recurses into every child, which would otherwise
        put the frozen encoder back into train mode on each ``model.train()``.
        """
        super().train(mode)
        self.esm.eval()
        return self

    @torch.no_grad()
    def _esm_forward(self, tokens):
        """Run the encoder without gradient tracking and return ``repr_layer``."""
        out = self.esm(tokens, repr_layers=[self.repr_layer], return_contacts=False)
        return out["representations"][self.repr_layer]

    @staticmethod
    def _attention_pool(scorer, H, mask):
        """Softmax-attention pooling of ``H`` (B, T, D) restricted to ``mask`` (B, T).

        Masked-out positions get the most negative value of the logits' dtype
        so they receive zero weight; unlike a hard-coded ``-1e9`` this value is
        also representable in half precision. A row whose mask is entirely
        False falls back to uniform weights over all positions.
        """
        logits = scorer(H).squeeze(-1)
        logits = logits.masked_fill(~mask, torch.finfo(logits.dtype).min)
        weights = F.softmax(logits, dim=1)
        return torch.sum(H * weights.unsqueeze(-1), dim=1)

    def forward(self, tokens, residue_mask, lengths):
        """Predict one scalar per sequence.

        Args:
            tokens: Token tensor (B, T) passed to the encoder.
            residue_mask: Bool tensor (B, T), True at residue positions.
            lengths: Per-row residue counts (B,); residues are taken to occupy
                columns ``1 .. length``.

        Returns:
            Tensor of shape (B,) with the raw (unclipped) predictions.
        """
        H = self._esm_forward(tokens)

        global_vec = self._attention_pool(self.global_attn, H, residue_mask)

        # C-terminal window: the last min(cterm_k, length) residue columns,
        # i.e. columns [1 + length - k, 1 + length). Built by broadcasting so
        # no per-sample .item() call is needed.
        T = H.size(1)
        positions = torch.arange(T, device=H.device).unsqueeze(0)
        seq_len = torch.as_tensor(lengths, device=H.device).to(torch.long).unsqueeze(1)
        window = seq_len.clamp(max=self.cterm_k)
        start = 1 + seq_len - window
        end = 1 + seq_len
        cterm_mask = (positions >= start) & (positions < end)

        cterm_vec = self._attention_pool(self.cterm_attn, H, cterm_mask)

        feat = torch.cat([global_vec, cterm_vec], dim=1)
        y_hat = self.head(feat).squeeze(-1)
        return y_hat


def collect_preds(model, loader, device, clip_01=True):
    """Run ``model`` over ``loader`` in eval mode and gather targets and predictions.

    Args:
        model: Module called as ``model(tokens, residue_mask, lengths)``.
        loader: Iterable yielding ``(tokens, residue_mask, lengths, ys)`` batches.
        device: Device the three model inputs are moved to.
        clip_01: When True, predictions are clipped to ``[0, 1]``.

    Returns:
        ``(y_true, y_pred)`` as concatenated NumPy arrays. The model is left
        in eval mode.
    """
    model.eval()
    true_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for batch in loader:
            tokens, residue_mask, lengths, ys = batch
            outputs = model(
                tokens.to(device),
                residue_mask.to(device),
                lengths.to(device),
            )
            batch_pred = outputs.detach().cpu().numpy()
            if clip_01:
                batch_pred = np.clip(batch_pred, 0.0, 1.0)

            true_chunks.append(ys.numpy())
            pred_chunks.append(batch_pred)

    return np.concatenate(true_chunks), np.concatenate(pred_chunks)


def eval_metrics(y_true, y_pred):
    """Compute regression and rank-correlation metrics.

    Args:
        y_true: Array-like of ground-truth values.
        y_pred: Array-like of predictions, same length as ``y_true``.

    Returns:
        ``(mae, rmse, r2, rho, p)`` where ``rho`` and ``p`` are the Spearman
        correlation coefficient and its p-value as returned by
        ``scipy.stats.spearmanr``.
    """
    mae = mean_absolute_error(y_true, y_pred)
    # Take the square root explicitly: the ``squared=False`` keyword of
    # ``mean_squared_error`` is deprecated in scikit-learn 1.4 and removed in 1.6.
    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    r2 = r2_score(y_true, y_pred)
    rho, p = spearmanr(y_true, y_pred)
    return mae, rmse, r2, rho, p


def pairwise_ranking_loss(pred, y, margin=0.05, n_pairs=2048, min_label_diff=0.02):
    """Pairwise hinge ranking loss between predictions and targets.

    For every considered pair ``(i, j)`` whose targets differ by more than
    ``min_label_diff``, the loss is ``relu(margin - sign(y_i - y_j) * (pred_i - pred_j))``;
    the result is the mean over those pairs.

    When the batch has at most ``n_pairs`` unique unordered pairs, all of them
    are used (deterministic, and equal to the expectation of random sampling
    because the per-pair loss is symmetric in ``i`` and ``j``). Otherwise
    ``n_pairs`` index pairs are sampled uniformly with replacement.

    Args:
        pred: 1-D tensor of predictions.
        y: 1-D tensor of targets, same length as ``pred``.
        margin: Hinge margin.
        n_pairs: Pair budget (number of sampled pairs for large batches).
        min_label_diff: Pairs with ``|y_i - y_j|`` at or below this are ignored.

    Returns:
        Scalar tensor. When no usable pair exists the value is zero but still
        attached to ``pred``'s autograd graph, so ``backward()`` remains valid.
    """
    # Graph-connected zero: a fresh constant tensor would have no grad_fn and
    # make ``loss.backward()`` raise when this term is used on its own.
    zero = pred.sum() * 0.0

    B = y.size(0)
    if B < 2:
        return zero

    n_unique = B * (B - 1) // 2
    if n_unique <= n_pairs:
        # Small batch: enumerate each unordered pair exactly once.
        pair_idx = torch.triu_indices(B, B, offset=1, device=pred.device)
        i, j = pair_idx[0], pair_idx[1]
    else:
        i = torch.randint(0, B, (n_pairs,), device=pred.device)
        j = torch.randint(0, B, (n_pairs,), device=pred.device)

    diff = y[i] - y[j]
    keep = diff.abs() > min_label_diff
    if not bool(keep.any()):
        return zero

    s = torch.sign(diff[keep])  # +1 if yi > yj else -1
    pred_gap = pred[i][keep] - pred[j][keep]

    # hinge ranking loss
    return torch.relu(margin - s * pred_gap).mean()


def train_one_fold(
    model,
    train_loader,
    val_loader,
    device,
    lr=1e-3,
    weight_decay=1e-4,
    max_epochs=40,
    patience=6,
    use_ranking_loss=True,
    rank_lambda=0.8,       
    rank_margin=0.05,
    rank_pairs=2048,
    min_label_diff=0.02,
):
    """Train ``model`` on one fold with early stopping on validation Spearman rho.

    Args:
        model: Module called as ``model(tokens, residue_mask, lengths)``.
        train_loader: Iterable of ``(tokens, residue_mask, lengths, ys)`` batches.
        val_loader: Same batch layout, used for model selection.
        device: Device to train on.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        max_epochs: Upper bound on the number of epochs.
        patience: Stop after this many consecutive epochs without improvement.
        use_ranking_loss: If True, optimise
            ``rank_lambda * ranking + (1 - rank_lambda) * MSE``; otherwise MSE only.
        rank_lambda: Weight of the ranking term.
        rank_margin: Margin passed to ``pairwise_ranking_loss``.
        rank_pairs: ``n_pairs`` passed to ``pairwise_ranking_loss``.
        min_label_diff: ``min_label_diff`` passed to ``pairwise_ranking_loss``.

    Returns:
        ``(model, (val_mae, val_rmse, val_r2, val_rho, val_p))`` where the model
        carries the weights of the best validation epoch (if any epoch improved)
        and the metrics are recomputed on the validation set with those weights.
    """
    model.to(device)

    mse_loss = nn.MSELoss()

    # Hand only trainable parameters to the optimizer and to gradient clipping.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)

    # Frozen parameters never change, so they are left out of the "best"
    # snapshot; this avoids cloning a large frozen backbone to CPU every time
    # the validation score improves.
    frozen_keys = {name for name, p in model.named_parameters() if not p.requires_grad}

    best_val_rho = float("-inf")
    best_state = None
    no_improve = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        total_loss = 0.0
        n = 0

        for tokens, residue_mask, lengths, ys in train_loader:
            tokens = tokens.to(device)
            residue_mask = residue_mask.to(device)
            lengths = lengths.to(device)
            ys = ys.to(device)

            pred_raw = model(tokens, residue_mask, lengths)
            if use_ranking_loss:
                loss_rank = pairwise_ranking_loss(
                    pred_raw, ys,
                    margin=rank_margin,
                    n_pairs=rank_pairs,
                    min_label_diff=min_label_diff
                )
                loss_mse = mse_loss(pred_raw, ys)
                loss = rank_lambda * loss_rank + (1.0 - rank_lambda) * loss_mse
            else:
                loss = mse_loss(pred_raw, ys)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()

            # Weight by batch size so the epoch average is per-sample.
            total_loss += loss.item() * ys.size(0)
            n += ys.size(0)

        train_loss = total_loss / max(n, 1)

        y_val_true, y_val_pred = collect_preds(model, val_loader, device, clip_01=True)
        val_mae, val_rmse, val_r2, val_rho, val_p = eval_metrics(y_val_true, y_val_pred)

        print(
            f"  Epoch {epoch:02d} | train_loss={train_loss:.4f} | "
            f"val_Spearmanρ={val_rho:.4f} (p={val_p:.2e}) | "
            f"val_R2={val_r2:.4f} val_MAE={val_mae:.4f} val_RMSE={val_rmse:.4f}"
        )

        # A NaN rho (e.g. constant predictions) compares False here and is
        # therefore counted as "no improvement".
        if val_rho > best_val_rho + 1e-4:
            best_val_rho = val_rho
            best_state = {
                key: value.detach().cpu().clone()
                for key, value in model.state_dict().items()
                if key not in frozen_keys
            }
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break

    if best_state is not None:
        # strict=False because the snapshot intentionally omits frozen weights,
        # which are still present (and unchanged) in the model.
        model.load_state_dict(best_state, strict=False)

    y_val_true, y_val_pred = collect_preds(model, val_loader, device, clip_01=True)
    val_mae, val_rmse, val_r2, val_rho, val_p = eval_metrics(y_val_true, y_val_pred)
    return model, (val_mae, val_rmse, val_r2, val_rho, val_p)


def run_5fold_cv(
    sequences,
    y,
    device,
    n_splits=5,
    seed=42,
    batch_size=8,
    cterm_k=6,
    lr=1e-3,
    weight_decay=1e-4,
    max_epochs=40,
    patience=6,
    save_dir=None,
    y_bin_q=10,            
    use_ranking_loss=True,
    rank_lambda=0.8,
    rank_margin=0.05,
    rank_pairs=2048,
    min_label_diff=0.02,
):
    """Stratified K-fold cross-validation of the ESM-based regressor.

    Targets are quantile-binned into ``y_bin_q`` bins so the folds can be
    stratified on a continuous label. A fresh regression head is trained per
    fold on top of a shared, frozen ESM-2 encoder.

    Args:
        sequences: Sequence strings (any iterable; materialised to a list).
        y: Regression targets, same length as ``sequences``.
        device: Device to train on.
        n_splits: Number of folds.
        seed: Seed for the RNGs and for the fold shuffling.
        batch_size: Batch size of both data loaders.
        cterm_k: C-terminal window size passed to the model.
        lr, weight_decay, max_epochs, patience: Forwarded to ``train_one_fold``.
        save_dir: If given, each fold's ``state_dict`` is saved there.
        y_bin_q: Number of quantile bins used for stratification.
        use_ranking_loss, rank_lambda, rank_margin, rank_pairs, min_label_diff:
            Forwarded to ``train_one_fold``.

    Returns:
        Float array of shape ``(n_splits, 5)`` with columns
        ``(mae, rmse, r2, rho, p)``, one row per fold.

    Raises:
        ValueError: If ``sequences`` and ``y`` differ in length.
    """
    set_seed(seed)

    # Positional indexing below needs a plain list and a NumPy array; a pandas
    # Series would be indexed by label and a Python list cannot take an index array.
    sequences = list(sequences)
    y = np.asarray(y, dtype=float)
    if len(sequences) != len(y):
        raise ValueError(
            f"sequences and y must have the same length, got {len(sequences)} and {len(y)}"
        )

    y_bins = pd.qcut(y, q=y_bin_q, labels=False, duplicates="drop")

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    # The encoder is frozen by the model wrapper, so it is identical for every
    # fold: load the checkpoint once instead of once per fold.
    esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    collate_fn = make_collate_fn(alphabet)

    fold_results = []
    for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(len(y)), y_bins), start=1):
        print(f"\n==== Fold {fold}/{n_splits} ====")

        X_tr = [sequences[i] for i in tr_idx]
        y_tr = y[tr_idx]
        X_va = [sequences[i] for i in va_idx]
        y_va = y[va_idx]

        train_ds = SeqRegressionDataset(X_tr, y_tr)
        val_ds = SeqRegressionDataset(X_va, y_va)

        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

        # New attention layers and head for every fold; only the encoder is shared.
        model = GlobalAttnPlusCTermAttnRegressor(
            esm_model=esm_model,
            repr_layer=33,
            emb_dim=1280,
            cterm_k=cterm_k,
            attn_hidden=256,
            cterm_attn_hidden=128,
            head_hidden=(512, 256),
            dropout=0.2,
        )

        # NOTE: the returned metrics belong to the epoch selected on this same
        # validation fold, so they are optimistic estimates.
        model, metrics = train_one_fold(
            model,
            train_loader,
            val_loader,
            device=device,
            lr=lr,
            weight_decay=weight_decay,
            max_epochs=max_epochs,
            patience=patience,
            use_ranking_loss=use_ranking_loss,
            rank_lambda=rank_lambda,
            rank_margin=rank_margin,
            rank_pairs=rank_pairs,
            min_label_diff=min_label_diff,
        )

        val_mae, val_rmse, val_r2, val_rho, val_p = metrics
        fold_results.append(metrics)
        print(
            f"Fold {fold} best | Spearmanρ={val_rho:.4f} (p={val_p:.2e}) | "
            f"val_R2={val_r2:.4f} val_MAE={val_mae:.4f} val_RMSE={val_rmse:.4f}"
        )

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            out_path = os.path.join(save_dir, f"fold{fold}_best_head.pth")
            # NOTE(review): despite the "_best_head" file name this is the full
            # state_dict, including the frozen encoder weights. Kept unchanged
            # so existing loading code continues to work.
            torch.save(model.state_dict(), out_path)
            print("  Saved:", out_path)

    fold_results = np.array(fold_results, dtype=float)

    # Columns 0..3 are mae, rmse, r2, rho; the p-value column is not summarised.
    mae_mean, rmse_mean, r2_mean, rho_mean = fold_results[:, :4].mean(axis=0)
    mae_std, rmse_std, r2_std, rho_std = fold_results[:, :4].std(axis=0, ddof=1)

    print("\n==== CV Summary (mean ± std) ====")
    print(f"Spearmanρ  = {rho_mean:.4f} ± {rho_std:.4f}")
    print(f"val_R2     = {r2_mean:.4f} ± {r2_std:.4f}")
    print(f"val_MAE    = {mae_mean:.4f} ± {mae_std:.4f}")
    print(f"val_RMSE   = {rmse_mean:.4f} ± {rmse_std:.4f}")

    return fold_results


In [None]:
if __name__ == "__main__":
    # Entry point: load the training table and run 5-fold cross-validation.
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    train_data = pd.read_csv("./Regression_train_data.csv")

    # Drop rows with a missing sequence or target. Without this,
    # ``astype(str)`` would turn a missing sequence into the literal string
    # "nan", and a missing target would break the quantile stratification.
    required_cols = ["protein_sequence", "mean_stability"]
    df = train_data.dropna(subset=required_cols).reset_index(drop=True)
    n_dropped = len(train_data) - len(df)
    if n_dropped:
        print(f"Dropped {n_dropped} rows with missing sequence or target")

    sequences = df["protein_sequence"].astype(str).tolist()
    # Targets are clipped to [0, 1], the same range predictions are clipped to
    # at evaluation time.
    y = np.clip(df["mean_stability"].astype(float).to_numpy(), 0.0, 1.0)

    results = run_5fold_cv(
        sequences=sequences,
        y=y,
        device=device,
        n_splits=5,
        seed=42,
        batch_size=16,          
        cterm_k=6,
        lr=1e-3,
        weight_decay=1e-4,
        max_epochs=40,
        patience=6,
        save_dir="./cv_models",
        y_bin_q=6,              
        use_ranking_loss=True,
        rank_lambda=0.75,      
        rank_margin=0.05,
        rank_pairs=4096,     
        min_label_diff=0.02,
    )