In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold, RandomizedSearchCV
from sklearn.metrics import f1_score, precision_recall_fscore_support, cohen_kappa_score, make_scorer, roc_auc_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
import xgboost as xgb
import lightgbm as lgb
import warnings
import random
import os
import gc
import time
from collections import defaultdict

# --- Experiment configuration -------------------------------------------
# Hugging Face ESM-2 checkpoints to compare, from 8M to 650M parameters.
ESM2_MODEL_NAMES = [
    "facebook/esm2_t6_8M_UR50D",
    "facebook/esm2_t12_35M_UR50D",
    "facebook/esm2_t30_150M_UR50D",
    "facebook/esm2_t33_650M_UR50D",
]
MAX_LENGTH = 256            # tokenizer max_length used for padding/truncation
EMBEDDING_BATCH_SIZE = 16   # sequences per embedding forward pass
NUM_CLASSES = 3             # number of target classes

SEED = 42
N_SPLITS_ML = 5             # outer StratifiedKFold splits
N_ITER_RANDOM_SEARCH = 30   # candidates sampled per RandomizedSearchCV
CV_RANDOM_SEARCH = 3        # inner CV folds inside each search

# --- Reproducibility ----------------------------------------------------
# Setting PYTHONHASHSEED at runtime does not change hash randomization of
# this already-running interpreter; it only affects child processes that
# inherit the environment.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.manual_seed(SEED)
warnings.filterwarnings("ignore")

# --- Device selection: prefer Apple MPS, then CUDA, then CPU --------------
if torch.backends.mps.is_available():
    _device_type, _device_label = "mps", "MPS (Apple Silicon GPU)"
elif torch.cuda.is_available():
    _device_type, _device_label = "cuda", "CUDA"
else:
    _device_type, _device_label = "cpu", "CPU"
device = torch.device(_device_type)
print(f"Using {_device_label} for embedding generation.")

def load_and_preprocess_data_for_embeddings(csv_path='data.csv'):
    """Load paired heavy/light-chain sequences and integer class labels.

    The CSV must contain ``VH``, ``VL`` and ``psr`` columns. If it also has a
    ``label`` column, that column is used as-is. Otherwise ``psr`` is split
    into ``NUM_CLASSES`` equal-frequency bins (0 = lowest psr). If the file
    does not exist, a small synthetic dataset is built instead so the rest of
    the pipeline can still run.

    Rows with a missing ``VH``/``VL`` or a missing label source are dropped.

    Args:
        csv_path: Path to the input CSV file.

    Returns:
        Tuple ``(sequences, labels)``: a list of ``VH + 'X' + VL`` strings and
        a NumPy array of labels aligned with it.

    Raises:
        ValueError: If a required column is missing, or if labels must be
            derived from ``psr`` and there are fewer rows than ``NUM_CLASSES``.
    """
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: The file {csv_path} was not found.")
        print("Creating a dummy DataFrame for demonstration purposes.")
        data = {
            'VH': ["EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMSWVRQAPGKGLEWVSAISGSGGSTYYADSVKGRFTISRDNSKNTLYLQMNSLRAEDTAVYYCAKDRLGRYFDYWGQGTLVTVSS",
                   "QVQLQESGPGLVKPSQTLSLTCTVSGGSISSYYWSWIRQPPGKGLEWIGYIYYSGSTYYNPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYCARWDYLRDYWGQGTLVTVSS"] * 5,
            'VL': ["DIQMTQSPSSLSASVGDRVTITCRASQGISSALAWYQQKPGKAPKLLIYDASSLESGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQFNSYPLTFGGGTKVEIK",
                   "EIVLTQSPGTLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPRLLIYGASSRATGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQYGSSPTFGGGTKVEIK"] * 5,
            'psr': [0.05, 0.25, 0.55, 0.08, 0.15, 0.02, 0.30, 0.60, 0.09, 0.20],
        }
        df = pd.DataFrame(data)
        # Truncate each chain so VH + 'X' + VL stays under MAX_LENGTH,
        # leaving a few positions spare for the separator and special tokens.
        chain_max_len = MAX_LENGTH // 2 - 2
        df['VH'] = df['VH'].str[:chain_max_len]
        df['VL'] = df['VL'].str[:chain_max_len]
        df = pd.concat([df] * 2, ignore_index=True)

    if not all(col in df.columns for col in ['VH', 'VL', 'psr']):
        raise ValueError("DataFrame must contain 'VH', 'VL', and 'psr' columns.")

    # Drop rows that would otherwise yield NaN sequences or labels downstream.
    label_source = 'label' if 'label' in df.columns else 'psr'
    n_before = len(df)
    df = df.dropna(subset=['VH', 'VL', label_source]).reset_index(drop=True)
    if len(df) < n_before:
        print(f"Dropped {n_before - len(df)} rows with missing VH/VL/{label_source} values.")

    if label_source == 'psr':
        if len(df) < NUM_CLASSES:
            raise ValueError(
                f"Need at least {NUM_CLASSES} rows to derive {NUM_CLASSES} psr classes, got {len(df)}."
            )
        # Rank first so tied psr values cannot collapse quantile edges; this
        # guarantees exactly NUM_CLASSES ordinal classes 0..NUM_CLASSES-1.
        psr_ranks = df['psr'].rank(method='first')
        df['label'] = pd.qcut(psr_ranks, q=NUM_CLASSES, labels=False).astype(int)

    # Heavy and light chains are joined into one string with an 'X' separator.
    df['combined_sequence'] = df['VH'] + 'X' + df['VL']

    print(f"Total sequences: {len(df)}")
    print(f"Label distribution:\n{df['label'].value_counts().sort_index()}")

    return df['combined_sequence'].tolist(), df['label'].to_numpy()

