SAINT

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
  accuracy_score, classification_report, roc_auc_score,
  f1_score, roc_curve
)

In [None]:
def run_saint_pipeline(
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
    feature_names,
    device="cuda" if torch.cuda.is_available() else "cpu",
    embed_dim=32,
    num_heads=4,
    num_layers=2,
    pretrain_epochs=10,
    finetune_epochs=10,
    lr=1e-4,
    batch_size=128,
    temperature=0.5,
    noise_std=0.01,
):
    """Pretrain, fine-tune and evaluate a transformer classifier on tabular data.

    Pipeline:
      1. Contrastive (NT-Xent) pretraining of the encoder on two Gaussian-noise
         views of each training row.
      2. Supervised fine-tuning of the whole model with cross-entropy.
      3. Pick the decision threshold that maximises F1 on the validation set.
      4. Report accuracy / classification report / AUC on the test set, plot
         the test ROC curve, and rank features by input-gradient saliency on
         the validation set.

    Args:
        X_train, X_val, X_test: 2-D feature matrices (numpy arrays or pandas
            DataFrames) with the same columns.
        y_train, y_val, y_test: Integer class labels. The classifier head has
            exactly two outputs, so labels must be 0 or 1.
        feature_names: One name per feature column, in column order.
        device: Torch device used for training and inference.
        embed_dim, num_heads, num_layers: Transformer encoder size;
            ``embed_dim`` must be divisible by ``num_heads``.
        pretrain_epochs, finetune_epochs: Number of epochs for each stage.
        lr: Adam learning rate for both stages.
        batch_size: Mini-batch size for training and batched inference.
        temperature: NT-Xent softmax temperature.
        noise_std: Standard deviation of the Gaussian noise used to build the
            two contrastive views.

    Returns:
        ``(model_saint, results_saint)``; ``results_saint`` holds validation and
        test probabilities, test predictions, accuracy, AUC, the chosen
        threshold, ROC points, the classification report as a DataFrame and the
        feature-importance table.

    Raises:
        ValueError: If ``len(feature_names)`` differs from the number of
            feature columns in ``X_train``.
    """

    # === Dataset Wrapper ===
    class TabularDataset_saint(Dataset):
        """Expose (features, label) pairs as float32 / int64 tensors."""

        def __init__(self, X, y):
            if isinstance(X, pd.DataFrame):
                X = X.values
            if isinstance(y, pd.Series):
                y = y.values
            self.X = torch.tensor(X, dtype=torch.float32)
            self.y = torch.tensor(y, dtype=torch.long)

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]

    # === SAINT Model ===
    class SAINT(nn.Module):
        """Linear embedding -> TransformerEncoder -> 2-way linear classifier.

        NOTE(review): each row is embedded as a single token, so self-attention
        runs over a length-1 sequence. Per-feature tokens would be needed for
        attention to relate features to each other.
        """

        def __init__(self, input_dim, embed_dim=32, num_heads=4, num_layers=2):
            super().__init__()
            self.embedding = nn.Linear(input_dim, embed_dim)
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, batch_first=True
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.classifier = nn.Linear(embed_dim, 2)

        def encode(self, x):
            """Map (batch, input_dim) features to (batch, embed_dim) representations."""
            x = self.embedding(x).unsqueeze(1)
            return self.transformer(x).squeeze(1)

        def forward(self, x):
            return self.classifier(self.encode(x))

    # === Contrastive Loss ===
    class NTXentLoss_saint(nn.Module):
        """NT-Xent loss over L2-normalised embeddings of two views of a batch."""

        def __init__(self, temperature=0.5):
            super().__init__()
            self.temperature = temperature

        def forward(self, z_i, z_j):
            n = z_i.shape[0]
            # Cosine similarity keeps logits bounded by 1/temperature, so the
            # softmax cannot overflow the way raw exp(dot product) can.
            z = torch.cat([F.normalize(z_i, dim=-1), F.normalize(z_j, dim=-1)], dim=0)
            sim = z @ z.T / self.temperature
            # A sample is never its own negative: drop the diagonal from the softmax.
            self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
            sim = sim.masked_fill(self_mask, float("-inf"))
            # Row k's positive is its other view: k + n in the first half, k - n in the second.
            idx = torch.arange(n, device=z.device)
            targets = torch.cat([idx + n, idx])
            return F.cross_entropy(sim, targets)

    def to_tensor(X):
        """Convert a feature matrix (array or DataFrame) to a float32 CPU tensor."""
        return torch.tensor(
            X.values if isinstance(X, pd.DataFrame) else X, dtype=torch.float32
        )

    def predict_proba(X):
        """Return P(class 1) for every row of X, evaluated in mini-batches."""
        x_all = to_tensor(X)
        model_saint.eval()
        chunks = []
        with torch.no_grad():
            for start in range(0, len(x_all), batch_size):
                logits = model_saint(x_all[start:start + batch_size].to(device))
                chunks.append(torch.softmax(logits, dim=1)[:, 1].cpu())
        if not chunks:
            return np.empty(0, dtype=np.float32)
        return torch.cat(chunks).numpy()

    def gradient_importance(X):
        """Mean |d(logit_1 - logit_0) / d x| per feature over X, normalised to sum to 1."""
        x_all = to_tensor(X)
        model_saint.eval()
        totals = torch.zeros(x_all.shape[1])
        for start in range(0, len(x_all), batch_size):
            x = x_all[start:start + batch_size].to(device).clone().requires_grad_(True)
            logits = model_saint(x)
            margin = (logits[:, 1] - logits[:, 0]).sum()
            # autograd.grad returns input gradients without touching parameter .grad
            (grad,) = torch.autograd.grad(margin, x)
            totals += grad.abs().sum(dim=0).cpu()
        grand_total = totals.sum()
        if grand_total > 0:
            totals = totals / grand_total
        return totals.numpy()

    # === Initialize model
    input_dim = X_train.shape[1]
    if len(feature_names) != input_dim:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but X_train has "
            f"{input_dim} columns"
        )
    model_saint = SAINT(input_dim, embed_dim, num_heads, num_layers).to(device)

    # The same shuffled loader serves both the pretraining and fine-tuning stages.
    train_loader = DataLoader(
        TabularDataset_saint(X_train, y_train), batch_size=batch_size, shuffle=True
    )
    n_batches = max(len(train_loader), 1)  # avoid ZeroDivisionError on empty data

    # === Pretraining
    optimizer_pre = torch.optim.Adam(model_saint.parameters(), lr=lr)
    contrastive_loss = NTXentLoss_saint(temperature=temperature)
    model_saint.train()
    for epoch in range(pretrain_epochs):
        total_loss = 0.0
        for x_batch, _ in train_loader:
            x_batch = x_batch.to(device)
            # Two independently noised copies of the same rows form the positive pairs.
            z_i = model_saint.encode(x_batch + torch.randn_like(x_batch) * noise_std)
            z_j = model_saint.encode(x_batch + torch.randn_like(x_batch) * noise_std)
            loss = contrastive_loss(z_i, z_j)
            optimizer_pre.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_saint.parameters(), 1.0)
            optimizer_pre.step()
            total_loss += loss.item()
        print(f"[Pretrain] Epoch {epoch+1}: Loss = {total_loss / n_batches:.4f}")

    # === Fine-tuning
    optimizer_ft = torch.optim.Adam(model_saint.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model_saint.train()
    for epoch in range(finetune_epochs):
        total_loss = 0.0
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            loss = criterion(model_saint(x_batch), y_batch)
            optimizer_ft.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model_saint.parameters(), 1.0)
            optimizer_ft.step()
            total_loss += loss.item()
        print(f"[Finetune] Epoch {epoch+1}: Loss = {total_loss / n_batches:.4f}")

    # === Predict Validation
    val_probs = predict_proba(X_val)

    # === Best threshold (first threshold wins on ties, as np.argmax does)
    thresholds = np.linspace(0, 1, 101)
    # zero_division=0 silences the warning when a threshold predicts no positives.
    f1s = [
        f1_score(y_val, (val_probs >= t).astype(int), zero_division=0)
        for t in thresholds
    ]
    best_threshold = thresholds[np.argmax(f1s)]
    print(f"\n✅ Best threshold (F1): {best_threshold:.2f}")

    # === Test Predictions
    test_probs = predict_proba(X_test)
    test_preds = (test_probs >= best_threshold).astype(int)

    # === Metrics
    accuracy = accuracy_score(y_test, test_preds)
    auc = roc_auc_score(y_test, test_probs)
    report_str = classification_report(
        y_test, test_preds, target_names=["Class 0", "Class 1"], zero_division=0
    )
    report_df = pd.DataFrame(
        classification_report(y_test, test_preds, output_dict=True, zero_division=0)
    ).T

    print(f"\n🔎 Test Accuracy: {accuracy:.2f}")
    print("Classification Report:")
    print(report_str)
    print(f"AUC-ROC (Test): {auc:.2f}")

    # === ROC Plot
    fpr, tpr, _ = roc_curve(y_test, test_probs)
    plt.figure(figsize=(6, 4))
    plt.plot(fpr, tpr, label=f"AUC = {auc:.2f}")
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve - SAINT")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # === Feature importance: input-gradient saliency on the validation set
    # (replaces the former random Dirichlet placeholder; values still sum to 1).
    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Importance': gradient_importance(X_val),
    }).sort_values(by='Importance', ascending=False)

    print("\n📊 Top 10 Important Features (SAINT):")
    print(feature_importance.head(10))

    # === Return structured result
    results_saint = {
        'val_probs': val_probs,
        'test_probs': test_probs,
        'test_preds': test_preds,
        'accuracy': accuracy,
        'auc': auc,
        'best_threshold': best_threshold,
        'fpr': fpr,
        'tpr': tpr,
        'report_df': report_df,
        'feature_importance': feature_importance,
    }

    return model_saint, results_saint
