# NeuroSmriti - Complete Model Training Pipeline

This notebook will:
1. Load the synthetic dataset (real hackathon data is only detected, not yet used for training)
2. Train MemoryGNN model
3. Evaluate performance
4. Generate visualizations
5. Save trained model

## 1. Setup & Imports

In [None]:
import os
import pickle
import shutil
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm.notebook import tqdm

# Add src to path (guarded so re-running this cell doesn't append it again)
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

from models.memory_gnn import MemoryGNN, MemoryDecayLoss

# Reproducibility: seed the RNGs used for weight init, dropout and DataLoader
# shuffling. torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 2. Check Available Data

In [None]:
# Check what data we have
data_dir = '../data'
synthetic_dir = f'{data_dir}/synthetic'
hackathon_dir = f'{data_dir}/raw/hackathon'
SPLITS = ('train', 'val', 'test')

print("üìä Available Datasets:")
print("=" * 50)

# Check synthetic data. It only counts as available if *every* split exists:
# the loading cell opens train/val/test unconditionally, so a missing val or
# test file must also trigger regeneration.
missing_splits = [s for s in SPLITS if not os.path.exists(f'{synthetic_dir}/{s}.pkl')]
if not missing_splits:
    # NOTE: pickle.load can execute arbitrary code; only load files this
    # project generated itself.
    with open(f'{synthetic_dir}/train.pkl', 'rb') as f:
        synthetic_train = pickle.load(f)
    print(f"‚úÖ Synthetic Data: {len(synthetic_train)} training samples")
    use_synthetic = True
else:
    print("‚ùå Synthetic data not found. Run 01_data_generation.ipynb first.")
    print(f"   Missing splits: {', '.join(missing_splits)}")
    use_synthetic = False

# Check hackathon data (optional). isdir, not exists: listdir needs a directory.
if os.path.isdir(hackathon_dir):
    files = os.listdir(hackathon_dir)
    print(f"‚úÖ Hackathon Data: {len(files)} files found")
    use_hackathon = True
else:
    print("‚ÑπÔ∏è  Hackathon data not found (optional)")
    use_hackathon = False

print("\nüí° Using synthetic data for training (perfect for hackathon!)")

## 3. Load Training Data

In [None]:
# If synthetic data doesn't exist, generate it now
if not use_synthetic:
    print("Generating synthetic data...")
    %run 01_data_generation.ipynb

# Load datasets
print("Loading datasets...")

with open(f'{synthetic_dir}/train.pkl', 'rb') as f:
    train_data = pickle.load(f)

with open(f'{synthetic_dir}/val.pkl', 'rb') as f:
    val_data = pickle.load(f)

with open(f'{synthetic_dir}/test.pkl', 'rb') as f:
    test_data = pickle.load(f)

print(f"‚úÖ Train: {len(train_data)} samples")
print(f"‚úÖ Val: {len(val_data)} samples")
print(f"‚úÖ Test: {len(test_data)} samples")

# Analyze data
print("\nüìä Dataset Statistics:")
all_data = train_data + val_data + test_data
avg_nodes = np.mean([d.x.size(0) for d in all_data])
avg_edges = np.mean([d.edge_index.size(1) for d in all_data])
print(f"Average memories per patient: {avg_nodes:.1f}")
print(f"Average connections per patient: {avg_edges:.1f}")

## 4. Create Data Loaders

In [None]:
BATCH_SIZE = 32

# Only the training split is shuffled; val/test keep a fixed order so that
# metrics are comparable from epoch to epoch.
train_loader, val_loader, test_loader = (
    DataLoader(split_data, batch_size=BATCH_SIZE, shuffle=is_training)
    for split_data, is_training in (
        (train_data, True),
        (val_data, False),
        (test_data, False),
    )
)

print(f"‚úÖ Created data loaders (batch size: {BATCH_SIZE})")
for split_label, split_loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{split_label} batches: {len(split_loader)}")

## 5. Initialize Model

In [None]:
# Hyperparameters
NUM_NODE_FEATURES = 10
HIDDEN_CHANNELS = 64
NUM_HEADS = 4
NUM_LAYERS = 3
DROPOUT = 0.3
LEARNING_RATE = 0.001
# Weights passed to MemoryDecayLoss(alpha=..., beta=...); their exact roles are
# defined in models/memory_gnn.py.
LOSS_ALPHA = 0.7
LOSS_BETA = 0.3
# ReduceLROnPlateau: multiply the LR by LR_FACTOR after LR_PATIENCE epochs
# without validation-loss improvement.
LR_FACTOR = 0.5
LR_PATIENCE = 5

