In [None]:
# GNN only model

import numpy as np
import pandas as pd
from pathlib import Path
import re
import random
import os
import warnings
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.spatial import KDTree
from scipy.sparse import coo_matrix
from sklearn.model_selection import GroupKFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from Bio.PDB import PDBParser, Polypeptide

def get_device():
    """Return the name of the preferred torch backend.

    Apple MPS is checked first, then CUDA; "cpu" is the fallback.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return backend

# Graph / model hyperparameters handed to the run at the bottom of this cell.
OPTIMAL_RADIUS = 8.0  # CA-CA distance cutoff for graph edges (same units as the PDB coordinates; presumably angstroms)
OPTIMAL_HIDDEN_DIM = 96
OPTIMAL_DROPOUT = 0.4

SEED = 42
N_SPLITS = 5          # number of GroupKFold folds
BATCH_SIZE = 16       # DataLoader batch size
MAX_EPOCHS = 50       # upper bound on epochs per fold
LEARNING_RATE = 1e-3

EARLY_STOPPING_PATIENCE = 15  # passed as `patience` to train_and_evaluate_gnn

# Seed the NumPy, Python and torch random number generators.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Silence UserWarning / FutureWarning output for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)



# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
mutation_data_path = 'mutations_full_info.pkl' # data stored in pickle format, with columns: mutation, scaled_activity, group, sequence
mutations = pd.read_pickle(mutation_data_path)

# Regression target, one value per row of `mutations`.
y = mutations["scaled_activity"].to_numpy(dtype=np.float32)
# Cross-validation group = the numeric position inside the mutation string, so all
# substitutions at one position share a group.
# NOTE(review): str.extract takes the first match anywhere in the string, while the
# feature loop below requires a full match - the two can disagree on malformed rows.
groups = mutations["mutation"].str.extract(r"[A-Z](\d+)[A-Z]").astype(int)[0].to_numpy()

def custom_three_to_one(residue_name: str) -> str:
    """Translate a three-letter residue name into its one-letter code.

    Matching is case-insensitive; any name outside the 20 standard amino
    acids maps to 'X'.
    """
    one_letter_by_name = dict(
        ALA='A', CYS='C', ASP='D', GLU='E', PHE='F', GLY='G', HIS='H',
        ILE='I', LYS='K', LEU='L', MET='M', ASN='N', PRO='P', GLN='Q',
        ARG='R', SER='S', THR='T', VAL='V', TRP='W', TYR='Y',
    )
    try:
        return one_letter_by_name[residue_name.upper()]
    except KeyError:
        return 'X'

WT_PDB_PATH = "wt_af.pdb" # wild-type protein structure in pdb format predicted by AlphaFold2
parser = PDBParser(QUIET=True)
structure = parser.get_structure("wt", WT_PDB_PATH)
# Only the first chain yielded by the structure is used.
chain = next(structure.get_chains())
# Keep standard amino-acid residues only; graph nodes and sequence positions both follow this list.
residues = [res for res in chain.get_residues() if Polypeptide.is_aa(res, standard=True)]
# One CA coordinate per kept residue.
# NOTE(review): res["CA"] presumably raises if a residue has no CA atom - confirm.
coords = np.array([res["CA"].get_coord() for res in residues], dtype=np.float32)
# Wild-type sequence read off the structure, one letter per kept residue.
wt_seq = "".join(custom_three_to_one(res.get_resname()) for res in residues)

# Alphabet that fixes the column order of every amino-acid one-hot vector.
AA_CODES = "ACDEFGHIKLMNPQRSTVWY"


def aa_onehot(code):
    """Return a float32 one-hot vector marking the position of *code* in AA_CODES.

    Raises ValueError (from ``str.index``) when *code* is not in the alphabet.
    """
    position = AA_CODES.index(code)
    encoding = np.zeros(len(AA_CODES), dtype=np.float32)
    encoding[position] = 1.0
    return encoding

# One-hot identity of every wild-type residue: one row per residue, one column per AA_CODES letter.
BASE_NODE_FEATURES = np.stack([aa_onehot(a) for a in wt_seq])

# Per-variant node-feature matrices. Entry k must describe row k of `mutations`,
# because y and groups are built from every row of that frame.
all_node_features = []
mut_pat = re.compile(r"([A-Z])(\d+)([A-Z])")
for _, row in mutations.iterrows():
    mut = row["mutation"]
    match = mut_pat.fullmatch(mut)
    if not match:
        # Silently skipping the row would leave all_node_features shorter than
        # y / groups and shift every later sample against its label.
        raise ValueError(f"Bad mutation string: {mut}")

    wt, pos, mut_aa = match.groups()
    idx = int(pos) - 1  # 1-based position -> 0-based residue index
    if not 0 <= idx < len(wt_seq):
        # Position 0 would otherwise wrap around to the last residue via index -1.
        raise ValueError(
            f"Mutation {mut} position is outside the wild-type sequence (length {len(wt_seq)})"
        )

    x = BASE_NODE_FEATURES.copy()
    # Single-column flag marking the mutated residue.
    mutation_indicator = np.zeros((len(wt_seq), 1), dtype=np.float32)
    mutation_indicator[idx] = 1.0
    # One-hot of the substituted amino acid, non-zero only on the mutated row.
    mut_aa_onehot = np.zeros((len(wt_seq), len(AA_CODES)), dtype=np.float32)
    mut_aa_onehot[idx] = aa_onehot(mut_aa)

    # Columns: wild-type one-hot | mutation flag | mutant amino-acid one-hot.
    combined_features = np.hstack([x, mutation_indicator, mut_aa_onehot])
    all_node_features.append(combined_features)



def create_radius_adjacency_matrix(coords, radius=10.0):
    """Build a symmetric CSR adjacency matrix with self-loops from point coordinates.

    Two points are connected when a KD-tree pair query reports them within
    *radius* of each other; every point is also connected to itself.
    """
    n_nodes = len(coords)
    neighbour_pairs = KDTree(coords).query_pairs(r=radius)
    if neighbour_pairs:
        src, dst = zip(*neighbour_pairs)
    else:
        src, dst = [], []

    # Each pair is listed in both directions, followed by the diagonal.
    self_loops = np.arange(n_nodes)
    rows = np.concatenate([src, dst, self_loops])
    cols = np.concatenate([dst, src, self_loops])

    weights = np.ones(len(rows))
    return coo_matrix((weights, (rows, cols)), shape=(n_nodes, n_nodes)).tocsr()



class GraphAttentionLayer(nn.Module):
    """Single-head graph attention layer working on a dense adjacency matrix.

    Node features are projected with ``W``; a score is computed for every
    ordered node pair from the concatenated projections and ``a``; scores of
    unconnected pairs are masked before a row-wise softmax.
    """

    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha

        # Projection matrix is initialised before the attention vector.
        projection = torch.zeros(size=(in_features, out_features))
        self.W = nn.Parameter(projection)
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        attention_vector = torch.zeros(size=(2 * out_features, 1))
        self.a = nn.Parameter(attention_vector)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, h, adj):
        """Return ELU-activated, attention-weighted neighbour sums of the projected features."""
        projected = torch.mm(h, self.W)
        pair_features = self._prepare_attentional_mechanism_input(projected)
        scores = self.leakyrelu(torch.matmul(pair_features, self.a).squeeze(2))

        # Pairs without an edge get a huge negative score so softmax sends them to ~0.
        masked_out = -9e15 * torch.ones_like(scores)
        weights = torch.where(adj > 0, scores, masked_out)
        weights = F.softmax(weights, dim=1)
        weights = F.dropout(weights, self.dropout, training=self.training)

        aggregated = torch.matmul(weights, projected)
        return F.elu(aggregated)

    def _prepare_attentional_mechanism_input(self, Wh):
        """Return an (N, N, 2 * out_features) tensor whose [i, j] entry is concat(Wh[i], Wh[j])."""
        n_nodes = Wh.size()[0]
        # Row i repeated N times in a row, then the whole matrix tiled N times.
        left = Wh.repeat_interleave(n_nodes, dim=0)
        right = Wh.repeat(n_nodes, 1)
        stacked = torch.cat([left, right], dim=1)
        return stacked.view(n_nodes, n_nodes, 2 * self.out_features)

class MutGraphNet(nn.Module):
    """Two graph-attention layers, mean pooling over nodes, and a linear output head."""

    def __init__(self, in_dim, hidden_dim, dropout):
        super().__init__()
        # Sub-modules are created in a fixed order: gat1, gat2, dropout, fc.
        self.gat1 = GraphAttentionLayer(in_dim, hidden_dim, dropout=dropout)
        self.gat2 = GraphAttentionLayer(hidden_dim, hidden_dim, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, adj):
        """Return a one-element prediction for the graph with node features *x*."""
        hidden = self.dropout(self.gat1(x, adj))
        node_embeddings = self.gat2(hidden, adj)
        # Average across the node axis to get a single graph-level vector.
        graph_embedding = node_embeddings.mean(dim=0)
        return self.fc(graph_embedding)

class ProteinGraphDataset(Dataset):
    """Pairs each per-sample node-feature array with its label."""

    def __init__(self, features_list, labels):
        self.features_list = features_list
        self.labels = labels

    def __len__(self):
        # The feature list defines the dataset size.
        return len(self.features_list)

    def __getitem__(self, idx):
        """Return (features, label) for sample *idx*, both as float32 tensors."""
        node_features = torch.tensor(self.features_list[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return node_features, target



def train_and_evaluate_gnn(features, labels, groups, adj, params, patience):
    """Cross-validate MutGraphNet with GroupKFold and collect out-of-fold predictions.

    Args:
        features: sequence of per-sample node-feature arrays; ``features[0].shape[1]``
            sets the model input width.
        labels: array of regression targets, indexed with the fold index arrays.
        groups: per-sample group ids passed to GroupKFold.split.
        adj: sparse adjacency matrix (must provide ``todense``), shared by all samples.
        params: dict with keys 'hidden_dim', 'dropout' and 'lr'.
        patience: number of consecutive non-improving epochs after which a fold stops.

    Returns:
        dict with 'predictions' (out-of-fold predictions), 'overall_metrics'
        (rmse, r2, pearson, spearman over all samples) and 'fold_metrics'
        (per-fold lists of the same four metrics).

    Side effects:
        Writes ``best_model_fold_<k>.pth`` in the working directory during each
        fold and deletes it afterwards; prints early-stopping and per-fold messages.
    """
    kf = GroupKFold(n_splits=N_SPLITS)
    oof_preds = np.zeros(len(labels))
    device = get_device()


    # Densify the adjacency once; the same tensor is reused for every sample.
    adj_torch = torch.from_numpy(adj.todense().astype(np.float32)).to(device)

    fold_metrics_history = {
        'rmse': [], 'r2': [], 'pearson': [], 'spearman': []
    }

    for fold, (train_idx, val_idx) in enumerate(kf.split(features, labels, groups)):

        # Fresh model, optimizer and scheduler for every fold.
        model = MutGraphNet(
            in_dim=features[0].shape[1],
            hidden_dim=params['hidden_dim'],
            dropout=params['dropout']
        ).to(device)
        
        optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

        train_dataset = ProteinGraphDataset([features[i] for i in train_idx], labels[train_idx])
        val_dataset = ProteinGraphDataset([features[i] for i in val_idx], labels[val_idx])
        loader_tr = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        # Not shuffled, so predictions come out in val_idx order.
        loader_va = DataLoader(val_dataset, batch_size=BATCH_SIZE)

        best_val_loss = float('inf')
        patience_counter = 0
        best_model_path = f"best_model_fold_{fold+1}.pth"

        for epoch in range(MAX_EPOCHS):
            model.train()
            total_train_loss = 0
            for x_batch, y_batch in loader_tr:
                # Each sample is its own forward/backward/step; the batch only groups loading.
                for i in range(x_batch.size(0)):
                    x, y = x_batch[i].to(device), y_batch[i].to(device)
                    optimizer.zero_grad()
                    output = model(x, adj_torch).squeeze()
                    loss = F.mse_loss(output, y)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    total_train_loss += loss.item()
            # NOTE(review): avg_train_loss is computed but not used anywhere below.
            avg_train_loss = total_train_loss / len(train_dataset)

            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for x_batch, y_batch in loader_va:
                    for i in range(x_batch.size(0)):
                        x, y = x_batch[i].to(device), y_batch[i].to(device)
                        output = model(x, adj_torch).squeeze()
                        total_val_loss += F.mse_loss(output, y).item()
            avg_val_loss = total_val_loss / len(val_dataset)
            # The scheduler watches the validation loss ('min' mode).
            scheduler.step(avg_val_loss)

            # Checkpoint on improvement, otherwise count towards early stopping.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                torch.save(model.state_dict(), best_model_path)
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}.")
                break
        
        # Restore the best checkpoint of this fold before predicting.
        # NOTE(review): if the validation loss never drops below inf (e.g. NaN),
        # no checkpoint is written and this load fails.
        model.load_state_dict(torch.load(best_model_path))
        model.eval()
        fold_preds_list = []
        fold_true_labels = labels[val_idx]
        with torch.no_grad():
            for x_batch, _ in loader_va:
                 for i in range(x_batch.size(0)):
                    x = x_batch[i].to(device)
                    fold_preds_list.append(model(x, adj_torch).squeeze().cpu().numpy())
        
        fold_preds = np.array(fold_preds_list).flatten()
        oof_preds[val_idx] = fold_preds

        fold_metrics_history['rmse'].append(np.sqrt(mean_squared_error(fold_true_labels, fold_preds)))
        fold_metrics_history['r2'].append(r2_score(fold_true_labels, fold_preds))
        fold_metrics_history['pearson'].append(pearsonr(fold_true_labels, fold_preds)[0])
        fold_metrics_history['spearman'].append(spearmanr(fold_true_labels, fold_preds)[0])
        
        print(f"Fold {fold+1} spearman: {fold_metrics_history['spearman'][-1]:.4f}")
        # The checkpoint file is only needed within the fold.
        os.remove(best_model_path)

    # Metrics over all out-of-fold predictions pooled together.
    overall_metrics = {
        'rmse': np.sqrt(mean_squared_error(labels, oof_preds)),
        'r2': r2_score(labels, oof_preds),
        'pearson': pearsonr(labels, oof_preds)[0],
        'spearman': spearmanr(labels, oof_preds)[0]
    }
    
    return {
        'predictions': oof_preds,
        'overall_metrics': overall_metrics,
        'fold_metrics': fold_metrics_history
    }


# Build the residue graph from the CA coordinates and run the cross-validated training.
adj_matrix = create_radius_adjacency_matrix(coords, radius=OPTIMAL_RADIUS)
model_params = {'lr': LEARNING_RATE, 'hidden_dim': OPTIMAL_HIDDEN_DIM, 'dropout': OPTIMAL_DROPOUT}
final_results = train_and_evaluate_gnn(
    all_node_features, y, groups, adj=adj_matrix, params=model_params, patience=EARLY_STOPPING_PATIENCE
)


# Mean and standard deviation of each metric across folds.
# NOTE(review): mean_val / std_val are overwritten on every iteration and never
# printed or stored - a reporting statement may be missing here.
for metric_name, metric_values in final_results['fold_metrics'].items():
    mean_val = np.mean(metric_values)
    std_val = np.std(metric_values)
    

# Metrics computed over all out-of-fold predictions together.
overall_metrics = final_results['overall_metrics']


In [None]:
# generate embeddings (ESM-2-650M) for all mutants
import os, re, json, numpy as np, pandas as pd, torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from Bio.PDB import PDBParser, Polypeptide

# Backend preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


DF_PATH      = "mutations_full_info.pkl" # data stored in pickle format, with columns: mutation, scaled_activity, group, sequence
OUT_PATH     = "mutant_embeddings.npz"   # save embeddings in npz format
MODEL_NAME   = "facebook/esm2_t33_650M_UR50D"
BATCH_SIZE   = 16                                 # number of mutant sequences per forward pass
dtype = torch.float16                            # precision requested when loading the model
# NOTE(review): float16 is requested even when the CPU fallback is selected - confirm this is intended.


# Read the wild-type sequence from the first chain of the structure file.
parser  = PDBParser(QUIET=True)
WT_PDB  = "wt_af.pdb" # wild-type protein structure in pdb format predicted by AlphaFold2
chain   = next(parser.get_structure("wt", WT_PDB).get_chains())
# Three-letter -> one-letter codes for the 20 standard amino acids.
aa3_to1 = {'ALA':'A','CYS':'C','ASP':'D','GLU':'E','PHE':'F','GLY':'G','HIS':'H',
           'ILE':'I','LYS':'K','LEU':'L','MET':'M','ASN':'N','PRO':'P','GLN':'Q',
           'ARG':'R','SER':'S','THR':'T','VAL':'V','TRP':'W','TYR':'Y'}
# Residues that fail the is_aa(res, True) check are dropped; unknown names map to 'X'.
# Unlike a case-insensitive lookup, the residue name is used as returned by the parser.
wt_seq  = "".join(aa3_to1.get(res.get_resname(), 'X')
                  for res in chain.get_residues() if Polypeptide.is_aa(res, True))


# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
df = pd.read_pickle(DF_PATH)
# Each distinct mutation string is embedded once.
muts = df["mutation"].unique()              


# Tokenizer and encoder are fetched by model name; the encoder is put in eval mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model     = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device)
model.eval()


# <wild-type letter><1-based position><new letter>, e.g. "A12G".
mut_pat = re.compile(r"([A-Z])(\d+)([A-Z])")


def apply_mut(seq, mut):
    """Return *seq* with the single substitution described by *mut* applied.

    *mut* must fully match ``<wt><position><new>`` with a 1-based position.
    The wild-type letter is parsed but not compared with *seq*.

    Raises:
        ValueError: if *mut* does not match the pattern, or if the position
            lies outside *seq*.
    """
    m = mut_pat.fullmatch(mut)
    if m is None:
        raise ValueError(f"Bad mutation string: {mut}")
    pos = int(m.group(2)) - 1
    if not 0 <= pos < len(seq):
        # Without this check slicing fails silently: a position past the end
        # appends a residue, and position 0 (index -1) duplicates the sequence.
        raise ValueError(
            f"Mutation {mut} position is outside the sequence (length {len(seq)})"
        )
    aa_new = m.group(3)
    return seq[:pos] + aa_new + seq[pos + 1:]


# Maps mutation string -> per-residue embedding array of that mutant sequence.
emb_bank = {}
with torch.inference_mode():
    for i in tqdm(range(0, len(muts), BATCH_SIZE), ncols=88,
                  desc="Embedding mutants"):
        batch_muts = muts[i:i+BATCH_SIZE]
        # Mutant sequences are built by substituting into the wild-type sequence.
        seqs = [apply_mut(wt_seq, m) for m in batch_muts]
        toks = tokenizer(seqs, return_tensors="pt", padding=True)
        toks = {k: v.to(device) for k, v in toks.items()}
        out  = model(**toks, output_hidden_states=False).last_hidden_state

        # Remove BOS / EOS
        # attention_mask.sum(1) - 2 is the number of unmasked tokens minus the two special tokens.
        for mut, emb, length in zip(batch_muts, out, toks["attention_mask"].sum(1)-2):
            # Skip the first token and keep the next `length` rows; store as float16.
            emb = emb[1:length+1].cpu().to(torch.float16).numpy() 
            emb_bank[mut] = emb

# One compressed archive with one array per mutation string.
np.savez_compressed(OUT_PATH, **emb_bank)

In [None]:
# GNN + ESM embedding model
import numpy as np
import pandas as pd
from pathlib import Path
import re
import random
import os
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

from scipy.spatial import KDTree
from scipy.sparse import coo_matrix
from sklearn.model_selection import GroupKFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from Bio.PDB import PDBParser, Polypeptide

# Graph / model hyperparameters handed to the run at the bottom of this cell.
OPTIMAL_RADIUS = 8.0  # CA-CA distance cutoff for graph edges (same units as the PDB coordinates; presumably angstroms)
OPTIMAL_HIDDEN_DIM = 96
OPTIMAL_DROPOUT = 0.4

SEED = 42
N_SPLITS = 5  # number of GroupKFold folds
BATCH_SIZE = 16  # DataLoader batch size
MAX_EPOCHS = 50  # upper bound on epochs per fold
LEARNING_RATE = 1e-3
EARLY_STOPPING_PATIENCE = 15  # passed as `patience` to train_and_evaluate_gnn

# Seed the NumPy, Python and torch random number generators.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

# Silence UserWarning / FutureWarning output for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)



# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
mutation_data_path = 'mutations_full_info.pkl' # data stored in pickle format, with columns: mutation, scaled_activity, group, sequence
mutations = pd.read_pickle(mutation_data_path)

# Regression target, one value per row of `mutations`.
y = mutations["scaled_activity"].to_numpy(dtype=np.float32)
# Cross-validation group = the numeric position inside the mutation string, so all
# substitutions at one position share a group.
groups = mutations["mutation"].str.extract(r"[A-Z](\d+)[A-Z]").astype(int)[0].to_numpy()

# Archive of per-mutant embeddings; it is indexed by mutation string below.
esm_embeddings_path = 'mutant_embeddings.npz' # ESM-2 embeddings of mutants, generated above
esm_bank = np.load(esm_embeddings_path)


def custom_three_to_one(residue_name: str) -> str:
    """Map a three-letter residue name (any case) to its one-letter code, or 'X' if unknown."""
    lookup = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }
    normalized = residue_name.upper()
    if normalized not in lookup:
        return 'X'
    return lookup[normalized]

WT_PDB_PATH = "wt_af.pdb" # wild-type protein structure in pdb format predicted by AlphaFold2
parser = PDBParser(QUIET=True)
structure = parser.get_structure("wt", WT_PDB_PATH)
# Only the first chain yielded by the structure is used.
chain = next(structure.get_chains())
# Keep standard amino-acid residues only; graph nodes follow this list.
residues = [res for res in chain.get_residues() if Polypeptide.is_aa(res, standard=True)]
# One CA coordinate per kept residue.
# NOTE(review): res["CA"] presumably raises if a residue has no CA atom - confirm.
coords = np.array([res["CA"].get_coord() for res in residues], dtype=np.float32)
# Wild-type sequence read off the structure; its length sizes the per-residue feature arrays below.
wt_seq = "".join(custom_three_to_one(res.get_resname()) for res in residues)

# Column order of the amino-acid one-hot encoding.
AA_CODES = "ACDEFGHIKLMNPQRSTVWY"


def aa_onehot(code):
    """One-hot encode *code* over AA_CODES as a float32 vector.

    ``str.index`` raises ValueError when *code* is absent from AA_CODES.
    """
    hot_column = AA_CODES.index(code)
    alphabet_size = len(AA_CODES)
    onehot = np.zeros(alphabet_size, dtype=np.float32)
    onehot[hot_column] = 1.0
    return onehot

# Per-variant node-feature matrices and mutated-residue indices. Entry k of both
# lists must describe row k of `mutations`, because y and groups are built from
# every row of that frame.
all_node_features = []
mut_pat = re.compile(r"([A-Z])(\d+)([A-Z])")
mut_indices = []

for _, row in mutations.iterrows():
    mut = row["mutation"]
    match = mut_pat.fullmatch(mut)
    if not match:
        # Silently skipping the row would leave the feature / index lists shorter
        # than y / groups and shift every later sample against its label.
        raise ValueError(f"Bad mutation string: {mut}")

    wt, pos, mut_aa = match.groups()
    idx = int(pos) - 1  # 1-based position -> 0-based residue index
    if not 0 <= idx < len(wt_seq):
        # Position 0 would otherwise wrap around to the last residue via index -1.
        raise ValueError(
            f"Mutation {mut} position is outside the wild-type sequence (length {len(wt_seq)})"
        )
    mut_indices.append(idx)

    # Per-residue embedding of this mutant, looked up by mutation string.
    esm_embedding = esm_bank[mut]

    # Single-column flag marking the mutated residue.
    mutation_indicator = np.zeros((len(wt_seq), 1), dtype=np.float32)
    mutation_indicator[idx] = 1.0
    # One-hot of the substituted amino acid, non-zero only on the mutated row.
    mut_aa_onehot = np.zeros((len(wt_seq), len(AA_CODES)), dtype=np.float32)
    mut_aa_onehot[idx] = aa_onehot(mut_aa)

    # Columns: embedding | mutation flag | mutant amino-acid one-hot.
    # hstack requires the embedding to have len(wt_seq) rows.
    combined_features = np.hstack([esm_embedding, mutation_indicator, mut_aa_onehot])
    all_node_features.append(combined_features)


def create_radius_adjacency_matrix(coords, radius=10.0):
    """Return a CSR adjacency matrix linking points within *radius*, plus self-loops."""
    node_count = len(coords)
    close_pairs = KDTree(coords).query_pairs(r=radius)
    first, second = zip(*close_pairs) if close_pairs else ([], [])

    # Each pair in both directions, then the diagonal.
    diagonal = np.arange(node_count)
    row_idx = np.concatenate([first, second, diagonal])
    col_idx = np.concatenate([second, first, diagonal])

    ones = np.ones(len(row_idx))
    sparse_adj = coo_matrix((ones, (row_idx, col_idx)), shape=(node_count, node_count))
    return sparse_adj.tocsr()

def get_device():
    """Name of the torch backend to use: "mps" first, then "cuda", otherwise "cpu"."""
    if torch.backends.mps.is_available():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"

class GraphAttentionLayer(nn.Module):
    """Single-head graph attention over a dense adjacency matrix.

    ``W`` projects node features; ``a`` scores every ordered node pair from
    the concatenation of the two projected feature vectors.
    """

    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha

        # W is created and initialised before a.
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, h, adj):
        """Apply masked attention over *adj* and return the ELU of the aggregated features."""
        Wh = torch.mm(h, self.W)
        pairwise = self._prepare_attentional_mechanism_input(Wh)
        raw_scores = torch.matmul(pairwise, self.a).squeeze(2)
        e = self.leakyrelu(raw_scores)

        # Entries without an edge are replaced by a very large negative number
        # so they vanish in the row-wise softmax.
        blocked = -9e15 * torch.ones_like(e)
        attention = F.softmax(torch.where(adj > 0, e, blocked), dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        return F.elu(torch.matmul(attention, Wh))

    def _prepare_attentional_mechanism_input(self, Wh):
        """Build the (N, N, 2 * out_features) tensor with concat(Wh[i], Wh[j]) at [i, j]."""
        N = Wh.size()[0]
        per_row_copies = Wh.repeat_interleave(N, dim=0)
        tiled = Wh.repeat(N, 1)
        flat_pairs = torch.cat([per_row_copies, tiled], dim=1)
        return flat_pairs.view(N, N, 2 * self.out_features)

class MutGraphNet(nn.Module):
    """Two graph-attention layers followed by a linear head read at one node."""

    def __init__(self, in_dim, hidden_dim, dropout):
        super().__init__()
        # Sub-modules are created in a fixed order: gat1, gat2, dropout, fc.
        self.gat1 = GraphAttentionLayer(in_dim, hidden_dim, dropout=dropout)
        self.gat2 = GraphAttentionLayer(hidden_dim, hidden_dim, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, adj, mut_idx):
        """Return the prediction computed from the embedding of node *mut_idx*."""
        hidden = self.dropout(self.gat1(x, adj))
        node_embeddings = self.gat2(hidden, adj)
        # Only the selected node's embedding feeds the output layer.
        selected = node_embeddings[mut_idx]
        return self.fc(selected)

class ProteinGraphDataset(Dataset):
    """Per-sample node features, label, and mutated-residue index."""

    def __init__(self, features_list, labels, mut_indices):
        self.features_list = features_list
        self.labels = labels
        self.mut_indices = mut_indices

    def __len__(self):
        # The feature list defines the dataset size.
        return len(self.features_list)

    def __getitem__(self, idx):
        """Return (features, label, mutation index); the first two as float32 tensors."""
        node_features = torch.tensor(self.features_list[idx], dtype=torch.float32)
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        mutated_node = self.mut_indices[idx]
        return node_features, target, mutated_node


def train_and_evaluate_gnn(features, labels, groups, mut_indices, adj, params, patience):
    """Cross-validate MutGraphNet with GroupKFold and collect out-of-fold predictions.

    Args:
        features: sequence of per-sample node-feature arrays; ``features[0].shape[1]``
            sets the model input width.
        labels: array of regression targets, indexed with the fold index arrays.
        groups: per-sample group ids passed to GroupKFold.split.
        mut_indices: per-sample index of the node whose embedding is read out.
        adj: sparse adjacency matrix (must provide ``todense``), shared by all samples.
        params: dict with keys 'hidden_dim', 'dropout' and 'lr'.
        patience: number of consecutive non-improving epochs after which a fold stops.

    Returns:
        dict with 'predictions' (out-of-fold predictions), 'overall_metrics'
        (rmse, r2, pearson, spearman over all samples) and 'fold_metrics'
        (per-fold lists of the same four metrics).

    Side effects:
        Writes ``best_model_fold_<k>.pth`` in the working directory during each
        fold and deletes it afterwards; prints early-stopping and per-fold messages.
    """
    kf = GroupKFold(n_splits=N_SPLITS)
    oof_preds = np.zeros(len(labels))
    device = get_device()

    # Densify the adjacency once; the same tensor is reused for every sample.
    adj_torch = torch.from_numpy(adj.todense().astype(np.float32)).to(device)

    fold_metrics_history = {
        'rmse': [], 'r2': [], 'pearson': [], 'spearman': []
    }

    for fold, (train_idx, val_idx) in enumerate(kf.split(features, labels, groups)):

        # Fresh model, optimizer and scheduler for every fold.
        model = MutGraphNet(
            in_dim=features[0].shape[1],
            hidden_dim=params['hidden_dim'],
            dropout=params['dropout']
        ).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

        train_dataset = ProteinGraphDataset(
            [features[i] for i in train_idx], labels[train_idx], [mut_indices[i] for i in train_idx]
        )
        val_dataset = ProteinGraphDataset(
            [features[i] for i in val_idx], labels[val_idx], [mut_indices[i] for i in val_idx]
        )
        loader_tr = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        # Not shuffled, so predictions come out in val_idx order.
        loader_va = DataLoader(val_dataset, batch_size=BATCH_SIZE)

        best_val_loss = float('inf')
        patience_counter = 0
        best_model_path = f"best_model_fold_{fold+1}.pth"

        for epoch in range(MAX_EPOCHS):
            model.train()
            for x_batch, y_batch, mut_idx_batch in loader_tr:
                # Each sample is its own forward/backward/step; the batch only groups loading.
                for i in range(x_batch.size(0)):
                    x, y, mut_idx = x_batch[i].to(device), y_batch[i].to(device), mut_idx_batch[i]
                    optimizer.zero_grad()
                    output = model(x, adj_torch, mut_idx).squeeze()
                    loss = F.mse_loss(output, y)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()

            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for x_batch, y_batch, mut_idx_batch in loader_va:
                    for i in range(x_batch.size(0)):
                        x, y, mut_idx = x_batch[i].to(device), y_batch[i].to(device), mut_idx_batch[i]
                        output = model(x, adj_torch, mut_idx).squeeze()
                        total_val_loss += F.mse_loss(output, y).item()
            avg_val_loss = total_val_loss / len(val_dataset)
            # The plateau scheduler only acts when it is fed the monitored metric;
            # without this call the learning rate would never be reduced.
            scheduler.step(avg_val_loss)

            # Checkpoint on improvement, otherwise count towards early stopping.
            if avg_val_loss < best_val_loss:
                best_val_loss, patience_counter = avg_val_loss, 0
                torch.save(model.state_dict(), best_model_path)
            else:
                patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}.")
                break

        # Restore the best checkpoint of this fold before predicting.
        model.load_state_dict(torch.load(best_model_path))
        model.eval()

        fold_preds_list = []
        fold_true_labels = labels[val_idx]
        with torch.no_grad():
            for x_batch, _, mut_idx_batch in loader_va:
                for i in range(x_batch.size(0)):
                    x, mut_idx = x_batch[i].to(device), mut_idx_batch[i]
                    fold_preds_list.append(model(x, adj_torch, mut_idx).squeeze().cpu().numpy())

        fold_preds = np.array(fold_preds_list).flatten()
        oof_preds[val_idx] = fold_preds

        fold_metrics_history['rmse'].append(np.sqrt(mean_squared_error(fold_true_labels, fold_preds)))
        fold_metrics_history['r2'].append(r2_score(fold_true_labels, fold_preds))
        fold_metrics_history['pearson'].append(pearsonr(fold_true_labels, fold_preds)[0])
        fold_metrics_history['spearman'].append(spearmanr(fold_true_labels, fold_preds)[0])

        print(f"Fold {fold+1} Spearman: {fold_metrics_history['spearman'][-1]:.4f}")
        # The checkpoint file is only needed within the fold.
        os.remove(best_model_path)

    # Metrics over all out-of-fold predictions pooled together.
    overall_metrics = {
        'rmse': np.sqrt(mean_squared_error(labels, oof_preds)),
        'r2': r2_score(labels, oof_preds),
        'pearson': pearsonr(labels, oof_preds)[0],
        'spearman': spearmanr(labels, oof_preds)[0]
    }

    return {
        'predictions': oof_preds,
        'overall_metrics': overall_metrics,
        'fold_metrics': fold_metrics_history
    }


# Build the residue graph from the CA coordinates and run the cross-validated training.
adj_matrix = create_radius_adjacency_matrix(coords, radius=OPTIMAL_RADIUS)
model_params = {'lr': LEARNING_RATE, 'hidden_dim': OPTIMAL_HIDDEN_DIM, 'dropout': OPTIMAL_DROPOUT}

final_results = train_and_evaluate_gnn(
    all_node_features, y, groups, mut_indices,
    adj=adj_matrix, params=model_params, patience=EARLY_STOPPING_PATIENCE
)

# Mean and standard deviation of each metric across folds.
# NOTE(review): mean_val / std_val are overwritten on every iteration and never
# printed or stored - a reporting statement may be missing here.
for metric_name, metric_values in final_results['fold_metrics'].items():
    mean_val = np.mean(metric_values)
    std_val = np.std(metric_values)
  

# Metrics computed over all out-of-fold predictions together.
overall_metrics = final_results['overall_metrics']


In [2]:
# GNN + ESM embedding model (optimized feature engineering)

import numpy as np
import pandas as pd
from pathlib import Path
import re
import random
import os
import warnings
import matplotlib.pyplot as plt
import seaborn as sns
from functools import partial

from scipy.spatial import KDTree
from scipy.sparse import coo_matrix
from sklearn.model_selection import GroupKFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

from Bio.PDB import PDBParser, Polypeptide
from transformers import AutoTokenizer, AutoModel

def get_device():
    """Return the preferred ``torch.device``: MPS if available, else CUDA, else CPU."""
    if torch.backends.mps.is_available():
        backend_name = "mps"
    elif torch.cuda.is_available():
        backend_name = "cuda"
    else:
        backend_name = "cpu"
    return torch.device(backend_name)
# Module-level device shared by the embedding step and the training code below.
device = get_device()

OPTIMAL_RADIUS = 8.0  # CA-CA distance cutoff for graph edges (same units as the PDB coordinates; presumably angstroms)
SEED = 42
N_SPLITS = 5  # number of GroupKFold folds
BATCH_SIZE = 8  # DataLoader batch size
MAX_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 15

# Model / optimisation settings read through the `params` argument of the training function.
hyperparams = {
    'lr': 1e-4,         
    'hidden_dim': 128,  
    'dropout': 0.4,
    'lambda_rho': 0.5   # weight of the (1 - soft Spearman) term in huber_rho_loss
}

# Seed the NumPy, Python and torch random number generators.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
# Silence UserWarning / FutureWarning output for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)


# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
mutation_data_path = 'mutations_full_info.pkl' # data stored in pickle format, with columns: mutation, scaled_activity, group, sequence
mutations = pd.read_pickle(mutation_data_path)
# Regression target, one value per row of `mutations`.
y_values = mutations["scaled_activity"].to_numpy(dtype=np.float32)
# Cross-validation group = the numeric position inside the mutation string.
groups = mutations["mutation"].str.extract(r"[A-Z](\d+)[A-Z]").astype(int)[0].to_numpy()

# Archive of per-mutant embeddings; it is indexed by mutation string below.
esm_embeddings_path = 'mutant_embeddings.npz' # ESM-2 embeddings of mutants, generated above
esm_bank = np.load(esm_embeddings_path)

def custom_three_to_one(residue_name: str) -> str:
    """Convert a three-letter residue name to one letter; unknown names give 'X'.

    The name is upper-cased before the lookup.
    """
    code_pairs = (
        ('ALA', 'A'), ('CYS', 'C'), ('ASP', 'D'), ('GLU', 'E'), ('PHE', 'F'),
        ('GLY', 'G'), ('HIS', 'H'), ('ILE', 'I'), ('LYS', 'K'), ('LEU', 'L'),
        ('MET', 'M'), ('ASN', 'N'), ('PRO', 'P'), ('GLN', 'Q'), ('ARG', 'R'),
        ('SER', 'S'), ('THR', 'T'), ('VAL', 'V'), ('TRP', 'W'), ('TYR', 'Y'),
    )
    table = dict(code_pairs)
    return table.get(residue_name.upper(), 'X')

WT_PDB_PATH = "wt_af.pdb" # wild-type protein structure in pdb format predicted by AlphaFold2
parser = PDBParser(QUIET=True)
structure = parser.get_structure("wt", WT_PDB_PATH)
# Only the first chain yielded by the structure is used.
chain = next(structure.get_chains())
# Keep standard amino-acid residues only; graph nodes follow this list.
residues = [res for res in chain.get_residues() if Polypeptide.is_aa(res, standard=True)]
# One CA coordinate per kept residue.
coords = np.array([res["CA"].get_coord() for res in residues], dtype=np.float32)
wt_seq_from_pdb = "".join(custom_three_to_one(res.get_resname()) for res in residues)

# The wild-type embedding is cached on disk and only computed when the file is missing.
# NOTE(review): an existing cache file is reused even if the PDB or model has changed.
WT_EMB_PATH = "wt_esm_embedding.npz" # to generate wild-type ESM-2 embedding
if not os.path.exists(WT_EMB_PATH):
    
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = AutoModel.from_pretrained("facebook/esm2_t33_650M_UR50D", torch_dtype=torch.float16).to(device)
    model.eval()
    with torch.inference_mode():
        tok = tokenizer(wt_seq_from_pdb, return_tensors="pt").to(device)
        # [0, 1:-1] takes the single sequence and drops its first and last token positions.
        wt_emb_tensor = model(**tok).last_hidden_state[0, 1:-1].cpu().to(torch.float16).numpy()
    np.savez_compressed(WT_EMB_PATH, emb=wt_emb_tensor)
    
# Always read the embedding back from the cache file (stored as float16).
wt_emb = np.load(WT_EMB_PATH)["emb"]

# Fixed amino-acid alphabet; a letter's position is its one-hot column.
AA_CODES = "ACDEFGHIKLMNPQRSTVWY"


def aa_onehot(code):
    """Return the float32 indicator vector of *code* within AA_CODES.

    A *code* that does not occur in AA_CODES makes ``str.index`` raise ValueError.
    """
    slot = AA_CODES.index(code)
    indicator = np.zeros(len(AA_CODES), dtype=np.float32)
    indicator[slot] = 1.0
    return indicator

# Per-variant node-feature matrices and mutated-residue indices. Entry k of both
# lists must describe row k of `mutations`, because y_values and groups are built
# from every row of that frame.
all_node_features = []
mut_pat = re.compile(r"([A-Z])(\d+)([A-Z])")
mut_indices = []


for _, row in mutations.iterrows():
    mut = row["mutation"]
    match = mut_pat.fullmatch(mut)
    if not match:
        # Silently skipping the row would leave the feature / index lists shorter
        # than y_values / groups and shift every later sample against its label.
        raise ValueError(f"Bad mutation string: {mut}")

    wt, pos, mut_aa = match.groups()
    idx = int(pos) - 1  # 1-based position -> 0-based residue index
    if not 0 <= idx < len(wt_seq_from_pdb):
        # Position 0 would otherwise wrap around to the last residue via index -1.
        raise ValueError(
            f"Mutation {mut} position is outside the wild-type sequence "
            f"(length {len(wt_seq_from_pdb)})"
        )
    mut_indices.append(idx)

    # Per-residue embedding of this mutant and its difference from the wild type.
    mut_emb = esm_bank[mut]
    # The wild-type embedding is truncated to the mutant's row count before subtracting.
    delta_emb = mut_emb - wt_emb[:mut_emb.shape[0]]

    # Single-column flag marking the mutated residue.
    mutation_indicator = np.zeros((len(wt_seq_from_pdb), 1), dtype=np.float32)
    mutation_indicator[idx] = 1.0
    # One-hot of the substituted amino acid, non-zero only on the mutated row.
    mut_aa_onehot = np.zeros((len(wt_seq_from_pdb), len(AA_CODES)), dtype=np.float32)
    mut_aa_onehot[idx] = aa_onehot(mut_aa)

    # Columns: mutant embedding | delta embedding | mutation flag | mutant amino-acid one-hot.
    combined_features = np.hstack([mut_emb, delta_emb, mutation_indicator, mut_aa_onehot])
    all_node_features.append(combined_features)
    
def create_radius_adjacency_matrix(coords, radius=10.0):
    """Connect every pair of points closer than *radius* and add self-loops; return CSR."""
    size = len(coords)
    edges = KDTree(coords).query_pairs(r=radius)
    if not edges:
        heads, tails = [], []
    else:
        heads, tails = zip(*edges)

    identity = np.arange(size)
    # Forward edges, reversed edges, then the diagonal.
    edge_rows = np.concatenate([heads, tails, identity])
    edge_cols = np.concatenate([tails, heads, identity])

    values = np.ones(len(edge_rows))
    graph = coo_matrix((values, (edge_rows, edge_cols)), shape=(size, size))
    return graph.tocsr()

class GraphAttentionLayer(nn.Module):
    """Dense single-head graph attention layer.

    Learns a projection ``W`` and an attention vector ``a``; attention is
    restricted to node pairs with a positive adjacency entry.
    """

    def __init__(self, in_features, out_features, dropout=0.2, alpha=0.2):
        super().__init__()
        self.alpha = alpha
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features

        # W first, then a.
        weight_shape = (in_features, out_features)
        self.W = nn.Parameter(torch.zeros(size=weight_shape))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        attention_shape = (2 * out_features, 1)
        self.a = nn.Parameter(torch.zeros(size=attention_shape))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, h, adj):
        """Return ELU(attention @ (h W)) with attention masked by *adj*."""
        transformed = torch.mm(h, self.W)
        pair_input = self._prepare_attentional_mechanism_input(transformed)
        logits = self.leakyrelu(torch.matmul(pair_input, self.a).squeeze(2))

        # Non-edges receive a huge negative logit so softmax drives them to ~0.
        fill = -9e15 * torch.ones_like(logits)
        logits = torch.where(adj > 0, logits, fill)

        coefficients = F.softmax(logits, dim=1)
        coefficients = F.dropout(coefficients, self.dropout, training=self.training)
        mixed = torch.matmul(coefficients, transformed)
        return F.elu(mixed)

    def _prepare_attentional_mechanism_input(self, Wh):
        """Return all ordered pairs concat(Wh[i], Wh[j]) as an (N, N, 2 * out_features) tensor."""
        count = Wh.size()[0]
        source_side = Wh.repeat_interleave(count, dim=0)
        target_side = Wh.repeat(count, 1)
        combined = torch.cat([source_side, target_side], dim=1)
        return combined.view(count, count, 2 * self.out_features)

class MutGraphNet(nn.Module):
    """Stack of two graph-attention layers with a linear read-out at a single node."""

    def __init__(self, in_dim, hidden_dim, dropout):
        super().__init__()
        # Creation order is kept: gat1, gat2, dropout, fc.
        self.gat1 = GraphAttentionLayer(in_dim, hidden_dim, dropout=dropout)
        self.gat2 = GraphAttentionLayer(hidden_dim, hidden_dim, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, adj, mut_idx):
        """Run both attention layers and score the embedding found at *mut_idx*."""
        first_pass = self.gat1(x, adj)
        regularised = self.dropout(first_pass)
        second_pass = self.gat2(regularised, adj)
        return self.fc(second_pass[mut_idx])

class ProteinGraphDataset(Dataset):
    """Dataset yielding (node features, label, mutated-node index) per sample."""

    def __init__(self, features_list, labels, mut_indices):
        self.features_list = features_list
        self.labels = labels
        self.mut_indices = mut_indices

    def __len__(self):
        # The feature list defines the dataset size.
        return len(self.features_list)

    def __getitem__(self, idx):
        """Return the sample at *idx*; features and label are converted to float32 tensors."""
        features = torch.tensor(self.features_list[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return (features, label, self.mut_indices[idx])

def soft_rank_pytorch(x, steepness=1.0):
    """Differentiable rank surrogate.

    For each element x_i, returns the sum over j of
    sigmoid(steepness * (x_i - x_j)); the j == i term contributes 0.5.
    """
    column = x.unsqueeze(1)
    row = x.unsqueeze(0)
    pairwise_gap = column - row
    soft_wins = torch.sigmoid(steepness * pairwise_gap)
    return torch.sum(soft_wins, dim=1)

def spearman_soft_pytorch(pred, target, steepness=1.0):
    """Differentiable Spearman-like correlation built on ``soft_rank_pytorch``.

    Returns the cosine similarity of the mean-centred soft ranks of ``pred``
    and ``target``. When either centred rank vector has norm below 1e-6 a
    fresh scalar 0.0 (with ``requires_grad=True``) is returned instead.
    """
    def _centred_soft_ranks(values):
        ranks = soft_rank_pytorch(values, steepness) - 1
        return ranks - ranks.mean()

    pred_dev = _centred_soft_ranks(pred)
    target_dev = _centred_soft_ranks(target)
    pred_norm = torch.linalg.norm(pred_dev)
    target_norm = torch.linalg.norm(target_dev)

    degenerate = pred_norm < 1e-6 or target_norm < 1e-6
    if degenerate:
        # Correlation is undefined for (near-)constant ranks.
        return torch.tensor(0.0, device=pred.device, requires_grad=True)
    return (pred_dev * target_dev).sum() / (pred_norm * target_norm)

# Shared Huber (smooth-L1) criterion; beta is the quadratic/linear switch point.
huber_loss_fn = nn.SmoothL1Loss(beta=1.0)


def huber_rho_loss(pred, target, lambda_rho, steepness=1.0):
    """Huber loss plus ``lambda_rho * (1 - soft Spearman)``.

    Args:
        pred: predicted values.
        target: reference values.
        lambda_rho: weight of the rank-correlation penalty.
        steepness: sigmoid steepness forwarded to the soft Spearman.
    """
    rank_penalty = 1.0 - spearman_soft_pytorch(pred, target, steepness)
    regression_term = huber_loss_fn(pred, target)
    return regression_term + lambda_rho * rank_penalty

def _predict_mutations(model, loader, adj_torch, device):
    """Return the model output for every sample of ``loader`` as one CPU tensor.

    The model is put in eval mode and run without gradient tracking, one graph
    at a time, in the order in which the loader yields the samples.
    """
    model.eval()
    outputs = []
    with torch.no_grad():
        for x_batch, _, mut_idx_batch in loader:
            for x, mut_idx in zip(x_batch, mut_idx_batch):
                outputs.append(model(x.to(device), adj_torch, mut_idx).squeeze().cpu())
    return torch.stack(outputs)


def train_and_evaluate_gnn(features, labels, groups, mut_indices_list, adj, params, patience):
    """Grouped K-fold training and out-of-fold evaluation of ``MutGraphNet``.

    Reads the module-level ``N_SPLITS``, ``BATCH_SIZE``, ``MAX_EPOCHS`` and
    ``device``.

    Args:
        features: sequence of per-sample node-feature matrices (indexable by
            integer; each item exposes ``.shape[1]``).
        labels: NumPy array of targets, indexed with the fold index arrays.
        groups: group label per sample, passed to ``GroupKFold``.
        mut_indices_list: per-sample index of the node whose embedding is
            read out by the model.
        adj: SciPy sparse adjacency matrix shared by all samples; it is
            densified once via ``toarray()``.
        params: dict with keys ``'lambda_rho'``, ``'hidden_dim'``,
            ``'dropout'`` and ``'lr'``.
        patience: number of epochs without validation-loss improvement after
            which a fold stops early.

    Returns:
        dict with ``'predictions'`` (out-of-fold predictions),
        ``'overall_metrics'`` (metrics over all out-of-fold predictions) and
        ``'fold_metrics'`` (per-fold metric lists).
    """
    # Function-scope imports keep the function usable on its own.
    import copy
    from functools import partial

    kf = GroupKFold(n_splits=N_SPLITS)
    oof_preds = np.zeros(len(labels))
    # toarray() yields a plain ndarray (todense() would give an np.matrix).
    adj_torch = torch.from_numpy(adj.toarray().astype(np.float32)).to(device)
    fold_metrics_history = {'rmse': [], 'r2': [], 'pearson': [], 'spearman': []}
    loss_fn = partial(huber_rho_loss, lambda_rho=params['lambda_rho'])

    for fold, (train_idx, val_idx) in enumerate(kf.split(features, labels, groups)):

        model = MutGraphNet(in_dim=features[0].shape[1], hidden_dim=params['hidden_dim'], dropout=params['dropout']).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

        train_ds = ProteinGraphDataset([features[i] for i in train_idx], labels[train_idx], [mut_indices_list[i] for i in train_idx])
        val_ds = ProteinGraphDataset([features[i] for i in val_idx], labels[val_idx], [mut_indices_list[i] for i in val_idx])
        loader_tr = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        loader_va = DataLoader(val_ds, batch_size=BATCH_SIZE)

        # Validation labels do not change between epochs; build them once.
        val_labels_tensor = torch.tensor(val_ds.labels, dtype=torch.float32)
        val_labels_np = val_labels_tensor.numpy()

        best_val_loss, patience_counter = float('inf'), 0
        # Best weights are kept in memory: no checkpoint file that could be
        # missing (no epoch ever improved) or stale (left by another run).
        best_state = None

        for epoch in range(MAX_EPOCHS):
            model.train()
            for x_batch, y_batch, mut_idx_batch in loader_tr:
                optimizer.zero_grad()
                batch_outputs = [model(x.to(device), adj_torch, mut_idx).squeeze() for x, mut_idx in zip(x_batch, mut_idx_batch)]
                outputs_tensor = torch.stack(batch_outputs)
                loss = loss_fn(outputs_tensor, y_batch.to(device))
                if torch.isnan(loss):
                    # Skip the update rather than propagate NaN gradients.
                    continue
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            val_outputs_tensor = _predict_mutations(model, loader_va, adj_torch, device)

            # The loss is evaluated over the whole validation fold at once.
            avg_val_loss = loss_fn(val_outputs_tensor.to(device), val_labels_tensor.to(device)).item()
            val_spearman = spearmanr(val_labels_np, val_outputs_tensor.numpy())[0]
            print(f"epoch {epoch+1:02d}/{MAX_EPOCHS} | val loss: {avg_val_loss:.4f} | val spearman: {val_spearman:.4f}")

            scheduler.step(avg_val_loss)
            if avg_val_loss < best_val_loss:
                best_val_loss, patience_counter = avg_val_loss, 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}.")
                break

        if best_state is not None:
            model.load_state_dict(best_state)
        else:
            # A NaN validation loss never compares as an improvement.
            print(f"Fold {fold+1}: validation loss never improved; evaluating the final weights.")

        fold_preds = _predict_mutations(model, loader_va, adj_torch, device).numpy()
        oof_preds[val_idx] = fold_preds
        fold_true = labels[val_idx]

        fold_metrics_history['rmse'].append(np.sqrt(mean_squared_error(fold_true, fold_preds)))
        fold_metrics_history['r2'].append(r2_score(fold_true, fold_preds))
        fold_metrics_history['pearson'].append(pearsonr(fold_true, fold_preds)[0])
        fold_metrics_history['spearman'].append(spearmanr(fold_true, fold_preds)[0])
        print(f"Fold {fold+1} Spearman: {fold_metrics_history['spearman'][-1]:.4f}")

    overall_metrics = {
        'rmse': np.sqrt(mean_squared_error(labels, oof_preds)),
        'r2': r2_score(labels, oof_preds),
        'pearson': pearsonr(labels, oof_preds)[0],
        'spearman': spearmanr(labels, oof_preds)[0],
    }
    return {'predictions': oof_preds, 'overall_metrics': overall_metrics, 'fold_metrics': fold_metrics_history}


# Build the shared residue graph and run the cross-validated GNN experiment.
adj_matrix = create_radius_adjacency_matrix(coords, radius=OPTIMAL_RADIUS)
final_results = train_and_evaluate_gnn(all_node_features, y_values, groups, mut_indices, adj=adj_matrix, params=hyperparams, patience=EARLY_STOPPING_PATIENCE)


# Report the per-fold summary; previously the statistics were computed and
# then discarded without being shown.
for metric_name, metric_values in final_results['fold_metrics'].items():
    mean_val, std_val = np.mean(metric_values), np.std(metric_values)
    print(f"{metric_name}: {mean_val:.4f} +/- {std_val:.4f} (mean +/- std over folds)")

overall_metrics = final_results['overall_metrics']
for metric_name, metric_value in overall_metrics.items():
    print(f"overall {metric_name}: {metric_value:.4f}")


In [None]:
# EGNN + ESM embedding model (optimized feature engineering)

import numpy as np
import pandas as pd
import re
import random
import os
import warnings
import torch
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import GroupKFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr

from Bio.PDB import PDBParser, Polypeptide
from transformers import AutoTokenizer, AutoModel

# Load the mutation table.
# NOTE(review): unpickling can execute arbitrary code; only load files from a trusted source.
mutations = pd.read_pickle('mutations_full_info.pkl') # data stored in pickle format, with columns: mutation, scaled_activity, group, sequence

# Regression targets as a float32 column vector (one row per mutation).
y_values = mutations["scaled_activity"].to_numpy(dtype=np.float32).reshape(-1, 1)
# Cross-validation group per row: the integer between two capital letters in
# the mutation string (str.extract takes the first such match).
groups = mutations["mutation"].str.extract(r"[A-Z](\d+)[A-Z]").astype(int)[0].to_numpy()

def get_device():
    """Pick the torch device to run on: MPS first, then CUDA, otherwise CPU."""
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

# Device used for the (optional) ESM-2 forward pass below.
device = get_device()


OUT_PATH = "mutant_embeddings.npz" # ESM-2 embeddings of mutants
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"

# NOTE(review): the literal ends in '...', which looks like a truncated
# placeholder rather than the full wild-type sequence; confirm before
# regenerating the embedding.
wt_seq = 'MDSLVVLVLCL...'

WT_EMB_PATH = "wt_esm650m_embedding.npz" # cache file for the wild-type ESM-2 embedding
# The language model is only loaded when the cached embedding file is missing.
if not os.path.exists(WT_EMB_PATH):

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
    model.eval()
    
    with torch.inference_mode():
        tok = tokenizer(wt_seq, return_tensors="pt").to(device)
        # [0, 1:-1] takes the single batch item and drops its first and last
        # positions (presumably the tokenizer's special tokens; confirm).
        wt_emb = model(**tok).last_hidden_state[0, 1:-1].cpu().to(torch.float16).numpy()
    
    # Stored under the key "emb", which is the key read back later in the file.
    np.savez_compressed(WT_EMB_PATH, emb=wt_emb)


# Experiment constants.
SEED = 42
N_SPLITS = 5
EPOCHS = 50

# Seed NumPy, Python's random module and torch with the same value.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

def scatter_mean_workaround(src, index, dim, dim_size):
    """Average the rows of ``src`` that share the same target ``index``.

    Args:
        src: 2-D tensor of values to aggregate.
        index: 1-D long tensor giving the output slot of each row of ``src``.
        dim: dimension passed to ``index_add_``.
        dim_size: number of output slots.

    Returns:
        ``(dim_size, src.size(1))`` tensor of per-slot means; slots that
        receive no rows stay zero because the divisor is clamped to 1.
    """
    totals = src.new_zeros(dim_size, src.size(1))
    totals.index_add_(dim, index, src)

    # Count the contributions per slot by scattering a column of ones.
    counts = src.new_zeros(dim_size, 1)
    counts.index_add_(dim, index, src.new_ones(src.size(0), 1))

    return totals / counts.clamp(min=1)

# Parse the wild-type structure and collect the C-alpha coordinates of the
# standard amino-acid residues of its first chain.
WT_PDB = "wt_af.pdb"
parser = PDBParser(QUIET=True)
structure = parser.get_structure("wt", WT_PDB)
chain = next(structure.get_chains())
residues = [res for res in chain.get_residues() if Polypeptide.is_aa(res, standard=True)]
coords = np.array([res["CA"].get_coord() for res in residues], dtype=np.float32)
# Number of residues with coordinates; used below as the node count per graph.
wt_seq_len = len(coords)

# Mutant embeddings are looked up by mutation string; the wild-type
# embedding is stored under the key "emb".
embedding_bank = np.load(OUT_PATH)
wt_emb = np.load(WT_EMB_PATH)["emb"]

# Feature widths, taken from the first stored mutant embedding and from the
# wild-type embedding.
MUTANT_ESM_DIM = embedding_bank[list(embedding_bank.keys())[0]].shape[1]
DELTA_ESM_DIM = wt_emb.shape[1]
# Column of the mutation-indicator flag in the concatenated node features
# (it follows the mutant and delta embedding columns).
MUT_INDICATOR_COL_IDX = MUTANT_ESM_DIM + DELTA_ESM_DIM

AA_CODES = "ACDEFGHIKLMNPQRSTVWY"
# Exact one-letter-code -> column lookup. str.index would also accept the
# empty string or multi-letter substrings and return a misleading column.
_AA_TO_INDEX = {aa: i for i, aa in enumerate(AA_CODES)}


def aa_onehot(code):
    """Return the float32 one-hot vector of a one-letter amino-acid code.

    Args:
        code: a single character from ``AA_CODES``.

    Returns:
        1-D float32 array of length ``len(AA_CODES)`` with a single 1.0.

    Raises:
        ValueError: if ``code`` is not exactly one of the letters in
            ``AA_CODES``.
    """
    try:
        position = _AA_TO_INDEX[code]
    except KeyError:
        raise ValueError(f"Unknown amino-acid code: {code!r}") from None
    vec = np.zeros(len(AA_CODES), dtype=np.float32)
    vec[position] = 1.0
    return vec

# Per-sample node-feature matrices and the 0-based index of the mutated residue.
all_node_features = []
mut_indices = []
mut_pat = re.compile(r"([A-Z])(\d+)([A-Z])")
# Row positions (in `mutations`) that actually produced a feature matrix.
kept_rows = []


for row_pos, mut_string in enumerate(mutations["mutation"]):
    match = mut_pat.fullmatch(mut_string)
    if not match:
        continue

    _, pos, mut_aa = match.groups()
    idx = int(pos) - 1  # mutation strings are 1-based

    mut_emb = embedding_bank[mut_string]
    delta_emb = mut_emb - wt_emb[:mut_emb.shape[0]]
    mut_indices.append(idx)
    # Flag column: 1 at the mutated residue, 0 elsewhere.
    mutation_indicator = np.zeros((wt_seq_len, 1), dtype=np.float16)
    mutation_indicator[idx] = 1.0
    # One-hot of the substituted amino acid, set only at the mutated residue.
    mut_aa_onehot = np.zeros((wt_seq_len, len(AA_CODES)), dtype=np.float16)
    mut_aa_onehot[idx] = aa_onehot(mut_aa)

    combined = np.hstack([mut_emb, delta_emb, mutation_indicator, mut_aa_onehot])
    all_node_features.append(combined)
    kept_rows.append(row_pos)

# Skipping a row above removes its features but not its label or group, which
# would shift every later label against the wrong graph. Re-align both arrays
# with the rows that were kept.
if len(kept_rows) != len(mutations):
    print(f"Skipped {len(mutations) - len(kept_rows)} rows with unparseable mutation strings.")
    kept_rows_arr = np.asarray(kept_rows, dtype=np.intp)
    y_values = y_values[kept_rows_arr]
    groups = groups[kept_rows_arr]



def create_radius_graph(coords, radius=8.0):
    """Connect every ordered pair of distinct points closer than ``radius``.

    Args:
        coords: NumPy array of point coordinates, one row per point.
        radius: strict distance cutoff.

    Returns:
        ``(2, E)`` long tensor of (source, target) indices in row-major
        order, without self-loops.
    """
    points = torch.from_numpy(coords)
    within_cutoff = torch.cdist(points, points) < radius
    # Remove the diagonal so no node is linked to itself.
    not_self = ~torch.eye(points.size(0), dtype=torch.bool)
    pairs = (within_cutoff & not_self).nonzero(as_tuple=False)
    return pairs.t().contiguous()

# Distance cutoff handed to create_radius_graph (same units as `coords`).
GRAPH_RADIUS = 8.0
# Edge list of the wild-type residue graph; built once at module level and
# read as a global by collate_fn.
edge_index = create_radius_graph(coords, radius=GRAPH_RADIUS)


class ProteinDataset(Dataset):
    """Per-mutant graph samples that share one set of residue coordinates.

    Each item is a dict with the node features ``"x"``, a copy of the
    coordinates ``"pos"`` in which the mutated residue's row is set to NaN,
    and the label ``"y"``.
    """

    def __init__(self, node_features_list, coords, labels, mut_idx_list):
        self.node_features_list = [
            torch.tensor(features, dtype=torch.float32) for features in node_features_list
        ]
        self.coords = torch.tensor(coords, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        self.mut_idx = mut_idx_list

    def __len__(self):
        return len(self.node_features_list)

    def __getitem__(self, idx):
        # Work on a copy so the shared coordinates are never modified.
        masked_pos = self.coords.clone()
        masked_pos[self.mut_idx[idx]] = float('nan')
        return {
            "x": self.node_features_list[idx],
            "pos": masked_pos,
            "y": self.labels[idx],
        }

def collate_fn(batch):
    """Merge dataset items into one disjoint batched graph.

    Node features and positions are concatenated, the module-level
    ``edge_index`` is repeated once per item with its node indices shifted by
    the number of nodes that precede that item, and ``"batch"`` records the
    item number of every node.
    """
    node_counts = [item["x"].shape[0] for item in batch]

    # Running node offset of each graph inside the merged batch.
    offsets, running_total = [], 0
    for count in node_counts:
        offsets.append(running_total)
        running_total += count

    graph_ids = [
        torch.full((count,), graph_no, dtype=torch.long)
        for graph_no, count in enumerate(node_counts)
    ]
    return {
        "x": torch.cat([item["x"] for item in batch], dim=0),
        "pos": torch.cat([item["pos"] for item in batch], dim=0),
        "edge_index": torch.cat([edge_index + shift for shift in offsets], dim=1),
        "batch": torch.cat(graph_ids, dim=0),
        "y": torch.stack([item["y"] for item in batch]).squeeze(-1),
    }

# One sample per entry of all_node_features; coordinates are shared by all samples.
dataset = ProteinDataset(all_node_features, coords, y_values, mut_indices)


class E_GCL(nn.Module): # equivariant graph convolution layer
    """Message-passing layer that updates node features and coordinates.

    Edge messages are computed from both endpoint features plus a radial-basis
    expansion of the edge length; nodes and coordinates are updated from the
    mean of the messages grouped by each edge's first index.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, act_fn=nn.SiLU(), residual=True, num_rbf=16, rbf_cutoff=12.0):
        super().__init__()
        self.residual = residual
        edge_in_features = 2 * input_nf + num_rbf
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_in_features, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf),
        )
        self.coord_mlp = nn.Sequential(
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1, bias=False),
        )
        # RBF centres are evenly spaced on [0, rbf_cutoff]; every basis
        # function shares the same width parameter.
        width = (rbf_cutoff / num_rbf) ** -2
        self.register_buffer("rbf_centres", torch.linspace(0.0, rbf_cutoff, num_rbf))
        self.register_buffer("rbf_gamma", torch.full((num_rbf,), width))

    def edge_model(self, source, target, rbf):
        """Return the edge messages for the given endpoint features and RBF values."""
        edge_input = torch.cat([source, target, rbf], dim=1)
        return self.edge_mlp(edge_input)

    def node_model(self, x, edge_index, edge_attr):
        """Return the node update computed from ``x`` and the mean edge message."""
        receivers = edge_index[0]
        pooled = scatter_mean_workaround(edge_attr, receivers, dim=0, dim_size=x.size(0))
        return self.node_mlp(torch.cat([x, pooled], dim=1))

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        """Return the mean, per node, of the edge vectors scaled by a learned weight."""
        receivers = edge_index[0]
        weighted_diff = coord_diff * self.coord_mlp(edge_feat)
        return scatter_mean_workaround(weighted_diff, receivers, dim=0, dim_size=coord.size(0))

    def forward(self, h, edge_index, coord):
        """Run one layer and return the updated ``(h, coord)``."""
        src, dst = edge_index
        missing = torch.isnan(coord)
        # NaN coordinates take part in the geometry as zeros.
        safe_coord = torch.nan_to_num(coord, nan=0.0)
        rel_vec = safe_coord[src] - safe_coord[dst]
        edge_len = torch.norm(rel_vec, dim=1, keepdim=True)
        rbf = torch.exp(-self.rbf_gamma * (edge_len - self.rbf_centres) ** 2)
        messages = self.edge_model(h[src], h[dst], rbf)
        shift = self.coord_model(safe_coord, edge_index, rel_vec, messages)
        # Entries that were NaN are replaced by the shift alone.
        coord = torch.where(missing, shift, coord + shift)
        update = self.node_model(h, edge_index, messages)
        h = h + update if self.residual else update
        return h, coord

