In [4]:
import torch
import torch.nn as nn
import shap
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd


In [5]:
class Expert(nn.Module):
    """Shared MoE expert: a stack of [Linear -> ReLU -> Dropout] blocks followed by a Linear head.

    Args:
        input_size: width of the input features.
        hidden_sizes: iterable of hidden-layer widths (may be empty -> a single Linear).
        output_size: width of the expert's output representation.
        p_dropout: dropout probability applied after every hidden ReLU.

    The modules are kept in one flat ``nn.Sequential`` so the state_dict keys are
    ``net.0``, ``net.3``, ... — a saved checkpoint depends on this exact layout.
    """

    def __init__(self, input_size, hidden_sizes, output_size, p_dropout=0.3):
        super().__init__()
        widths = [input_size, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(p_dropout)))
        modules.append(nn.Linear(widths[-1], output_size))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        out = self.net(x)
        return out

class Tower(nn.Module):
    """Per-task head: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: width of the gated expert mixture fed to this tower.
        hidden_size: width of the single hidden layer.
        output_size: width of this task's output.
        p_dropout: dropout probability after the hidden ReLU.

    Layout (``net.0`` / ``net.3`` for the two Linears) must stay fixed so saved
    checkpoints load.
    """

    def __init__(self, input_size, hidden_size, output_size, p_dropout=0.4):
        super().__init__()
        stages = [nn.Linear(input_size, hidden_size), nn.ReLU(), nn.Dropout(p_dropout)]
        stages.append(nn.Linear(hidden_size, output_size))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class MMoE(nn.Module):
    """Multi-gate Mixture-of-Experts: shared experts, one softmax gate and one tower per task.

    Args:
        input_size: number of input features (last axis of ``x``).
        num_experts: number of shared ``Expert`` networks.
        expert_hidden: hidden-layer widths passed to every ``Expert``.
        expert_output_dim: output width of each expert (= input width of each tower).
        tower_hidden_dim: hidden width of each task ``Tower``.
        task_output_dims: one output width per task; its length sets the number of tasks.

    ``forward(x)`` takes a 2-D ``x`` of shape (batch, input_size) and returns a list
    with one tensor per task, of shape (batch, task_output_dims[t]).
    """

    def __init__(self, input_size, num_experts, expert_hidden, expert_output_dim, tower_hidden_dim, task_output_dims):
        super().__init__()
        self.num_tasks = len(task_output_dims)
        # Attribute names (experts / gates / towers) define the state_dict keys; keep them.
        self.experts = nn.ModuleList(
            [Expert(input_size, expert_hidden, expert_output_dim) for _ in range(num_experts)]
        )
        self.gates = nn.ModuleList(
            [nn.Linear(input_size, num_experts, bias=False) for _ in range(self.num_tasks)]
        )
        self.towers = nn.ModuleList(
            [Tower(expert_output_dim, tower_hidden_dim, out_dim) for out_dim in task_output_dims]
        )

    def forward(self, x):
        # Stacked expert outputs: (num_experts, batch, expert_output_dim).
        expert_stack = torch.stack([expert(x) for expert in self.experts], dim=0)
        outputs = []
        for gate, tower in zip(self.gates, self.towers):
            # torch.softmax is used because torch.nn.functional is not imported as F here.
            weights = torch.softmax(gate(x), dim=1)  # (batch, num_experts)
            # Per-sample convex combination of the experts -> (batch, expert_output_dim).
            weighted = torch.einsum("be,ebd->bd", weights, expert_stack)
            outputs.append(tower(weighted))
        return outputs

# ---- 2. Placeholder: per-task isotonic calibration ----
class IsotonicCalibrator:
    """Calibrate each task's score column with its own isotonic regression.

    One ``IsotonicRegression(out_of_bounds='clip')`` is created per task; column ``i``
    of the score matrix is fitted / mapped through calibrator ``i``.

    NOTE(review): ``IsotonicRegression`` (sklearn.isotonic) is not imported in the
    import cell shown, so constructing this class raises NameError until
    ``from sklearn.isotonic import IsotonicRegression`` is added.
    """

    def __init__(self, num_tasks):
        self.num_tasks = num_tasks
        self.calibrators = []
        for _ in range(num_tasks):
            self.calibrators.append(IsotonicRegression(out_of_bounds='clip'))

    def fit(self, logits, targets):
        """Fit calibrator ``i`` on ``(logits[:, i], targets[:, i])`` for every task ``i``."""
        for idx in range(self.num_tasks):
            calibrator = self.calibrators[idx]
            calibrator.fit(logits[:, idx], targets[:, idx])

    def predict(self, logits):
        """Return calibrated values with the same shape as ``logits``.

        NOTE(review): the result array copies ``logits``'s dtype (``np.zeros_like``),
        so integer-typed input would truncate the calibrated probabilities.
        """
        result = np.zeros_like(logits)
        for idx in range(self.num_tasks):
            column = logits[:, idx]
            result[:, idx] = self.calibrators[idx].predict(column)
        return result


