# Graph-Liquid-KAN Sea Lice Prediction - A100 GPU Training
## Phase 4: Production Training on Google Colab

This notebook runs **A100-optimized GPU training** of the **SeaLicePredictor** architecture with **Weights & Biases** integration for experiment tracking.

**Architecture: SeaLicePredictor**
- **FastKAN Layers**: Gaussian RBF basis functions (learnable non-linearities)
- **GraphonAggregator**: 1/N normalized message passing (scale invariant)
- **LiquidKANCell**: Closed-form Continuous-time (CfC) dynamics with adaptive time constant (tau)
- **BelehradekKAN**: Temperature-dependent development rate
- **SalinityMortalityKAN**: Salinity survival factor
- **LarvalTransportModule**: Cross-farm infection via ocean currents
- **Physics-Informed Loss**: Tweedie (p=1.5) + L_bio
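
As an illustration of the first bullet, here is a minimal sketch of a Gaussian RBF basis and the learnable univariate function built from it. The grid range, the coefficient values, and the function names are assumptions made for this example; the repository's `FastKAN` layer may be parameterized differently.

```python
import math

def gaussian_rbf_basis(x, grid_min=-2.0, grid_max=2.0, n_bases=8):
    """Evaluate n_bases Gaussian bumps with evenly spaced centers at a scalar x."""
    width = (grid_max - grid_min) / (n_bases - 1)  # center spacing doubles as bandwidth
    centers = [grid_min + k * width for k in range(n_bases)]
    return [math.exp(-((x - c) / width) ** 2) for c in centers]

def kan_edge(x, coefficients, grid_min=-2.0, grid_max=2.0):
    """A learnable 1-D non-linearity: weighted sum of the basis responses."""
    basis = gaussian_rbf_basis(x, grid_min, grid_max, n_bases=len(coefficients))
    return sum(w * b for w, b in zip(coefficients, basis))

# Hypothetical coefficients; in training these would be learned parameters.
coeffs = [0.0, 0.1, 0.4, 1.0, 1.0, 0.4, 0.1, 0.0]
print(kan_edge(0.0, coeffs))
```

Training adjusts the coefficients, so each edge of the network learns its own smooth function instead of using a fixed activation.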

**Key Features:**
1. **Tweedie Loss (p=1.5)**: Suited to zero-inflated, non-negative lice counts. Discourages "mean reversion", where the model predicts the dataset mean for every input.
2. **Dynamic Learning Rate**: Starts at lr=1e-2 and halves (factor 0.5) after 15 epochs without validation improvement (ReduceLROnPlateau)
3. **Conformal Prediction**: 90% coverage prediction intervals for uncertainty quantification
4. **Risk-Aware Detection**: Uses upper bounds for conservative outbreak detection
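
To make the Tweedie item concrete, the sketch below evaluates the Tweedie negative log-likelihood for 1 < p < 2, with terms that depend only on the target dropped. It is a standalone illustration with a hypothetical function name and toy data, not the repository's `PhysicsInformedLoss`.

```python
def tweedie_loss(y_true, mu, p=1.5, eps=1e-8):
    """Mean Tweedie negative log-likelihood for 1 < p < 2 (target-only terms dropped)."""
    total = 0.0
    for y, m in zip(y_true, mu):
        m = max(m, eps)  # the predicted mean must be strictly positive
        total += -y * m ** (1 - p) / (1 - p) + m ** (2 - p) / (2 - p)
    return total / len(y_true)

# Toy zero-inflated targets: three clean farms and one outbreak.
y = [0.0, 0.0, 0.0, 2.0]
mean_pred = [0.5] * 4                 # predicting the dataset mean everywhere
good_pred = [0.01, 0.01, 0.01, 2.0]   # near zero where clean, high at the outbreak

print(tweedie_loss(y, mean_pred), tweedie_loss(y, good_pred))
```

The loss stays finite for targets of exactly zero. In this toy case the prediction that separates clean farms from the outbreak scores lower than the constant mean prediction.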

**Outbreak Threshold:** 0.5 adult female lice per fish (Norwegian regulatory threshold)
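
The conformal and risk-aware items combine as follows: calibrate an interval half-width on held-out residuals, then raise an alarm when the upper bound exceeds the 0.5 threshold. This is a sketch of standard split conformal prediction with made-up residuals; the function names are hypothetical and `ConformalSeaLicePredictor` may use a different nonconformity score.

```python
import math

def conformal_quantile(residuals, alpha=0.1):
    """Finite-sample-corrected (1 - alpha) quantile of absolute calibration residuals."""
    n = len(residuals)
    rank = math.ceil((n + 1) * (1 - alpha))  # 1-based order statistic
    if rank > n:
        return float("inf")  # too few calibration points for this coverage level
    return sorted(residuals)[rank - 1]

def flag_outbreaks(predictions, q, threshold=0.5):
    """Risk-aware rule: alarm when the interval's upper bound exceeds the threshold."""
    return [(pred + q) > threshold for pred in predictions]

# Hypothetical calibration residuals |y - y_hat| from a held-out split.
calib_residuals = [0.02, 0.05, 0.01, 0.10, 0.07, 0.03, 0.04, 0.15, 0.06, 0.08,
                   0.02, 0.09, 0.05, 0.03, 0.11, 0.04, 0.06, 0.02, 0.07, 0.05]
q = conformal_quantile(calib_residuals, alpha=0.1)
print(q, flag_outbreaks([0.30, 0.42, 0.55], q))
```

With these numbers the half-width is 0.11, so a point prediction of 0.42 is flagged even though it sits below the threshold. This trades some precision for recall.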

**Target Metrics:**
| Metric | Target | Description |
|--------|--------|-------------|
| Recall | >=90% | Catch 9/10 outbreaks |
| Precision | >=80% | 8/10 outbreak alarms are real outbreaks |
| F1 Score | >=0.85 | Harmonic mean of precision and recall |
| Conformal Coverage | 90% | Share of prediction intervals that contain the true value |
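
For reference, the detection metrics in the table can be computed from observed lice levels and binary alarms as sketched below. The helper names and example values are illustrative only and are not taken from the repository.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0

def outbreak_metrics(observed, flagged, threshold=0.5):
    """Precision, recall, and F1 for alarms against observed adult female lice per fish."""
    actual = [y > threshold for y in observed]
    tp = sum(a and f for a, f in zip(actual, flagged))
    fp = sum((not a) and f for a, f in zip(actual, flagged))
    fn = sum(a and (not f) for a, f in zip(actual, flagged))
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return {"precision": precision, "recall": recall, "f1": f1_score(precision, recall)}

observed = [0.1, 0.7, 0.6, 0.2, 0.9, 0.3]
flagged = [False, True, True, True, False, False]
print(outbreak_metrics(observed, flagged))
```

A model sitting exactly at 90% recall and 80% precision has an F1 of about 0.847, so the F1 target of 0.85 is slightly stricter than the other two targets taken together.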

**Runtime Configuration:**
- Runtime -> Change runtime type -> **A100 GPU**
- Runtime -> Change runtime type -> **High RAM**

**Setup:**
1. Get wandb API key from https://wandb.ai/authorize
2. Upload `glkan_data.zip` to Google Drive root (must contain `tensors.npz`, including `feature_indices`, and `spatial_graph.pt`)
3. Run all cells
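
Before uploading, it can help to confirm that the archive has the two files Cell 4 expects at its top level. The helper below is a hypothetical convenience, not part of the repository.

```python
import zipfile

REQUIRED = ("tensors.npz", "spatial_graph.pt")

def missing_members(zip_path, required=REQUIRED):
    """Return the required file names that are absent from the archive's top level."""
    with zipfile.ZipFile(zip_path) as archive:
        names = set(archive.namelist())
    return [name for name in required if name not in names]
```

For example, `missing_members('glkan_data.zip')` returns an empty list when the archive is complete. Files nested in a subfolder are reported as missing, because Cell 4 looks for them directly under the extraction directory.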

## Cell 1: Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
os.makedirs('/content/drive/MyDrive/GLKAN_Project/checkpoints', exist_ok=True)
os.makedirs('/content/drive/MyDrive/GLKAN_Project/outputs', exist_ok=True)
print('Google Drive mounted and directories created')

## Cell 2: Clone Repository & Install Dependencies

In [None]:
import os

# Clone the Graph-Liquid-KAN repository
REPO_URL = 'https://github.com/themythicalyeti/graph-liquid-kan.git'
REPO_DIR = '/content/graph-liquid-kan'

if os.path.exists(REPO_DIR):
    print(f'Repository already exists at {REPO_DIR}')
    %cd {REPO_DIR}
    !git pull
else:
    !git clone {REPO_URL} {REPO_DIR}
    %cd {REPO_DIR}

print(f'\nWorking directory: {os.getcwd()}')
!ls -la

In [None]:
# Install dependencies
!pip install -q torch torchvision --upgrade
!pip install -q numpy pandas scipy scikit-learn
!pip install -q loguru tqdm matplotlib
!pip install -q torch-geometric
!pip install -q wandb

print('\nDependencies installed')

# =============================================================================
# WANDB AUTHENTICATION
# =============================================================================
# Option 1: Use Colab Secrets (recommended - no prompt each time)
#   1. Click the key icon in left sidebar
#   2. Add secret named "WANDB_API_KEY" with your key from https://wandb.ai/authorize
#
# Option 2: Manual login (will prompt for API key)
#   Just run the cell - it will ask for your key

import wandb
import os

try:
    from google.colab import userdata
    WANDB_KEY = userdata.get('WANDB_API_KEY')
    os.environ['WANDB_API_KEY'] = WANDB_KEY
    wandb.login(key=WANDB_KEY)
    print('Logged in to wandb using Colab Secrets')
except Exception:  # Colab Secrets unavailable (missing secret or not running in Colab)
    print('Colab Secrets not configured - will prompt for API key')
    print('Get your key at: https://wandb.ai/authorize')
    wandb.login()
    
print(f'wandb authenticated as: {wandb.api.viewer()["entity"]}')

## Cell 3: Verify GPU

In [None]:
!nvidia-smi

import torch
import gc

print('\n' + '='*60)
print('GPU VERIFICATION')
print('='*60)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_memory:.1f} GB")
    device = torch.device('cuda')
    
    # Enable TF32 for faster training on Ampere GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    # IMPORTANT: Disable AMP - sparse matrix ops don't support FP16
    # torch.sparse.mm raises NotImplementedError for 'Half' dtype
    USE_AMP = False
    print(f"Mixed Precision (AMP): Disabled (sparse ops don't support FP16)")
    
    torch.cuda.empty_cache()
    gc.collect()
else:
    print("[WARN] CUDA not available - using CPU (will be VERY slow)")
    print("[WARN] Please enable GPU: Runtime -> Change runtime type -> A100 GPU")
    device = torch.device('cpu')
    USE_AMP = False

print(f"\nDefault device: {device}")
print('='*60)

## Cell 4: Load Data from Drive

In [None]:
import os
import numpy as np
import torch

# Path to data on Drive
DRIVE_DATA = '/content/drive/MyDrive/glkan_data.zip'
LOCAL_DATA = '/content/data'

if not os.path.exists(DRIVE_DATA):
    print(f'ERROR: Data not found at {DRIVE_DATA}')
    print('Please upload glkan_data.zip containing:')
    print('  - tensors.npz (from Phase 2)')
    print('  - spatial_graph.pt (from Phase 2)')
else:
    print('Extracting data...')
    os.makedirs(LOCAL_DATA, exist_ok=True)
    !unzip -q "{DRIVE_DATA}" -d {LOCAL_DATA}
    !ls -la {LOCAL_DATA}
    print('\nData loaded')

# Verify data
TENSOR_PATH = f'{LOCAL_DATA}/tensors.npz'
GRAPH_PATH = f'{LOCAL_DATA}/spatial_graph.pt'

if os.path.exists(TENSOR_PATH) and os.path.exists(GRAPH_PATH):
    data = np.load(TENSOR_PATH, allow_pickle=True)
    print(f"\nData shapes:")
    print(f"  X (features): {data['X'].shape}")
    print(f"  Y (targets):  {data['Y'].shape}")
    print(f"  mask:         {data['mask'].shape}")
    
    graph = torch.load(GRAPH_PATH, weights_only=False)
    print(f"  edges:        {graph['edge_index'].shape[1]}")

## Cell 5: Import GLKAN Architecture from Repository

In [None]:
import sys
sys.path.insert(0, '/content/graph-liquid-kan')

# Import architecture from src/models
from src.models import (
    FastKAN,
    GraphonAggregator,
    LiquidKANCell,
    GraphLiquidKANCell,
    GLKANNetwork,
    GLKANPredictor,
)

# Import SeaLicePredictor (domain-specific with biological modules)
from src.models.sea_lice_network import SeaLicePredictor

# Import conformal prediction for uncertainty quantification
from src.models.conformal import ConformalSeaLicePredictor

# Import training utilities from src/training
from src.training import PhysicsInformedLoss, GLKANLoss
from src.training.losses import LossConfig
from src.training.trainer import TrainingConfig

# Import dataset from src/data
from src.data import SeaLiceGraphDataset

print('Imported from repository:')
print('  - FastKAN, GraphonAggregator, LiquidKANCell')
print('  - GraphLiquidKANCell, GLKANNetwork, GLKANPredictor')
print('  - SeaLicePredictor (domain-specific with biology modules)')
print('  - ConformalSeaLicePredictor (uncertainty quantification)')
print('  - PhysicsInformedLoss, GLKANLoss, LossConfig, TrainingConfig')
print('  - SeaLiceGraphDataset')
print('\nGraph-Liquid-KAN architecture loaded from src/')

## Cell 6: Create Dataset and DataLoaders
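
The chronological split and sliding-window indexing used in this cell can be sketched on their own, without tensors. The function names below are hypothetical and the 1000-step timeline is made up; the fractions, window size, and stride mirror the values in the cell.

```python
def split_bounds(n_timesteps, train_end=0.70, val_end=0.85):
    """Chronological 70/15/15 split: return the last train and last val indices (exclusive)."""
    return int(n_timesteps * train_end), int(n_timesteps * val_end)

def window_starts(time_start, time_end, window_size=30, stride=7):
    """Start indices of windows that lie entirely inside [time_start, time_end)."""
    return list(range(time_start, time_end - window_size, stride))

t_train, t_val = split_bounds(1000)
train_windows = window_starts(0, t_train)
val_windows = window_starts(t_train, t_val)
print(t_train, t_val, len(train_windows), len(val_windows))
```

Because each split gets its own start and end, no window straddles a split boundary, which avoids leaking future observations into training. The exclusive `range` end also means a window that would finish exactly at `time_end` is not generated.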

In [None]:
from torch.utils.data import Dataset, DataLoader

class SeaLiceDataset(Dataset):
    """Dataset for GLKAN training with feature_indices support."""
    
    def __init__(self, X, Y, mask, edge_index, feature_indices=None, window_size=30, stride=7, time_start=0, time_end=None):
        self.X = X
        self.Y = Y
        self.mask = mask
        self.edge_index = edge_index
        self.feature_indices = feature_indices
        self.window_size = window_size
        
        time_end = time_end or X.shape[0]
        self.sequences = []
        for t in range(time_start, time_end - window_size, stride):
            self.sequences.append((t, t + window_size))
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        t_start, t_end = self.sequences[idx]
        return {
            'x': self.X[t_start:t_end],
            'y': self.Y[t_start:t_end],
            'mask': self.mask[t_start:t_end],
            'edge_index': self.edge_index,
            'feature_indices': self.feature_indices,
        }

def collate_fn(batch):
    return {
        'x': torch.stack([b['x'] for b in batch]),
        'y': torch.stack([b['y'] for b in batch]),
        'mask': torch.stack([b['mask'] for b in batch]),
        'edge_index': batch[0]['edge_index'],
        'feature_indices': batch[0]['feature_indices'],
    }

# Load data
data = np.load(TENSOR_PATH, allow_pickle=True)
graph = torch.load(GRAPH_PATH, weights_only=False)

X = torch.from_numpy(data['X']).float()
Y = torch.from_numpy(data['Y']).float()
mask = torch.from_numpy(data['mask']).bool()
edge_index = graph['edge_index']

# Load feature_indices for SeaLicePredictor biological modules
if 'feature_indices' in data:
    feature_indices = data['feature_indices'].item()
    print(f'Loaded feature_indices: {list(feature_indices.keys())}')
else:
    feature_indices = None
    print('WARNING: No feature_indices found - biological modules will use defaults')

print(f'\nData loaded:')
print(f'  X: {X.shape}')
print(f'  Y: {Y.shape}')
print(f'  edge_index: {edge_index.shape}')

# Train/Val/Test split (70/15/15)
T_total = X.shape[0]
T_train = int(T_total * 0.70)
T_val = int(T_total * 0.85)

# Configuration
WINDOW_SIZE = 30
STRIDE = 7
BATCH_SIZE = 8  # Provisional: Cell 7 rebuilds the loaders with CONFIG['batch_size']

train_ds = SeaLiceDataset(X, Y, mask, edge_index, feature_indices, WINDOW_SIZE, STRIDE, 0, T_train)
val_ds = SeaLiceDataset(X, Y, mask, edge_index, feature_indices, WINDOW_SIZE, STRIDE, T_train, T_val)
test_ds = SeaLiceDataset(X, Y, mask, edge_index, feature_indices, WINDOW_SIZE, STRIDE, T_val, T_total)

# NOTE: num_workers=0 required for Google Colab (multiprocessing causes crashes)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=collate_fn, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=collate_fn, num_workers=0, pin_memory=True)

print(f'\nDataLoaders created:')
print(f'  Train sequences: {len(train_ds)}')
print(f'  Val sequences: {len(val_ds)}')
print(f'  Test sequences: {len(test_ds)}')
print(f'  Nodes: {X.shape[1]}')
print(f'  Edges: {edge_index.shape[1]}')

## Cell 6b: Pre-training with Physics-Informed Outbreak Augmentation (Optional)

**Why Pre-train?**
Real sea lice data is heavily imbalanced:
- ~95% of observations are low/normal lice levels
- ~5% are outbreaks (above 0.5 adult female lice per fish)

The model rarely sees outbreaks during training, making it hard to learn outbreak patterns.

**Solution: HybridSpatialOutbreakSimulator**
- Generates synthetic outbreak scenarios using physics-informed simulation
- Uses **dynamic flux** from ocean currents (larvae follow water flow)
- Supports counterfactual "nightmare" scenarios (warmer + stronger currents)
- Creates balanced pre-training dataset (50% outbreaks)
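
One plausible reading of "dynamic flux" is sketched below: the transport weight on a directed edge is the current's component along that edge, damped by an exponential survival term. This formula and the function name are assumptions for illustration; the actual `HybridSpatialOutbreakSimulator` may compute flux differently. Only the 15 km decay length comes from the notebook's configuration.

```python
import math

def edge_flux(current_uv, edge_direction, distance_km, decay_km=15.0):
    """Assumed larval flux weight along one directed edge (source farm -> target farm)."""
    u, v = current_uv
    dx, dy = edge_direction                        # unit vector from source to target
    along_edge = max(0.0, u * dx + v * dy)         # only downstream transport counts
    survival = math.exp(-distance_km / decay_km)   # larvae die off with distance
    return along_edge * survival

# Current flowing east at 0.4 m/s, target farm 10 km to the east.
print(edge_flux((0.4, 0.0), (1.0, 0.0), 10.0))
# Same farms, current reversed: no transport in this direction.
print(edge_flux((-0.4, 0.0), (1.0, 0.0), 10.0))
```

Under this assumption, scaling the current (the "nightmare" scenarios) scales the flux linearly, while distance suppresses it exponentially.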

**Workflow:**
1. Train simulator on real data (learns epidemic dynamics)
2. Generate synthetic data with controlled outbreak ratio
3. Pre-train model on synthetic data
4. Fine-tune on real data (next cells)
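
Step 2 of the workflow amounts to deciding how many scenarios of each kind to generate and which perturbations to apply. A minimal planning sketch follows, using the default values from `PRETRAIN_CONFIG`; the function name and the dictionary layout are hypothetical.

```python
import random

def plan_scenarios(n_samples=500, outbreak_ratio=0.5,
                   temperature_range=(3.0, 16.0), current_range=(0.8, 1.5), seed=0):
    """Build a shuffled list of scenario specs with a controlled outbreak fraction."""
    rng = random.Random(seed)  # seeded so the plan is reproducible
    n_outbreaks = int(n_samples * outbreak_ratio)
    plan = [
        {"kind": "outbreak",
         "temperature": rng.uniform(*temperature_range),
         "current_scale": rng.uniform(*current_range)}
        for _ in range(n_outbreaks)
    ]
    plan += [{"kind": "normal"} for _ in range(n_samples - n_outbreaks)]
    rng.shuffle(plan)
    return plan

plan = plan_scenarios()
print(sum(spec["kind"] == "outbreak" for spec in plan), len(plan))
```

The requested ratio controls how many outbreak scenarios are attempted. The realized outbreak rate still has to be measured on the generated trajectories, since a perturbed simulation is not guaranteed to cross the 0.5 threshold.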

In [None]:
# =============================================================================
# PHYSICS-INFORMED OUTBREAK AUGMENTATION
# =============================================================================
# Set to True to enable pre-training with synthetic outbreak data
ENABLE_PRETRAINING = True

# Pre-training configuration
PRETRAIN_CONFIG = {
    'simulator_epochs': 50,      # Epochs to train the outbreak simulator
    'simulator_lr': 1e-3,        # Learning rate for simulator
    'n_synthetic_samples': 500,  # Number of synthetic sequences to generate
    'outbreak_ratio': 0.5,       # Fraction of outbreaks (vs ~5% in real data)
    'pretrain_epochs': 20,       # Epochs of pre-training on synthetic data
    'pretrain_lr': 1e-3,         # Lower LR for pre-training
    # Nightmare scenario parameters
    'temperature_range': (3.0, 16.0),  # Temperature perturbation range (°C) - realistic Norwegian coastal temps
    'current_range': (0.8, 1.5),        # Current scaling range
}

if ENABLE_PRETRAINING:
    print('='*60)
    print('PHYSICS-INFORMED OUTBREAK AUGMENTATION')
    print('='*60)
    
    # Check if graph has required static edge features
    if 'edge_direction' not in graph or 'edge_distance' not in graph:
        print('\n[WARN] Graph missing edge_direction/edge_distance')
        print('       Run add_static_edge_features_to_graph() first')
        print('       Skipping pre-training...')
        ENABLE_PRETRAINING = False
    else:
        print(f'\nGraph has required static edge features:')
        print(f'  edge_direction: {graph["edge_direction"].shape}')
        print(f'  edge_distance: {graph["edge_distance"].shape}')

if ENABLE_PRETRAINING:
    from src.models.timegan import HybridSpatialOutbreakSimulator, HybridOutbreakAugmenter
    from tqdm.auto import tqdm
    import torch.nn.functional as F
    
    # Load feature_indices from JSON if needed
    if feature_indices is None:
        import json
        if 'feature_indices_json' in data.files:
            fi_json = data['feature_indices_json'].item()
            feature_indices = json.loads(fi_json)
            print(f'\nLoaded feature_indices: {list(feature_indices.keys())}')
    
    # ==========================================================================
    # Step 1: Create HybridSpatialOutbreakSimulator
    # ==========================================================================
    print('\n--- Step 1: Creating Hybrid Outbreak Simulator ---')
    
    simulator = HybridSpatialOutbreakSimulator(
        n_farms=graph['n_nodes'],
        edge_index=graph['edge_index'],
        edge_distance=graph['edge_distance'],
        edge_direction=graph['edge_direction'],
        feature_indices=feature_indices,
        feature_dim=Y.shape[-1],  # 3: adult female, mobile, attached
        hidden_dim=32,            # Smaller than main model
        decay_km=15.0,            # Larval survival distance
    ).to(device)
    
    n_sim_params = sum(p.numel() for p in simulator.parameters())
    print(f'  Simulator parameters: {n_sim_params:,}')
    print(f'  Device: {device}')
    
    # ==========================================================================
    # Step 2: Train simulator on real outbreak patterns
    # ==========================================================================
    print('\n--- Step 2: Training Simulator on Real Data ---')
    
    # Use training data portion
    X_train = X[:T_train].to(device)
    Y_train = Y[:T_train].to(device)
    mask_train = mask[:T_train].to(device)
    
    sim_optimizer = torch.optim.Adam(simulator.parameters(), lr=PRETRAIN_CONFIG['simulator_lr'])
    
    print(f'  Training on {T_train} timesteps...')
    
    for epoch in tqdm(range(PRETRAIN_CONFIG['simulator_epochs']), desc='Training simulator'):
        simulator.train()
        sim_optimizer.zero_grad()
        
        # Forward pass (flux computed dynamically!)
        Y_pred = simulator.forward(X_train, initial_lice=Y_train[0], add_noise=False)
        
        # MSE loss on valid observations
        diff = (Y_pred - Y_train) ** 2
        loss = (diff * mask_train.unsqueeze(-1).float()).sum() / mask_train.float().sum().clamp(min=1)
        
        loss.backward()
        sim_optimizer.step()
        
        if (epoch + 1) % 10 == 0:
            print(f'  Epoch {epoch+1}: loss={loss.item():.4f}')
    
    simulator.eval()
    print(f'  Final simulator loss: {loss.item():.4f}')
    
    # ==========================================================================
    # Step 3: Generate synthetic outbreak scenarios
    # ==========================================================================
    print('\n--- Step 3: Generating Synthetic Outbreak Data ---')
    
    n_samples = PRETRAIN_CONFIG['n_synthetic_samples']
    n_outbreaks = int(n_samples * PRETRAIN_CONFIG['outbreak_ratio'])
    n_normal = n_samples - n_outbreaks
    
    print(f'  Generating {n_normal} normal + {n_outbreaks} outbreak scenarios...')
    print(f'  Using dynamic flux from ocean currents')
    
    synthetic_X = []  # Node features (for context)
    synthetic_Y = []  # Lice counts (targets)
    
    # Use a sample of real X for environmental context
    X_sample = X[:WINDOW_SIZE].to(device)
    
    # Generate outbreak scenarios (perturbed conditions)
    for i in tqdm(range(n_outbreaks), desc='Generating outbreaks'):
        # Random perturbation
        temp_perturb = PRETRAIN_CONFIG['temperature_range'][0] + \
                       torch.rand(1).item() * (PRETRAIN_CONFIG['temperature_range'][1] - PRETRAIN_CONFIG['temperature_range'][0])
        current_scale = PRETRAIN_CONFIG['current_range'][0] + \
                        torch.rand(1).item() * (PRETRAIN_CONFIG['current_range'][1] - PRETRAIN_CONFIG['current_range'][0])
        
        with torch.no_grad():
            scenario = simulator.generate_outbreak_scenarios(
                n_scenarios=1,
                X=X_sample,
                initial_intensity=0.3,
                temperature_perturbation=temp_perturb,
                current_scale=current_scale,
                device=device,
            )
        synthetic_Y.append(scenario[0].cpu())
        synthetic_X.append(X_sample.cpu())
    
    # Generate normal scenarios
    for i in tqdm(range(n_normal), desc='Generating normal'):
        with torch.no_grad():
            initial = torch.zeros(X.shape[1], Y.shape[-1], device=device)
            initial[:, 0] = torch.rand(X.shape[1], device=device) * 0.1  # Small random initial
            
            scenario = simulator.forward(X_sample, initial_lice=initial, add_noise=True)
        synthetic_Y.append(scenario.cpu())
        synthetic_X.append(X_sample.cpu())
    
    # Stack into tensors
    synthetic_X = torch.stack(synthetic_X, dim=0)  # (n_samples, T, N, F)
    synthetic_Y = torch.stack(synthetic_Y, dim=0)  # (n_samples, T, N, 3)
    synthetic_mask = torch.ones(n_samples, WINDOW_SIZE, X.shape[1], dtype=torch.bool)
    
    print(f'\nSynthetic data generated:')
    print(f'  X shape: {synthetic_X.shape}')
    print(f'  Y shape: {synthetic_Y.shape}')
    print(f'  Y range: [{synthetic_Y.min():.4f}, {synthetic_Y.max():.4f}]')
    
    # Check outbreak rate
    outbreak_threshold = 0.5
    has_outbreak = (synthetic_Y[:, :, :, 0] > outbreak_threshold).any(dim=(1, 2))
    actual_outbreak_rate = has_outbreak.float().mean().item()
    print(f'  Actual outbreak rate: {actual_outbreak_rate:.1%}')
    
    # ==========================================================================
    # Step 4: Create pre-training DataLoader
    # ==========================================================================
    print('\n--- Step 4: Creating Pre-training DataLoader ---')
    
    class SyntheticDataset(Dataset):
        def __init__(self, X, Y, mask, edge_index, feature_indices):
            self.X = X
            self.Y = Y
            self.mask = mask
            self.edge_index = edge_index
            self.feature_indices = feature_indices
        
        def __len__(self):
            return len(self.X)
        
        def __getitem__(self, idx):
            return {
                'x': self.X[idx],
                'y': self.Y[idx],
                'mask': self.mask[idx],
                'edge_index': self.edge_index,
                'feature_indices': self.feature_indices,
            }
    
    pretrain_ds = SyntheticDataset(synthetic_X, synthetic_Y, synthetic_mask, edge_index, feature_indices)
    pretrain_loader = DataLoader(pretrain_ds, batch_size=BATCH_SIZE, shuffle=True,
                                  collate_fn=collate_fn, num_workers=0, pin_memory=True)
    
    print(f'  Pre-training batches: {len(pretrain_loader)}')
    
    # Clear simulator from GPU memory
    del simulator, sim_optimizer, X_train, Y_train, mask_train
    torch.cuda.empty_cache()
    gc.collect()
    
    print('\n' + '='*60)
    print('SYNTHETIC DATA READY FOR PRE-TRAINING')
    print('='*60)
    print(f'  Samples: {n_samples}')
    print(f'  Outbreak ratio: {actual_outbreak_rate:.1%}')
    print(f'  Window size: {WINDOW_SIZE} days')
    print('\nPre-training runs at the start of Cell 8, before fine-tuning on real data')
else:
    print('\n[INFO] Pre-training disabled (ENABLE_PRETRAINING = False)')
    pretrain_loader = None

## Cell 7: Create Model and Optimizer
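
The configuration in this cell passes normalization statistics to the model so that the biological modules see physical units. The conversion itself is the inverse z-score, sketched here with the constants from `CONFIG`; the helper names are illustrative and how `SeaLicePredictor` applies them internally is not shown in this notebook.

```python
TEMP_MEAN, TEMP_STD = 9.30, 3.94   # degrees Celsius
SAL_MEAN, SAL_STD = 31.96, 2.84    # PSU

def denormalize(z, mean, std):
    """Map a z-scored value back to raw physical units."""
    return z * std + mean

def normalize(raw, mean, std):
    """Forward z-score, included to show the round trip."""
    return (raw - mean) / std

# A z-score of +1 in temperature and -2 in salinity, in physical units:
print(denormalize(1.0, TEMP_MEAN, TEMP_STD))   # about 13.24 degrees Celsius
print(denormalize(-2.0, SAL_MEAN, SAL_STD))    # about 26.28 PSU
```

A temperature-response curve is defined over raw degrees Celsius, so feeding it z-scores clustered around zero would probe only a narrow and physically meaningless part of that curve.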

In [None]:
# Configuration - tuned for 1777 nodes on A100 (80GB)
# Using SeaLicePredictor with biological modules and dynamic LR
CONFIG = {
    'hidden_dim': 64,      # Reduced from 128 (OOM fix)
    'n_bases': 8,          # Reduced from 12
    'k_hops': 3,           # Multi-hop spatial aggregation
    'dropout': 0.1,
    'lr': 1e-2,            # CRITICAL: Must be >= 1e-2 for Tweedie loss to work!
    'weight_decay': 1e-4,  # Prevents spline coefficient oscillation
    'grad_clip': 1.0,
    'epochs': 100,
    'lambda_bio': 0.01,    # Bio penalty
    'lambda_stability': 0.001,
    # Dynamic LR scheduler (ReduceLROnPlateau)
    'scheduler_patience': 15,  # Epochs without improvement before reducing LR
    'scheduler_factor': 0.5,   # Multiply LR by this when reducing
    'min_lr': 1e-6,            # Don't reduce below this
    # Early stopping DISABLED - will run all epochs
    'early_stopping_patience': 10000,  # Set very high to effectively disable
    'min_delta': 1e-6,
    # Data config
    'n_nodes': X.shape[1],
    'n_edges': edge_index.shape[1],
    'n_features': X.shape[-1],
    'window_size': WINDOW_SIZE,
    'stride': STRIDE,
    'batch_size': 4,       # Reduced from 16 (OOM fix)
    # ==========================================================================
    # INPUT DENORMALIZATION FOR BIOLOGICAL MODULES
    # ==========================================================================
    # These values are used to convert z-scored temperature/salinity back to
    # raw physical units (°C, PSU) before passing to BelehradekKAN and
    # SalinityMortalityKAN. Without these, the biological modules receive
    # normalized values and output constants.
    'temp_mean': 9.30,     # Mean temperature in °C (from data normalization)
    'temp_std': 3.94,      # Temperature std in °C
    'sal_mean': 31.96,     # Mean salinity in PSU
    'sal_std': 2.84,       # Salinity std in PSU
}

BATCH_SIZE = CONFIG['batch_size']

# Initialize wandb run (already authenticated in Cell 2)
print('='*60)
print('INITIALIZING WANDB RUN')
print('='*60)

run = wandb.init(
    project="graph-liquid-kan",
    name=f"sealice-predictor-lr{CONFIG['lr']}-{CONFIG['hidden_dim']}h",
    config=CONFIG,
    tags=["A100", "sea-lice", "tweedie-loss", "SeaLicePredictor", "dynamic-lr"],
)

if torch.cuda.is_available():
    wandb.config.update({
        'gpu_name': torch.cuda.get_device_name(0),
        'gpu_memory_gb': torch.cuda.get_device_properties(0).total_memory / 1e9,
        'mixed_precision': USE_AMP,
    })

print(f'wandb run: {wandb.run.name}')
print(f'wandb URL: {wandb.run.get_url()}')

# edge_index stays on CPU inside the datasets; a single GPU copy is created
# below and reused for every batch
print('\n' + '='*60)
print('SETTING UP DATA PIPELINE')
print('='*60)

# Rebuild the DataLoaders with the final batch size from CONFIG
# NOTE: num_workers=0 required for Google Colab (multiprocessing causes crashes)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          collate_fn=collate_fn, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        collate_fn=collate_fn, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         collate_fn=collate_fn, num_workers=0, pin_memory=True)

# Create a GPU copy of edge_index to use in training loop
# This avoids moving it every batch (it's constant)
edge_index_gpu = edge_index.to(device)
print(f'  edge_index_gpu created on: {edge_index_gpu.device}')
print(f'  DataLoaders using num_workers=0 (Colab compatibility)')
print(f'  Batch size: {BATCH_SIZE}')

# Create SeaLicePredictor model (domain-specific with biological modules)
input_dim = X.shape[-1]
output_dim = Y.shape[-1]

# =============================================================================
# MODEL CREATION WITH INPUT DENORMALIZATION
# =============================================================================
# The temp_mean/std and sal_mean/std parameters enable biological modules
# (BelehradekKAN, SalinityMortalityKAN) to receive raw physical values
# instead of z-scored values. This is CRITICAL for proper biological response.
model = SeaLicePredictor(
    input_dim=input_dim,
    hidden_dim=CONFIG['hidden_dim'],
    output_dim=output_dim,
    n_bases=CONFIG['n_bases'],
    k_hops=CONFIG['k_hops'],
    dropout=CONFIG['dropout'],
    # Input denormalization for biological modules
    temp_mean=CONFIG['temp_mean'],
    temp_std=CONFIG['temp_std'],
    sal_mean=CONFIG['sal_mean'],
    sal_std=CONFIG['sal_std'],
).to(device)

# Optimizer - AdamW with weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
)

# Scheduler - ReduceLROnPlateau (reduces LR when val_loss plateaus)
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=CONFIG['scheduler_factor'],
    patience=CONFIG['scheduler_patience'],
    min_lr=CONFIG['min_lr'],
)

# Create loss function with LossConfig
# IMPORTANT: Uses Tweedie loss (p=1.5) for zero-inflated count data
loss_config = LossConfig(
    loss_type='tweedie',
    tweedie_p=1.5,
    lambda_bio=CONFIG['lambda_bio'],
    lambda_stability=CONFIG['lambda_stability'],
)
criterion = PhysicsInformedLoss(config=loss_config)

print(f'\nModel: SeaLicePredictor (with biological modules)')
print(f'  - BelehradekKAN: Temperature-dependent development')
print(f'  - SalinityMortalityKAN: Salinity survival factor')
print(f'  - LarvalTransportModule: Cross-farm infection')
print(f'\nInput Denormalization (for biological modules):')
print(f'  - Temperature: mean={CONFIG["temp_mean"]}°C, std={CONFIG["temp_std"]}°C')
print(f'  - Salinity: mean={CONFIG["sal_mean"]} PSU, std={CONFIG["sal_std"]} PSU')
print(f'\nLoss function: Tweedie (p={loss_config.tweedie_p})')
print(f'  - Handles zero-inflated count data (many farms have 0 lice)')
print(f'  - Prevents convergence to dataset mean')
print(f'\nLearning Rate Schedule:')
print(f'  - Initial: {CONFIG["lr"]}')
print(f'  - Reduces by {CONFIG["scheduler_factor"]}x after {CONFIG["scheduler_patience"]} epochs w/o improvement')
print(f'  - Min LR: {CONFIG["min_lr"]}')
print(f'\nEarly Stopping: DISABLED (will run all {CONFIG["epochs"]} epochs)')

scaler = torch.cuda.amp.GradScaler() if USE_AMP else None

n_params = sum(p.numel() for p in model.parameters())
wandb.config.update({
    'n_parameters': n_params, 
    'loss_type': 'tweedie', 
    'tweedie_p': 1.5,
    'model_type': 'SeaLicePredictor',
})
wandb.watch(model, log='all', log_freq=100)

# Clear cache before training
torch.cuda.empty_cache()
gc.collect()

print(f'\nModel created:')
print(f'  Parameters: {n_params:,}')
print(f'  Device: {next(model.parameters()).device}')
print(f'  Hidden dim: {CONFIG["hidden_dim"]}')
print(f'  K-hops: {CONFIG["k_hops"]}')
print(f'  Batch size: {BATCH_SIZE}')

# Show memory usage
if torch.cuda.is_available():
    mem_alloc = torch.cuda.memory_allocated() / 1e9
    mem_reserved = torch.cuda.memory_reserved() / 1e9
    print(f'\nGPU Memory:')
    print(f'  Allocated: {mem_alloc:.2f} GB')
    print(f'  Reserved:  {mem_reserved:.2f} GB')

## Cell 8: Training Loop (with Pre-training)
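
Both the pre-training and fine-tuning loops clip gradients to a global norm of `CONFIG['grad_clip']` (1.0) before each optimizer step. The pure-Python sketch below mirrors the arithmetic of `torch.nn.utils.clip_grad_norm_` on a flat list of gradient values; it is for illustration only and is not used by the notebook.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale gradients so their joint L2 norm is at most max_norm.

    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    return [g * scale for g in grads], total_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0])
print(norm_before, clipped)
```

Clipping preserves the gradient direction and only shrinks its length, which bounds the size of the raw gradient passed to the optimizer at the relatively high initial learning rate of 1e-2.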

In [None]:
from tqdm.auto import tqdm
import time

EPOCHS = CONFIG['epochs']
PATIENCE = CONFIG['early_stopping_patience']
CHECKPOINT_DIR = '/content/drive/MyDrive/GLKAN_Project/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# =============================================================================
# PRE-TRAINING PHASE (if enabled)
# =============================================================================
if ENABLE_PRETRAINING and pretrain_loader is not None:
    print('='*60)
    print('PRE-TRAINING ON SYNTHETIC OUTBREAK DATA')
    print('='*60)
    
    pretrain_epochs = PRETRAIN_CONFIG['pretrain_epochs']
    pretrain_lr = PRETRAIN_CONFIG['pretrain_lr']
    
    # Use separate optimizer with lower LR for pre-training
    pretrain_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=pretrain_lr,
        weight_decay=CONFIG['weight_decay'],
    )
    
    print(f'Pre-training epochs: {pretrain_epochs}')
    print(f'Pre-training LR: {pretrain_lr}')
    print(f'Synthetic samples: {len(pretrain_loader.dataset)}')
    
    pretrain_history = []
    
    for epoch in range(pretrain_epochs):
        model.train()
        epoch_loss = 0
        n_batches = 0
        
        pbar = tqdm(pretrain_loader, desc=f'Pre-train {epoch+1}/{pretrain_epochs}', leave=False)
        for batch in pbar:
            batch = {
                'x': batch['x'].to(device, non_blocking=True),
                'y': batch['y'].to(device, non_blocking=True),
                'mask': batch['mask'].to(device, non_blocking=True),
                'edge_index': edge_index_gpu,
                'feature_indices': batch['feature_indices'],
            }
            
            pretrain_optimizer.zero_grad(set_to_none=True)
            
            output = model(batch)
            loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
            
            if torch.isnan(loss):
                continue
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            pretrain_optimizer.step()
            
            epoch_loss += loss.item()
            n_batches += 1
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        avg_loss = epoch_loss / max(n_batches, 1)
        pretrain_history.append(avg_loss)
        print(f'Pre-train Epoch {epoch+1}/{pretrain_epochs} | Loss: {avg_loss:.4f}')
        
        wandb.log({
            'pretrain/epoch': epoch + 1,
            'pretrain/loss': avg_loss,
        })
    
    print(f'\nPre-training complete!')
    print(f'Final pre-train loss: {pretrain_history[-1]:.4f}')
    
    # Log pre-training to wandb
    wandb.config.update({
        'pretrain_enabled': True,
        'pretrain_epochs': pretrain_epochs,
        'pretrain_lr': pretrain_lr,
        'pretrain_samples': len(pretrain_loader.dataset),
        'pretrain_outbreak_ratio': PRETRAIN_CONFIG['outbreak_ratio'],
    })
    
    # Clear pre-training data from memory
    del pretrain_loader, pretrain_optimizer
    torch.cuda.empty_cache()
    gc.collect()
    
    print('\n' + '='*60)
    print('STARTING FINE-TUNING ON REAL DATA')
    print('='*60)
else:
    print('\n[INFO] Skipping pre-training (disabled or no data)')
    wandb.config.update({'pretrain_enabled': False})

# =============================================================================
# MAIN TRAINING LOOP (Fine-tuning on real data)
# =============================================================================
best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'train_rmse': [], 'val_rmse': [], 'lr': []}

print('='*60)
print('GRAPH-LIQUID-KAN TRAINING (SeaLicePredictor)')
print('='*60)
print(f'Device: {device}')
print(f'Epochs: {EPOCHS}')
print(f'Initial LR: {CONFIG["lr"]} (reduces by {CONFIG["scheduler_factor"]}x after {CONFIG["scheduler_patience"]} epochs w/o improvement)')
print(f'wandb: {wandb.run.get_url()}')

start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()
    
    # Training
    model.train()
    train_loss = 0
    train_rmse = 0
    n_batches = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}', leave=False)
    for batch_idx, batch in enumerate(pbar):
        batch = {
            'x': batch['x'].to(device, non_blocking=True),
            'y': batch['y'].to(device, non_blocking=True),
            'mask': batch['mask'].to(device, non_blocking=True),
            'edge_index': edge_index_gpu,
            'feature_indices': batch['feature_indices'],
        }
        
        optimizer.zero_grad(set_to_none=True)
        
        if USE_AMP:
            with torch.cuda.amp.autocast():
                output = model(batch)
                loss, metrics = criterion(output['predictions'], batch['y'], batch['mask'])
            
            if torch.isnan(loss):
                continue
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(batch)
            loss, metrics = criterion(output['predictions'], batch['y'], batch['mask'])
            
            if torch.isnan(loss):
                continue
            
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            optimizer.step()
        
        train_loss += loss.item()
        
        with torch.no_grad():
            mask_exp = batch['mask'].unsqueeze(-1).expand_as(output['predictions'])
            sq_err = ((output['predictions'] - batch['y']) ** 2) * mask_exp.float()
            rmse = torch.sqrt(sq_err.sum() / mask_exp.float().sum().clamp(min=1))
            train_rmse += rmse.item()
        
        n_batches += 1
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
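        if batch_idx % 10 == 0:
            # Commit every batch log. With commit=False, repeated keys overwrite each
            # other and only the last value before the epoch-level log is recorded.
            wandb.log({'batch/loss': loss.item(), 'batch/rmse': rmse.item()})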
    
    train_loss /= max(n_batches, 1)
    train_rmse /= max(n_batches, 1)
    
    # Validation
    model.eval()
    val_loss = 0
    val_rmse = 0
    n_val = 0
    
    with torch.no_grad():
        for batch in val_loader:
            batch = {
                'x': batch['x'].to(device, non_blocking=True),
                'y': batch['y'].to(device, non_blocking=True),
                'mask': batch['mask'].to(device, non_blocking=True),
                'edge_index': edge_index_gpu,
                'feature_indices': batch['feature_indices'],
            }
            
            if USE_AMP:
                with torch.cuda.amp.autocast():
                    output = model(batch)
                    loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
            else:
                output = model(batch)
                loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
            
            val_loss += loss.item()
            
            mask_exp = batch['mask'].unsqueeze(-1).expand_as(output['predictions'])
            sq_err = ((output['predictions'] - batch['y']) ** 2) * mask_exp.float()
            rmse = torch.sqrt(sq_err.sum() / mask_exp.float().sum().clamp(min=1))
            val_rmse += rmse.item()
            n_val += 1
    
    val_loss /= max(n_val, 1)
    val_rmse /= max(n_val, 1)
    
    # Update scheduler (ReduceLROnPlateau) - based on val_loss
    old_lr = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    
    if new_lr < old_lr:
        print(f'>>> LR REDUCED: {old_lr:.2e} -> {new_lr:.2e} <<<')
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_rmse'].append(train_rmse)
    history['val_rmse'].append(val_rmse)
    history['lr'].append(new_lr)
    
    epoch_time = time.time() - epoch_start
    
    # Log to wandb
    wandb.log({
        'epoch': epoch + 1,
        'train/loss': train_loss,
        'train/rmse': train_rmse,
        'val/loss': val_loss,
        'val/rmse': val_rmse,
        'learning_rate': new_lr,
        'epoch_time_seconds': epoch_time,
    })
    
    # Checkpointing
    improved = val_loss < best_val_loss - CONFIG['min_delta']
    
    if improved:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
            'config': CONFIG,
            'pretrain_enabled': ENABLE_PRETRAINING,
        }, f'{CHECKPOINT_DIR}/best_model.pt')
        marker = '* Best'
        wandb.run.summary['best_val_loss'] = val_loss
        wandb.run.summary['best_epoch'] = epoch + 1
    else:
        patience_counter += 1
        marker = f'({patience_counter}/{PATIENCE})'
    
    print(f'Epoch {epoch+1:3d}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | LR: {new_lr:.2e} | {marker}')
    
    if patience_counter >= PATIENCE:
        print(f'\nEarly stopping at epoch {epoch+1}')
        break
    
    if (epoch + 1) % 5 == 0:
        # Collect garbage before emptying the CUDA cache so freed tensors are released
        gc.collect()
        torch.cuda.empty_cache()

elapsed = time.time() - start_time
print(f'\nTraining complete in {elapsed/60:.1f} minutes')
print(f'Best validation loss: {best_val_loss:.6f}')
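
The loop above applies two separate plateau rules to `val_loss`: `ReduceLROnPlateau` lowers the learning rate, and the `patience_counter` check stops training. The sketch below replays a made-up loss trace through both rules in plain Python. It is a simplified illustration, not PyTorch's implementation: it ignores the scheduler's `threshold`, `cooldown`, and `min_lr` options, and the function name `simulate_schedule` and the small patience values are assumptions chosen for the example. One detail it does mirror is that PyTorch's scheduler decays only once the number of bad epochs exceeds its patience, while the early-stopping check fires as soon as the counter reaches `PATIENCE`.

```python
def simulate_schedule(val_losses, lr=1e-2, factor=0.5, lr_patience=2,
                      stop_patience=4, min_delta=0.0):
    """Replay a validation-loss trace through plateau LR decay and early stopping.

    Hypothetical helper for illustration only; returns (lr_history, stop_epoch),
    where stop_epoch is 1-based or None if training never stops early.
    """
    best_for_lr = float('inf')    # best loss as seen by the LR rule
    best_for_stop = float('inf')  # best loss as seen by the early-stopping rule
    lr_bad_epochs = 0
    stop_bad_epochs = 0
    lr_history = []

    for epoch, loss in enumerate(val_losses, start=1):
        # Plateau rule: decay once the bad-epoch count EXCEEDS the patience.
        if loss < best_for_lr:
            best_for_lr = loss
            lr_bad_epochs = 0
        else:
            lr_bad_epochs += 1
        if lr_bad_epochs > lr_patience:
            lr *= factor
            lr_bad_epochs = 0
        lr_history.append(lr)

        # Early-stopping rule: an improvement must beat the best loss by min_delta.
        if loss < best_for_stop - min_delta:
            best_for_stop = loss
            stop_bad_epochs = 0
        else:
            stop_bad_epochs += 1
        if stop_bad_epochs >= stop_patience:
            return lr_history, epoch

    return lr_history, None


# Made-up trace: one improvement, then a flat plateau.
trace = [1.0, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
lr_history, stopped_at = simulate_schedule(trace)
print('LR per epoch:', lr_history)
print('Stopped at epoch:', stopped_at)
```

With the notebook's own settings the same interplay applies at a larger scale: if `PATIENCE` is not comfortably larger than `CONFIG['scheduler_patience']`, training can stop before a reduced learning rate has had a chance to help.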

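Note that `train_rmse` and `val_rmse` in the loop above are averages of per-batch RMSE values. That is not the same as the RMSE over all valid observations when batches contain different numbers of unmasked entries. The sketch below uses hypothetical numbers and a hypothetical helper `masked_sq_err` to show the gap; pooling squared errors and counts before taking the square root gives the dataset-level figure.

```python
import math


def masked_sq_err(pred, target, mask):
    """Return (sum of squared errors, number of valid entries) where mask is True."""
    pairs = [(p, t) for p, t, m in zip(pred, target, mask) if m]
    return sum((p - t) ** 2 for p, t in pairs), len(pairs)


# Two hypothetical batches with very different numbers of valid observations.
batches = [
    ([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 6.0], [True, True, True, True]),
    ([5.0, 0.0, 0.0, 0.0], [1.0, 9.0, 9.0, 9.0], [True, False, False, False]),
]

per_batch_rmse = []
total_sq, total_n = 0.0, 0
for pred, target, mask in batches:
    sq, n = masked_sq_err(pred, target, mask)
    per_batch_rmse.append(math.sqrt(sq / max(n, 1)))  # what the loop accumulates
    total_sq += sq
    total_n += n

mean_of_batch_rmse = sum(per_batch_rmse) / len(per_batch_rmse)
pooled_rmse = math.sqrt(total_sq / max(total_n, 1))

print('Mean of per-batch RMSE:', mean_of_batch_rmse)
print('Pooled RMSE:           ', pooled_rmse)
```

The test-set RMSE in Cell 9 is computed on the pooled, masked observations, so it is the pooled variant that is comparable with it.
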
## Cell 9: Evaluate and Scientific Audit

In [None]:
from sklearn.metrics import precision_recall_curve, f1_score, precision_score, recall_score, confusion_matrix
import numpy as np

# Outbreak threshold (Norwegian regulatory threshold)
OUTBREAK_THRESHOLD = 0.5
CONFORMAL_COVERAGE = 0.90

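# Load best model. weights_only=False may be required on recent PyTorch releases because
# the checkpoint also stores optimizer/scheduler state and CONFIG. Use it only on
# checkpoints you wrote yourself.
checkpoint = torch.load(f'{CHECKPOINT_DIR}/best_model.pt', map_location=device, weights_only=False)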
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f'Loaded best model from epoch {checkpoint["epoch"]+1}')

# =============================================================================
# STEP 1: Wrap model with Conformal Prediction
# =============================================================================
print('\n' + '='*60)
print('CONFORMAL PREDICTION SETUP')
print('='*60)

conformal_model = ConformalSeaLicePredictor(
    base_model=model,
    coverage=CONFORMAL_COVERAGE,
    calibration_window=100,
    use_adaptive=True,
)

print(f'Target coverage: {CONFORMAL_COVERAGE*100:.0f}%')

# =============================================================================
# STEP 2: Calibrate on validation set
# =============================================================================
print('\nCalibrating conformal predictor on validation set...')

calib_preds = []
calib_targets = []
calib_masks = []

with torch.no_grad():
    for batch in tqdm(val_loader, desc='Calibrating'):
        batch = {
            'x': batch['x'].to(device),
            'y': batch['y'].to(device),
            'mask': batch['mask'].to(device),
            'edge_index': edge_index_gpu,
            'feature_indices': batch['feature_indices'],
        }
        output = model(batch)
        calib_preds.append(output['predictions'].cpu())
        calib_targets.append(batch['y'].cpu())
        calib_masks.append(batch['mask'].cpu())

calib_preds = torch.cat(calib_preds, dim=0)
calib_targets = torch.cat(calib_targets, dim=0)
calib_masks = torch.cat(calib_masks, dim=0)

# Merge the batch and time dimensions for calibration: (B, T, N, C) -> (B*T, N, C)
B, T, N, C = calib_preds.shape
calib_preds = calib_preds.reshape(B*T, N, C)
calib_targets = calib_targets.reshape(B*T, N, C)
calib_masks = calib_masks.reshape(B*T, N)

conformal_model.conformal.calibrate(calib_preds, calib_targets, calib_masks)
calib_diagnostics = conformal_model.conformal.get_diagnostics()
print(f'Calibration complete: {calib_diagnostics.get("n_residuals", "N/A")} residuals collected')

# =============================================================================
# STEP 3: Evaluate on test set with uncertainty
# =============================================================================
print('\n' + '='*60)
print('TEST SET EVALUATION')
print('='*60)

all_point_preds = []
all_lower_bounds = []
all_upper_bounds = []
all_targets = []
all_masks = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        batch_x = batch['x'].to(device)
        batch_y = batch['y']
        batch_mask = batch['mask']
        
        # Process each sequence in the batch
        for i in range(batch_x.shape[0]):
            x_seq = batch_x[i]  # (T, N, F)
            
            interval = conformal_model.predict_with_uncertainty(
                x_seq, edge_index_gpu, feature_indices=batch['feature_indices']
            )
            
            all_point_preds.append(interval.point_prediction.cpu())
            all_lower_bounds.append(interval.lower_bound.cpu())
            all_upper_bounds.append(interval.upper_bound.cpu())
            all_targets.append(batch_y[i])
            all_masks.append(batch_mask[i])

# Concatenate all
point_preds = torch.cat(all_point_preds, dim=0)
lower_bounds = torch.cat(all_lower_bounds, dim=0)
upper_bounds = torch.cat(all_upper_bounds, dim=0)
targets = torch.cat(all_targets, dim=0)
masks = torch.cat(all_masks, dim=0)

# Extract adult female lice (index 0)
pred_af = point_preds[:, :, 0].numpy()
lower_af = lower_bounds[:, :, 0].numpy()
upper_af = upper_bounds[:, :, 0].numpy()
target_af = targets[:, :, 0].numpy()
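# Cast to bool so the mask selects entries; a 0/1 integer or float mask would
# otherwise be treated as index values (or raise an error)
mask_np = masks.numpy().astype(bool)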

# Flatten and filter by mask
pred_flat = pred_af[mask_np]
lower_flat = lower_af[mask_np]
upper_flat = upper_af[mask_np]
target_flat = target_af[mask_np]

print(f'Valid observations: {len(pred_flat)}')

# =============================================================================
# STEP 4: Conformal Coverage Analysis
# =============================================================================
print('\n' + '='*60)
print('CONFORMAL PREDICTION COVERAGE')
print('='*60)

in_interval = (target_flat >= lower_flat) & (target_flat <= upper_flat)
empirical_coverage = in_interval.mean()

print(f'Target coverage: {CONFORMAL_COVERAGE*100:.0f}%')
print(f'Empirical coverage: {empirical_coverage*100:.1f}%')
print(f'Mean interval width: {(upper_flat - lower_flat).mean():.4f}')

# =============================================================================
# STEP 5: Regression Metrics
# =============================================================================
print('\n' + '='*60)
print('REGRESSION METRICS')
print('='*60)

rmse = np.sqrt(np.mean((pred_flat - target_flat) ** 2))
mae = np.mean(np.abs(pred_flat - target_flat))

print(f'RMSE: {rmse:.4f}')
print(f'MAE:  {mae:.4f}')

# =============================================================================
# STEP 6: Outbreak Detection - Point Predictions
# =============================================================================
print('\n' + '='*60)
print(f'OUTBREAK DETECTION - POINT PREDICTIONS (threshold={OUTBREAK_THRESHOLD})')
print('='*60)

target_binary = (target_flat > OUTBREAK_THRESHOLD).astype(int)
n_outbreaks = target_binary.sum()
n_normal = len(target_binary) - n_outbreaks

print(f'Actual outbreaks: {n_outbreaks} ({100*n_outbreaks/len(target_binary):.1f}%)')

if n_outbreaks > 0:
    pred_binary_point = (pred_flat > OUTBREAK_THRESHOLD).astype(int)
    precision_point = precision_score(target_binary, pred_binary_point, zero_division=0)
    recall_point = recall_score(target_binary, pred_binary_point, zero_division=0)
    f1_point = f1_score(target_binary, pred_binary_point, zero_division=0)
    
    print(f'Precision: {precision_point:.2%}')
    print(f'Recall:    {recall_point:.2%}')
    print(f'F1 Score:  {f1_point:.4f}')
else:
    precision_point, recall_point, f1_point = 0, 0, 0
    print('No outbreaks in test set!')

# =============================================================================
# STEP 7: Outbreak Detection - Risk-Aware (Upper Bounds)
# =============================================================================
print('\n' + '='*60)
print('OUTBREAK DETECTION - RISK-AWARE (Upper Bounds)')
print('='*60)

if n_outbreaks > 0:
    pred_binary_risk = (upper_flat > OUTBREAK_THRESHOLD).astype(int)
    precision_risk = precision_score(target_binary, pred_binary_risk, zero_division=0)
    recall_risk = recall_score(target_binary, pred_binary_risk, zero_division=0)
    f1_risk = f1_score(target_binary, pred_binary_risk, zero_division=0)
    
    print(f'Precision: {precision_risk:.2%}')
    print(f'Recall:    {recall_risk:.2%}')
    print(f'F1 Score:  {f1_risk:.4f}')
    
    cm = confusion_matrix(target_binary, pred_binary_risk)
    print(f'\nConfusion Matrix:')
    print(f'              Predicted')
    print(f'            Normal  At-Risk')
    print(f'  Normal     {cm[0,0]:5d}    {cm[0,1]:5d}')
    print(f'  Outbreak   {cm[1,0]:5d}    {cm[1,1]:5d}')
else:
    precision_risk, recall_risk, f1_risk = 0, 0, 0

# =============================================================================
# STEP 8: Scientific Audit
# =============================================================================
print('\n' + '='*60)
print('SCIENTIFIC AUDIT')
print('='*60)

sample_batch = next(iter(val_loader))
sample_batch = {
    'x': sample_batch['x'].to(device),
    'y': sample_batch['y'].to(device),
    'mask': sample_batch['mask'].to(device),
    'edge_index': edge_index_gpu,
    'feature_indices': sample_batch['feature_indices'],
}

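# Test 1: Counterfactual warming. Higher temperature should increase predicted growth.
# NOTE: this assumes feature index 0 is sea temperature in unnormalized degrees C.
# If inputs are standardized, +5.0 means five standard deviations, not 5 C.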
with torch.no_grad():
    output_orig = model(sample_batch)
    growth_orig = (output_orig['predictions'][:, 1:] - output_orig['predictions'][:, :-1]).mean().item()
    
    x_hot = sample_batch['x'].clone()
    x_hot[..., 0] += 5.0
    batch_hot = {**sample_batch, 'x': x_hot}
    output_hot = model(batch_hot)
    growth_hot = (output_hot['predictions'][:, 1:] - output_hot['predictions'][:, :-1]).mean().item()

test1_pass = growth_hot > growth_orig
print(f'[TEST 1] Counterfactual: {"PASS" if test1_pass else "FAIL"}')
print(f'         Growth normal: {growth_orig:.6f}, Growth +5C: {growth_hot:.6f}')

# Test 2: Long-horizon stability (tile the input to at most 90 steps, check for NaN/Inf)
with torch.no_grad():
    x = sample_batch['x']
    x_ext = x.repeat(1, 3, 1, 1)[:, :90]
    model.network.reset_cache()
    pred_ext, _ = model.network(x_ext, edge_index_gpu)
    has_nan = torch.isnan(pred_ext).any().item()
    has_inf = torch.isinf(pred_ext).any().item()

test2_pass = not (has_nan or has_inf)
print(f'[TEST 2] Long-horizon: {"PASS" if test2_pass else "FAIL"}')

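# Test 3: Graphon scale invariance (adding a second, disjoint copy of the graph
# should change the mean absolute prediction by less than 10%)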
with torch.no_grad():
    model.network.reset_cache()
    pred_n, _ = model.network(x, edge_index_gpu)
    mean_n = pred_n.abs().mean().item()
    
    N = x.shape[2]
    x_2n = x.repeat(1, 1, 2, 1)
    edge_2n = torch.cat([edge_index_gpu, edge_index_gpu + N], dim=1)
    model.network.reset_cache()
    pred_2n, _ = model.network(x_2n, edge_2n)
    mean_2n = pred_2n.abs().mean().item()
    deviation = abs(mean_2n - mean_n) / (mean_n + 1e-8)

test3_pass = deviation < 0.10
print(f'[TEST 3] Graphon: {"PASS" if test3_pass else "FAIL"} ({100*deviation:.1f}% deviation)')

print(f'\nAll tests: {"PASSED" if all([test1_pass, test2_pass, test3_pass]) else "SOME FAILED"}')

# =============================================================================
# STEP 9: Final Summary
# =============================================================================
print('\n' + '='*60)
print('FINAL SUMMARY')
print('='*60)

print(f'\n--- Point Prediction Performance ---')
print(f'  Recall:    {recall_point:.1%} (target: 90%)')
print(f'  Precision: {precision_point:.1%} (target: 80%)')
print(f'  F1:        {f1_point:.4f} (target: 0.85)')

print(f'\n--- Risk-Aware Performance (Upper Bound) ---')
print(f'  Recall:    {recall_risk:.1%} (target: 90%)')
print(f'  Precision: {precision_risk:.1%} (target: 80%)')
print(f'  F1:        {f1_risk:.4f} (target: 0.85)')

print(f'\n--- Conformal Coverage ---')
print(f'  Target: {CONFORMAL_COVERAGE*100:.0f}%, Achieved: {empirical_coverage*100:.1f}%')

# Log to wandb
wandb.run.summary['outbreak_threshold'] = OUTBREAK_THRESHOLD
wandb.run.summary['conformal_coverage_target'] = CONFORMAL_COVERAGE
wandb.run.summary['conformal_coverage_achieved'] = float(empirical_coverage)
wandb.run.summary['audit/counterfactual'] = test1_pass
wandb.run.summary['audit/long_horizon'] = test2_pass
wandb.run.summary['audit/graphon'] = test3_pass
wandb.run.summary['test/rmse'] = float(rmse)
wandb.run.summary['test/mae'] = float(mae)
wandb.run.summary['test/precision_point'] = float(precision_point)
wandb.run.summary['test/recall_point'] = float(recall_point)
wandb.run.summary['test/f1_point'] = float(f1_point)
wandb.run.summary['test/precision_risk'] = float(precision_risk)
wandb.run.summary['test/recall_risk'] = float(recall_risk)
wandb.run.summary['test/f1_risk'] = float(f1_risk)
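
Step 2 calibrates on validation residuals and Step 4 checks empirical coverage. As a rough illustration of what such a calibration does, here is plain split conformal prediction in pure Python. Treat it as a sketch under stated assumptions: the internals of `ConformalSeaLicePredictor` are not shown in this notebook, and its `use_adaptive=True` and `calibration_window` options suggest an adaptive variant that this example does not implement. The data values, the helper name `conformal_quantile`, and the clipping of lower bounds at zero are all hypothetical.

```python
import math


def conformal_quantile(residuals, coverage=0.90):
    """Finite-sample split-conformal quantile of absolute residuals."""
    n = len(residuals)
    # Rank ceil((n + 1) * coverage), clipped to n, is the standard correction.
    rank = min(n, math.ceil((n + 1) * coverage))
    return sorted(residuals)[rank - 1]


# Hypothetical calibration data: predicted vs. observed lice per fish.
calib_pred = [0.10, 0.30, 0.20, 0.60, 0.40, 0.05, 0.25, 0.50, 0.15, 0.35]
calib_true = [0.12, 0.25, 0.35, 0.55, 0.70, 0.00, 0.20, 0.45, 0.16, 0.30]
residuals = [abs(p - t) for p, t in zip(calib_pred, calib_true)]
q = conformal_quantile(residuals, coverage=0.90)

# Hypothetical test points: symmetric intervals, lower bound clipped at zero
# because lice counts cannot be negative (an assumption of this sketch).
test_pred = [0.20, 0.45, 0.10]
test_true = [0.30, 0.90, 0.05]
intervals = [(max(0.0, p - q), p + q) for p in test_pred]
covered = [lo <= y <= hi for (lo, hi), y in zip(intervals, test_true)]
empirical = sum(covered) / len(covered)

print(f'Interval half-width q: {q:.3f}')
print(f'Empirical coverage on toy test set: {empirical:.2f}')
```

With only ten calibration points the corrected rank equals `n`, so `q` is simply the largest residual. On three test points the empirical coverage is very noisy, which is why Step 4 reports it over all valid test observations.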

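Steps 6 and 7 compare two alarm rules: flag a farm when the point prediction exceeds the threshold, or when the interval's upper bound does. The toy example below uses hypothetical values and a hand-rolled `precision_recall` helper standing in for the scikit-learn calls. It shows the expected trade-off: the upper-bound rule catches more true outbreaks but raises more false alarms.

```python
def precision_recall(y_true, y_flag):
    """Precision and recall for boolean labels, returning 0.0 when undefined."""
    tp = sum(1 for t, f in zip(y_true, y_flag) if t and f)
    fp = sum(1 for t, f in zip(y_true, y_flag) if f and not t)
    fn = sum(1 for t, f in zip(y_true, y_flag) if t and not f)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


THRESHOLD = 0.5  # same outbreak threshold as the notebook

# Hypothetical observations, point predictions, and interval upper bounds.
observed = [0.10, 0.60, 0.70, 0.20, 0.55, 0.30]
point = [0.20, 0.45, 0.80, 0.10, 0.40, 0.35]
upper = [0.40, 0.65, 1.00, 0.30, 0.60, 0.55]

truth = [y > THRESHOLD for y in observed]
point_pr = precision_recall(truth, [p > THRESHOLD for p in point])
risk_pr = precision_recall(truth, [u > THRESHOLD for u in upper])

print(f'Point rule:      precision={point_pr[0]:.2f}, recall={point_pr[1]:.2f}')
print(f'Upper-bound rule: precision={risk_pr[0]:.2f}, recall={risk_pr[1]:.2f}')
```

Which rule is preferable depends on the relative cost of a missed outbreak versus an unnecessary intervention; the recall and precision targets in the notebook header apply to both rules.
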
## Cell 10: Save Results and Finish

In [None]:
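import json
import os
from datetime import datetime
import matplotlib.pyplot as plt

OUTPUT_DIR = '/content/drive/MyDrive/GLKAN_Project/outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # open() and savefig() fail if the folder is missing
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')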

# Save results
results = {
    'timestamp': timestamp,
    'config': CONFIG,
    'outbreak_threshold': OUTBREAK_THRESHOLD,
    'conformal_coverage': CONFORMAL_COVERAGE,
    'training': {
        'epochs': len(history['train_loss']),
        'best_val_loss': best_val_loss,
        'lr_schedule': history['lr'],
    },
    'regression': {'rmse': float(rmse), 'mae': float(mae)},
    'point_prediction': {
        'precision': float(precision_point),
        'recall': float(recall_point),
        'f1': float(f1_point),
    },
    'risk_aware': {
        'precision': float(precision_risk),
        'recall': float(recall_risk),
        'f1': float(f1_risk),
    },
    'conformal': {
        'target_coverage': CONFORMAL_COVERAGE,
        'empirical_coverage': float(empirical_coverage),
    },
    'scientific_audit': {
        'counterfactual': test1_pass,
        'long_horizon': test2_pass,
        'graphon': test3_pass,
    },
    'history': history,
}

results_path = f'{OUTPUT_DIR}/results_{timestamp}.json'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2, default=str)

# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['val_loss'], label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].set_yscale('log')

axes[1].plot(history['train_rmse'], label='Train')
axes[1].plot(history['val_rmse'], label='Val')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('RMSE')
axes[1].set_title('Training RMSE')
axes[1].legend()

axes[2].plot(history['lr'], label='Learning Rate', color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule')
axes[2].set_yscale('log')
axes[2].legend()

plt.tight_layout()
plot_path = f'{OUTPUT_DIR}/training_curves_{timestamp}.png'
plt.savefig(plot_path, dpi=150)
wandb.log({'training_curves': wandb.Image(fig)})
plt.show()

# Save model artifact
artifact = wandb.Artifact('sealice-predictor-model', type='model')
artifact.add_file(f'{CHECKPOINT_DIR}/best_model.pt')
wandb.log_artifact(artifact)

print(f'Results saved to: {results_path}')
print(f'Training curves: {plot_path}')

# Print summary
print('\n' + '='*60)
print('FINAL SUMMARY')
print('='*60)
print(f'Model: SeaLicePredictor (with biological modules)')
print(f'Outbreak threshold: {OUTBREAK_THRESHOLD}')
print(f'Conformal coverage: {CONFORMAL_COVERAGE*100:.0f}% target, {empirical_coverage*100:.1f}% achieved')
print(f'\nRegression:')
print(f'  RMSE: {rmse:.4f}')
print(f'  MAE:  {mae:.4f}')
print(f'\nPoint Prediction:')
print(f'  Precision: {precision_point:.3f}')
print(f'  Recall:    {recall_point:.3f}')
print(f'  F1 Score:  {f1_point:.3f}')
print(f'\nRisk-Aware (Upper Bound):')
print(f'  Precision: {precision_risk:.3f}')
print(f'  Recall:    {recall_risk:.3f}')
print(f'  F1 Score:  {f1_risk:.3f}')
print('='*60)

# Finish wandb
wandb.finish()
print('\nwandb run finished!')
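
One caveat about the saved JSON: `best_val_loss` starts as `float('inf')` and stays that way if no epoch ever improves, and loss histories can contain NaN. Python's `json` module writes these as `Infinity` and `NaN`, which strict JSON parsers reject, and `default=str` does not intercept them because floats are serialized natively. The sketch below shows one possible workaround; the helper name `json_safe` and the choice to map non-finite values to `null` are assumptions, not part of the notebook.

```python
import json
import math


def json_safe(value):
    """Recursively replace non-finite floats with None so the output is strict JSON."""
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if isinstance(value, dict):
        return {key: json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [json_safe(item) for item in value]
    return value


# Hypothetical results dict with the problematic values.
results_demo = {
    'best_val_loss': float('inf'),
    'history': {'val_loss': [0.5, float('nan')]},
    'passed': True,
}

loose = json.dumps(results_demo, default=str)                  # non-standard tokens
strict = json.dumps(json_safe(results_demo), allow_nan=False)  # strict JSON

print(loose)
print(strict)
```

Passing `allow_nan=False` makes `json.dumps` raise a `ValueError` if a non-finite float slips through, which turns a silent formatting problem into an explicit one.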