class MutEGNN(nn.Module):
    """Stack of ``E_GCL`` layers with a read-out weighted by the mutation flag.

    Column ``mut_col`` of the node features is used as a per-node weight, so
    the graph representation is the weighted average of the node embeddings
    over the flagged nodes of each graph.
    """

    def __init__(self, in_dim, hidden_dim, n_layers, dropout, mut_col: int):
        super().__init__()
        self.mut_col = mut_col
        self.embed = nn.Linear(in_dim, hidden_dim)
        self.layers = nn.ModuleList(
            [E_GCL(hidden_dim, hidden_dim, hidden_dim) for _ in range(n_layers)]
        )
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, 1),
        )

    def forward(self, data_batch, edge_index):
        """Return one prediction per graph of the collated ``data_batch``."""
        x = data_batch['x']
        coords = data_batch['pos']
        graph_ids = data_batch['batch']
        # Kept as a column (N, 1) so it broadcasts over the hidden features.
        mut_mask = x[:, self.mut_col:self.mut_col + 1]

        h = self.embed(x)
        for layer in self.layers:
            h, coords = layer(h, edge_index, coords)

        n_graphs = data_batch['y'].size(0)
        pooled = scatter_mean_workaround(h * mut_mask, graph_ids, dim=0, dim_size=n_graphs)
        mask_fraction = scatter_mean_workaround(mut_mask, graph_ids, dim=0, dim_size=n_graphs)
        # The two per-graph means divide out the node count; the epsilon
        # guards graphs without any flagged node.
        graph_features = pooled / (mask_fraction + 1e-6)
        return self.head(graph_features).squeeze(-1)