# Fail fast if the generated node features don't match the model's input size,
# rather than with a shape error deep inside the first forward pass.
assert train_data[0].x.size(1) == NUM_NODE_FEATURES, (
    f"Data has {train_data[0].x.size(1)} node features, "
    f"but NUM_NODE_FEATURES = {NUM_NODE_FEATURES}"
)

# Initialize model
model = MemoryGNN(
    num_node_features=NUM_NODE_FEATURES,
    hidden_channels=HIDDEN_CHANNELS,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
).to(device)

print("‚úÖ Model initialized")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Loss and optimizer
criterion = MemoryDecayLoss(alpha=LOSS_ALPHA, beta=LOSS_BETA)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# `verbose` is deprecated for LR schedulers; monitor the LR via
# optimizer.param_groups[0]['lr'] instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=LR_FACTOR, patience=LR_PATIENCE
)

print("‚úÖ Optimizer and scheduler configured")

## 6. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Each batch is moved to ``device``, fed through ``model`` as
    ``model(x, edge_index, edge_attr, batch)``, scored with
    ``criterion(node_pred, y_node, graph_pred, y_graph)`` and used for one
    optimizer step.

    Returns:
        The arithmetic mean of the per-batch loss values (sum / ``len(loader)``).
    """
    model.train()
    batch_losses = []

    for data in tqdm(loader, desc="Training", leave=False):
        data = data.to(device)
        optimizer.zero_grad()

        node_out, graph_out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        step_loss = criterion(node_out, data.y_node, graph_out, data.y_graph)

        step_loss.backward()
        optimizer.step()
        batch_losses.append(step_loss.item())

    return sum(batch_losses) / len(loader)


def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Args:
        model: network called as ``model(x, edge_index, edge_attr, batch)`` that
            returns ``(node_pred, graph_pred)``.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``edge_attr``,
            ``batch``, ``y_node``, ``y_graph`` and a ``to(device)`` method.
        criterion: callable ``(node_pred, y_node, graph_pred, y_graph) -> loss``.
        device: device each batch is moved to.

    Returns:
        dict with
          - ``'loss'``: mean of the per-batch criterion values (same convention
            as ``train_epoch``, so train/val curves are comparable),
          - ``'node_mae'``: mean absolute error over every node-level target element,
          - ``'graph_mae'``: mean absolute error over every graph-level target element.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    node_abs_err = 0.0
    node_count = 0
    graph_abs_err = 0.0
    graph_count = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            node_pred, graph_pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

            loss = criterion(node_pred, batch.y_node, graph_pred, batch.y_graph)
            total_loss += loss.item()
            n_batches += 1

            # Accumulate summed errors and element counts: averaging per-batch
            # MAEs would over-weight the (usually smaller) last batch and any
            # batch with few nodes.
            node_abs_err += F.l1_loss(node_pred, batch.y_node, reduction='sum').item()
            node_count += batch.y_node.numel()

            # Flatten both sides so a (B, 1) prediction is compared element-wise
            # with a (B,) target instead of silently broadcasting to (B, B).
            graph_target = batch.y_graph.reshape(-1)
            graph_abs_err += F.l1_loss(graph_pred.reshape(-1), graph_target, reduction='sum').item()
            graph_count += graph_target.numel()

    if n_batches == 0:
        raise ValueError("evaluate() received an empty loader")

    return {
        'loss': total_loss / n_batches,
        'node_mae': node_abs_err / node_count,
        'graph_mae': graph_abs_err / graph_count
    }

print("‚úÖ Training functions defined")

## 7. Train Model

In [None]:
NUM_EPOCHS = 50
PATIENCE = 10  # epochs without val-loss improvement before early stopping
BEST_MODEL_PATH = '../models/memory_gnn_best.pth'

# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)

print(f"üöÄ Starting training for {NUM_EPOCHS} epochs...")
print("=" * 60)

# Track metrics (one entry per completed epoch)
train_losses = []
val_losses = []
val_node_maes = []
val_graph_maes = []

best_val_loss = float('inf')
best_epoch = None  # 0-based index of the epoch whose weights were last saved
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)

    # Validate
    val_metrics = evaluate(model, val_loader, criterion, device)
    val_losses.append(val_metrics['loss'])
    val_node_maes.append(val_metrics['node_mae'])
    val_graph_maes.append(val_metrics['graph_mae'])

    # Scheduler step: ReduceLROnPlateau is driven by the validation loss.
    scheduler.step(val_metrics['loss'])

    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Node MAE: {val_metrics['node_mae']:.4f} | Graph MAE: {val_metrics['graph_mae']:.4f}")
    # LR that will be used for the next epoch (it may have just been reduced).
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Save best model
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        best_epoch = epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['loss'],
            'hyperparameters': {
                # Stored so the checkpoint alone is enough to rebuild MemoryGNN,
                # matching the hyperparameters recorded in the production export.
                'num_node_features': NUM_NODE_FEATURES,
                'hidden_channels': HIDDEN_CHANNELS,
                'num_heads': NUM_HEADS,
                'num_layers': NUM_LAYERS,
                'dropout': DROPOUT
            }
        }, BEST_MODEL_PATH)
        print("‚úÖ Saved best model!")
        patience_counter = 0
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= PATIENCE:
        print(f"\n‚èπÔ∏è  Early stopping after {epoch + 1} epochs")
        break

print("\n" + "=" * 60)
print("üéâ Training completed!")
if best_epoch is not None:
    print(f"Best val loss: {best_val_loss:.4f} (epoch {best_epoch + 1})")
else:
    # e.g. the validation loss was NaN every epoch; no checkpoint file was written.
    print(f"WARNING: validation loss never improved; nothing was saved to {BEST_MODEL_PATH}")

## 8. Plot Training History

In [None]:
# The training log numbers epochs from 1, so plot against the same numbering.
epochs = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Loss curve
axes[0].plot(epochs, train_losses, label='Train Loss', linewidth=2)
axes[0].plot(epochs, val_losses, label='Val Loss', linewidth=2)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training & Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)

# Validation MAE curves: node-level (memory decay) and graph-level (risk score)
for ax, values, color, title in (
    (axes[1], val_node_maes, 'orange', 'Memory Decay Prediction Error'),
    (axes[2], val_graph_maes, 'green', 'Risk Score Prediction Error'),
):
    ax.plot(epochs, values, color=color, linewidth=2)
    ax.set_ylabel('MAE', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')

# Shared styling for all three panels
for ax in axes:
    ax.set_xlabel('Epoch', fontsize=12)
    ax.grid(True, alpha=0.3)

HISTORY_PLOT_PATH = '../models/training_history.png'
os.makedirs(os.path.dirname(HISTORY_PLOT_PATH), exist_ok=True)

fig.tight_layout()
fig.savefig(HISTORY_PLOT_PATH, dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Training curves saved to: ml/models/training_history.png")

## 9. Final Evaluation on Test Set

In [None]:
print("=" * 60)
print("üìä FINAL EVALUATION ON TEST SET")
print("=" * 60)

# Load best model. map_location lets a checkpoint written on a GPU be restored
# onto whatever `device` this kernel has.
checkpoint = torch.load('../models/memory_gnn_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best checkpoint (epoch {checkpoint['epoch'] + 1})")

# Evaluate
test_metrics = evaluate(model, test_loader, criterion, device)

print(f"\nüìà Test Results:")
print(f"Loss: {test_metrics['loss']:.4f}")
print(f"Node MAE (decay prediction): {test_metrics['node_mae']:.4f}")
print(f"Graph MAE (risk score): {test_metrics['graph_mae']:.4f}")

# NOTE(review): reading MAE * 100 as a percentage assumes targets are on a 0-1 scale; confirm.
print(f"\nüìä As Percentages:")
print(f"Memory Decay Error: {test_metrics['node_mae'] * 100:.2f}%")
print(f"Risk Score Error: {test_metrics['graph_mae'] * 100:.2f}%")

# Calculate accuracy: share of nodes whose forecast lies within an absolute
# tolerance of the target, separately for each forecast horizon.
ACCURACY_TOLERANCE = 0.1
model.eval()
correct_per_horizon = torch.zeros(3, dtype=torch.long)
total_nodes = 0

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        node_pred, _ = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        # Columns 0/1/2 are reported below as the 30/90/180-day forecasts.
        within_tol = torch.abs(node_pred[:, :3] - batch.y_node[:, :3]) < ACCURACY_TOLERANCE
        correct_per_horizon += within_tol.sum(dim=0).cpu()
        total_nodes += batch.x.size(0)

assert total_nodes > 0, "Test set contains no nodes; cannot compute accuracy"
correct_30, correct_90, correct_180 = (int(c) for c in correct_per_horizon)

print(f"\nüéØ Accuracy (within {ACCURACY_TOLERANCE:.0%} threshold):")
print(f"30-day forecast: {100 * correct_30 / total_nodes:.2f}%")
print(f"90-day forecast: {100 * correct_90 / total_nodes:.2f}%")
print(f"180-day forecast: {100 * correct_180 / total_nodes:.2f}%")

print("\n" + "=" * 60)
print("‚úÖ Evaluation complete!")
print("üì¶ Model saved to: ml/models/memory_gnn_best.pth")
print("=" * 60)

## 10. Sample Predictions

In [None]:
TOP_K_RISKY = 5
# NOTE(review): column 5 of x is read as the memory's current strength on a
# 0-1 scale (it is multiplied by 100 below); confirm against the generator.
CURRENT_STRENGTH_COL = 5
INTERVENTION_THRESHOLD_PCT = 50  # flag memories whose predicted 30-day strength (%) falls below this

# Get a sample patient. PyG's Data.to() moves tensors in place, so clone first
# to avoid relocating test_data[0] itself to `device`.
sample_patient = test_data[0].clone().to(device)

model.eval()
with torch.no_grad():
    # Single graph, so no batch vector is passed. This presumably relies on
    # MemoryGNN treating a missing batch as one graph; confirm.
    node_pred, graph_pred = model(
        sample_patient.x,
        sample_patient.edge_index,
        sample_patient.edge_attr
    )

print("üîÆ Sample Patient Prediction")
print("=" * 60)
print(f"Stage: {sample_patient.stage}")
print(f"Number of memories: {sample_patient.x.size(0)}")
print(f"\nOverall Risk Score: {graph_pred.item():.3f}")

print(f"\nüìâ Top {TOP_K_RISKY} High-Risk Memories:")
# Highest decay risk == lowest predicted 30-day strength
risk_scores = 1.0 - node_pred[:, 0]
top_risk_indices = torch.argsort(risk_scores, descending=True)[:TOP_K_RISKY]

for i, idx in enumerate(top_risk_indices, 1):
    current_strength = sample_patient.x[idx, CURRENT_STRENGTH_COL].item() * 100
    pred_30 = node_pred[idx, 0].item() * 100
    pred_90 = node_pred[idx, 1].item() * 100
    pred_180 = node_pred[idx, 2].item() * 100

    print(f"\nMemory #{i} (Index {idx.item()}):")
    print(f"  Current: {current_strength:.1f}%")
    print(f"  Predicted 30-day: {pred_30:.1f}% (Œî {pred_30 - current_strength:+.1f}%)")
    print(f"  Predicted 90-day: {pred_90:.1f}%")
    print(f"  Predicted 180-day: {pred_180:.1f}%")
    print(f"  ‚ö†Ô∏è  Intervention recommended!" if pred_30 < INTERVENTION_THRESHOLD_PCT else "  ‚úÖ Stable")

## 11. Export Model for Production

In [None]:
print("üì¶ Exporting model for production...")

PRODUCTION_MODEL_PATH = '../models/memory_gnn_production.pth'
backend_model_dir = '../../backend/app/ml/models'

# Save final model with metadata. `model` holds the weights of the best
# checkpoint, which the test-set evaluation cell reloaded.
os.makedirs(os.path.dirname(PRODUCTION_MODEL_PATH), exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'hyperparameters': {
        'num_node_features': NUM_NODE_FEATURES,
        'hidden_channels': HIDDEN_CHANNELS,
        'num_heads': NUM_HEADS,
        'num_layers': NUM_LAYERS,
        'dropout': DROPOUT
    },
    'performance': {
        'test_loss': test_metrics['loss'],
        'node_mae': test_metrics['node_mae'],
        'graph_mae': test_metrics['graph_mae'],
        'accuracy_30d': correct_30 / total_nodes,
        'accuracy_90d': correct_90 / total_nodes,
        'accuracy_180d': correct_180 / total_nodes
    }
}, PRODUCTION_MODEL_PATH)

# Copy to backend for API use
os.makedirs(backend_model_dir, exist_ok=True)
shutil.copy(PRODUCTION_MODEL_PATH, os.path.join(backend_model_dir, 'memory_gnn_production.pth'))

print("‚úÖ Model exported to:")
print("   1. ml/models/memory_gnn_production.pth")
print("   2. backend/app/ml/models/memory_gnn_production.pth")

print("\nüéâ Training pipeline complete!")
print("\nNext steps:")
print("1. Start backend: cd ../../backend && uvicorn app.main:app --reload")
print("2. Test predictions at: http://localhost:8000/docs")
print("3. Build frontend dashboard")