# Graph-Liquid-KAN Sea Lice Prediction - A100 GPU Training
## Phase 4: Production Training on Google Colab

This notebook runs **A100-optimized GPU training** of the Graph-Liquid-KAN architecture.

**A100 Optimized Configuration:**
| Parameter | Default | A100 Optimized | Notes |
|-----------|---------|----------------|-------|
| hidden_dim | 64 | **128** | 2x hidden width |
| n_bases | 8 | **12** | Finer RBF resolution |
| n_layers | 2 | **3** | Deeper network |
| dropout | 0.1 | **0.15** | More regularization |
| batch_size | 8 | **16** | Lower-variance gradient estimates |
| num_workers | 2 | **4** | Faster data loading |
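
The "2x" in the first row refers to layer width. The parameter count grows faster than that, because each FastKAN layer in Cell 5 stores an `in_features × n_bases × out_features` weight tensor plus a bias and a LayerNorm. Below is a rough count for one hidden-to-hidden FastKAN layer. It assumes the Cell 5 definition, and the helper `fastkan_params` is hypothetical:

```python
def fastkan_params(in_features: int, out_features: int, n_bases: int) -> int:
    """Parameter count of one FastKAN layer as defined in Cell 5 (assumed unchanged)."""
    weights = in_features * n_bases * out_features  # one weight per (input, basis, output)
    bias = out_features
    layer_norm = 2 * in_features                     # LayerNorm gain and shift
    return weights + bias + layer_norm

default = fastkan_params(64, 64, 8)
a100 = fastkan_params(128, 128, 12)
print(default, a100, round(a100 / default, 2))  # 32960 196992 5.98
```

At the A100 setting, one such layer therefore has roughly 6x more parameters, before counting the additional third layer.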

**Training Data:**
| Metric | Value |
|--------|-------|
| Nodes | 1,777 (all farms) |
| Edges | 73,168 |
| Features | 8 |
| Observations | 1,660 (0.3% coverage) |
| Outbreaks | 12 (0.72%) |

**Architecture:**
- **FastKAN Layers**: Gaussian RBF basis functions (learnable non-linearities)
- **GraphonAggregator**: degree-normalized (1/N_i) message passing, where each node averages over its neighborhood including itself (scale invariant)
- **LiquidKANCell**: Closed-form Continuous (CfC) dynamics with adaptive tau
- **Physics-Informed Loss**: L_data + λ_bio * L_bio + λ_stability * L_stability
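
The FastKAN bullet can be made concrete. Each input (after LayerNorm) is expanded into Gaussian bumps centered on an evenly spaced grid over [-1, 1], with σ equal to half the grid spacing, as in Cell 5. The following is a standalone sketch of that expansion. The function `rbf_basis` is illustrative and is not part of the notebook:

```python
import math
import torch

def rbf_basis(x: torch.Tensor, n_bases: int = 8) -> torch.Tensor:
    """Gaussian RBF expansion mirroring FastKAN in Cell 5; returns x.shape + (n_bases,)."""
    centers = torch.linspace(-1.0, 1.0, n_bases)
    spacing = 2.0 / (n_bases - 1)
    sigma = spacing / 2.0
    return torch.exp(-((x.unsqueeze(-1) - centers) ** 2) / (2.0 * sigma ** 2))

centers = torch.linspace(-1.0, 1.0, 8)
phi = rbf_basis(centers[3:4])  # evaluate exactly at the 4th grid center
# The matching basis equals 1. A neighbor lies one spacing away, where
# d^2 / (2 sigma^2) = spacing^2 / (2 * (spacing / 2)^2) = 2, so it equals exp(-2).
print(phi[0, 3].item(), phi[0, 2].item(), math.exp(-2))
```

With `n_bases=12` the grid is denser and σ shrinks with it, so the same input range is covered by more, narrower bumps.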

**Target Metrics:**
| Metric | Target | Description |
|--------|--------|-------------|
| Recall | ≥90% | Catch 9/10 outbreaks |
| Precision | ≥80% | 8/10 outbreak alerts are real outbreaks |
| F1 Score | ≥0.85 | Balance P/R |
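
These three targets are not independent. F1 is the harmonic mean of precision and recall, so hitting exactly 80% precision and 90% recall gives F1 ≈ 0.847, just under the 0.85 target. Meeting the F1 goal therefore requires exceeding at least one of the other two. There are also only 12 outbreaks in the whole dataset, so any evaluation split holds at most 12. Each miss then moves recall by at least 1/12 (about 8 points). A quick check:

```python
import math

def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# F1 at exactly the precision and recall targets
print(round(f1(0.80, 0.90), 4))  # 0.8471, below the 0.85 F1 target

# If all 12 outbreaks were evaluated, recall >= 90% means catching at least 11
n_outbreaks = 12
min_caught = math.ceil(0.90 * n_outbreaks)
print(min_caught, round(min_caught / n_outbreaks, 3))
```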

**Estimated Training Time on A100:** ~30-45 minutes for 100 epochs

**Runtime Configuration:**
- Runtime → Change runtime type → **A100 GPU**
- Runtime → Change runtime type → **High RAM**

## Cell 1: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
os.makedirs('/content/drive/MyDrive/GLKAN_Project/checkpoints', exist_ok=True)
os.makedirs('/content/drive/MyDrive/GLKAN_Project/outputs', exist_ok=True)
print('✅ Google Drive mounted and directories created')

## Cell 2: Verify GPU

In [None]:
!nvidia-smi

import torch
import gc

print('\n' + '='*60)
print('GPU VERIFICATION')
print('='*60)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_memory:.1f} GB")
    device = torch.device('cuda')
    
    # Enable TF32 for faster training on Ampere GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    # A100 (Ampere) supports FP16 autocast, so enable automatic mixed precision
    USE_AMP = True
    print("Mixed Precision (AMP): Enabled")
    
    # Set memory-efficient options
    torch.cuda.empty_cache()
    gc.collect()
else:
    print("[WARN] CUDA not available - using CPU (will be VERY slow)")
    print("[WARN] Please enable GPU: Runtime → Change runtime type → A100 GPU")
    device = torch.device('cpu')
    USE_AMP = False

print(f"\nDefault device: {device}")
print('='*60)

## Cell 3: Install Dependencies

In [None]:
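# Colab ships PyTorch preinstalled. If this line upgrades torch, restart the runtime,
# because Cell 2 has already imported the previous version.
!pip install -q torch torchvision --upgrade
!pip install -q numpy pandas scipy scikit-learn
!pip install -q loguru tqdm matplotlib
!pip install -q torch-geometric  # Optional: Cell 5 implements message passing with torch.sparse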

print('\n✅ Dependencies installed')

## Cell 4: Load Data from Drive

In [None]:
import os
import numpy as np
import torch

# Path to data on Drive
DRIVE_DATA = '/content/drive/MyDrive/glkan_data.zip'
LOCAL_DATA = '/content/data'

if not os.path.exists(DRIVE_DATA):
    print(f'❌ ERROR: Data not found at {DRIVE_DATA}')
    print('Please upload glkan_data.zip containing:')
    print('  - tensors.npz (from Phase 2)')
    print('  - spatial_graph.pt (from Phase 2)')
else:
    print('Extracting data...')
    os.makedirs(LOCAL_DATA, exist_ok=True)
    !unzip -q -o "{DRIVE_DATA}" -d {LOCAL_DATA}  # -o: overwrite without prompting on re-runs
    !ls -la {LOCAL_DATA}
    print('\n✅ Data loaded')

# Verify data
TENSOR_PATH = f'{LOCAL_DATA}/tensors.npz'
GRAPH_PATH = f'{LOCAL_DATA}/spatial_graph.pt'

if os.path.exists(TENSOR_PATH) and os.path.exists(GRAPH_PATH):
    data = np.load(TENSOR_PATH, allow_pickle=True)
    print(f"\nData shapes:")
    print(f"  X (features): {data['X'].shape}")
    print(f"  Y (targets):  {data['Y'].shape}")
    print(f"  mask:         {data['mask'].shape}")
    
    graph = torch.load(GRAPH_PATH, weights_only=False)
    print(f"  edges:        {graph['edge_index'].shape[1]}")
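
You can also run a few consistency checks to catch a mismatched archive before training. This sketch assumes the layouts that the later cells rely on: `X` as (T, N, F), `Y` as (T, N, C), `mask` as (T, N) and `edge_index` as (2, E). The helper `check_glkan_tensors` is hypothetical and is not part of the Phase 2 outputs:

```python
import numpy as np

def check_glkan_tensors(X: np.ndarray, Y: np.ndarray, mask: np.ndarray,
                        edge_index: np.ndarray) -> None:
    """Hypothetical sanity checks on the Phase 2 arrays (layouts assumed, see above)."""
    T, N, _ = X.shape
    assert Y.shape[:2] == (T, N), f"Y is {Y.shape}, expected ({T}, {N}, C)"
    assert mask.shape == (T, N), f"mask is {mask.shape}, expected ({T}, {N})"
    assert edge_index.shape[0] == 2, "edge_index must have shape (2, E)"
    assert edge_index.min() >= 0 and edge_index.max() < N, "node index out of range"
    print(f"T={T}, N={N}, observed fraction={mask.mean():.2%}")
```

In this cell it could be called as `check_glkan_tensors(data['X'], data['Y'], data['mask'], graph['edge_index'].numpy())`.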

## Cell 5: Define Graph-Liquid-KAN Architecture

Complete architecture from Phase 3 protocol:
- FastKAN with RBF basis
- Graphon-compliant aggregation
- Liquid time-constant dynamics
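
The liquid update in `LiquidKANCell` is `decay * h + (1 - decay) * x_eq * gate`, with `decay = exp(-dt / tau)`. This is the exact solution, over one step of length `dt`, of the linear ODE dh/dt = (gate · x_eq − h) / tau, provided tau, x_eq and gate are held fixed during the step. The following numerical check compares it against fine-grained Euler integration, using toy scalar values rather than model outputs:

```python
import math

def cfc_step(h: float, target: float, tau: float, dt: float) -> float:
    """Closed-form update used by LiquidKANCell, with target = gate * x_eq."""
    decay = math.exp(-dt / tau)
    return decay * h + (1.0 - decay) * target

def euler(h: float, target: float, tau: float, dt: float, n_steps: int = 100_000) -> float:
    """Explicit Euler integration of dh/dt = (target - h) / tau over [0, dt]."""
    step = dt / n_steps
    for _ in range(n_steps):
        h += step * (target - h) / tau
    return h

h0, target, tau, dt = 2.0, 0.5, 0.3, 1.0 / 30  # dt = 1/T for a 30-step window
print(cfc_step(h0, target, tau, dt), euler(h0, target, tau, dt))
```

For any tau > 0, `decay` lies in (0, 1). The update is therefore a convex combination of the previous state and the gated target, so the step itself cannot blow up regardless of `dt`.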

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict

# =============================================================================
# FastKAN Layer - Gaussian RBF Basis Functions
# =============================================================================
class FastKAN(nn.Module):
    """KAN layer with Gaussian RBF basis functions."""
    
    def __init__(self, in_features: int, out_features: int, n_bases: int = 8):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.n_bases = n_bases
        
        self.layer_norm = nn.LayerNorm(in_features)
        
        grid_centers = torch.linspace(-1.0, 1.0, n_bases)
        self.register_buffer("grid_centers", grid_centers)
        
        grid_spacing = 2.0 / (n_bases - 1)
        sigma = grid_spacing / 2.0
        self.register_buffer("gaussian_denom", torch.tensor(2.0 * sigma * sigma))
        
        self.weights = nn.Parameter(torch.empty(in_features, n_bases, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        
        scale = 1.0 / math.sqrt(in_features * n_bases)
        nn.init.uniform_(self.weights, -scale, scale)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer_norm(x)
        x_expanded = x.unsqueeze(-1)
        centers = self.grid_centers.view(1, 1, -1)
        distances_sq = (x_expanded - centers) ** 2
        basis = torch.exp(-distances_sq / self.gaussian_denom)
        output = torch.einsum("...if,ifo->...o", basis, self.weights)
        return output + self.bias


# =============================================================================
# Graphon Aggregator - 1/N Normalized Message Passing
# =============================================================================
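class GraphonAggregator(nn.Module):
    """Mean aggregation: node i averages x over its edge_index neighbors (rows with
    src == i) plus itself, i.e. each message is weighted by 1/deg(i).

    The normalized sparse adjacency is cached on first use and keyed only on the
    node count, so call reset_cache() whenever edge_index changes.
    """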
    
    def __init__(self, add_self_loops: bool = True):
        super().__init__()
        self.add_self_loops = add_self_loops
        self._cached_adj = None
        self._cached_n = None
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        if x.dim() == 3:
            return torch.stack([self._aggregate(x[b], edge_index) for b in range(x.shape[0])])
        return self._aggregate(x, edge_index)
    
    def _aggregate(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        n_nodes = x.shape[0]
        
        if self._cached_adj is None or self._cached_n != n_nodes:
            src, dst = edge_index[0], edge_index[1]
            edge_weight = torch.ones(edge_index.shape[1], device=x.device)
            
            if self.add_self_loops:
                loop_idx = torch.arange(n_nodes, device=x.device)
                src = torch.cat([src, loop_idx])
                dst = torch.cat([dst, loop_idx])
                edge_weight = torch.cat([edge_weight, torch.ones(n_nodes, device=x.device)])
            
            degree = torch.zeros(n_nodes, device=x.device)
            degree.scatter_add_(0, src, edge_weight)
            degree = degree.clamp(min=1.0)
            edge_weight = edge_weight / degree[src]
            
            indices = torch.stack([src, dst])
            self._cached_adj = torch.sparse_coo_tensor(indices, edge_weight, (n_nodes, n_nodes)).coalesce()
            self._cached_n = n_nodes
        
        return torch.sparse.mm(self._cached_adj, x)
    
    def reset_cache(self):
        self._cached_adj = None
        self._cached_n = None


# =============================================================================
# Liquid-KAN Cell - CfC Dynamics
# =============================================================================
class LiquidKANCell(nn.Module):
    """Liquid time-constant cell with KAN-parameterized dynamics."""
    
    def __init__(self, input_dim: int, hidden_dim: int, n_bases: int = 8,
                 tau_min: float = 0.01, tau_max: float = 10.0):
        super().__init__()
        self.tau_min = tau_min
        self.tau_max = tau_max
        
        context_dim = input_dim + hidden_dim
        self.kan_tau = FastKAN(context_dim, hidden_dim, n_bases)
        self.kan_eq = FastKAN(context_dim, hidden_dim, n_bases)
        self.kan_gate = FastKAN(context_dim, hidden_dim, n_bases)
    
    def forward(self, h: torch.Tensor, u: torch.Tensor, p: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        context = torch.cat([u, p], dim=-1)
        
        tau = F.softplus(self.kan_tau(context)) + self.tau_min
        tau = torch.clamp(tau, self.tau_min, self.tau_max)
        
        x_eq = self.kan_eq(context)
        gate = torch.sigmoid(self.kan_gate(context))
        
        if isinstance(dt, (int, float)):
            dt = torch.tensor(dt, device=h.device, dtype=h.dtype)
        while dt.dim() < tau.dim():
            dt = dt.unsqueeze(-1)
        
        decay = torch.exp(-dt / tau)
        return decay * h + (1 - decay) * x_eq * gate


# =============================================================================
# Graph-Liquid-KAN Cell
# =============================================================================
class GraphLiquidKANCell(nn.Module):
    """Combines graph aggregation with liquid dynamics."""
    
    def __init__(self, input_dim: int, hidden_dim: int, n_bases: int = 8):
        super().__init__()
        self.aggregator = GraphonAggregator()
        self.pressure_proj = FastKAN(hidden_dim, hidden_dim, n_bases)
        self.liquid_cell = LiquidKANCell(input_dim, hidden_dim, n_bases)
    
    def forward(self, h: torch.Tensor, u: torch.Tensor, edge_index: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        h_agg = self.aggregator(h, edge_index)
        p = self.pressure_proj(h_agg)
        return self.liquid_cell(h, u, p, dt)
    
    def reset_cache(self):
        self.aggregator.reset_cache()


# =============================================================================
# Complete GLKAN Network
# =============================================================================
class GLKANNetwork(nn.Module):
    """Full Graph-Liquid-KAN network."""
    
    def __init__(self, input_dim: int, hidden_dim: int = 64, output_dim: int = 3,
                 n_bases: int = 8, n_layers: int = 1, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        
        self.input_encoder = FastKAN(input_dim, hidden_dim, n_bases)
        self.cells = nn.ModuleList([GraphLiquidKANCell(hidden_dim, hidden_dim, n_bases) for _ in range(n_layers)])
        self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.output_decoder = FastKAN(hidden_dim, output_dim, n_bases)
        self.h0 = nn.Parameter(torch.zeros(1, 1, hidden_dim))
    
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if x.dim() == 3:
            x = x.unsqueeze(0)
            squeeze = True
        else:
            squeeze = False
        
        B, T, N, _ = x.shape  # avoid rebinding F (torch.nn.functional)
        dt = torch.ones(T, device=x.device) / T  # uniform step: the window spans unit time
        
        h_list = [self.h0.expand(B, N, -1).clone() for _ in range(self.n_layers)]
        outputs = []
        
        for t in range(T):
            u_t = self.dropout(self.input_encoder(x[:, t]))
            
            for i in range(self.n_layers):
                h_new = self.cells[i](h_list[i], u_t if i == 0 else h_list[i-1], edge_index, dt[t])
                h_new = self.layer_norms[i](h_new)
                if i > 0:
                    h_new = h_new + h_list[i]
                h_list[i] = self.dropout(h_new)
            
            outputs.append(self.output_decoder(h_list[-1]))
        
        pred = torch.stack(outputs, dim=1)
        if squeeze:
            pred = pred.squeeze(0)
        return pred, h_list[-1]
    
    def reset_cache(self):
        for cell in self.cells:
            cell.reset_cache()


# =============================================================================
# GLKAN Predictor (Training Wrapper)
# =============================================================================
class GLKANPredictor(nn.Module):
    """Predictor wrapper for training."""
    
    def __init__(self, input_dim: int, hidden_dim: int = 64, output_dim: int = 3,
                 n_bases: int = 8, n_layers: int = 1, dropout: float = 0.1):
        super().__init__()
        self.network = GLKANNetwork(input_dim, hidden_dim, output_dim, n_bases, n_layers, dropout)
        self.output_dim = output_dim
    
    def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        pred, _ = self.network(batch['x'], batch['edge_index'])
        return {'predictions': pred}

print('✅ Graph-Liquid-KAN architecture defined')
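
`GraphonAggregator.forward` loops over the batch and calls `torch.sparse.mm` once per sample. The normalized adjacency is shared across samples, so the batch can instead be folded into the feature dimension and aggregated with a single sparse product. The sketch below implements that alternative and checks it against the per-sample loop on a toy row-normalized graph. The function name `aggregate_batched` is hypothetical:

```python
import torch

def aggregate_batched(adj: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Apply a shared sparse (N, N) adjacency to x of shape (B, N, F) in one call."""
    B, N, n_feat = x.shape
    x_flat = x.permute(1, 0, 2).reshape(N, B * n_feat)  # (N, B*F)
    out = torch.sparse.mm(adj, x_flat)                   # (N, B*F)
    return out.reshape(N, B, n_feat).permute(1, 0, 2)    # back to (B, N, F)

# Toy graph: edges 0->1 and 1->2 plus self-loops, each row scaled by 1/deg(src)
indices = torch.tensor([[0, 1, 0, 1, 2], [1, 2, 0, 1, 2]])
values = torch.tensor([0.5, 0.5, 0.5, 0.5, 1.0])
adj = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()

x = torch.randn(4, 3, 5)
looped = torch.stack([torch.sparse.mm(adj, x[b]) for b in range(x.shape[0])])
print(torch.allclose(aggregate_batched(adj, x), looped))
```

To adopt this approach, `forward` would build the cached adjacency once and pass `self._cached_adj` to such a helper. The numerical result matches the loop, and it uses one kernel launch instead of B.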

## Cell 6: Define Physics-Informed Loss

In [None]:
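class PhysicsInformedLoss(nn.Module):
    """
    Physics-Informed Loss for GLKAN.
    
    Intended form: L = L_data + λ_bio * L_bio + λ_stability * L_stability
    
    Components:
    - L_data: Huber (SmoothL1) loss on masked observations
    - L_bio: Non-negativity + bound on relative step-to-step change
    - L_stability: Tau regularization. Not implemented yet: forward() returns
      L_data + λ_bio * L_bio, and lambda_stability is stored but unused.
    """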
    
    def __init__(self, lambda_bio: float = 0.1, lambda_stability: float = 0.01,
                 max_daily_change: float = 0.2, huber_delta: float = 1.0):
        super().__init__()
        self.lambda_bio = lambda_bio
        self.lambda_stability = lambda_stability
        self.max_daily_change = max_daily_change
        self.huber = nn.SmoothL1Loss(reduction='none', beta=huber_delta)
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        metrics = {}
        
        # L_data: Huber loss on masked observations
        if mask.dim() == pred.dim() - 1:
            mask = mask.unsqueeze(-1).expand_as(pred)
        
        huber = self.huber(pred, target)
        masked_huber = huber * mask.float()
        n_valid = mask.float().sum()
        l_data = masked_huber.sum() / n_valid.clamp(min=1)
        metrics['l_data'] = l_data.item()
        
        # L_bio: Non-negativity + growth rate
        l_nonneg = F.relu(-pred).mean()
        
        if pred.shape[1] > 1:
            delta = pred[:, 1:] - pred[:, :-1]
            threshold = self.max_daily_change * (pred[:, :-1].abs() + 0.1)
            l_growth = F.relu(delta.abs() - threshold).mean()
        else:
            l_growth = torch.tensor(0.0, device=pred.device)
        
        l_bio = l_nonneg + 0.1 * l_growth
        metrics['l_bio'] = l_bio.item()
        metrics['l_nonneg'] = l_nonneg.item()
        metrics['l_growth'] = l_growth.item()
        
        # Total loss
        total = l_data + self.lambda_bio * l_bio
        metrics['total_loss'] = total.item()
        
        return total, metrics

print('✅ Physics-Informed Loss defined')
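
The data term is described as a Huber loss. PyTorch provides both `nn.HuberLoss(delta=δ)` and `nn.SmoothL1Loss(beta=δ)`, and they differ by a factor of δ: `HuberLoss = δ · SmoothL1Loss`. The two coincide at the default `huber_delta=1.0`. For any other value, the SmoothL1 form rescales the data term by 1/δ relative to Huber, which also shifts its balance against `lambda_bio * L_bio`. The following elementwise check uses toy residuals:

```python
import torch
import torch.nn as nn

pred = torch.tensor([0.0, 0.5, 3.0, -4.0])
target = torch.zeros(4)

for delta in (1.0, 2.0):
    huber = nn.HuberLoss(reduction='none', delta=delta)(pred, target)
    smooth_l1 = nn.SmoothL1Loss(reduction='none', beta=delta)(pred, target)
    # Elementwise: HuberLoss(delta) == delta * SmoothL1Loss(beta=delta)
    print(delta, torch.allclose(huber, delta * smooth_l1))
```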

## Cell 7: Create Dataset and DataLoaders

In [None]:
from torch.utils.data import Dataset, DataLoader

class SeaLiceDataset(Dataset):
    """Dataset for GLKAN training."""
    
    def __init__(self, X, Y, mask, edge_index, window_size=30, stride=7, time_start=0, time_end=None):
        self.X = X
        self.Y = Y
        self.mask = mask
        self.edge_index = edge_index
        self.window_size = window_size
        
        if time_end is None:
            time_end = X.shape[0]
        # Start indices of every full window inside [time_start, time_end)
        self.sequences = []
        for t in range(time_start, time_end - window_size + 1, stride):
            self.sequences.append((t, t + window_size))
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        t_start, t_end = self.sequences[idx]
        return {
            'x': self.X[t_start:t_end],
            'y': self.Y[t_start:t_end],
            'mask': self.mask[t_start:t_end],
            'edge_index': self.edge_index,
        }

def collate_fn(batch):
    return {
        'x': torch.stack([b['x'] for b in batch]),
        'y': torch.stack([b['y'] for b in batch]),
        'mask': torch.stack([b['mask'] for b in batch]),
        'edge_index': batch[0]['edge_index'],
    }

# Load data
data = np.load(TENSOR_PATH, allow_pickle=True)
graph = torch.load(GRAPH_PATH, weights_only=False)

X = torch.from_numpy(data['X']).float()
Y = torch.from_numpy(data['Y']).float()
mask = torch.from_numpy(data['mask']).bool()
edge_index = graph['edge_index']

print(f'Data loaded on CPU (will move to GPU later):')
print(f'  X device: {X.device}')
print(f'  edge_index device: {edge_index.device}')

# Chronological train/val/test split (70/15/15):
# T_train and T_val are the boundary indices at 70% and 85% of the timeline
T_total = X.shape[0]
T_train = int(T_total * 0.70)
T_val = int(T_total * 0.85)

# Initial loader configuration (Cell 8 rebuilds the loaders for the A100)
WINDOW_SIZE = 30  # 30-day sequences
STRIDE = 7        # Weekly stride
BATCH_SIZE = 8    # Overridden to 16 in Cell 8

train_ds = SeaLiceDataset(X, Y, mask, edge_index, WINDOW_SIZE, STRIDE, 0, T_train)
val_ds = SeaLiceDataset(X, Y, mask, edge_index, WINDOW_SIZE, STRIDE, T_train, T_val)
test_ds = SeaLiceDataset(X, Y, mask, edge_index, WINDOW_SIZE, STRIDE, T_val, T_total)

# Use pin_memory for faster GPU transfer
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, 
                          collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, 
                        collate_fn=collate_fn, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, 
                         collate_fn=collate_fn, num_workers=2, pin_memory=True)

print(f'\n✅ DataLoaders created (FULL DATA):')
print(f'  Train sequences: {len(train_ds)}')
print(f'  Val sequences: {len(val_ds)}')
print(f'  Test sequences: {len(test_ds)}')
print(f'  Batch size: {BATCH_SIZE}')
print(f'  Window size: {WINDOW_SIZE}')
print(f'  Nodes: {X.shape[1]} (ALL FARMS)')
print(f'  Edges: {edge_index.shape[1]}')
print(f'  Features: {X.shape[2]}')
print(f'\n  Estimated batches/epoch: {len(train_loader)}')
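
A split that spans `span` time steps holds floor((span − W) / S) + 1 full windows of length W at stride S. The stop value passed to `range` decides whether the window starting exactly at `span − W` is kept. With `stop = span − W`, that window is dropped whenever `span − W` is a multiple of the stride, and a split exactly one window long yields no sequences at all. The check below uses W = 30 and S = 7. The helper `count_windows` is hypothetical:

```python
def count_windows(span: int, window: int, stride: int) -> int:
    """Number of full windows of length `window` starting every `stride` steps."""
    return 0 if span < window else (span - window) // stride + 1

for span in (100, 107, 30):
    exclusive = len(range(0, span - 30, 7))      # stop = span - window
    inclusive = len(range(0, span - 30 + 1, 7))  # stop = span - window + 1
    print(span, count_windows(span, 30, 7), exclusive, inclusive)
```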

## Cell 8: Create Model and Optimizer

In [None]:
# =============================================================================
# A100 GPU OPTIMIZED CONFIGURATION
# =============================================================================
# Tuned for NVIDIA A100 (40GB/80GB) with sparse sea lice data
# - Increased capacity without overfitting on 0.3% observed data
# - Larger batches for better gradient estimates
# - More RBF bases for finer non-linearity resolution

CONFIG = {
    # Model architecture - A100 optimized
    'hidden_dim': 128,    # 2x capacity (was 64)
    'n_bases': 12,        # More RBF resolution (was 8)
    'n_layers': 3,        # Deeper network (was 2)
    'dropout': 0.15,      # Slightly more regularization
    
    # Training hyperparameters
    'lr': 1e-4,
    'weight_decay': 1e-4,
    'grad_clip': 1.0,
    'epochs': 100,
    
    # Loss weights
    'lambda_bio': 0.1,
    'lambda_stability': 0.01,
    
    # Early stopping
    'patience': 15,
    'min_delta': 1e-6,
}

# A100 optimized batch size
BATCH_SIZE = 16  # Larger batches for A100 (was 8)

# =============================================================================
# edge_index placement
# =============================================================================
# Keep the datasets' edge_index on the CPU. DataLoader workers (num_workers > 0)
# run in forked subprocesses, which cannot initialize CUDA or return CUDA
# tensors. Each batch, edge_index included, is moved to `device` inside the
# training loop, so the GraphonAggregator builds its cached adjacency on the
# GPU during the first GPU forward pass.

print('='*60)
print('PREPARING DATA FOR GPU')
print('='*60)

# GPU copy of the constant graph (used for logging in the training cell)
edge_index_gpu = edge_index.to(device)
print(f'  edge_index copied: {edge_index.device} -> {edge_index_gpu.device}')
print('  Datasets keep the CPU edge_index (moved per batch in the loop)')

# Recreate dataloaders with the A100 batch size; pinned host memory speeds up
# the non_blocking host-to-device copies in the training loop
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=collate_fn, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=collate_fn, num_workers=4, pin_memory=True)
print(f'  DataLoaders recreated with batch_size={BATCH_SIZE}')

# =============================================================================
# Verify GPU placement
# =============================================================================
print('\nGPU Verification:')
sample = next(iter(train_loader))
print(f'  Sample x device:          {sample["x"].device}')
print(f'  Sample edge_index device: {sample["edge_index"].device}')

# Move sample tensors to GPU to verify the pipeline
sample_gpu = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()}
print(f'  After .to(device):')
print(f'    x device:          {sample_gpu["x"].device}')
print(f'    edge_index device: {sample_gpu["edge_index"].device}')
print(f'    y device:          {sample_gpu["y"].device}')
print(f'    mask device:       {sample_gpu["mask"].device}')

# =============================================================================
# Create model ON GPU
# =============================================================================
input_dim = X.shape[-1]
output_dim = Y.shape[-1]

model = GLKANPredictor(
    input_dim=input_dim,
    hidden_dim=CONFIG['hidden_dim'],
    output_dim=output_dim,
    n_bases=CONFIG['n_bases'],
    n_layers=CONFIG['n_layers'],
    dropout=CONFIG['dropout'],
).to(device)

# Verify model is on GPU
model_device = next(model.parameters()).device
print(f'\nModel device: {model_device}')

# Verify a forward pass works with GPU data
print('Testing forward pass on GPU...')
with torch.no_grad():
    test_out = model(sample_gpu)
    print(f'  Output device: {test_out["predictions"].device}')
    print(f'  Output shape:  {test_out["predictions"].shape}')

# Create optimizer with weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
    betas=(0.9, 0.999),
)

# Cosine annealing with warm restarts: 20-epoch first cycle, each later cycle twice as long
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=20, T_mult=2, eta_min=1e-7
)

# Create loss function
criterion = PhysicsInformedLoss(
    lambda_bio=CONFIG['lambda_bio'],
    lambda_stability=CONFIG['lambda_stability'],
)

# Mixed precision scaler
scaler = torch.cuda.amp.GradScaler() if USE_AMP else None

n_params = sum(p.numel() for p in model.parameters())
print(f'\n{"="*60}')
print(f'A100 OPTIMIZED MODEL - ALL ON GPU')
print(f'{"="*60}')
print(f'  Parameters: {n_params:,}')
print(f'  Device: {model_device}')
print(f'  Hidden dim: {CONFIG["hidden_dim"]} (2x default)')
print(f'  RBF bases: {CONFIG["n_bases"]} (1.5x default)')
print(f'  Layers: {CONFIG["n_layers"]}')
print(f'  Dropout: {CONFIG["dropout"]}')
print(f'  Batch size: {BATCH_SIZE}')
print(f'  Mixed Precision: {USE_AMP}')
print(f'{"="*60}')
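
With `T_0=20` and `T_mult=2`, the cycles last 20, 40 and 80 epochs. Within the 100-epoch budget, the learning rate therefore jumps back to `CONFIG['lr']` at epochs 20 and 60 (0-indexed), and the third cycle would only end at epoch 140. A validation-loss bump right after a restart can still count toward the early-stopping patience of 15. The sketch below computes the restart epochs and cross-checks them against PyTorch's scheduler, stepping once per epoch as the training loop does. The helper `restart_epochs` is hypothetical:

```python
import torch

def restart_epochs(t_0: int, t_mult: int, max_epochs: int) -> list:
    """0-indexed epochs at which CosineAnnealingWarmRestarts resets the LR to its maximum."""
    restarts, cycle, boundary = [], t_0, t_0
    while boundary < max_epochs:
        restarts.append(boundary)
        cycle *= t_mult
        boundary += cycle
    return restarts

param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=1e-4)
sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=20, T_mult=2, eta_min=1e-7)
lrs = []
for epoch in range(100):
    lrs.append(opt.param_groups[0]['lr'])  # LR used during this epoch
    opt.step()                             # no gradients, so parameters are unchanged
    sched.step()
jumps = [e for e in range(1, 100) if lrs[e] > lrs[e - 1]]
print(restart_epochs(20, 2, 100), jumps)  # [20, 60] [20, 60]
```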

## Cell 9: Full GPU Training Loop

**Features:**
- Mixed precision (FP16) for faster training
- Early stopping with patience
- Gradient clipping for stability
- Periodic checkpointing
- Memory-efficient batch processing
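
The loop below uses FP16 autocast with a `GradScaler`. A100 GPUs also support bfloat16, which keeps FP32's exponent range, so loss scaling is generally unnecessary. The following is a minimal sketch of one bf16 training step on a toy `nn.Linear` model, not the notebook's model. Whether bf16 improves speed or accuracy for GLKAN has not been measured here. The sketch falls back to CPU autocast when no bf16-capable GPU is present:

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
device_type = 'cuda' if use_cuda else 'cpu'

toy_model = nn.Linear(8, 3).to(device_type)
toy_opt = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
toy_x = torch.randn(16, 8, device=device_type)
toy_y = torch.randn(16, 3, device=device_type)

toy_opt.zero_grad(set_to_none=True)
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    toy_pred = toy_model(toy_x)                                # linear layer runs in bfloat16
    toy_loss = nn.functional.mse_loss(toy_pred.float(), toy_y)  # loss computed in FP32
toy_loss.backward()                                           # no GradScaler for bfloat16
torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 1.0)
toy_opt.step()
print(toy_pred.dtype, torch.isfinite(toy_loss).item())
```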

In [None]:
from tqdm.auto import tqdm
import time

# Training settings
EPOCHS = CONFIG['epochs']
PATIENCE = CONFIG['patience']
CHECKPOINT_DIR = '/content/drive/MyDrive/GLKAN_Project/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'train_rmse': [], 'val_rmse': [], 'lr': []}

print('='*60)
print('GRAPH-LIQUID-KAN FULL GPU TRAINING')
print('='*60)
print(f'Device: {device}')
print(f'Mixed Precision: {USE_AMP}')
print(f'Epochs: {EPOCHS}')
print(f'Early Stopping Patience: {PATIENCE}')
print(f'Nodes: {X.shape[1]} | Edges: {edge_index_gpu.shape[1]}')
print(f'Checkpoints: {CHECKPOINT_DIR}')

# =============================================================================
# GPU VERIFICATION - Critical for performance!
# =============================================================================
print('\n' + '-'*40)
print('GPU VERIFICATION BEFORE TRAINING')
print('-'*40)

# Check model is on GPU
model_device = next(model.parameters()).device
print(f'Model on GPU: {model_device.type == "cuda"} ({model_device})')

# Check edge_index is on GPU  
print(f'edge_index on GPU: {edge_index_gpu.device.type == "cuda"} ({edge_index_gpu.device})')

# Get first batch and verify
first_batch = next(iter(train_loader))
print(f'Batch x on: {first_batch["x"].device}')
print(f'Batch edge_index on: {first_batch["edge_index"].device}')

# Move batch to GPU and verify
first_batch_gpu = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in first_batch.items()}
print(f'After .to(device):')
for k, v in first_batch_gpu.items():
    if isinstance(v, torch.Tensor):
        print(f'  {k}: {v.device}')

# Verify GPU memory is being used
if torch.cuda.is_available():
    torch.cuda.synchronize()
    mem_allocated = torch.cuda.memory_allocated() / 1e9
    mem_reserved = torch.cuda.memory_reserved() / 1e9
    print(f'\nGPU Memory:')
    print(f'  Allocated: {mem_allocated:.2f} GB')
    print(f'  Reserved:  {mem_reserved:.2f} GB')

print('-'*40 + '\n')
print('='*60 + '\n')

start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()
    
    # =========================================================================
    # Training
    # =========================================================================
    model.train()
    train_loss = 0
    train_rmse = 0
    n_batches = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS} [Train]', leave=False)
    for batch in pbar:
        # =================================================================
        # CRITICAL: Move ALL batch tensors to GPU
        # =================================================================
        # Move x, y, mask and edge_index; .to() is a no-op for tensors already on the GPU
        batch = {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v 
                 for k, v in batch.items()}
        
        optimizer.zero_grad(set_to_none=True)  # More efficient than zero_grad()
        
        # Mixed precision forward pass
        if USE_AMP:
            with torch.cuda.amp.autocast():
                output = model(batch)
                loss, metrics = criterion(output['predictions'], batch['y'], batch['mask'])
            
            if torch.isnan(loss):
                continue
            
            # Scaled backward pass
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(batch)
            loss, metrics = criterion(output['predictions'], batch['y'], batch['mask'])
            
            if torch.isnan(loss):
                continue
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            optimizer.step()
        
        train_loss += loss.item()
        
        # Compute RMSE
        with torch.no_grad():
            mask_exp = batch['mask'].unsqueeze(-1).expand_as(output['predictions'])
            sq_err = ((output['predictions'] - batch['y']) ** 2) * mask_exp.float()
            rmse = torch.sqrt(sq_err.sum() / mask_exp.float().sum().clamp(min=1))
            train_rmse += rmse.item()
        
        n_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'rmse': f'{rmse.item():.4f}'})
    
    train_loss /= max(n_batches, 1)
    train_rmse /= max(n_batches, 1)
    
    # =========================================================================
    # Validation
    # =========================================================================
    model.eval()
    val_loss = 0
    val_rmse = 0
    n_val = 0
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f'Epoch {epoch+1}/{EPOCHS} [Val]', leave=False):
            batch = {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v 
                     for k, v in batch.items()}
            
            if USE_AMP:
                with torch.cuda.amp.autocast():
                    output = model(batch)
                    loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
            else:
                output = model(batch)
                loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
            
            val_loss += loss.item()
            
            mask_exp = batch['mask'].unsqueeze(-1).expand_as(output['predictions'])
            sq_err = ((output['predictions'] - batch['y']) ** 2) * mask_exp.float()
            rmse = torch.sqrt(sq_err.sum() / mask_exp.float().sum().clamp(min=1))
            val_rmse += rmse.item()
            n_val += 1
    
    val_loss /= max(n_val, 1)
    val_rmse /= max(n_val, 1)
    
    # Update scheduler
    scheduler.step()
    
    # Save history
    lr = optimizer.param_groups[0]['lr']
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_rmse'].append(train_rmse)
    history['val_rmse'].append(val_rmse)
    history['lr'].append(lr)
    
    epoch_time = time.time() - epoch_start
    
    # =========================================================================
    # Checkpointing & Early Stopping
    # =========================================================================
    improved = val_loss < best_val_loss - CONFIG['min_delta']
    
    if improved:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': CONFIG,
        }, f'{CHECKPOINT_DIR}/best_model.pt')
        marker = '✓ Best'
    else:
        patience_counter += 1
        marker = f'({patience_counter}/{PATIENCE})'
    
    # Periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'history': history,
        }, f'{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pt')
    
    # Print progress with GPU memory info
    if torch.cuda.is_available():
        gpu_mem = torch.cuda.memory_allocated() / 1e9
        gpu_util = f'{gpu_mem:.1f}GB'
    else:
        gpu_util = 'N/A'
    
    print(f'Epoch {epoch+1:3d}/{EPOCHS} | '
          f'Train: {train_loss:.4f} | Val: {val_loss:.4f} | '
          f'RMSE: {val_rmse:.4f} | LR: {lr:.2e} | '
          f'GPU: {gpu_util} | Time: {epoch_time:.1f}s | {marker}')
    
    # Early stopping check
    if patience_counter >= PATIENCE:
        print(f'\n⚠️ Early stopping triggered after {epoch+1} epochs')
        print(f'   No improvement for {PATIENCE} consecutive epochs')
        break
    
    # Clear GPU cache periodically
    if (epoch + 1) % 5 == 0:
        torch.cuda.empty_cache()
        gc.collect()

elapsed = time.time() - start_time
print(f'\n{"="*60}')
print(f'✅ TRAINING COMPLETE')
print(f'{"="*60}')
print(f'Total time: {elapsed/60:.1f} minutes ({elapsed/3600:.2f} hours)')
print(f'Epochs completed: {epoch+1}')
print(f'Best validation loss: {best_val_loss:.6f}')
print(f'Best model saved to: {CHECKPOINT_DIR}/best_model.pt')

# Final GPU stats
if torch.cuda.is_available():
    print(f'\nFinal GPU Memory:')
    print(f'  Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB')
    print(f'  Max Allocated: {torch.cuda.max_memory_allocated()/1e9:.2f} GB')
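
The checkpointing and early-stopping logic in the loop above reduces to a small rule: an epoch counts as an improvement only if its validation loss beats the best so far by more than `min_delta`, and training stops after `PATIENCE` consecutive epochs without one. The sketch below isolates that rule in a hypothetical `EarlyStopper` class (not part of the notebook's code) so it can be sanity-checked without a GPU or a model; the losses are made up.

```python
from dataclasses import dataclass


@dataclass
class EarlyStopper:
    """Hypothetical helper mirroring the min_delta/patience rule in the training loop."""
    patience: int
    min_delta: float = 0.0
    best_loss: float = float("inf")
    best_epoch: int = -1
    counter: int = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Meaningful improvement: this is where best_model.pt would be saved
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
        else:
            # No meaningful improvement: count toward patience
            self.counter += 1
        return self.counter >= self.patience


# Made-up losses: 0.499 and 0.498 are within min_delta of 0.5, so neither resets the counter
stopper = EarlyStopper(patience=2, min_delta=0.01)
losses = [1.0, 0.5, 0.499, 0.498]
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break
print(stopper.best_epoch, stopper.best_loss, stopped_at)
```

Because the test is `val_loss < best - min_delta`, small gains such as 0.5 to 0.499 do not count when `min_delta` is 0.01, so a larger `min_delta` tends to trigger early stopping sooner.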

## Cell 10: Evaluate Outbreak Detection

**Target: 90% Recall, 80% Precision**
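
To make these targets concrete, precision, recall, and F1 can be computed directly from confusion-matrix counts, with F1 = 2*TP / (2*TP + FP + FN). The sketch below uses the 12 outbreaks from the training-data summary at the top of the notebook purely as an illustrative count (the held-out test split will contain fewer); `detection_metrics` is a hypothetical helper, not a notebook function.

```python
def detection_metrics(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """Precision, recall, and F1 from confusion-matrix counts (0.0 when undefined)."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


n_outbreaks = 12  # illustrative count from the dataset summary

# Smallest number of caught outbreaks that reaches >= 90% recall
min_tp = next(tp for tp in range(n_outbreaks + 1) if tp / n_outbreaks >= 0.90)

# Largest number of false alarms that keeps precision >= 80% at that TP count
max_fp = max(
    fp for fp in range(50)
    if detection_metrics(min_tp, fp, n_outbreaks - min_tp)[0] >= 0.80
)

p, r, f1 = detection_metrics(tp=min_tp, fp=max_fp, fn=n_outbreaks - min_tp)
print(min_tp, max_fp, round(p, 3), round(r, 3), round(f1, 3))
```

In this example, meeting both targets at the boundary (11 of 12 outbreaks caught, 2 false alarms) gives F1 = 22/25 = 0.88. That does not hold in general: a model sitting exactly at 90% recall and 80% precision has F1 of about 0.847, just under the 0.85 target.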

In [None]:
from sklearn.metrics import precision_recall_curve, f1_score, precision_score, recall_score, confusion_matrix

OUTBREAK_THRESHOLD = 0.5  # Norwegian regulatory limit: 0.5 adult female lice per fish (assumes targets are in raw, unnormalized units)

# Load best model
checkpoint = torch.load(f'{CHECKPOINT_DIR}/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f'Loaded best model from epoch {checkpoint["epoch"]+1}')
print(f'Validation loss: {checkpoint["val_loss"]:.6f}')

# Collect predictions on TEST set (held out data)
all_preds = []
all_targets = []
all_masks = []

print(f'\nEvaluating on TEST set ({len(test_loader)} batches)...')
print(f'Model device: {next(model.parameters()).device}')

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        # Move ALL tensors to GPU
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        
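        if USE_AMP:
            # torch.cuda.amp.autocast() is deprecated; torch.autocast is the current API
            with torch.autocast(device_type='cuda', dtype=torch.float16):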
                output = model(batch)
        else:
            output = model(batch)
        
        all_preds.append(output['predictions'].float().cpu())  # upcast half-precision AMP outputs for metrics
        all_targets.append(batch['y'].cpu())
        all_masks.append(batch['mask'].cpu())

preds = torch.cat(all_preds, dim=0)
targets = torch.cat(all_targets, dim=0)
masks = torch.cat(all_masks, dim=0)

# Extract adult female lice (index 0) with mask
pred_af = preds[:, :, :, 0].numpy().flatten()
target_af = targets[:, :, :, 0].numpy().flatten()
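# Mask is assumed to be (B, T, N), aligned with pred_af; cast to bool so it filters
# (an int mask would act as fancy indexing and a float mask would raise)
mask_flat = masks.numpy().flatten().astype(bool)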

pred_valid = pred_af[mask_flat]
target_valid = target_af[mask_flat]

print(f'\n{"="*60}')
print('OUTBREAK DETECTION EVALUATION (TEST SET)')
print(f'{"="*60}')
print(f'Valid observations: {len(pred_valid)}')
print(f'Outbreak threshold: {OUTBREAK_THRESHOLD}')

# Regression metrics
rmse = np.sqrt(np.mean((pred_valid - target_valid) ** 2))
mae = np.mean(np.abs(pred_valid - target_valid))
corr = np.corrcoef(pred_valid, target_valid)[0, 1] if len(pred_valid) > 1 else 0

print(f'\nRegression Metrics:')
print(f'  RMSE:        {rmse:.4f}')
print(f'  MAE:         {mae:.4f}')
print(f'  Correlation: {corr:.4f}')
print(f'  Pred range:   [{pred_valid.min():.3f}, {pred_valid.max():.3f}]')
print(f'  Target range: [{target_valid.min():.3f}, {target_valid.max():.3f}]')
print(f'  Pred std:     {pred_valid.std():.4f}')
print(f'  Target std:   {target_valid.std():.4f}')

# Classification metrics
target_binary = (target_valid > OUTBREAK_THRESHOLD).astype(int)
n_outbreaks = target_binary.sum()
print(f'\nOutbreak Distribution:')
print(f'  Outbreaks: {n_outbreaks} ({100*n_outbreaks/len(target_binary):.1f}%)')
print(f'  Normal:    {len(target_binary)-n_outbreaks} ({100*(1-n_outbreaks/len(target_binary)):.1f}%)')

if n_outbreaks > 0:
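    # Find the best-F1 threshold that still reaches 90% recall.
    # Note: the threshold is tuned on the test set itself, so the metrics below are
    # optimistic; selecting it on the validation set would give an unbiased estimate.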
    precisions, recalls, thresholds = precision_recall_curve(target_binary, pred_valid)
    
    # Find best threshold achieving target recall
    best_thresh = None
    best_f1 = 0
    target_recall = 0.90
    
    for p, r, t in zip(precisions[:-1], recalls[:-1], thresholds):
        if r >= target_recall:
            f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = t
                best_p = p
                best_r = r
    
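    if best_thresh is not None:  # a threshold of 0.0 is valid, so don't test truthiness
        pred_binary = (pred_valid >= best_thresh).astype(int)  # >= matches precision_recall_curve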
        precision = precision_score(target_binary, pred_binary, zero_division=0)
        recall = recall_score(target_binary, pred_binary, zero_division=0)
        f1 = f1_score(target_binary, pred_binary, zero_division=0)
        cm = confusion_matrix(target_binary, pred_binary, labels=[0, 1])  # always 2x2
        
        tn, fp, fn, tp = cm.ravel()
        
        print(f'\n{"="*60}')
        print(f'OUTBREAK DETECTION @ threshold={best_thresh:.4f}')
        print(f'{"="*60}')
        print(f'  Precision: {precision:.2%} (target: >=80%)')
        print(f'  Recall:    {recall:.2%} (target: >=90%)')
        print(f'  F1 Score:  {f1:.4f} (target: >=0.85)')
        
        print(f'\nConfusion Matrix:')
        print(f'                    Predicted')
        print(f'                 Normal  Outbreak')
        print(f'  Actual Normal   {tn:5d}    {fp:5d}')
        print(f'  Actual Outbreak {fn:5d}    {tp:5d}')
        
        # Check targets
        print(f'\n{"="*60}')
        print('TARGET ASSESSMENT')
        print(f'{"="*60}')
        recall_pass = recall >= 0.90
        precision_pass = precision >= 0.80
        f1_pass = f1 >= 0.85
        
        print(f'  Recall >= 90%:    {"PASS" if recall_pass else "FAIL"} ({recall:.1%})')
        print(f'  Precision >= 80%: {"PASS" if precision_pass else "FAIL"} ({precision:.1%})')
        print(f'  F1 >= 0.85:       {"PASS" if f1_pass else "FAIL"} ({f1:.4f})')
        
        if recall_pass and precision_pass:
            print('\nMODEL MEETS OUTBREAK DETECTION TARGETS!')
        else:
            print('\nModel needs more training or tuning')
    else:
        print(f'\n[WARN] Could not achieve {target_recall:.0%} recall with any threshold')
else:
    print('\n[WARN] No outbreaks in test set - cannot compute detection metrics')
    print('       This may indicate data imbalance issues')
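
The threshold search above can also be written without scikit-learn, which makes one detail explicit: `precision_recall_curve` counts a score as a positive prediction when it is `>=` the threshold, so the final binarization should use `>=` as well, or an outbreak scored exactly at the chosen threshold is missed and recall can drop below the target. The NumPy-only sketch below assumes a hypothetical helper `threshold_for_recall`; it scans each distinct score as a candidate threshold and keeps the best-F1 candidate that reaches the target recall. The scores and labels are toy values.

```python
import numpy as np


def threshold_for_recall(scores, labels, target_recall=0.90):
    """Hypothetical helper: best-F1 threshold among those with recall >= target.

    A score counts as a positive prediction when score >= threshold, matching
    scikit-learn's precision_recall_curve convention. Returns
    (f1, threshold, precision, recall), or None if the target is unreachable.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    n_pos = labels.sum()
    if n_pos == 0:
        return None

    best = None
    for t in np.unique(scores):
        pred = scores >= t
        tp = np.sum(pred & labels)
        fp = np.sum(pred & ~labels)
        recall = tp / n_pos
        if recall < target_recall:
            continue
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        if best is None or f1 > best[0]:
            best = (float(f1), float(t), float(precision), float(recall))
    return best


# Toy example: 4 outbreaks (label 1) among 10 observations
scores = [0.05, 0.10, 0.20, 0.30, 0.35, 0.40, 0.60, 0.70, 0.80, 0.90]
labels = [0, 0, 0, 0, 1, 0, 1, 1, 0, 1]
result = threshold_for_recall(scores, labels, target_recall=0.90)
print(result)
```

Here the selected threshold is 0.35, which is itself the score of an outbreak; binarizing with `>` instead of `>=` would miss it and cut recall from 100% to 75%.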

## Cell 11: Scientific Validation Audit
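
The first audit test is a counterfactual: add 5 to the temperature feature (feature 0) and check that the mean step-to-step change in the predictions increases, consistent with faster lice development in warmer water. The sketch below runs the same check on a hypothetical toy rollout in which the per-step growth rate is assumed proportional to temperature; `toy_rollout`, its coefficients, and the random temperatures are illustrative only and are not the Graph-Liquid-KAN model.

```python
import numpy as np


def mean_growth(pred: np.ndarray) -> float:
    """Mean first difference along the time axis (axis 1), as in the audit cell."""
    if pred.shape[1] < 2:
        return 0.0
    return float((pred[:, 1:] - pred[:, :-1]).mean())


def toy_rollout(temps: np.ndarray, start: float = 0.1) -> np.ndarray:
    """Hypothetical toy model: lice load grows by a temperature-dependent rate each step.

    temps has shape (B, T, N); the returned trajectory has the same shape.
    """
    rate = 0.01 * np.clip(temps, 0.0, None)  # assumed: growth rate proportional to temperature
    load = np.empty_like(temps, dtype=float)
    load[:, 0] = start
    for t in range(1, temps.shape[1]):
        load[:, t] = load[:, t - 1] * (1.0 + rate[:, t - 1])
    return load


rng = np.random.default_rng(0)
temps = rng.uniform(6.0, 12.0, size=(2, 14, 5))  # (batch, days, farms), degrees C

growth_orig = mean_growth(toy_rollout(temps))
growth_hot = mean_growth(toy_rollout(temps + 5.0))  # +5 degrees C counterfactual
print(growth_orig < growth_hot)
```

The toy model passes by construction; the point is the shape of the check, which compares average growth under the original and shifted inputs rather than absolute prediction levels.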

In [None]:
print('='*60)
print('SCIENTIFIC VALIDATION AUDIT')
print('='*60)

model.eval()
print(f'Model device: {next(model.parameters()).device}')

# Get a sample batch and move to GPU
sample_batch = next(iter(val_loader))
sample_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in sample_batch.items()}

# Verify all tensors are on GPU
print(f'\nSample batch devices:')
for k, v in sample_batch.items():
    if isinstance(v, torch.Tensor):
        print(f'  {k}: {v.device}')

# =============================================================================
# TEST 1: Counterfactual (Temperature Effect)
# =============================================================================
print('\n[TEST 1] Counterfactual: Temperature Effect')
print('-' * 40)

with torch.no_grad():
    # Original prediction
    output_orig = model(sample_batch)
    pred_orig = output_orig['predictions']
    
    if pred_orig.shape[1] > 1:
        growth_orig = (pred_orig[:, 1:] - pred_orig[:, :-1]).mean().item()
    else:
        growth_orig = 0
    
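    # Counterfactual: +5 °C. Assumes feature 0 is temperature in raw °C; if the
    # inputs are standardized, +5.0 shifts it by 5 standard deviations instead.
    x_hot = sample_batch['x'].clone()
    x_hot[..., 0] += 5.0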
    batch_hot = {**sample_batch, 'x': x_hot}
    
    output_hot = model(batch_hot)
    pred_hot = output_hot['predictions']
    
    if pred_hot.shape[1] > 1:
        growth_hot = (pred_hot[:, 1:] - pred_hot[:, :-1]).mean().item()
    else:
        growth_hot = 0

print(f'  Original growth rate: {growth_orig:.6f}')
print(f'  +5C growth rate:      {growth_hot:.6f}')
print(f'  Difference:           {growth_hot - growth_orig:.6f}')

test1_pass = growth_hot > growth_orig
print(f'  Result: {"[PASS]" if test1_pass else "[FAIL]"} Temperature {"increases" if test1_pass else "does not increase"} growth')

# =============================================================================
# TEST 2: Long-Horizon Stability
# =============================================================================
print('\n[TEST 2] Long-Horizon Stability (90 days)')
print('-' * 40)

with torch.no_grad():
    x = sample_batch['x']
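    B, T, N, _ = x.shape  # '_' avoids shadowing an `F` alias such as torch.nn.functional
    
    # Extend to 90 days by tiling the input window along the time axis
    n_repeats = (90 // T) + 2
    x_extended = x.repeat(1, n_repeats, 1, 1)[:, :90]
    
    # Reset the model cache before the extended (90-step) rollout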
    model.network.reset_cache()
    pred_ext, _ = model.network(x_extended, sample_batch['edge_index'])
    
    pred_min = pred_ext.min().item()
    pred_max = pred_ext.max().item()
    has_nan = torch.isnan(pred_ext).any().item()
    has_inf = torch.isinf(pred_ext).any().item()
    
    initial_scale = pred_ext[:, :T].abs().mean().item() + 1e-6
    final_scale = pred_ext[:, -T:].abs().mean().item()
    ratio = final_scale / initial_scale

print(f'  Rollout: 90 days')
print(f'  Range: [{pred_min:.4f}, {pred_max:.4f}]')
print(f'  Scale ratio: {ratio:.2f}x')
print(f'  NaN/Inf: {has_nan or has_inf}')

test2_pass = not (has_nan or has_inf) and ratio < 100 and pred_max < 1000
print(f'  Result: {"[PASS]" if test2_pass else "[FAIL]"} System is {"stable" if test2_pass else "unstable"}')

# =============================================================================
# TEST 3: Graphon Generalization
# =============================================================================
print('\n[TEST 3] Graphon Generalization (Scale Invariance)')
print('-' * 40)

with torch.no_grad():
    model.network.reset_cache()
    
    x = sample_batch['x']
    edge_index_local = sample_batch['edge_index']  # Already on GPU
    
    pred_n, _ = model.network(x, edge_index_local)
    mean_n = pred_n.abs().mean().item()
    
    # Double nodes
    B, T, N, _ = x.shape  # '_' avoids shadowing an `F` alias such as torch.nn.functional
    x_2n = x.repeat(1, 1, 2, 1)
    edge_2n = torch.cat([edge_index_local, edge_index_local + N], dim=1)
    
    model.network.reset_cache()
    pred_2n, _ = model.network(x_2n, edge_2n)
    mean_2n = pred_2n.abs().mean().item()
    
    deviation = abs(mean_2n - mean_n) / (mean_n + 1e-8)

print(f'  N={N} mean: {mean_n:.6f}')
print(f'  N={2*N} mean: {mean_2n:.6f}')
print(f'  Deviation: {100*deviation:.2f}%')

test3_pass = deviation < 0.10
print(f'  Result: {"[PASS]" if test3_pass else "[FAIL]"} Scale invariance {"within" if test3_pass else "exceeds"} 10% tolerance')

# =============================================================================
# Summary
# =============================================================================
print('\n' + '='*60)
print('AUDIT SUMMARY')
print('='*60)
print(f'  {"[x]" if test1_pass else "[ ]"} Counterfactual (Temperature): {"PASS" if test1_pass else "FAIL"}')
print(f'  {"[x]" if test2_pass else "[ ]"} Long-Horizon Stability: {"PASS" if test2_pass else "FAIL"}')
print(f'  {"[x]" if test3_pass else "[ ]"} Graphon Generalization: {"PASS" if test3_pass else "FAIL"}')

all_pass = test1_pass and test2_pass and test3_pass
print(f'\n{"All tests PASSED" if all_pass else "Some tests FAILED"}')
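
Test 3 relies on the GraphonAggregator's 1/N normalization: dividing each node's neighbor sum by the total node count N makes the aggregate approximate an integral over the underlying graphon, so its scale should stay roughly constant as the graph grows. The sketch below illustrates that property on a hypothetical dense graph sampled from the toy graphon W(u, v) = exp(-|u - v|) at N and 2N nodes, using the same 10% tolerance as Test 3. It is not the notebook's aggregator, and it grows the graph by resampling the graphon rather than by duplicating nodes as Test 3 does.

```python
import numpy as np


def graphon_aggregate(n_nodes: int) -> np.ndarray:
    """1/N-normalized message passing on a dense graph sampled from a toy graphon.

    Assumptions (illustrative only): latent positions u_i on a regular grid in [0, 1],
    edge weights W(u_i, u_j) = exp(-|u_i - u_j|), node signal f(u) = sin(pi * u).
    """
    u = (np.arange(n_nodes) + 0.5) / n_nodes
    weights = np.exp(-np.abs(u[:, None] - u[None, :]))  # (N, N) weighted adjacency
    signal = np.sin(np.pi * u)                           # (N,) node features
    return weights @ signal / n_nodes                    # divide by N, not by degree


n = 200
mean_n = np.abs(graphon_aggregate(n)).mean()
mean_2n = np.abs(graphon_aggregate(2 * n)).mean()
deviation = abs(mean_2n - mean_n) / (mean_n + 1e-8)

# Undo the 1/N factor to compare raw neighbor sums at both sizes
raw_ratio = (mean_2n * 2 * n) / (mean_n * n)
print(f"deviation={deviation:.4%}, raw sum ratio={raw_ratio:.2f}")
```

The normalized means agree closely, whereas the raw (unnormalized) sums roughly double with N.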

## Cell 12: Save Results

In [None]:
import json
from datetime import datetime

OUTPUT_DIR = '/content/drive/MyDrive/GLKAN_Project/outputs'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

results = {
    'timestamp': timestamp,
    'config': CONFIG,
    'training': {
        'epochs': len(history['train_loss']),
        'best_val_loss': best_val_loss,
        'final_train_loss': history['train_loss'][-1] if history['train_loss'] else None,
        'final_val_loss': history['val_loss'][-1] if history['val_loss'] else None,
    },
    'regression': {
        'rmse': float(rmse),
        'mae': float(mae),
        'pred_std': float(pred_valid.std()),
    },
    'scientific_audit': {
        'counterfactual': test1_pass,
        'long_horizon': test2_pass,
        'graphon': test3_pass,
        'all_passed': all_pass,
    },
    'history': history,
}

results_path = f'{OUTPUT_DIR}/results_{timestamp}.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=str)

print(f'\n✅ Results saved to: {results_path}')
print(f'Checkpoints at: {CHECKPOINT_DIR}')
print('\nTraining complete!')
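
`json.dump(..., default=str)` never fails, but it silently turns anything JSON does not understand into a string: a NumPy `float32` loss, for instance, is stored as `"0.25"` rather than as a number, which complicates later analysis of the results file. A stricter `default` hook, sketched below as a hypothetical `to_jsonable` helper, converts NumPy scalars and arrays to native Python values and only falls back to `str` for everything else. The `record` dict is a made-up example.

```python
import json

import numpy as np


def to_jsonable(obj):
    """Hypothetical json.dump default hook: keep NumPy values numeric, stringify the rest."""
    if isinstance(obj, np.generic):  # NumPy scalars: float32, int64, bool_, ...
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)  # last resort, same behavior as default=str


record = {
    "loss": np.float32(0.25),
    "epoch": np.int64(7),
    "passed": np.bool_(True),
    "curve": np.array([0.5, 0.25]),
}
print(json.dumps(record, default=str))
print(json.dumps(record, default=to_jsonable))
```

Passing `default=to_jsonable` to the `json.dump` call above would keep the existing fallback for objects such as `datetime` while preserving numeric types.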

## Cell 13: Plot Training History

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss
axes[0, 0].plot(history['train_loss'], label='Train Loss', color='blue', alpha=0.8)
axes[0, 0].plot(history['val_loss'], label='Val Loss', color='orange', alpha=0.8)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_yscale('log')

# RMSE
axes[0, 1].plot(history['train_rmse'], label='Train RMSE', color='blue', alpha=0.8)
axes[0, 1].plot(history['val_rmse'], label='Val RMSE', color='orange', alpha=0.8)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('RMSE')
axes[0, 1].set_title('Training & Validation RMSE')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Learning Rate
axes[1, 0].plot(history['lr'], label='Learning Rate', color='green', alpha=0.8)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True, alpha=0.3)

# Train/val loss with the lowest-validation-loss epoch marked
axes[1, 1].plot(history['train_loss'], label='Train Loss', color='blue', alpha=0.8)
axes[1, 1].plot(history['val_loss'], label='Val Loss', color='orange', alpha=0.8)
best_epoch = int(np.argmin(history['val_loss']))
axes[1, 1].axvline(x=best_epoch, color='red', linestyle='--', label=f'Best (epoch {best_epoch+1})')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].set_title(f'Best Model @ Epoch {best_epoch+1}')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/training_curves_{timestamp}.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\n✅ Training curves saved to: {OUTPUT_DIR}/training_curves_{timestamp}.png')

# Print training summary
print(f'\n{"="*60}')
print('TRAINING SUMMARY')
print(f'{"="*60}')
print(f'  Total epochs: {len(history["train_loss"])}')
print(f'  Best epoch: {best_epoch+1}')
print(f'  Best val loss: {min(history["val_loss"]):.6f}')
print(f'  Final train loss: {history["train_loss"][-1]:.6f}')
print(f'  Final val loss: {history["val_loss"][-1]:.6f}')
print(f'  Final val RMSE: {history["val_rmse"][-1]:.4f}')
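
With `min_delta > 0`, the epoch with the lowest validation loss (`np.argmin(history['val_loss'])`) is not necessarily the epoch saved as `best_model.pt`: the checkpoint only updates when the loss improves by more than `CONFIG['min_delta']`. The sketch below replays that rule on made-up losses with a hypothetical `checkpoint_epoch` helper to show how the two can disagree.

```python
import numpy as np


def checkpoint_epoch(val_losses, min_delta):
    """Hypothetical helper: epoch the training loop would save as best_model.pt."""
    best_loss, best_epoch = float("inf"), None
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:  # same improvement rule as the training loop
            best_loss, best_epoch = loss, epoch
    return best_epoch


val_losses = [0.80, 0.42, 0.40, 0.3995, 0.41]
print(checkpoint_epoch(val_losses, min_delta=0.001), int(np.argmin(val_losses)))
```

When they differ, `checkpoint['epoch']` (loaded in Cell 10) identifies the weights that were actually evaluated.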