def soft_rank_pytorch(x, steepness=1.0):
    """Differentiable surrogate for the rank of each entry of ``x``.

    Entry ``i`` of the result is ``sum_j sigmoid(steepness * (x[i] - x[j]))``;
    the ``j == i`` term always contributes ``sigmoid(0) = 0.5``.
    """
    pairwise_gap = x[:, None] - x[None, :]
    return (steepness * pairwise_gap).sigmoid().sum(dim=1)

def spearman_soft_pytorch(pred, target, steepness=1.0):
    """Differentiable Spearman-like correlation built on ``soft_rank_pytorch``.

    Returns the cosine similarity of the mean-centred soft ranks of ``pred``
    and ``target``. When either centred rank vector has norm below 1e-6 a
    fresh scalar 0.0 (with ``requires_grad=True``) is returned instead.
    """
    def _centred_soft_ranks(values):
        ranks = soft_rank_pytorch(values, steepness)
        return ranks - ranks.mean()

    pred_dev = _centred_soft_ranks(pred)
    target_dev = _centred_soft_ranks(target)
    pred_norm = torch.linalg.norm(pred_dev)
    target_norm = torch.linalg.norm(target_dev)

    degenerate = pred_norm < 1e-6 or target_norm < 1e-6
    if degenerate:
        # Correlation is undefined for (near-)constant ranks.
        return torch.tensor(0.0, device=pred.device, requires_grad=True)

    return (pred_dev * target_dev).sum() / (pred_norm * target_norm)

