# In [None]:
# ============================================================
# MODEL DEVELOPMENT PIPELINE FOR TNBC SUBTYPE CLASSIFICATION
# ============================================================

import copy
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFE
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier

# ------------------------------------------------------------
# 1. LOAD SAVED CLUSTERING OUTPUT
# ------------------------------------------------------------
# Unpickle the clustering output and pull out the train/test matrices.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only point
# this at a pickle produced by this project, never at an untrusted download.
with open("clustering_output.pkl", "rb") as f:
    data = pickle.load(f)

X_train_df, X_test_df = data["X_train_df"], data["X_test_df"]

# ------------------------------------------------------------
# 2. LOAD DEG FEATURE LISTS
# ------------------------------------------------------------
# One significant-gene table per cluster (C1..C4), each read with its first
# CSV column as the row index.
files = [f"significant_genes_C{cluster}_v2.csv" for cluster in range(1, 5)]

dfs = []
for deg_path in files:
    deg_table = pd.read_csv(deg_path, index_col=0)
    # Clear the index name (taken from the header of the first CSV column).
    deg_table.index.name = None
    dfs.append(deg_table)

# Union of all DEG columns. The outer join on the row index keeps every row
# that appears in at least one table (missing cells become NaN).
df = pd.concat(dfs, axis=1, join="outer")

# Subtype labels assigned by clustering. pandas names an empty first header
# "Unnamed: 0", which is used as the index here.
subs = pd.read_csv("only_subtypes_v2.csv").set_index("Unnamed: 0")
subs.index.name = None

# Attach labels by index alignment; rows of `df` absent from `subs` get NaN.
df = df.assign(subtype=subs["subtype"])

# ------------------------------------------------------------
# 3. DATA PREPROCESSING
# ------------------------------------------------------------
# Data-quality report on the DEG matrix plus its index-aligned "subtype" column.
# The previous one-liner, df.isna().all().sum(), counted columns that are
# *entirely* NaN (including "subtype"), not missing values.
expr_missing = df.drop(columns=["subtype"]).isna()
print("Missing values:", int(expr_missing.sum().sum()))
print("Columns with no values:", int(expr_missing.all().sum()))
print("Rows without a subtype label:", int(df["subtype"].isna().sum()))
# dropna=False so unlabeled rows appear as a NaN entry instead of being hidden.
print(df["subtype"].value_counts(dropna=False))

# Split features / labels for the training (DEG) and test matrices.
X_train_deg = df.drop(columns=["subtype"])
y_train      = df["subtype"]

y_test = X_test_df["subtype"]
X_test_df = X_test_df.drop(columns=["subtype"])

# Collapse duplicate gene columns (e.g. a gene present in more than one DEG
# table after the outer concat) by averaging them. `groupby(..., axis=1)` is
# deprecated since pandas 2.1 and removed in 3.0, so group the rows of the
# transpose instead; the result is the same (columns sorted by gene name).
X_train_deg = X_train_deg.T.groupby(level=0).mean().T
X_test_df   = X_test_df.T.groupby(level=0).mean().T

# Remove highly correlated genes: for each column, drop it if its |Pearson r|
# with any earlier column exceeds the threshold. A strong negative correlation
# is as redundant as a strong positive one, hence the absolute value.
CORR_THRESHOLD = 0.80
corr = X_train_deg.corr()
upper = corr.where(np.triu(np.ones(corr.shape), k=1).astype(bool))
to_drop = [col for col in upper.columns if (upper[col].abs() > CORR_THRESHOLD).any()]
X_train_filtered = X_train_deg.drop(columns=to_drop)

# Restrict the test matrix to the training gene set (same column order), then
# apply the same drop so train and test columns match.
deg_genes = X_train_deg.columns.tolist()
X_test_deg = X_test_df[deg_genes]
X_test_filtered = X_test_deg.drop(columns=to_drop)

# Map subtype labels to consecutive integer codes. The encoder is fitted on the
# training labels only; test labels must be a subset of them.
le = LabelEncoder()
y_train_encoded = le.fit_transform(y_train)
y_test_encoded  = le.transform(y_test)

# le.classes_ is ordered by code, so a label's position is its encoded value.
print("Label Mapping:")
for code, label in enumerate(le.classes_):
    print(f"{label} → {code}")