In [14]:
# Load the ED cohort, build the numeric feature matrix, and restore the trained MMoE.
import warnings

# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
DATA_PATH = "/users/2/wan00232/ondemand/jupyter/ED_multitask/data/master_dataset_with_outcomes.csv"
CHECKPOINT_PATH = "../best_mmoe_iso.pt"
N_BACKGROUND = 1000  # rows [0, N_BACKGROUND) form the DeepExplainer background set
N_EXPLAIN = 500      # the next N_EXPLAIN rows are the samples being explained

df = pd.read_csv(DATA_PATH)
col_list = [
    "age", "gender", "n_ed_30d", "n_ed_90d", "n_ed_365d", "n_hosp_30d", "n_hosp_90d", "n_hosp_365d", "n_icu_30d", "n_icu_90d", "n_icu_365d", "triage_temperature", "triage_heartrate", "triage_resprate", "triage_o2sat", "triage_sbp", "triage_dbp", "triage_pain", "triage_acuity", "chiefcom_chest_pain", "chiefcom_abdominal_pain", "chiefcom_headache", "chiefcom_shortness_of_breath", "chiefcom_back_pain", "chiefcom_cough", "chiefcom_nausea_vomiting", "chiefcom_fever_chills", "chiefcom_syncope", "chiefcom_dizziness", "cci_MI", "cci_CHF", "cci_PVD", "cci_Stroke", "cci_Dementia", "cci_Pulmonary", "cci_Rheumatic", "cci_PUD", "cci_Liver1", "cci_DM1", "cci_DM2", "cci_Paralysis", "cci_Renal", "cci_Cancer1", "cci_Liver2", "cci_Cancer2", "cci_HIV", "eci_Arrhythmia", "eci_Valvular", "eci_PHTN",  "eci_HTN1", "eci_HTN2", "eci_NeuroOther", "eci_Hypothyroid", "eci_Lymphoma", "eci_Coagulopathy", "eci_Obesity", "eci_WeightLoss", "eci_FluidsLytes", "eci_BloodLoss", "eci_Anemia", "eci_Alcohol", "eci_Drugs","eci_Psychoses", "eci_Depression"
]
outcome_list = [
    "outcome_hospitalization", "outcome_critical", "outcome_ed_revisit_3d", "outcome_sepsis", "outcome_pneumonia_bacterial", "outcome_pneumonia_viral", "outcome_pneumonia_all", "outcome_ards", "outcome_pe", "outcome_copd_exac", "outcome_acs_mi", "outcome_stroke", "outcome_aki"
]

# select_dtypes silently drops non-numeric entries of col_list; that shrinks
# input_size and shifts every later column index, so make it visible.
feature_frame = df.loc[:, col_list].select_dtypes(include=[np.number, bool])
dropped_cols = [c for c in col_list if c not in feature_frame.columns]
if dropped_cols:
    warnings.warn(f"Non-numeric columns excluded from the feature matrix: {dropped_cols}")
# NOTE(review): no scaling is applied here — confirm this matches the training preprocessing.
X_np = feature_frame.fillna(0).to_numpy(dtype=float)

# Model setup (must match the training config, otherwise load_state_dict fails)
input_size = X_np.shape[1]
num_experts = 8
expert_hidden = [128, 64]
expert_output_dim = 32
tower_hidden_dim = 16
task_output_dims = [1] * len(outcome_list)  # one binary-logit head per outcome (13)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MMoE(input_size, num_experts, expert_hidden, expert_output_dim, tower_hidden_dim, task_output_dims).to(device)
# weights_only=True: the checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))
model.eval()

# Background set for DeepExplainer, followed by the rows to explain.
background = torch.from_numpy(X_np[:N_BACKGROUND]).float().to(device)
test_samples = torch.from_numpy(X_np[N_BACKGROUND:N_BACKGROUND + N_EXPLAIN]).float().to(device)
if len(test_samples) == 0:
    raise ValueError(
        f"Need more than {N_BACKGROUND} rows to have samples to explain; got {len(X_np)}."
    )

In [15]:
# --- Helper: Wrapper for expert/gate ---
class ExpertWrapper(torch.nn.Module):
    """Thin module shell that hands its input straight to one wrapped expert.

    Used as the model passed to ``shap.DeepExplainer`` so a single expert can be
    explained in isolation. The call goes through ``__call__`` (not ``.forward``)
    so module hooks on the wrapped expert still fire.
    """

    def __init__(self, expert):
        super().__init__()
        self.expert = expert

    def forward(self, x):
        inner = self.expert
        return inner(x)

class GateWrapper(torch.nn.Module):
    """Wrap a gate layer so its output is the softmax-normalised expert weighting.

    ``forward`` applies the wrapped gate and normalises over ``dim=1`` (the expert
    axis for a 2-D input), mirroring how the gate scores are consumed in the model.
    """

    def __init__(self, gate):
        super().__init__()
        self.gate = gate

    def forward(self, x):
        scores = self.gate(x)
        return scores.softmax(dim=1)

