In [None]:
import pickle
import re

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import seaborn as sns
import shap
import torch
import xgboost as xgb
from lightgbm import LGBMRegressor
from scipy.stats import spearmanr, pearsonr
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import ElasticNet, LinearRegression
from sklearn.metrics import mean_squared_error, r2_score, make_scorer
from sklearn.model_selection import GroupKFold, KFold, cross_val_score, cross_val_predict
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel
from xgboost import XGBRegressor
from xgboost.callback import EarlyStopping


# Choose the compute backend once for the whole notebook.
# Preference order: Apple-silicon MPS, then CUDA, then plain CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)


In [None]:
# Load the mutation table. It should contain the mutation string, the sequence,
# the scaled activity, and the group label of every variant.
mutations = pd.read_csv('your_path/mutations.csv')

# Fail fast, before the expensive embedding step, if a column that later cells
# read ('sequence', 'scaled_activity', 'group') is absent from the file.
_missing_columns = {'sequence', 'scaled_activity', 'group'} - set(mutations.columns)
if _missing_columns:
    raise ValueError(f"mutations.csv is missing required column(s): {sorted(_missing_columns)}")

In [2]:
def get_embeddings(model_name, sequences, device):
    """Generate mean-pooled embeddings for a list of sequences with a Hugging Face model.

    The tokenizer and model are loaded with ``AutoTokenizer`` / ``AutoModel``
    ``from_pretrained(model_name)``; the model is moved to ``device`` and put in
    eval mode. Each sequence is tokenized and embedded on its own, and the last
    hidden state is averaged over the positions the attention mask marks as real
    tokens.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.
        sequences: Iterable of sequence strings; one embedding row is produced
            per element, in iteration order.
        device: ``torch.device`` the model and the tokenized inputs are moved to.

    Returns:
        A NumPy array with one row per input sequence (the per-sequence arrays
        stacked with ``np.vstack``). For an empty ``sequences`` an array with
        zero rows is returned instead of raising.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    embeddings = []
    with torch.no_grad():
        for seq in tqdm(sequences, desc=f"Generating embeddings for {model_name}"):
            # NOTE(review): truncation=True presumably clips sequences longer than the
            # tokenizer's maximum length, so very long inputs are embedded from a
            # prefix only -- confirm this is acceptable for the dataset.
            inputs = tokenizer(seq, return_tensors="pt", truncation=True).to(device)
            outputs = model(**inputs)

            token_embeddings = outputs.last_hidden_state
            # Mean pooling over the positions the attention mask marks as real tokens.
            # The mask gets a trailing axis so it broadcasts over the hidden dimension.
            mask = inputs['attention_mask'].unsqueeze(-1).to(token_embeddings.dtype)
            masked_sum = (token_embeddings * mask).sum(dim=1)
            # Keep the count as (batch, 1) so the division is correct for any batch
            # size, and clamp it so an all-masked row cannot divide by zero.
            token_count = mask.sum(dim=1).clamp(min=1)
            mean_pooled_embedding = masked_sum / token_count
            embeddings.append(mean_pooled_embedding.cpu().numpy())

    if not embeddings:
        # np.vstack raises on an empty list; return a well-formed empty matrix instead.
        hidden_size = getattr(model.config, "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    return np.vstack(embeddings)

In [None]:
# Embed every sequence in the mutation table with the
# facebook/esm2_t33_650M_UR50D checkpoint; rows follow the order of `mutations`.
embeddings = get_embeddings(
    model_name="facebook/esm2_t33_650M_UR50D",
    sequences=mutations['sequence'].tolist(),
    device=device,
)

In [5]:
# Modelling inputs shared by the cells below:
#   X      -- embedding matrix (features)
#   y      -- regression target taken from the 'scaled_activity' column
#   groups -- labels from the 'group' column, used by GroupKFold to keep groups intact
X, y, groups = (
    embeddings,
    mutations['scaled_activity'].values,
    mutations['group'].values,
)

In [22]:
# Persist the embedding matrix to disk so it can be reloaded later
# without recomputing it.
with open('esm_embeddings.pkl', 'wb') as embeddings_file:
    pickle.dump(embeddings, embeddings_file)

In [7]:
# XGBoost
def objective_xgb_final(trial):
    """Optuna objective for XGBoost: mean Spearman correlation over grouped 5-fold CV.

    Reads the module-level ``X``, ``y`` and ``groups`` arrays. For every outer
    ``GroupKFold`` split, a grouped hold-out is carved out of the *training* fold
    and used as the early-stopping monitor; the outer validation fold is only
    used for scoring. Monitoring early stopping on the fold that is subsequently
    scored would let the validation labels choose the number of trees and bias
    the score upward.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        Mean Spearman correlation across the evaluated folds, or ``-1.0`` if no
        fold was evaluated.
    """
    params = {
        'objective': 'reg:squarederror', 'eval_metric': 'rmse',
        'n_estimators': trial.suggest_int('n_estimators', 200, 1000),
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'tree_method': 'hist', 'random_state': 42
    }
    gkf = GroupKFold(n_splits=5)
    # One grouped 80/20 split of each training fold; the 20% part only drives early stopping.
    early_stop_split = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    spearman_scores = []

    for train_idx, val_idx in gkf.split(X, y, groups):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]
        if len(X_train) == 0 or len(X_val) == 0: continue

        # Indices are relative to the training fold, hence groups[train_idx].
        fit_idx, stop_idx = next(early_stop_split.split(X_train, y_train, groups[train_idx]))

        early_stopping_callback = EarlyStopping(rounds=50, save_best=True)
        model = xgb.XGBRegressor(**params, callbacks=[early_stopping_callback])
        model.fit(
            X_train[fit_idx], y_train[fit_idx],
            eval_set=[(X_train[stop_idx], y_train[stop_idx])],
            verbose=False,
        )
        preds = model.predict(X_val)
        spearman_scores.append(spearmanr(y_val, preds)[0])

    return np.mean(spearman_scores) if spearman_scores else -1.0

In [None]:

# Hyperparameter search for XGBoost (maximises mean CV Spearman).
# The sampler is seeded so a Restart & Run All reproduces the same trials;
# the models themselves already pin random_state=42.
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=42),
)
study.optimize(objective_xgb_final, n_trials=50, show_progress_bar=True)


In [None]:


# Cross-validated evaluation of the tuned XGBoost configuration.
# The fixed (non-searched) settings are re-applied on top of the best trial's values.
best_params = study.best_params
best_params.update({'objective': 'reg:squarederror', 'random_state': 42, 'tree_method': 'hist'})

gkf = GroupKFold(n_splits=5)
# Early stopping is monitored on a grouped hold-out taken from each training fold.
# Using the scored validation fold as the monitor would let its labels pick the
# number of trees and inflate the out-of-fold metrics.
early_stop_split = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
fold_metrics = []
oof_preds = np.zeros(len(y))
oof_true = np.zeros(len(y))

for fold, (train_idx, val_idx) in enumerate(gkf.split(X, y, groups)):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    if len(X_train) == 0 or len(X_val) == 0:
        print(f"Skipping Fold {fold+1} due to empty split.")
        continue

    # Indices are relative to the training fold, hence groups[train_idx].
    fit_idx, stop_idx = next(early_stop_split.split(X_train, y_train, groups[train_idx]))

    final_early_stopping_callback = EarlyStopping(rounds=50, save_best=True)
    final_model = xgb.XGBRegressor(**best_params, callbacks=[final_early_stopping_callback])

    final_model.fit(
        X_train[fit_idx], y_train[fit_idx],
        eval_set=[(X_train[stop_idx], y_train[stop_idx])],
        verbose=False,
    )

    # The outer validation fold is touched only here, for prediction and scoring.
    preds = final_model.predict(X_val)
    oof_preds[val_idx], oof_true[val_idx] = preds, y_val

    metrics = {
        'spearman': spearmanr(y_val, preds)[0],
        'pearson': pearsonr(y_val, preds)[0],
        'rmse': np.sqrt(mean_squared_error(y_val, preds)),
        'r2': r2_score(y_val, preds)
    }
    fold_metrics.append(metrics)


In [None]:

# Per-fold summary for XGBoost: mean and standard deviation (np.std, ddof=0)
# of every metric across the CV folds. Previously these were computed and discarded.
for metric_name in fold_metrics[0].keys():
    metric_values = [m[metric_name] for m in fold_metrics]
    mean_val, std_val = np.mean(metric_values), np.std(metric_values)
    print(f"{metric_name}: {mean_val:.4f} ± {std_val:.4f}")

# Pooled metrics computed once over all out-of-fold predictions.
overall_metrics = {
    'Spearman ρ': spearmanr(oof_true, oof_preds)[0],
    'Pearson r ': pearsonr(oof_true, oof_preds)[0],
    'RMSE      ': np.sqrt(mean_squared_error(oof_true, oof_preds)),
    'R-squared ': r2_score(oof_true, oof_preds)
}

for overall_name, overall_value in overall_metrics.items():
    print(f"{overall_name}: {overall_value:.4f}")


In [32]:
# Random Forest
def objective_rf(trial):
    """Optuna objective for a Random Forest regressor.

    Samples the forest hyperparameters from ``trial``, fits one model per
    ``GroupKFold`` split of the module-level ``X``/``y``/``groups`` arrays, and
    returns the mean Spearman correlation between held-out targets and
    predictions (``-1.0`` if no fold was evaluated).
    """
    rf_settings = {
        'n_estimators': trial.suggest_int('n_estimators', 100, 1000),
        'max_depth': trial.suggest_int('max_depth', 5, 50),
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 20),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 20),
        'max_features': trial.suggest_float('max_features', 0.1, 1.0),
        'random_state': 42,
        'n_jobs': -1,
    }

    fold_scores = []
    for fit_rows, held_out_rows in GroupKFold(n_splits=5).split(X, y, groups):
        # Skip degenerate splits that leave either side without samples.
        if len(fit_rows) == 0 or len(held_out_rows) == 0:
            continue

        forest = RandomForestRegressor(**rf_settings)
        forest.fit(X[fit_rows], y[fit_rows])
        held_out_preds = forest.predict(X[held_out_rows])
        fold_scores.append(spearmanr(y[held_out_rows], held_out_preds)[0])

    if not fold_scores:
        return -1.0
    return np.mean(fold_scores)

In [None]:

# Hyperparameter search for the Random Forest (maximises mean CV Spearman).
# The sampler is seeded so a Restart & Run All reproduces the same trials;
# the forests themselves already pin random_state=42.
study_rf = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=42),
)
study_rf.optimize(objective_rf, n_trials=50, show_progress_bar=True)


In [None]:

# Grouped 5-fold evaluation of the tuned Random Forest: collects per-fold
# metrics and the out-of-fold predictions for every sample.
best_params_rf = {**study_rf.best_params, 'random_state': 42, 'n_jobs': -1}

gkf = GroupKFold(n_splits=5)
oof_preds_rf, oof_true_rf = np.zeros(len(y)), np.zeros(len(y))
fold_metrics_rf = []

cv_folds_rf = tqdm(gkf.split(X, y, groups), total=5, desc="CV Folds")
for fold, (train_idx, val_idx) in enumerate(cv_folds_rf):
    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]

    # Nothing to fit or score when either side of the split is empty.
    if min(len(X_train), len(X_val)) == 0:
        print(f"Skipping Fold {fold+1} due to empty split.")
        continue

    final_model_rf = RandomForestRegressor(**best_params_rf)
    final_model_rf.fit(X_train, y_train)
    preds = final_model_rf.predict(X_val)

    # Store predictions and targets at the validation rows' original positions.
    oof_preds_rf[val_idx], oof_true_rf[val_idx] = preds, y_val

    metrics = {
        'spearman': spearmanr(y_val, preds)[0],
        'pearson': pearsonr(y_val, preds)[0],
        'rmse': np.sqrt(mean_squared_error(y_val, preds)),
        'r2': r2_score(y_val, preds),
    }
    fold_metrics_rf.append(metrics)
    

In [None]:

# Per-fold summary for the Random Forest: mean and standard deviation (np.std, ddof=0)
# of every metric across the CV folds. Previously these were computed and discarded.
for metric_name in fold_metrics_rf[0].keys():
    metric_values = [m[metric_name] for m in fold_metrics_rf]
    mean_val, std_val = np.mean(metric_values), np.std(metric_values)
    print(f"{metric_name}: {mean_val:.4f} ± {std_val:.4f}")

# Pooled metrics computed once over all out-of-fold predictions.
overall_metrics_rf = {
    'Spearman ρ': spearmanr(oof_true_rf, oof_preds_rf)[0],
    'Pearson r ': pearsonr(oof_true_rf, oof_preds_rf)[0],
    'RMSE      ': np.sqrt(mean_squared_error(oof_true_rf, oof_preds_rf)),
    'R-squared ': r2_score(oof_true_rf, oof_preds_rf)
}

for overall_name, overall_value in overall_metrics_rf.items():
    print(f"{overall_name}: {overall_value:.4f}")


In [35]:
# ElasticNet

# Re-bind the modelling inputs for the ElasticNet section:
# features, regression target, and the group labels used for grouped CV.
X = embeddings
groups, y = (mutations[column].values for column in ('group', 'scaled_activity'))

In [36]:
def objective_en(trial):
    """Optuna objective for ElasticNet: mean Spearman correlation over grouped 5-fold CV.

    Reads the module-level ``X``, ``y`` and ``groups`` arrays. Features are
    standardised inside every fold (the scaler is fitted on the training part
    only) before the model is fitted.

    Args:
        trial: Optuna trial used to sample ``alpha`` and ``l1_ratio``.

    Returns:
        Mean Spearman correlation across the evaluated folds, or ``-1.0`` if no
        fold was evaluated. A fold whose correlation is undefined (NaN) counts
        as ``0.0``.
    """

    params = {
        'alpha': trial.suggest_float('alpha', 1e-4, 1e1, log=True),
        'l1_ratio': trial.suggest_float('l1_ratio', 0.0, 1.0),
        'random_state': 42,
        'max_iter': 2000
    }

    gkf = GroupKFold(n_splits=5)
    spearman_scores = []

    for train_idx, val_idx in gkf.split(X, y, groups):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        if len(X_train) == 0 or len(X_val) == 0:
            continue

        # Fit the scaler on the training fold only so no validation statistics leak in.
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_val_scaled = scaler.transform(X_val)

        model = ElasticNet(**params)
        model.fit(X_train_scaled, y_train)
        preds = model.predict(X_val_scaled)

        fold_spearman = spearmanr(y_val, preds)[0]
        # Strong regularisation can shrink every coefficient to zero, giving constant
        # predictions for which Spearman is NaN. Score that as "no correlation" so the
        # NaN does not propagate through the mean and invalidate the whole trial.
        spearman_scores.append(0.0 if np.isnan(fold_spearman) else fold_spearman)

    return np.mean(spearman_scores) if spearman_scores else -1.0

In [None]:

# Hyperparameter search for ElasticNet (maximises mean CV Spearman).
# The sampler is seeded so a Restart & Run All reproduces the same trials;
# the models themselves already pin random_state=42.
study_en = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=42),
)

study_en.optimize(objective_en, n_trials=50, show_progress_bar=True)


In [None]:

# Grouped 5-fold evaluation of the tuned ElasticNet: features are standardised
# per fold, and per-fold metrics plus out-of-fold predictions are collected.
best_params_en = {**study_en.best_params, 'random_state': 42, 'max_iter': 2000}

gkf = GroupKFold(n_splits=5)
oof_preds_en, oof_true_en = np.zeros(len(y)), np.zeros(len(y))
fold_metrics_en = []

cv_folds_en = tqdm(gkf.split(X, y, groups), total=5, desc="CV Folds")
for fold, (train_idx, val_idx) in enumerate(cv_folds_en):
    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]

    # Nothing to fit or score when either side of the split is empty.
    if min(len(X_train), len(X_val)) == 0:
        continue

    # The scaler learns its statistics from the training fold only.
    scaler = StandardScaler().fit(X_train)
    X_train_scaled = scaler.transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    final_model_en = ElasticNet(**best_params_en)
    final_model_en.fit(X_train_scaled, y_train)
    preds = final_model_en.predict(X_val_scaled)

    # Store predictions and targets at the validation rows' original positions.
    oof_preds_en[val_idx], oof_true_en[val_idx] = preds, y_val

    metrics = {
        'spearman': spearmanr(y_val, preds)[0],
        'pearson': pearsonr(y_val, preds)[0],
        'rmse': np.sqrt(mean_squared_error(y_val, preds)),
        'r2': r2_score(y_val, preds),
    }
    fold_metrics_en.append(metrics)

In [None]:

# Per-fold summary for ElasticNet: mean and standard deviation (np.std, ddof=0)
# of every metric across the CV folds. Previously these were computed and discarded.
for metric_name in fold_metrics_en[0].keys():
    metric_values = [m[metric_name] for m in fold_metrics_en]
    mean_val, std_val = np.mean(metric_values), np.std(metric_values)
    print(f"{metric_name}: {mean_val:.4f} ± {std_val:.4f}")

# Pooled metrics computed once over all out-of-fold predictions.
overall_metrics_en = {
    'Spearman ρ': spearmanr(oof_true_en, oof_preds_en)[0],
    'Pearson r ': pearsonr(oof_true_en, oof_preds_en)[0],
    'RMSE      ': np.sqrt(mean_squared_error(oof_true_en, oof_preds_en)),
    'R-squared ': r2_score(oof_true_en, oof_preds_en)
}

for overall_name, overall_value in overall_metrics_en.items():
    print(f"{overall_name}: {overall_value:.4f}")


In [None]:
# linear regression

# Baseline: ordinary least-squares regression evaluated with grouped 5-fold CV.
# No hyperparameters are tuned; features are standardised inside every fold.
gkf = GroupKFold(n_splits=5)
fold_metrics_lr = []
oof_preds_lr = np.zeros(len(y))
oof_true_lr = np.zeros(len(y))

for fold, (train_idx, val_idx) in enumerate(tqdm(gkf.split(X, y, groups), total=5, desc="CV Folds")):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]

    if len(X_train) == 0 or len(X_val) == 0:
        continue

    # Fit the scaler on the training fold only so no validation statistics leak in.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    final_model_lr = LinearRegression()
    final_model_lr.fit(X_train_scaled, y_train)

    preds = final_model_lr.predict(X_val_scaled)
    # Store predictions and targets at the validation rows' original positions.
    oof_preds_lr[val_idx] = preds
    oof_true_lr[val_idx] = y_val

    metrics = {
        'spearman': spearmanr(y_val, preds)[0],
        'pearson': pearsonr(y_val, preds)[0],
        'rmse': np.sqrt(mean_squared_error(y_val, preds)),
        'r2': r2_score(y_val, preds)
    }
    fold_metrics_lr.append(metrics)

# Per-fold summary: mean and standard deviation (np.std, ddof=0) of every metric
# across the CV folds. Previously these were computed and discarded.
for metric_name in fold_metrics_lr[0].keys():
    metric_values = [m[metric_name] for m in fold_metrics_lr]
    mean_val, std_val = np.mean(metric_values), np.std(metric_values)
    print(f"{metric_name}: {mean_val:.4f} ± {std_val:.4f}")

# Pooled metrics computed once over all out-of-fold predictions.
overall_metrics_lr = {
    'Spearman ρ': spearmanr(oof_true_lr, oof_preds_lr)[0],
    'Pearson r ': pearsonr(oof_true_lr, oof_preds_lr)[0],
    'RMSE      ': np.sqrt(mean_squared_error(oof_true_lr, oof_preds_lr)),
    'R-squared ': r2_score(oof_true_lr, oof_preds_lr)
}

for overall_name, overall_value in overall_metrics_lr.items():
    print(f"{overall_name}: {overall_value:.4f}")