def generate_embeddings_batched(sequences, esm_model_name, tokenizer_name, device, max_len, batch_size_embed):
    """Embed sequences with a pretrained transformer using masked mean pooling.

    Each batch is tokenized (padded/truncated to ``max_len``) and run through
    the model. The last hidden state is then averaged over every position
    whose attention mask is 1, so any special tokens the tokenizer adds are
    averaged in as well.

    Args:
        sequences: Iterable of input strings.
        esm_model_name: Model name/path passed to ``AutoModel.from_pretrained``.
        tokenizer_name: Tokenizer name/path passed to ``AutoTokenizer.from_pretrained``.
        device: ``torch.device`` to run the forward passes on.
        max_len: Tokenizer ``max_length`` for padding and truncation.
        batch_size_embed: Number of sequences per forward pass.

    Returns:
        ``np.ndarray`` of shape ``(len(sequences), hidden_size)``. For empty
        input, an empty ``(0, hidden_size)`` float32 array.
    """
    print(f"\nLoading ESM model: {esm_model_name} for embedding generation...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModel.from_pretrained(esm_model_name)
    model.eval()
    model.to(device)
    hidden_size = model.config.hidden_size

    sequences = list(sequences)
    num_sequences = len(sequences)
    num_batches = (num_sequences + batch_size_embed - 1) // batch_size_embed
    all_pooled_embeddings = []

    print(f"Generating mean-pooled embeddings for {num_sequences} sequences with batch size {batch_size_embed}...")
    try:
        for batch_idx, start in enumerate(range(0, num_sequences, batch_size_embed)):
            batch_sequences = sequences[start:start + batch_size_embed]

            inputs = tokenizer(
                batch_sequences, add_special_tokens=True, max_length=max_len,
                padding='max_length', truncation=True, return_tensors='pt',
                return_attention_mask=True,
            )
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)

            with torch.no_grad():
                token_embeddings = model(input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
                # (batch, seq, 1) mask broadcasts over the hidden dimension.
                mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
                summed = (token_embeddings * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1e-9)  # avoid division by zero
                # .float() keeps numpy conversion valid for half/bfloat16 models.
                pooled = (summed / counts).float().cpu().numpy()

            all_pooled_embeddings.append(pooled)

            if batch_idx % 5 == 0 or batch_idx == num_batches - 1:
                print(f"  Processed batch {batch_idx + 1}/{num_batches}")

            del input_ids, attention_mask, token_embeddings, mask, summed, counts, pooled
            if device.type == 'mps':
                torch.mps.empty_cache()

        print("Embedding generation complete.")
    finally:
        # Release the model even if a batch fails, so the next model can load.
        del model, tokenizer
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        if device.type == 'mps':
            torch.mps.empty_cache()
        gc.collect()

    if not all_pooled_embeddings:
        return np.empty((0, hidden_size), dtype=np.float32)
    return np.concatenate(all_pooled_embeddings, axis=0)

def make_pipe(preprocessor, clf):
    """Chain *preprocessor* and *clf* into a two-step scikit-learn Pipeline.

    The steps are named ``"prep"`` and ``"clf"``, so hyperparameter search
    spaces address the classifier as ``clf__<param>``.
    """
    steps = [
        ("prep", preprocessor),
        ("clf", clf),
    ]
    return Pipeline(steps)

# One summary dict per ESM model is collected here.
overall_esm_results = []
# Sequences and labels are loaded once and reused for every ESM model.
all_sequences, y_labels = load_and_preprocess_data_for_embeddings(csv_path="data.csv")

def _roc_auc_for_fold(estimator, X_eval, y_eval, label, fold_idx):
    """Return the weighted one-vs-rest ROC AUC of *estimator* on one test fold.

    Returns NaN (after printing why) if the estimator has no predict_proba,
    if the fold does not contain all NUM_CLASSES classes, or if
    roc_auc_score raises ValueError.
    """
    if not hasattr(estimator, "predict_proba"):
        print(f"    Info: {label} does not have predict_proba. ROC AUC N/A.")
        return np.nan
    n_present = len(np.unique(y_eval))
    if n_present != NUM_CLASSES:
        print(f"    Info: ROC AUC for {label} (Fold {fold_idx}) N/A (only {n_present}/{NUM_CLASSES} classes in y_test_ml).")
        return np.nan
    probas = estimator.predict_proba(X_eval)
    try:
        return roc_auc_score(y_eval, probas, multi_class='ovr', average='weighted', labels=list(range(NUM_CLASSES)))
    except ValueError as err:
        print(f"    Warning: ROC AUC for {label} (Fold {fold_idx}) failed: {err}")
        return np.nan


def _score_fold(estimator, X_eval, y_eval, label, fold_idx):
    """Return (macro-F1, quadratic-weighted kappa, weighted OvR ROC AUC) on one fold."""
    preds = estimator.predict(X_eval)
    macro_f1 = f1_score(y_eval, preds, average="macro", zero_division=0)
    kappa = cohen_kappa_score(y_eval, preds, weights="quadratic")
    return macro_f1, kappa, _roc_auc_for_fold(estimator, X_eval, y_eval, label, fold_idx)


def _build_model_grids():
    """Return fresh {name: {"model": estimator, "params": search space}} configs."""
    return {
        "logreg": {
            # multi_class is omitted: it is deprecated in scikit-learn >= 1.5, and
            # the default already fits a multinomial model for saga with >2 classes.
            "model": LogisticRegression(class_weight="balanced", solver="saga", max_iter=500, random_state=SEED, n_jobs=-1),
            "params": {"clf__C": np.logspace(-3, 2, 10), "clf__penalty": ["l1", "l2"]},
        },
        "svc_rbf": {
            "model": SVC(kernel="rbf", probability=True, class_weight="balanced", random_state=SEED),
            "params": {"clf__C": np.logspace(-2, 2, 10), "clf__gamma": np.logspace(-3, 0, 10)},
        },
        "knn": {
            "model": KNeighborsClassifier(n_jobs=-1),
            "params": {"clf__n_neighbors": range(3, 12, 2), "clf__weights": ["uniform", "distance"]},
        },
        "rf": {
            "model": RandomForestClassifier(class_weight="balanced_subsample", n_jobs=-1, random_state=SEED),
            "params": {"clf__n_estimators": [200, 300, 400], "clf__max_depth": [4, 6, 8, 10], "clf__max_features": ["sqrt", 0.3]},
        },
        "xgb": {
            # use_label_encoder was removed in xgboost 2.x (now only triggers a warning).
            "model": xgb.XGBClassifier(objective="multi:softprob", eval_metric="mlogloss", n_jobs=-1, random_state=SEED, num_class=NUM_CLASSES),
            "params": {"clf__n_estimators": [150, 250, 350], "clf__max_depth": [3, 4, 5], "clf__learning_rate": [0.05, 0.1],
                       "clf__subsample": [0.7, 0.9], "clf__colsample_bytree": [0.6, 0.8]},
        },
        "lgb": {
            # LightGBM ignores `subsample` unless bagging is enabled via subsample_freq > 0.
            "model": lgb.LGBMClassifier(objective="multiclass", random_state=SEED, n_jobs=-1, n_estimators=250,
                                        num_class=NUM_CLASSES, subsample_freq=1, verbose=-1),
            "params": {"clf__num_leaves": [31, 63], "clf__max_depth": [4, 6], "clf__learning_rate": [0.05, 0.1],
                       "clf__subsample": [0.7, 0.9], "clf__colsample_bytree": [0.6, 0.8]},
        },
    }


for esm_name in ESM2_MODEL_NAMES:
    print(f"\n\n{'='*20} Processing ESM Model: {esm_name} {'='*20}")

    tokenizer_to_use = esm_name
    X_embeddings = generate_embeddings_batched(all_sequences, esm_name, tokenizer_to_use, device, MAX_LENGTH, EMBEDDING_BATCH_SIZE)

    print(f"Generated embeddings shape for {esm_name}: {X_embeddings.shape}")
    print(f"Labels shape: {y_labels.shape}")

    current_embedding_dim = X_embeddings.shape[1]
    if current_embedding_dim <= 1:
        print(f"Warning: Embedding dimension {current_embedding_dim} is too low for SVD. Skipping SVD.")
        preproc_pipeline = Pipeline([("scaler", StandardScaler())])
    else:
        # TruncatedSVD needs n_components < n_features.
        n_svd_components = min(150, current_embedding_dim - 1)
        print(f"Using TruncatedSVD with n_components={n_svd_components} for {esm_name}")
        preproc_pipeline = Pipeline([
            ("scaler", StandardScaler()),
            ("svd", TruncatedSVD(n_components=n_svd_components, random_state=SEED)),
        ])

    grids = _build_model_grids()

    skf_ml = StratifiedKFold(n_splits=N_SPLITS_ML, shuffle=True, random_state=SEED)
    cv_results_current_esm = defaultdict(list)
    scoring_for_search = "f1_macro"

    for fold_idx, (train_idx, test_idx) in enumerate(skf_ml.split(X_embeddings, y_labels), 1):
        X_train_fold, X_test_fold = X_embeddings[train_idx], X_embeddings[test_idx]
        y_train_ml, y_test_ml = y_labels[train_idx], y_labels[test_idx]

        print(f"\n===== ML Fold {fold_idx}/{N_SPLITS_ML} for {esm_name} =====")
        fold_fitted_models = {}
        fold_inner_scores = {}  # inner-CV score of each model's best candidate

        for model_name, cfg in grids.items():
            print(f"\n  ⏳ Tuning {model_name} for {esm_name}, Fold {fold_idx}...")
            search = RandomizedSearchCV(
                estimator=make_pipe(preproc_pipeline, cfg["model"]), param_distributions=cfg["params"],
                n_iter=N_ITER_RANDOM_SEARCH, scoring=scoring_for_search, cv=CV_RANDOM_SEARCH,
                random_state=SEED, n_jobs=-1, verbose=0,
            )
            search.fit(X_train_fold, y_train_ml)
            best_estimator_for_model = search.best_estimator_
            fold_fitted_models[model_name] = best_estimator_for_model
            fold_inner_scores[model_name] = search.best_score_

            macro_f1, kappa, roc_auc_val = _score_fold(best_estimator_for_model, X_test_fold, y_test_ml, model_name, fold_idx)
            cv_results_current_esm[f"{model_name}_f1"].append(macro_f1)
            cv_results_current_esm[f"{model_name}_kappa"].append(kappa)
            cv_results_current_esm[f"{model_name}_roc_auc"].append(roc_auc_val)
            print(f"    → Best {model_name} (Fold {fold_idx}) Macro-F1: {macro_f1:.4f} / Kappa: {kappa:.4f} / ROC AUC: {roc_auc_val:.4f}")

        # Pick ensemble members by their inner-CV score, computed on training
        # data only. Ranking by this fold's test F1 (and then scoring the
        # ensemble on the same fold) would leak test information and inflate
        # the ensemble's reported performance. NaN scores sort last.
        ranked_names = sorted(
            fold_inner_scores,
            key=lambda name: -np.inf if np.isnan(fold_inner_scores[name]) else fold_inner_scores[name],
            reverse=True,
        )
        ensemble_estimators_this_fold = [(name, fold_fitted_models[name]) for name in ranked_names[:3]]

        if not ensemble_estimators_this_fold:
            print(f"  🚫 Could not form ensemble for Fold {fold_idx}. Skipping.")
            cv_results_current_esm["ens_f1"].append(np.nan)
            cv_results_current_esm["ens_kappa"].append(np.nan)
            cv_results_current_esm["ens_roc_auc"].append(np.nan)
        else:
            print(f"  Ensemble for Fold {fold_idx} using: {[name for name, _ in ensemble_estimators_this_fold]}")
            ens = VotingClassifier(
                estimators=ensemble_estimators_this_fold, voting="soft",
                weights=[3, 2, 1][:len(ensemble_estimators_this_fold)],  # higher weight for higher rank
            )
            ens.fit(X_train_fold, y_train_ml)
            ens_macro_f1, ens_kappa, ens_roc_auc_val = _score_fold(ens, X_test_fold, y_test_ml, "Ensemble", fold_idx)

            cv_results_current_esm["ens_f1"].append(ens_macro_f1)
            cv_results_current_esm["ens_kappa"].append(ens_kappa)
            cv_results_current_esm["ens_roc_auc"].append(ens_roc_auc_val)
            print(f"  ✅ Ensemble (Fold {fold_idx}) Macro-F1: {ens_macro_f1:.4f} / Kappa: {ens_kappa:.4f} / ROC AUC: {ens_roc_auc_val:.4f}")
            del ens

        del X_train_fold, X_test_fold, y_train_ml, y_test_ml, fold_fitted_models, fold_inner_scores
        gc.collect()

    print(f"\n===== Averaged Results for {esm_name} (over {N_SPLITS_ML} ML folds) =====")
    esm_summary = {"ESM Model": esm_name}
    for model_key in list(grids.keys()) + ["ens"]:
        f1_values = cv_results_current_esm.get(f"{model_key}_f1", [np.nan])
        kappa_values = cv_results_current_esm.get(f"{model_key}_kappa", [np.nan])
        auc_values = cv_results_current_esm.get(f"{model_key}_roc_auc", [np.nan])
        mean_f1, std_f1 = np.nanmean(f1_values), np.nanstd(f1_values)
        mean_kappa, std_kappa = np.nanmean(kappa_values), np.nanstd(kappa_values)
        mean_roc_auc, std_roc_auc = np.nanmean(auc_values), np.nanstd(auc_values)

        print(f"  {model_key:<8} Macro-F1 = {mean_f1:.4f} ± {std_f1:.4f} | Kappa = {mean_kappa:.4f} ± {std_kappa:.4f} | ROC AUC = {mean_roc_auc:.4f} ± {std_roc_auc:.4f}")
        if model_key == "ens":
            esm_summary["Ensemble Macro F1 Mean"] = mean_f1
            esm_summary["Ensemble Macro F1 Std"] = std_f1
            esm_summary["Ensemble Kappa Mean"] = mean_kappa
            esm_summary["Ensemble Kappa Std"] = std_kappa
            esm_summary["Ensemble ROC AUC Mean"] = mean_roc_auc
            esm_summary["Ensemble ROC AUC Std"] = std_roc_auc

    overall_esm_results.append(esm_summary)
    del X_embeddings, cv_results_current_esm
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    if device.type == 'mps':
        torch.mps.empty_cache()

# Tabulate the per-ESM-model ensemble summaries collected above.
# The header needs the f-prefix; without it the literal "{'='*30}" is printed.
print(f"\n\n{'='*30} Overall Summary Across ESM Models {'='*30}")
summary_df = pd.DataFrame(overall_esm_results)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
print(summary_df.round(4).to_string(index=False))
