In [1]:
# -*- coding: utf-8 -*-
import os, json, joblib, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score

# ---------- Paths (relative to this notebook in compare/) ----------
HERE = os.getcwd()
COMPARE = os.path.abspath(HERE)

MA_DIR  = os.path.join(COMPARE, "ma")
MDD_DIR = os.path.join(COMPARE, "mdd")

# ---------- Metrics helpers ----------
def to_py(o):
    import numpy as _np
    if isinstance(o, (_np.generic,)):
        return o.item()
    if isinstance(o, (_np.ndarray,)):
        return o.tolist()
    return o

def variant_metrics(y_true, y_prob, thr=0.9):
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)
    pred = (y_prob >= thr).astype(int)
    tp = int(np.sum((pred==1)&(y_true==1)))
    fp = int(np.sum((pred==1)&(y_true==0)))
    fn = int(np.sum((pred==0)&(y_true==1)))
    tn = int(np.sum((pred==0)&(y_true==0)))
    ppv = float(tp/max(tp+fp,1))
    tpr = float(tp/max(tp+fn,1))
    f1  = float(2*ppv*tpr/max(ppv+tpr,1e-12))
    fdr = float(fp/max(tp+fp,1))
    return {"PPV":ppv,"TPR":tpr,"F1":f1,"FDR":fdr,"TP":tp,"FP":fp,"FN":fn,"TN":tn}

def cluster_loci(chr_arr, pos_arr, window=250_000):
    """Group variants into putative loci by genomic windows (±window)."""
    df = pd.DataFrame({"chr": chr_arr, "pos": pos_arr}).copy()
    df["chr"] = df["chr"].astype(str)
    df["pos"] = df["pos"].astype(int)
    df["_i"] = np.arange(len(df))
    out = pd.Series(index=np.arange(len(df)), dtype=object)
    for c in df["chr"].unique():
        sub = df[df["chr"]==c].sort_values("pos")
        prev = None; cid = None; start = None
        for _, r in sub.iterrows():
            p = int(r["pos"])
            if prev is None or (p - prev) > window:
                start = p
                cid = f"{c}:{start}"
            out.loc[r["_i"]] = f"{cid}-{p}"
            prev = p
    return out.values

def locus_metrics(chr_arr, pos_arr, y_true, y_prob, thr=0.9, window=250_000):
    clusters = cluster_loci(chr_arr, pos_arr, window=window)
    df = pd.DataFrame({"cluster": clusters, "y": y_true, "p": y_prob})
    agg = df.groupby("cluster").agg(y=("y","max"), p=("p","max")).reset_index(drop=True)
    return variant_metrics(agg["y"].values, agg["p"].values, thr=thr)


In [2]:
class TransformerClassifier(nn.Module):
    def __init__(self, input_dim, num_heads=4, num_layers=2, hidden_dim=64, output_dim=1, p_drop=0.05):
        super().__init__()
        self.input_mapping = nn.Linear(input_dim, 20)
        self.input_dropout = nn.Dropout(p_drop)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=20, nhead=num_heads, dim_feedforward=hidden_dim,
            dropout=0.1, activation='relu', batch_first=False
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.fc = nn.Linear(20, output_dim)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = self.input_mapping(x)
        x = self.input_dropout(x)
        x = x.unsqueeze(1).permute(1,0,2)  # [1,B,20]
        x = self.transformer(x)[0]         # [B,20]
        x = self.fc(x)                     # [B,1]
        return self.sigmoid(x)


In [3]:
# ---------- Load MA config ----------
ma_cfg_path = os.path.join(MA_DIR, "config.json")
with open(ma_cfg_path, "r", encoding="utf-8") as f:
    ma_cfg = json.load(f)

feature_cols = ma_cfg["feature_cols"]
label_col    = ma_cfg["label_col"]
MA_THR       = float(ma_cfg.get("threshold", 0.9))
MA_WINDOW_BP = int(ma_cfg.get("locus_window_bp", 250_000))

# ---------- Load MA data ----------
ma_csv = os.path.join(MA_DIR, "data", "final_training_set_data.csv")
ma_scaler_pkl = os.path.join(MA_DIR, "artifacts", "scaler.pkl")
ma_scratch_pth = os.path.join(MA_DIR, "models", "model_scratch.pth")
ma_transfer_pth = os.path.join(MA_DIR, "models", "model_transfer.pth")

