<a href="https://colab.research.google.com/github/ramyahramzy/Colab/blob/main/interpretable.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install torch numpy scikit-learn

import torch, torch.nn as nn
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Make a dataset with 4 features, only 3 are truly informative
X, y = make_classification(
    n_samples=4000, n_features=4, n_informative=3, n_redundant=0,
    class_sep=1.5, random_state=42
)
feature_names = ["BMI", "Glucose", "Age", "Noise"]  # pretend these are real features

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_test_t  = torch.tensor(X_test,  dtype=torch.float32)
y_test_t  = torch.tensor(y_test,  dtype=torch.long)


In [2]:
class LogisticReg(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.linear = nn.Linear(n_in, 2)  # binary classes -> 2 logits
    def forward(self, x):
        return self.linear(x)

model_lr = LogisticReg(X_train.shape[1])
opt = torch.optim.Adam(model_lr.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(200):
    model_lr.train()
    logits = model_lr(X_train_t)
    loss = loss_fn(logits, y_train_t)
    opt.zero_grad(); loss.backward(); opt.step()

model_lr.eval()
pred = model_lr(X_test_t).argmax(dim=1).cpu().numpy()
print("LogReg accuracy:", accuracy_score(y_test, pred))

# Inspect weights (class 1 minus class 0 effect)
with torch.no_grad():
    W = model_lr.linear.weight.detach().cpu().numpy()  # shape [2, 4]
    # difference between class-1 and class-0 logits per feature
    interpretable_w = W[1] - W[0]

print("Feature weights (higher => pushes toward class 1):")
for name, w in zip(feature_names, interpretable_w):
    print(f"  {name:8s}  {w:+.3f}")


LogReg accuracy: 0.939
Feature weights (higher => pushes toward class 1):
  BMI       +1.117
  Glucose   +1.569
  Age       -0.001
  Noise     +0.684


In [3]:
class MLP(nn.Module):
    def __init__(self, n_in):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, 32), nn.ReLU(),
            nn.Linear(32, 16), nn.ReLU(),
            nn.Linear(16, 2)
        )
    def forward(self, x): return self.net(x)

model_mlp = MLP(X_train.shape[1])
opt = torch.optim.Adam(model_mlp.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(15):
    model_mlp.train()
    logits = model_mlp(X_train_t)
    loss = loss_fn(logits, y_train_t)
    opt.zero_grad(); loss.backward(); opt.step()

model_mlp.eval()
pred = model_mlp(X_test_t).argmax(dim=1).cpu().numpy()
print("MLP accuracy:", accuracy_score(y_test, pred))


MLP accuracy: 0.861


In [4]:
rng = np.random.default_rng(0)

def perm_importance(model, X_test, y_test, repeats=5):
    base_pred = model(torch.tensor(X_test, dtype=torch.float32)).argmax(dim=1).cpu().numpy()
    base_acc = accuracy_score(y_test, base_pred)
    importances = []
    for j in range(X_test.shape[1]):
        drops = []
        for _ in range(repeats):
            Xp = X_test.copy()
            rng.shuffle(Xp[:, j])  # destroy info in feature j
            pred = model(torch.tensor(Xp, dtype=torch.float32)).argmax(dim=1).cpu().numpy()
            drops.append(base_acc - accuracy_score(y_test, pred))
        importances.append(np.mean(drops))
    return base_acc, np.array(importances)

# LR permutation importance
model_lr.eval()
acc_lr, imp_lr = perm_importance(model_lr, X_test, y_test)
print("\nPermutation importance (LogReg):")
for name, imp in sorted(zip(feature_names, imp_lr), key=lambda x: -x[1]):
    print(f"  {name:8s}  drop={imp:.4f}")

# MLP permutation importance
model_mlp.eval()
acc_mlp, imp_mlp = perm_importance(model_mlp, X_test, y_test)
print("\nPermutation importance (MLP):")
for name, imp in sorted(zip(feature_names, imp_mlp), key=lambda x: -x[1]):
    print(f"  {name:8s}  drop={imp:.4f}")



Permutation importance (LogReg):
  Glucose   drop=0.3130
  BMI       drop=0.0970
  Noise     drop=0.0398
  Age       drop=0.0000

Permutation importance (MLP):
  Glucose   drop=0.1906
  BMI       drop=0.0830
  Noise     drop=0.0316
  Age       drop=-0.0016