# Shared Huber (smooth-L1) criterion; beta is the quadratic/linear switch point.
huber_loss_fn = nn.SmoothL1Loss(beta=1.0)


def huber_rho_loss(pred, target, lambda_rho, steepness=1.0):
    """Huber loss plus ``lambda_rho * (1 - soft Spearman)``.

    Args:
        pred: predicted values.
        target: reference values.
        lambda_rho: weight of the rank-correlation penalty.
        steepness: sigmoid steepness forwarded to the soft Spearman.
    """
    rank_penalty = 1.0 - spearman_soft_pytorch(pred, target, steepness)
    regression_term = huber_loss_fn(pred, target)
    return regression_term + lambda_rho * rank_penalty

def train_evaluate_egnn(dataset, groups, edge_index_template, params, mut_indicator_col):
    """Grouped K-fold training and out-of-fold evaluation of ``MutEGNN``.

    Reads the module-level ``N_SPLITS`` and ``EPOCHS``; batches are built by
    the module-level ``collate_fn``.

    Args:
        dataset: map-style dataset whose items are dicts with ``'x'``,
            ``'pos'`` and ``'y'``.
        groups: group label per sample, passed to ``GroupKFold``.
        edge_index_template: kept for interface compatibility; the model is
            fed the batched ``'edge_index'`` produced by ``collate_fn``.
        params: dict with keys ``'lambda_rho'``, ``'batch_size'``,
            ``'hidden_dim'``, ``'n_layers'``, ``'dropout'``, ``'lr'`` and
            optionally ``'patience'`` (default 15).
        mut_indicator_col: feature column holding the mutation flag.

    Returns:
        dict with ``'overall_metrics'``, ``'fold_metrics'`` and
        ``'oof_predictions'``. The latter holds the arrays ``'true'`` and
        ``'pred'`` plus a boolean ``'valid'`` mask marking the samples that
        received a finite prediction; overall metrics use only those.
    """
    import copy

    kf = GroupKFold(n_splits=N_SPLITS)
    n_samples = len(dataset)
    oof_preds, true_labels = np.zeros(n_samples), np.zeros(n_samples)
    has_prediction = np.zeros(n_samples, dtype=bool)
    device = get_device()

    fold_metrics_history = {'rmse': [], 'r2': [], 'pearson': [], 'spearman': []}
    loss_fn = partial(huber_rho_loss, lambda_rho=params['lambda_rho'])
    patience = params.get('patience', 15)

    for fold, (train_idx, val_idx) in enumerate(kf.split(dataset, groups=groups)):

        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)
        loader_tr = DataLoader(train_subset, batch_size=params['batch_size'], shuffle=True, collate_fn=collate_fn)
        loader_va = DataLoader(val_subset, batch_size=params['batch_size'], collate_fn=collate_fn)

        model = MutEGNN(
            in_dim=dataset[0]['x'].shape[1],
            hidden_dim=params['hidden_dim'],
            n_layers=params['n_layers'],
            dropout=params['dropout'],
            mut_col=mut_indicator_col
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

        best_val_loss, epochs_no_improve = float('inf'), 0
        # Best weights are kept in memory: no checkpoint file that could be
        # missing (every epoch unstable) or stale (left by another run).
        best_state = None

        for epoch in range(EPOCHS):
            model.train()
            for batch in loader_tr:
                batch = {k: v.to(device) for k, v in batch.items()}
                optimizer.zero_grad()
                output = model(batch, batch["edge_index"])
                loss = loss_fn(output, batch['y'])

                if torch.isnan(loss):
                    # RuntimeWarning: the file silences UserWarning globally,
                    # which would hide this message.
                    warnings.warn(f"NaN loss detected in fold {fold+1}, epoch {epoch+1}. skip current batch.", RuntimeWarning)
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            model.eval()
            total_val_loss, fold_preds_val, fold_true_val = 0, [], []
            with torch.no_grad():
                for batch in loader_va:
                    batch_device = {k: v.to(device) for k, v in batch.items()}
                    preds = model(batch_device, batch_device['edge_index'])
                    if not torch.all(torch.isfinite(preds)):
                        continue

                    val_loss = loss_fn(preds, batch_device['y'])
                    if torch.isnan(val_loss):
                        continue
                    total_val_loss += val_loss.item() * batch_device['y'].size(0)
                    fold_preds_val.extend(preds.cpu().numpy().flatten())
                    fold_true_val.extend(batch['y'].cpu().numpy().flatten())

            if not fold_true_val:
                print(f"All validation batches were unstable in fold {fold+1}, epoch {epoch+1}. skip current epoch.")
                continue

            avg_val_loss = total_val_loss / len(fold_true_val)
            val_spearman = spearmanr(fold_true_val, fold_preds_val)[0]
            print(f"epoch {epoch+1:02d}/{EPOCHS} | val loss: {avg_val_loss:.4f} | val spearman: {val_spearman:.4f}")

            scheduler.step(avg_val_loss)

            if avg_val_loss < best_val_loss:
                best_val_loss, epochs_no_improve = avg_val_loss, 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}!")
                break

        if best_state is not None:
            model.load_state_dict(best_state)
        else:
            print(f"Fold {fold+1}: no stable validation epoch; evaluating the final weights.")

        model.eval()
        fold_preds, fold_true, fold_rows = [], [], []
        # loader_va is not shuffled, so its batches walk val_idx in order;
        # `offset` maps each batch back to dataset rows even when an earlier
        # batch is dropped for non-finite predictions.
        offset = 0
        with torch.no_grad():
            for batch in loader_va:
                n_in_batch = batch['y'].size(0)
                batch_rows = val_idx[offset:offset + n_in_batch]
                offset += n_in_batch

                batch_device = {k: v.to(device) for k, v in batch.items()}
                preds = model(batch_device, batch_device['edge_index'])
                if torch.all(torch.isfinite(preds)):
                    fold_preds.extend(preds.cpu().numpy().flatten())
                    fold_true.extend(batch['y'].cpu().numpy().flatten())
                    fold_rows.extend(batch_rows)

        if len(fold_true) > 1:
            rows = np.asarray(fold_rows, dtype=np.intp)
            oof_preds[rows], true_labels[rows] = fold_preds, fold_true
            has_prediction[rows] = True
            fold_metrics_history['rmse'].append(np.sqrt(mean_squared_error(fold_true, fold_preds)))
            fold_metrics_history['r2'].append(r2_score(fold_true, fold_preds))
            fold_metrics_history['pearson'].append(pearsonr(fold_true, fold_preds)[0])
            fold_metrics_history['spearman'].append(spearmanr(fold_true, fold_preds)[0])
            print(f"Fold {fold+1} Spearman: {fold_metrics_history['spearman'][-1]:.4f}")

    # Score only samples that actually received a prediction; the zero
    # placeholders of skipped samples would otherwise distort the metrics.
    if has_prediction.sum() >= 2:
        scored_true, scored_pred = true_labels[has_prediction], oof_preds[has_prediction]
        overall_metrics = {
            'rmse': np.sqrt(mean_squared_error(scored_true, scored_pred)),
            'r2': r2_score(scored_true, scored_pred),
            'pearson': pearsonr(scored_true, scored_pred)[0],
            'spearman': spearmanr(scored_true, scored_pred)[0]
        }
    else:
        overall_metrics = dict.fromkeys(('rmse', 'r2', 'pearson', 'spearman'), float('nan'))

    return {
        'overall_metrics': overall_metrics,
        'fold_metrics': fold_metrics_history,
        'oof_predictions': {'true': true_labels, 'pred': oof_preds, 'valid': has_prediction},
    }

# Training configuration consumed by train_evaluate_egnn.
hyperparams = {
    'lr': 1e-4,
    'batch_size': 16,
    'hidden_dim': 128,
    'n_layers': 3,
    'dropout': 0.2,
    'patience': 15,
    'lambda_rho': 0.5
}

final_results = train_evaluate_egnn(dataset, groups, edge_index, params=hyperparams, mut_indicator_col=MUT_INDICATOR_COL_IDX)

# Report the per-fold summary; previously the statistics were computed and
# then discarded without being shown.
for metric_name, metric_values in final_results['fold_metrics'].items():
    mean_val, std_val = np.mean(metric_values), np.std(metric_values)
    print(f"{metric_name}: {mean_val:.4f} +/- {std_val:.4f} (mean +/- std over folds)")

overall_metrics = final_results['overall_metrics']
for metric_name, metric_value in overall_metrics.items():
    print(f"overall {metric_name}: {metric_value:.4f}")