# Standardize every gene with a StandardScaler (default settings) fitted on the
# training matrix only, then apply the same transform to the test matrix.
col_train = X_train_filtered.columns
col_test  = X_test_filtered.columns

scaler = StandardScaler().fit(X_train_filtered)
X_train_scaled = scaler.transform(X_train_filtered)
X_test_scaled  = scaler.transform(X_test_filtered)

# Re-wrap the arrays as DataFrames (fresh RangeIndex, gene names as columns).
X_train_scaled_df, X_test_scaled_df = (
    pd.DataFrame(X_train_scaled, columns=col_train),
    pd.DataFrame(X_test_scaled, columns=col_test),
)

# ------------------------------------------------------------
# 4. RANDOM FOREST RFE FEATURE SELECTION
# ------------------------------------------------------------
rf = RandomForestClassifier(random_state=42)

n_features_to_select = 50  # Adjust as needed

# Recursive feature elimination ranked by random-forest importances. `step` is
# left at RFE's default, so one feature is removed per refit (slow for wide data).
# NOTE(review): selection runs on the full training set before cross-validation,
# so the CV scores computed afterwards are optimistic; nest RFE inside the folds
# for an unbiased estimate.
start_time = time.time()
rfe = RFE(estimator=rf, n_features_to_select=n_features_to_select, verbose=1).fit(
    X_train_scaled_df, y_train
)
end_time = time.time()

print(f"RFE completed in {end_time - start_time:.2f} seconds")

# Boolean mask over the scaled columns; True marks a retained gene.
selected_features = rfe.get_support()
X_train_rfe, X_test_rfe = (
    X_train_scaled_df.loc[:, selected_features],
    X_test_scaled_df.loc[:, selected_features],
)

top_features = X_train_filtered.columns[selected_features]

# ------------------------------------------------------------
# 5. TRAIN MODELS (RF, XGB, DT, SVM) — CROSS-VALIDATION
# ------------------------------------------------------------

def _per_class_metrics(cm):
    """Compute one-vs-rest metrics for every class of a confusion matrix.

    Parameters
    ----------
    cm : array-like
        Square confusion matrix with true classes on rows and predicted
        classes on columns (the layout of ``sklearn.metrics.confusion_matrix``).

    Returns
    -------
    dict[str, numpy.ndarray]
        Per-class "sensitivity", "specificity", "precision" and "accuracy"
        (one-vs-rest accuracy, (TP + TN) / total). A ratio whose denominator
        is zero is reported as 0.
    """
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = total - tp - fp - fn

    def ratio(num, den):
        # 0 where the denominator is 0, instead of a division-by-zero NaN.
        return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    return {
        "sensitivity": ratio(tp, tp + fn),
        "specificity": ratio(tn, tn + fp),
        "precision": ratio(tp, tp + fp),
        "accuracy": ratio(tp + tn, np.full_like(tp, total)),
    }


# Number of subtypes comes from the fitted encoder rather than a hard-coded value.
n_classes = len(le.classes_)
class_labels = np.arange(n_classes)

models = {
    "RandomForest": RandomForestClassifier(n_estimators=100, criterion="entropy",
                                           random_state=42, n_jobs=-1),
    "XGB": XGBClassifier(objective="multi:softmax", num_class=n_classes, random_state=42),
    "DecisionTree": DecisionTreeClassifier(random_state=42),
    "SVM": SVC(kernel="rbf", probability=True, random_state=42),
}

results = {}