In [None]:
# --- SHAP attributions for every expert and every gate ---
# NOTE(review): DeepExplainer with the full background over every explained row is
# expensive; the saved output stops right after the first header, so a smaller
# background set may be needed in practice.

MAX_PLOT_FEATURES = 20

# Names of the columns that actually made it into X_np. select_dtypes() can drop
# non-numeric entries of col_list, so indexing col_list directly may mislabel features.
# head(0) keeps the dtypes without copying any rows.
shap_feature_names = np.asarray(
    df.head(0).loc[:, col_list].select_dtypes(include=[np.number, bool]).columns
)
assert len(shap_feature_names) == X_np.shape[1], "feature names do not match X_np columns"
# Raw feature values of exactly the rows being explained (colour axis of the plot).
shap_feature_values = test_samples.detach().cpu().numpy()


def _shap_to_feature_matrix(shap_values, n_features):
    """Reduce DeepExplainer output to a (batch, n_features) array.

    Accepts both shap return conventions for multi-output models:
    - older releases: a list with one (batch, n_features) array per model output;
    - newer releases: an ndarray of shape (batch, n_features, n_outputs).
    A single output keeps its signed SHAP values; several outputs are reduced to the
    mean |SHAP| across outputs.
    """
    if isinstance(shap_values, list):
        arr = np.stack([np.asarray(v) for v in shap_values], axis=-1)
    else:
        arr = np.asarray(shap_values)

    if arr.ndim == 2:
        out = arr
    elif arr.ndim == 3:
        # Locate the output axis by which axis matches the feature count
        # (outputs-last is preferred when both match).
        if arr.shape[1] == n_features:
            out_axis = 2
        elif arr.shape[2] == n_features:
            out_axis = 1
        else:
            raise ValueError(f"Cannot find a feature axis of size {n_features} in SHAP shape {arr.shape}")
        if arr.shape[out_axis] == 1:
            out = np.take(arr, 0, axis=out_axis)
        else:
            out = np.abs(arr).mean(axis=out_axis)
    else:
        raise ValueError(f"Unexpected SHAP value rank {arr.ndim} (shape {arr.shape})")

    if out.shape[1] != n_features:
        raise ValueError(f"Reduced SHAP matrix has shape {out.shape}, expected (*, {n_features})")
    return out


def _explain_and_plot(module, out_dir, label, out_path, check_additivity=True,
                      max_features=MAX_PLOT_FEATURES):
    """Explain `module` on `test_samples` and save a summary plot of its top features.

    Args:
        module: torch module mapping (batch, n_features) -> (batch, n_outputs).
        out_dir: directory created (if missing) to hold the plot.
        label: plot-title prefix, e.g. "Expert 3".
        out_path: file path of the saved PNG.
        check_additivity: forwarded to ``DeepExplainer.shap_values``.
        max_features: maximum number of features shown.
    Returns:
        Number of features plotted.
    """
    os.makedirs(out_dir, exist_ok=True)
    explainer = shap.DeepExplainer(module, background)
    raw_values = explainer.shap_values(test_samples, check_additivity=check_additivity)
    shap_matrix = _shap_to_feature_matrix(raw_values, X_np.shape[1])

    mean_abs_shap = np.abs(shap_matrix).mean(axis=0)
    n_feat = min(max_features, mean_abs_shap.shape[0])
    top_idx = np.argsort(mean_abs_shap)[-n_feat:][::-1]
    shap.summary_plot(
        shap_matrix[:, top_idx],
        shap_feature_values[:, top_idx],
        feature_names=shap_feature_names[top_idx],
        show=False,
    )
    plt.title(f"{label} Top {n_feat} Features")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    return n_feat


print("\nSHAP for Experts:")
for i, expert in enumerate(model.experts, start=1):
    expert_dir = f"SHAP_Plot/expert_{i}"
    _explain_and_plot(
        ExpertWrapper(expert), expert_dir, f"Expert {i}",
        os.path.join(expert_dir, f"expert_{i}_shap.png"),
    )
    print(f"Expert {i} SHAP plot saved to {expert_dir}")

print("\nSHAP for Gates:")
for t, gate in enumerate(model.gates, start=1):
    gate_dir = f"SHAP_Plot/gate_{t}"
    # Additivity is not checked for gates (as before); the softmax in GateWrapper is
    # presumably why DeepLIFT attributions do not sum exactly to the output.
    _explain_and_plot(
        GateWrapper(gate), gate_dir, f"Gate {t}",
        os.path.join(gate_dir, f"gate_{t}_shap.png"),
        check_additivity=False,
    )
    print(f"Gate {t} SHAP plot saved to {gate_dir}")



SHAP for Experts:
