# Train HAN Model for Sepsis Prediction

**Input**: HeteroData from `owl_to_heterodata.ipynb`  
**Output**: Trained HAN model with attention weights

**HAN** (Heterogeneous Attention Network) uses:
- **Node-level attention**: Which neighbors are important?
- **Semantic-level attention**: Which metapaths are important?

In [11]:
from pathlib import Path
import torch
import torch.nn.functional as F
from torch_geometric.nn import HANConv  # Use HANConv layer
from torch_geometric.data import HeteroData

project_root = Path.cwd().parent.parent
print(f"✓ Project root: {project_root}")

✓ Project root: /Users/silviatrottet/Documents/M2_GENIOMHE/Deep Learning/2526-m2geniomhe-GNN-sepsis


## 1. Load HeteroData

In [12]:
data_path = project_root / "data" / "han" / "hetero_graph_han.pt"

print(f"Loading HeteroData from: {data_path.name}")
data = torch.load(data_path, weights_only=False)  # PyTorch 2.6+ requires this

print(f"\n{data}")
print(f"\nNode types: {data.node_types}")
print(f"Edge types: {data.edge_types}")

Loading HeteroData from: hetero_graph_han.pt

HeteroData(
  patient={
    x=[163, 163],
    y=[163],
    train_mask=[163],
    val_mask=[163],
    test_mask=[163],
  },
  protein={ x=[1295, 1295] },
  (patient, expresses, protein)={ edge_index=[2, 277784] },
  (protein, interacts, protein)={ edge_index=[2, 2821] }
)

Node types: ['patient', 'protein']
Edge types: [('patient', 'expresses', 'protein'), ('protein', 'interacts', 'protein')]


## 2. Define Metapaths

**Metapaths** = paths through the graph that capture different semantic relationships

For our sepsis prediction:
- `Patient → Protein → Patient`: Patients connected through shared proteins
- `Patient → Protein → Protein → Patient`: Patients connected through interacting proteins

In [None]:
# Define metapaths for HAN
metapaths = [
    # Metapath 1: Patient → expresses → Protein → rev_expresses → Patient
    [('patient', 'expresses', 'protein'), ('protein', 'rev_expresses', 'patient')],
    
    # Metapath 2: Patient → expresses → Protein → interacts → Protein → rev_expresses → Patient
    [('patient', 'expresses', 'protein'), 
     ('protein', 'interacts', 'protein'),
     ('protein', 'rev_expresses', 'patient')]
]

print("Metapaths defined:")
for i, mp in enumerate(metapaths, 1):
    path_str = ' → '.join([f"{s}--[{r}]--{d}" for s, r, d in mp])
    print(f"  {i}. {path_str}")

### Add Reverse Edges

HAN needs reverse edges for metapaths

In [None]:
# Add reverse edges for patient→protein
edge_index = data['patient', 'expresses', 'protein'].edge_index
data['protein', 'rev_expresses', 'patient'].edge_index = edge_index.flip(0)

print(f"✓ Added reverse edges")
print(f"\nUpdated edge types: {data.edge_types}")

## 3. Create HAN Model

In [None]:
class HANClassifier(torch.nn.Module):
    """
    HAN-based classifier for sepsis prediction using HANConv layers
    """
    def __init__(self, in_channels, hidden_channels, out_channels, metadata, num_heads=8):
        super().__init__()
        
        # HANConv layer
        self.han_conv = HANConv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            metadata=metadata,
            heads=num_heads
        )
        
        # Classifier for patient nodes
        self.classifier = torch.nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x_dict, edge_index_dict):
        # HANConv forward pass
        out_dict = self.han_conv(x_dict, edge_index_dict)
        
        # Classify patient nodes
        patient_out = self.classifier(out_dict['patient'])
        
        return patient_out

In [None]:
# Model parameters
in_channels = -1  # Auto-infer from data
hidden_channels = 64
out_channels = 2  # Binary: sepsis vs healthy
num_heads = 8

# Create model
model = HANClassifier(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    metadata=data.metadata(),
    num_heads=num_heads
)

# Initialize lazy modules with a dummy forward pass
with torch.no_grad():
    _ = model(data.x_dict, data.edge_index_dict)

print(f"✓ Model created and initialized")
print(f"\nParameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"\n{model}")

## 4. Training Setup

In [40]:
# Calculate class weights to handle imbalance
y_train = data['patient'].y[data['patient'].train_mask]
n_class_0 = (y_train == 0).sum().item()
n_class_1 = (y_train == 1).sum().item()
total = n_class_0 + n_class_1

# Calculate inverse frequency weights automatically
weight_class_0 = total / (2 * n_class_0)
weight_class_1 = total / (2 * n_class_1)
class_weights = torch.tensor([weight_class_0, weight_class_1], dtype=torch.float)

print(f"Training set class distribution:")
print(f"  Class 0 (healthy): {n_class_0}")
print(f"  Class 1 (sepsis):  {n_class_1}")
print(f"  Ratio: {n_class_0 / n_class_1:.2f}:1")
print(f"\nClass weights (computed automatically):")
print(f"  Class 0: {weight_class_0:.4f}")
print(f"  Class 1: {weight_class_1:.4f}")

# Optimizer and loss with class weights
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

print("\n✓ Optimizer: Adam (lr=0.001)")
print("✓ Loss: CrossEntropyLoss with class weights")

Training set class distribution:
  Class 0 (healthy): 22
  Class 1 (sepsis):  92
  Ratio: 0.24:1

Class weights (computed automatically):
  Class 0: 2.5909
  Class 1: 0.6196

✓ Optimizer: Adam (lr=0.001)
✓ Loss: CrossEntropyLoss with class weights


## 5. Training Loop

In [35]:
def train():
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    out = model(data.x_dict, data.edge_index_dict)
    
    # Loss on training nodes only
    loss = criterion(out[data['patient'].train_mask], 
                     data['patient'].y[data['patient'].train_mask])
    
    # Backward
    loss.backward()
    optimizer.step()
    
    return loss.item()


@torch.no_grad()
def evaluate(mask):
    model.eval()
    
    out = model(data.x_dict, data.edge_index_dict)
    pred = out.argmax(dim=1)
    
    # Accuracy
    correct = (pred[mask] == data['patient'].y[mask]).sum()
    acc = int(correct) / int(mask.sum())
    
    return acc

In [None]:
# Train!
num_epochs = 200
patience = 20
best_val_acc = 0
patience_counter = 0

print("Training HAN...\n")
print(f"Epoch | Train Loss | Val Acc | Test Acc")
print("-" * 45)

for epoch in range(1, num_epochs + 1):
    loss = train()
    
    if epoch % 10 == 0:
        val_acc = evaluate(data['patient'].val_mask)
        test_acc = evaluate(data['patient'].test_mask)
        
        print(f"{epoch:4d}  | {loss:.4f}     | {val_acc:.3f}   | {test_acc:.3f}")
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), project_root / 'best_han_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\n Early stopping at epoch {epoch}")
                break

print(f"\n Best validation accuracy: {best_val_acc:.3f}")

Training HAN...

Epoch | Train Loss | Val Acc | Test Acc
---------------------------------------------
  10  | 0.6863     | 0.875   | 0.680
  20  | 0.6854     | 0.917   | 0.880
  30  | 0.6845     | 0.958   | 0.920
  40  | 0.6834     | 0.917   | 0.800
  50  | 0.6821     | 0.917   | 0.760
  60  | 0.6807     | 0.917   | 0.800
  70  | 0.6789     | 0.917   | 0.800
  80  | 0.6768     | 0.917   | 0.800
  90  | 0.6740     | 0.917   | 0.800
 100  | 0.6704     | 0.917   | 0.840
 110  | 0.6654     | 0.917   | 0.840
 120  | 0.6584     | 0.958   | 0.880
 130  | 0.6478     | 0.958   | 0.920
 140  | 0.6315     | 0.958   | 0.920
 150  | 0.6065     | 0.958   | 0.880
 160  | 0.5695     | 0.958   | 0.920
 170  | 0.5172     | 0.750   | 0.840
 180  | 0.4499     | 0.667   | 0.840
 190  | 0.3728     | 0.667   | 0.840
 200  | 0.2973     | 0.625   | 0.800

 Best validation accuracy: 0.958


## 6. Final Evaluation

In [43]:
# Load best model
model.load_state_dict(torch.load(project_root / 'best_han_model.pt', weights_only=False))

# Evaluate
train_acc = evaluate(data['patient'].train_mask)
val_acc = evaluate(data['patient'].val_mask)
test_acc = evaluate(data['patient'].test_mask)

print("\n" + "="*50)
print("Final Results")
print("="*50)
print(f"Train Accuracy: {train_acc:.1%}")
print(f"Val Accuracy:   {val_acc:.1%}")
print(f"Test Accuracy:  {test_acc:.1%}")


Final Results
Train Accuracy: 97.4%
Val Accuracy:   95.8%
Test Accuracy:  92.0%


## 7. Extract Attention Weights

Get node-level and semantic-level attention for interpretability

In [None]:
# Get attention weights from HAN
model.eval()

# Forward pass to get embeddings and attentions
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

print("✓ Attention weights computed")

✓ Attention weights computed

Next: Extract and visualize attention for specific patients


## 8. Save Model

In [45]:
# Save complete model info
output_path = project_root / "results" / "han" / "han_model.pt"
output_path.parent.mkdir(parents=True, exist_ok=True)

torch.save({
    'model_state_dict': model.state_dict(),
    'train_acc': train_acc,
    'val_acc': val_acc,
    'test_acc': test_acc,
    'hidden_channels': hidden_channels,
    'num_heads': num_heads
}, output_path)

print(f"\n Model saved to: {output_path}")


 Model saved to: /Users/silviatrottet/Documents/M2_GENIOMHE/Deep Learning/2526-m2geniomhe-GNN-sepsis/results/han/han_model.pt