df_ma = pd.read_csv(ma_csv)
X_ma = df_ma[feature_cols].values.astype(np.float32)
y_ma = df_ma[label_col].astype(int).values
chr_ma = df_ma["chr"].values
pos_ma = df_ma["bpos"].astype(int).values

scaler_ma = joblib.load(ma_scaler_pkl)
X_ma = scaler_ma.transform(X_ma).astype(np.float32)

Xt_ma = torch.tensor(X_ma, dtype=torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_dim = X_ma.shape[1]
def _load_and_predict_ma(model_path):
    model = TransformerClassifier(input_dim=input_dim).to(device)
    state = torch.load(model_path, map_location="cpu")
    try:
        model.load_state_dict(state, strict=False)
    except Exception:
        model = state
    model.eval()
    with torch.no_grad():
        pred = model(Xt_ma.to(device)).cpu().numpy().ravel()
    return pred

pred_ma_scratch  = _load_and_predict_ma(ma_scratch_pth)
pred_ma_transfer = _load_and_predict_ma(ma_transfer_pth)

# ---------- MA: print results (PR-AUC + Variant/Locus @ threshold), no plots ----------
res_ma = {}
res_ma["Scratch"] = {
    "PR_AUC": float(average_precision_score(y_ma, pred_ma_scratch)),
    f"Variant@{MA_THR:.2f}": variant_metrics(y_ma, pred_ma_scratch, thr=MA_THR),
    f"Locus@{MA_THR:.2f}":   locus_metrics(chr_ma, pos_ma, y_ma, pred_ma_scratch, thr=MA_THR, window=MA_WINDOW_BP),
}
res_ma["Transfer"] = {
    "PR_AUC": float(average_precision_score(y_ma, pred_ma_transfer)),
    f"Variant@{MA_THR:.2f}": variant_metrics(y_ma, pred_ma_transfer, thr=MA_THR),
    f"Locus@{MA_THR:.2f}":   locus_metrics(chr_ma, pos_ma, y_ma, pred_ma_transfer, thr=MA_THR, window=MA_WINDOW_BP),
}

print("\n=== MA: Transfer vs Scratch (no plots) ===")
print(json.dumps(res_ma, indent=2, default=to_py))



=== MA: Transfer vs Scratch (no plots) ===
{
  "Scratch": {
    "PR_AUC": 0.9405108673729293,
    "Variant@0.90": {
      "PPV": 0.8655519535915372,
      "TPR": 0.9220283533260633,
      "F1": 0.8928980022881281,
      "FDR": 0.13444804640846272,
      "TP": 5073,
      "FP": 788,
      "FN": 429,
      "TN": 98248
    },
    "Locus@0.90": {
      "PPV": 0.682129891085115,
      "TPR": 0.9220283533260633,
      "F1": 0.7841409691629955,
      "FDR": 0.31787010891488504,
      "TP": 1691,
      "FP": 788,
      "FN": 143,
      "TN": 98248
    }
  },
  "Transfer": {
    "PR_AUC": 0.9861127642825256,
    "Variant@0.90": {
      "PPV": 0.9777537796976242,
      "TPR": 0.8227917121046892,
      "F1": 0.8936044216344255,
      "FDR": 0.02224622030237581,
      "TP": 4527,
      "FP": 103,
      "FN": 975,
      "TN": 98933
    },
    "Locus@0.90": {
      "PPV": 0.9361042183622829,
      "TPR": 0.8227917121046892,
      "F1": 0.8757980266976203,
      "FDR": 0.06389578163771713,
      "TP

In [6]:
# ---------- Load MDD config & artifacts ----------
mdd_cfg_path = os.path.join(MDD_DIR, "config.json")
with open(mdd_cfg_path, "r", encoding="utf-8") as f:
    mdd_cfg = json.load(f)

MDD_THR = float(mdd_cfg.get("threshold", 0.9))
MDD_WINDOW_BP = 250_000  # use ±250 kb window to approximate LD clumping

mdd_npz = os.path.join(MDD_DIR, "data", "splits_and_arrays.npz")
mdd_chrpos_csv = os.path.join(MDD_DIR, "data", "chrpos_val.csv")
mdd_scaler_pkl = os.path.join(MDD_DIR, "artifacts", "scaler.pkl")
pth_insight = os.path.join(MDD_DIR, "models", "insightgwas_transformer.pt")
pth_mlp     = os.path.join(MDD_DIR, "models", "deepgwas_mlp.h5")

# Load arrays (already scaled in your saved pipeline; we will use them as-is)
npz = np.load(mdd_npz)
X_val = npz["X_val"]
y_val = npz["y_val"]

chrpos_val = pd.read_csv(mdd_chrpos_csv)
chr_mdd = chrpos_val["chr"].values
pos_mdd = chrpos_val["bpos"].astype(int).values

# ---------- Rebuild Transformer and load weights ----------
class Transformer_MDD(nn.Module):
    def __init__(self, input_dim, num_heads=4, num_layers=2, hidden_dim=64, output_dim=1):
        super().__init__()
        self.input_mapping = nn.Linear(input_dim, 20)
        enc = nn.TransformerEncoderLayer(d_model=20, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer = nn.TransformerEncoder(enc, num_layers=num_layers)
        self.fc = nn.Linear(20, output_dim)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = self.input_mapping(x)
        x = x.unsqueeze(1).permute(1,0,2)  # [1,B,20]
        x = self.transformer(x)[0]         # [B,20]
        x = self.fc(x)
        return self.sigmoid(x)

tconf = mdd_cfg.get("transformer", {"num_heads":4,"num_layers":2,"hidden_dim":64})
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mdl_insight = Transformer_MDD(
    input_dim=int(mdd_cfg["input_dim"]),
    num_heads=int(tconf["num_heads"]),
    num_layers=int(tconf["num_layers"]),
    hidden_dim=int(tconf["hidden_dim"])
).to(device)
mdl_insight.load_state_dict(torch.load(pth_insight, map_location=device))
mdl_insight.eval()

from tensorflow.keras.models import load_model as tf_load_model
mlp = tf_load_model(pth_mlp)

# ---------- Predict ----------
with torch.no_grad():
    Xt_val = torch.tensor(X_val, dtype=torch.float32).to(device)
    preds_insight = mdl_insight(Xt_val).cpu().numpy().ravel()

preds_mlp = mlp.predict(X_val, verbose=0).ravel()

# ---------- MDD: print results (PR-AUC + Variant/Locus @ threshold), no plots ----------
res_mdd = {}
res_mdd["InsightGWAS(Transformer)"] = {
    "PR_AUC": float(average_precision_score(y_val, preds_insight)),
    f"Variant@{MDD_THR:.2f}": variant_metrics(y_val, preds_insight, thr=MDD_THR),
    f"Locus@{MDD_THR:.2f}":   locus_metrics(chr_mdd, pos_mdd, y_val, preds_insight, thr=MDD_THR, window=MDD_WINDOW_BP),
}
res_mdd["DeepGWAS(MLP)"] = {
    "PR_AUC": float(average_precision_score(y_val, preds_mlp)),
    f"Variant@{MDD_THR:.2f}": variant_metrics(y_val, preds_mlp, thr=MDD_THR),
    f"Locus@{MDD_THR:.2f}":   locus_metrics(chr_mdd, pos_mdd, y_val, preds_mlp, thr=MDD_THR, window=MDD_WINDOW_BP),
}

print("\n=== MDD: InsightGWAS vs DeepGWAS (no plots) ===")
print(json.dumps(res_mdd, indent=2, default=to_py))





=== MDD: InsightGWAS vs DeepGWAS (no plots) ===
{
  "InsightGWAS(Transformer)": {
    "PR_AUC": 0.9944808690190807,
    "Variant@0.90": {
      "PPV": 0.9848249027237355,
      "TPR": 0.9401931649331352,
      "F1": 0.9619916381603953,
      "FDR": 0.015175097276264591,
      "TP": 2531,
      "FP": 39,
      "FN": 161,
      "TN": 24869
    },
    "Locus@0.90": {
      "PPV": 0.9814462416745956,
      "TPR": 0.9407204742362061,
      "F1": 0.960651920838184,
      "FDR": 0.018553758325404377,
      "TP": 2063,
      "FP": 39,
      "FN": 130,
      "TN": 24869
    }
  },
  "DeepGWAS(MLP)": {
    "PR_AUC": 0.9892302971830891,
    "Variant@0.90": {
      "PPV": 0.9865891819400984,
      "TPR": 0.8198365527488856,
      "F1": 0.895516331913167,
      "FDR": 0.013410818059901655,
      "TP": 2207,
      "FP": 30,
      "FN": 485,
      "TN": 24878
    },
    "Locus@0.90": {
      "PPV": 0.9835975943138326,
      "TPR": 0.8203374373005016,
      "F1": 0.8945798110392839,
      "FDR": 0.01