X_train_np = X_train_rfe.values
y_train_np = np.array(y_train_encoded)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for model_name, clf in models.items():
    print("\n===============================")
    print(f"Training model: {model_name}")
    print("===============================")

    all_accuracy, all_sensitivity, all_specificity, all_precision = [], [], [], []
    all_overall_accuracy = []
    best_model = None
    best_acc = -np.inf  # ensures best_model is set on the first fold

    for fold, (train_idx, val_idx) in enumerate(skf.split(X_train_np, y_train_np), 1):
        X_tr, X_val = X_train_np[train_idx], X_train_np[val_idx]
        y_tr, y_val = y_train_np[train_idx], y_train_np[val_idx]

        clf.fit(X_tr, y_tr)
        y_pred = clf.predict(X_val)

        # Fixed label set -> always a K x K matrix, even if a class is absent
        # from this fold's true and predicted labels.
        cm = confusion_matrix(y_val, y_pred, labels=class_labels)
        fold_metrics = _per_class_metrics(cm)

        # "Accuracy" here is the macro mean of per-class one-vs-rest accuracy.
        # It equals 1 - 2*errors/(K*N), so for K > 2 it is higher than the
        # fraction of correct predictions; the latter is tracked separately.
        fold_acc = fold_metrics["accuracy"].mean()
        fold_overall_acc = np.trace(cm) / cm.sum()

        all_accuracy.append(fold_acc)
        all_overall_accuracy.append(fold_overall_acc)
        all_sensitivity.append(fold_metrics["sensitivity"].mean())
        all_specificity.append(fold_metrics["specificity"].mean())
        all_precision.append(fold_metrics["precision"].mean())

        if fold_acc > best_acc:
            best_acc = fold_acc
            # `clf` is refit on the next fold, so keep an independent snapshot;
            # storing the bare reference would always yield the last-fold model.
            best_model = copy.deepcopy(clf)

        print(f"Fold {fold}: Accuracy = {fold_acc:.4f} | "
              f"Overall accuracy = {fold_overall_acc:.4f}")

    results[model_name] = {
        "accuracy": np.mean(all_accuracy),
        "overall_accuracy": np.mean(all_overall_accuracy),
        "sensitivity": np.mean(all_sensitivity),
        "specificity": np.mean(all_specificity),
        "precision": np.mean(all_precision),
        "best_model": best_model,
    }

    print("\n--- CV SUMMARY ---")
    print(f"Accuracy:   {np.mean(all_accuracy):.4f}")
    print(f"Overall accuracy: {np.mean(all_overall_accuracy):.4f}")
    print(f"Sensitivity:{np.mean(all_sensitivity):.4f}")
    print(f"Specificity:{np.mean(all_specificity):.4f}")
    print(f"Precision:  {np.mean(all_precision):.4f}")

# ------------------------------------------------------------
# 6. TEST SET EVALUATION (RANDOM FOREST ONLY)
# ------------------------------------------------------------
# Evaluate the RandomForest kept by the CV loop on the held-out test set.
rf_best = results["RandomForest"]["best_model"]
# The CV models were fitted on plain NumPy arrays, so predict on an array too;
# passing the DataFrame makes sklearn warn "X has feature names, but ... was
# fitted without feature names". The predictions themselves are unchanged.
y_test_pred = rf_best.predict(X_test_rfe.to_numpy())

print("\nConfusion Matrix (TEST SET, RF):")
# Rows/columns follow le.classes_ order. The fixed label set keeps the matrix
# K x K even when a subtype is missing from both test labels and predictions.
print(confusion_matrix(y_test_encoded, y_test_pred, labels=np.arange(len(le.classes_))))
print(f"Test accuracy (RF): {np.mean(y_test_pred == y_test_encoded):.4f}")

# ------------------------------------------------------------
# 7. SHAP ANALYSIS FOR RANDOM FOREST
# ------------------------------------------------------------
# Per-class SHAP summary plots for the selected RandomForest, computed on the
# RFE-selected training features.
explainer = shap.TreeExplainer(rf_best)
feature_names = top_features.tolist()

X_train_df_shap = pd.DataFrame(X_train_rfe, columns=feature_names)

shap_values = explainer.shap_values(X_train_df_shap)
# Normalize to shape (n_classes, n_samples, n_features). Older shap releases
# return a list with one (n_samples, n_features) array per class; recent ones
# return a single (n_samples, n_features, n_classes) array. Applying moveaxis
# to the list form would put the feature axis first and scramble the plots.
if isinstance(shap_values, list):
    shap_values_corrected = np.stack(shap_values, axis=0)
else:
    shap_values_corrected = np.moveaxis(shap_values, -1, 0)

for i, class_shap in enumerate(shap_values_corrected):
    # rf_best was trained on encoded labels; decode back to the subtype name.
    class_name = le.classes_[rf_best.classes_[i]]
    shap.summary_plot(class_shap, X_train_df_shap, show=False)
    # Title set after plotting so it lands on the axes shap drew into.
    plt.title(f"SHAP Summary Plot – Class {i} ({class_name})")
    plt.show()