# Graph-Liquid-KAN Sea Lice Prediction - A100 GPU Training
## Phase 4: Production Training on Google Colab

This notebook runs **A100-optimized GPU training** of the Graph-Liquid-KAN architecture with **Weights & Biases** integration for experiment tracking.

**Architecture:**
- **FastKAN Layers**: Gaussian RBF basis functions (learnable non-linearities)
- **GraphonAggregator**: 1/N normalized message passing (scale invariant)
- **LiquidKANCell**: Closed-form Continuous (CfC) dynamics with adaptive tau
- **Physics-Informed Loss**: L_data + lambda_bio * L_bio

**Target Metrics:**
| Metric | Target | Description |
|--------|--------|-------------|
| Recall | >=90% | Catch 9/10 outbreaks |
| Precision | >=80% | 8/10 predictions correct |
| F1 Score | >=0.85 | Balance P/R |

**Runtime Configuration:**
- Runtime -> Change runtime type -> **A100 GPU**
- Runtime -> Change runtime type -> **High RAM**

**Setup:**
1. Get your wandb API key from https://wandb.ai/authorize and either store it as a Colab secret named `WANDB_API_KEY` or paste it when prompted
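2. Upload `glkan_data.zip` (containing `tensors.npz` and `spatial_graph.pt` at the top level) to the root of My Drive, so it is at `MyDrive/glkan_data.zip`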
3. Run all cells

## Cell 1: Mount Google Drive

In [None]:
import os

from google.colab import drive

# Mount Google Drive; checkpoints and outputs from later cells are written under it.
drive.mount('/content/drive')

# Create the project folders the training/results cells write into (no-op if present).
for _project_dir in (
    '/content/drive/MyDrive/GLKAN_Project/checkpoints',
    '/content/drive/MyDrive/GLKAN_Project/outputs',
):
    os.makedirs(_project_dir, exist_ok=True)

print('Google Drive mounted and directories created')

## Cell 2: Clone Repository & Install Dependencies

In [None]:
import os

# Clone the Graph-Liquid-KAN repository
REPO_URL = 'https://github.com/themythicalyeti/graph-liquid-kan.git'
REPO_DIR = '/content/graph-liquid-kan'

if os.path.exists(REPO_DIR):
    print(f'Repository already exists at {REPO_DIR}')
    %cd {REPO_DIR}
    !git pull
else:
    !git clone {REPO_URL} {REPO_DIR}
    %cd {REPO_DIR}

print(f'\nWorking directory: {os.getcwd()}')
!ls -la

In [None]:
# Install dependencies
!pip install -q torch torchvision --upgrade
!pip install -q numpy pandas scipy scikit-learn
!pip install -q loguru tqdm matplotlib
!pip install -q torch-geometric
!pip install -q wandb

print('\nDependencies installed')

# =============================================================================
# WANDB AUTHENTICATION
# =============================================================================
# Option 1: Use Colab Secrets (recommended - no prompt each time)
#   1. Click the key icon in left sidebar
#   2. Add secret named "WANDB_API_KEY" with your key from https://wandb.ai/authorize
#
# Option 2: Manual login (will prompt for API key)
#   Just run the cell - it will ask for your key

import wandb
import os

try:
    from google.colab import userdata
    WANDB_KEY = userdata.get('WANDB_API_KEY')
    os.environ['WANDB_API_KEY'] = WANDB_KEY
    wandb.login(key=WANDB_KEY)
    print('Logged in to wandb using Colab Secrets')
except:
    print('Colab Secrets not configured - will prompt for API key')
    print('Get your key at: https://wandb.ai/authorize')
    wandb.login()
    
print(f'wandb authenticated as: {wandb.api.viewer()["entity"]}')

## Cell 3: Verify GPU

In [None]:
!nvidia-smi

import torch
import gc

print('\n' + '='*60)
print('GPU VERIFICATION')
print('='*60)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_memory:.1f} GB")
    device = torch.device('cuda')
    
    # Enable TF32 for faster training on Ampere GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    USE_AMP = True
    print(f"Mixed Precision (AMP): Enabled")
    
    torch.cuda.empty_cache()
    gc.collect()
else:
    print("[WARN] CUDA not available - using CPU (will be VERY slow)")
    print("[WARN] Please enable GPU: Runtime -> Change runtime type -> T4 GPU")
    device = torch.device('cpu')
    USE_AMP = False

print(f"\nDefault device: {device}")
print('='*60)

## Cell 4: Load Data from Drive

In [None]:
import os
import numpy as np
import torch

# Path to data on Drive
DRIVE_DATA = '/content/drive/MyDrive/glkan_data.zip'
LOCAL_DATA = '/content/data'

if not os.path.exists(DRIVE_DATA):
    print(f'ERROR: Data not found at {DRIVE_DATA}')
    print('Please upload glkan_data.zip containing:')
    print('  - tensors.npz (from Phase 2)')
    print('  - spatial_graph.pt (from Phase 2)')
else:
    print('Extracting data...')
    os.makedirs(LOCAL_DATA, exist_ok=True)
    !unzip -q "{DRIVE_DATA}" -d {LOCAL_DATA}
    !ls -la {LOCAL_DATA}
    print('\nData loaded')

# Verify data
TENSOR_PATH = f'{LOCAL_DATA}/tensors.npz'
GRAPH_PATH = f'{LOCAL_DATA}/spatial_graph.pt'

if os.path.exists(TENSOR_PATH) and os.path.exists(GRAPH_PATH):
    data = np.load(TENSOR_PATH, allow_pickle=True)
    print(f"\nData shapes:")
    print(f"  X (features): {data['X'].shape}")
    print(f"  Y (targets):  {data['Y'].shape}")
    print(f"  mask:         {data['mask'].shape}")
    
    graph = torch.load(GRAPH_PATH, weights_only=False)
    print(f"  edges:        {graph['edge_index'].shape[1]}")

## Cell 5: Import GLKAN Architecture from Repository

In [None]:
import sys

# Put the cloned checkout first on the import path so `src.*` resolves to it.
sys.path.insert(0, '/content/graph-liquid-kan')

# Model architecture (src/models)
from src.models import (
    FastKAN,
    GLKANNetwork,
    GLKANPredictor,
    GraphLiquidKANCell,
    GraphonAggregator,
    LiquidKANCell,
)

# Loss functions and the loss configuration class (src/training)
from src.training import GLKANLoss, PhysicsInformedLoss
from src.training.losses import LossConfig

# Dataset class (src/data)
from src.data import SeaLiceGraphDataset

print('\n'.join([
    'Imported from repository:',
    '  - FastKAN, GraphonAggregator, LiquidKANCell',
    '  - GraphLiquidKANCell, GLKANNetwork, GLKANPredictor',
    '  - PhysicsInformedLoss, GLKANLoss, LossConfig',
    '  - SeaLiceGraphDataset',
]))
print('\nGraph-Liquid-KAN architecture loaded from src/')

## Cell 6: Create Dataset and DataLoaders

In [None]:
from torch.utils.data import Dataset, DataLoader

class SeaLiceDataset(Dataset):
    """Sliding-window dataset for GLKAN training.

    Each item is one contiguous window ``[t, t + window_size)`` sliced along the
    leading (time) axis of ``X``, ``Y`` and ``mask``; the same ``edge_index`` is
    returned with every window.

    Args:
        X: Feature tensor indexed by time along dim 0.
        Y: Target tensor indexed by time along dim 0.
        mask: Observation mask indexed by time along dim 0.
        edge_index: Graph connectivity, returned unchanged with every item.
        window_size: Number of time steps per window.
        stride: Step between consecutive window starts.
        time_start: First time index (inclusive) at which a window may start.
        time_end: Exclusive upper bound of the usable time range; ``None`` means
            ``X.shape[0]``. No window extends past it.
    """

    def __init__(self, X, Y, mask, edge_index, window_size=30, stride=7, time_start=0, time_end=None):
        self.X = X
        self.Y = Y
        self.mask = mask
        self.edge_index = edge_index
        self.window_size = window_size

        # Explicit None check: `time_end or X.shape[0]` would turn time_end=0 into
        # "use the whole series".
        if time_end is None:
            time_end = X.shape[0]

        # The last valid start is time_end - window_size (that window ends exactly at
        # time_end), so the exclusive range bound is one past it.
        last_start = time_end - window_size
        self.sequences = [
            (t, t + window_size)
            for t in range(time_start, last_start + 1, stride)
        ]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        t_start, t_end = self.sequences[idx]
        return {
            'x': self.X[t_start:t_end],
            'y': self.Y[t_start:t_end],
            'mask': self.mask[t_start:t_end],
            'edge_index': self.edge_index,
        }

def collate_fn(batch):
    """Merge per-window samples into one mini-batch dict.

    The 'x', 'y' and 'mask' tensors are stacked along a new leading batch
    dimension. 'edge_index' is taken from the first sample unchanged, because
    every window from the dataset carries the same graph object.
    """
    merged = {key: torch.stack([sample[key] for sample in batch]) for key in ('x', 'y', 'mask')}
    merged['edge_index'] = batch[0]['edge_index']
    return merged

# Load data
# NOTE: allow_pickle=True / weights_only=False unpickle arbitrary objects - only
# load files you produced yourself.
data = np.load(TENSOR_PATH, allow_pickle=True)
graph = torch.load(GRAPH_PATH, weights_only=False)

X = torch.from_numpy(data['X']).float()
Y = torch.from_numpy(data['Y']).float()
mask = torch.from_numpy(data['mask']).bool()
edge_index = graph['edge_index']

print(f'Data loaded:')
print(f'  X: {X.shape}')
print(f'  Y: {Y.shape}')
print(f'  edge_index: {edge_index.shape}')

# Chronological Train/Val/Test split (70/15/15) along the leading (time) axis.
T_total = X.shape[0]
T_train = int(T_total * 0.70)  # exclusive end of the training span
T_val = int(T_total * 0.85)    # exclusive end of the validation span

# Configuration
WINDOW_SIZE = 30
STRIDE = 7
BATCH_SIZE = 8

train_ds = SeaLiceDataset(X, Y, mask, edge_index, WINDOW_SIZE, STRIDE, 0, T_train)
val_ds = SeaLiceDataset(X, Y, mask, edge_index, WINDOW_SIZE, STRIDE, T_train, T_val)
test_ds = SeaLiceDataset(X, Y, mask, edge_index, WINDOW_SIZE, STRIDE, T_val, T_total)

# An empty split would make the training loop average over max(n, 1) batches and
# report a validation loss of 0 (a bogus "best" model), and the test evaluation
# would have nothing to concatenate - fail early with an actionable message.
_empty_splits = [
    split_name
    for split_name, split_ds in (('train', train_ds), ('val', val_ds), ('test', test_ds))
    if len(split_ds) == 0
]
if _empty_splits:
    raise ValueError(
        f'No {WINDOW_SIZE}-step windows fit in split(s) {_empty_splits} '
        f'(T_total={T_total}, T_train={T_train}, T_val={T_val}); '
        'reduce WINDOW_SIZE or provide a longer time series.'
    )

# The model-setup cell rebuilds these loaders with CONFIG['batch_size'] and num_workers=0.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, 
                          collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, 
                        collate_fn=collate_fn, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, 
                         collate_fn=collate_fn, num_workers=2, pin_memory=True)

print(f'\nDataLoaders created:')
print(f'  Train sequences: {len(train_ds)}')
print(f'  Val sequences: {len(val_ds)}')
print(f'  Test sequences: {len(test_ds)}')
print(f'  Nodes: {X.shape[1]}')
print(f'  Edges: {edge_index.shape[1]}')

## Cell 7: Create Model and Optimizer

In [None]:
# A100 Optimized Configuration
CONFIG = {
    'hidden_dim': 128,
    'n_bases': 12,
    'n_layers': 3,
    'dropout': 0.15,
    'lr': 1e-4,
    'weight_decay': 1e-4,
    'grad_clip': 1.0,
    'epochs': 100,
    'lambda_bio': 0.1,
    'lambda_stability': 0.01,
    'patience': 15,
    'min_delta': 1e-6,
    'n_nodes': X.shape[1],
    'n_edges': edge_index.shape[1],
    'n_features': X.shape[-1],
    'window_size': WINDOW_SIZE,
    'stride': STRIDE,
    'batch_size': 16,
}

# From here on CONFIG is the source of truth for the batch size; it overrides the
# BATCH_SIZE used for the dataset cell's initial loaders.
BATCH_SIZE = CONFIG['batch_size']

# Initialize wandb run (already authenticated in Cell 2)
print('='*60)
print('INITIALIZING WANDB RUN')
print('='*60)

run = wandb.init(
    project="graph-liquid-kan",
    name=f"glkan-a100-{CONFIG['hidden_dim']}h-{CONFIG['n_layers']}L",
    config=CONFIG,
    tags=["A100", "sea-lice", "graph-neural-network", "liquid-networks"],
)

if torch.cuda.is_available():
    wandb.config.update({
        'gpu_name': torch.cuda.get_device_name(0),
        'gpu_memory_gb': torch.cuda.get_device_properties(0).total_memory / 1e9,
        'mixed_precision': USE_AMP,
    })

print(f'wandb run: {wandb.run.name}')
print(f'wandb URL: {wandb.run.get_url()}')

# Move edge_index to GPU
print('\n' + '='*60)
print('MOVING DATA TO GPU')
print('='*60)

edge_index_gpu = edge_index.to(device)
print(f'  edge_index moved to: {edge_index_gpu.device}')

# Every window then carries the device-resident graph, so no per-batch copy is needed.
train_ds.edge_index = edge_index_gpu
val_ds.edge_index = edge_index_gpu
test_ds.edge_index = edge_index_gpu

# Recreate dataloaders with num_workers=0.
# IMPORTANT: the datasets now hold a CUDA tensor. DataLoader workers are started with
# `fork` by default on Linux, and the CUDA runtime does not support use from a forked
# subprocess, so all loading stays in the main process.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, 
                          collate_fn=collate_fn, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, 
                        collate_fn=collate_fn, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, 
                         collate_fn=collate_fn, num_workers=0)

print(f'  DataLoaders recreated with num_workers=0 (required for GPU edge_index)')

# Create model
input_dim = X.shape[-1]
output_dim = Y.shape[-1]

model = GLKANPredictor(
    input_dim=input_dim,
    hidden_dim=CONFIG['hidden_dim'],
    output_dim=output_dim,
    n_bases=CONFIG['n_bases'],
    n_layers=CONFIG['n_layers'],
    dropout=CONFIG['dropout'],
).to(device)

# Optimizer and scheduler
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['lr'],
    weight_decay=CONFIG['weight_decay'],
)

# Stepped once per epoch by the training loop: with T_0=20, T_mult=2 the cosine cycles
# last 20, 40, 80 epochs, i.e. the LR jumps back to CONFIG['lr'] after epochs 20 and 60.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=20, T_mult=2, eta_min=1e-7
)

# Create loss function with LossConfig
loss_config = LossConfig(
    lambda_bio=CONFIG['lambda_bio'],
    lambda_stability=CONFIG['lambda_stability'],
)
criterion = PhysicsInformedLoss(config=loss_config)

# torch.cuda.amp.GradScaler is deprecated; torch.amp.GradScaler('cuda') is the
# supported equivalent. The training loop only touches the scaler when USE_AMP is set.
scaler = torch.amp.GradScaler('cuda') if USE_AMP else None

n_params = sum(p.numel() for p in model.parameters())
wandb.config.update({'n_parameters': n_params})
wandb.watch(model, log='all', log_freq=100)

print(f'\nModel created:')
print(f'  Parameters: {n_params:,}')
print(f'  Device: {next(model.parameters()).device}')
print(f'  Hidden dim: {CONFIG["hidden_dim"]}')
print(f'  Layers: {CONFIG["n_layers"]}')

## Cell 8: Training Loop

In [None]:
from tqdm.auto import tqdm
import gc
import os
import time

EPOCHS = CONFIG['epochs']
PATIENCE = CONFIG['patience']
CHECKPOINT_DIR = '/content/drive/MyDrive/GLKAN_Project/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'train_rmse': [], 'val_rmse': [], 'lr': []}

# Per-batch metrics get their own x-axis ('batch/step') and epoch metrics are plotted
# against 'epoch'. Logging batch metrics with commit=False would merge every call into
# the pending epoch step, so only the last logged batch of each epoch would survive.
wandb.define_metric('batch/step')
wandb.define_metric('batch/*', step_metric='batch/step')
wandb.define_metric('epoch')
for _metric_pattern in ('train/*', 'val/*', 'learning_rate', 'epoch_time_seconds'):
    wandb.define_metric(_metric_pattern, step_metric='epoch')
global_step = 0


def _masked_rmse(predictions, targets, obs_mask):
    """Return the RMSE over entries where obs_mask is True.

    obs_mask is broadcast over the last dimension of predictions (via
    unsqueeze(-1).expand_as); the denominator is clamped to 1 so an all-masked
    batch yields 0 instead of NaN.
    """
    weights = obs_mask.unsqueeze(-1).expand_as(predictions).float()
    sq_err = ((predictions - targets) ** 2) * weights
    return torch.sqrt(sq_err.sum() / weights.sum().clamp(min=1))


print('='*60)
print('GRAPH-LIQUID-KAN TRAINING')
print('='*60)
print(f'Device: {device}')
print(f'Epochs: {EPOCHS}')
print(f'wandb: {wandb.run.get_url()}')

start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()

    # Training
    model.train()
    train_loss = 0
    train_rmse = 0
    n_batches = 0
    n_skipped = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS}', leave=False)
    for batch_idx, batch in enumerate(pbar):
        batch = {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v 
                 for k, v in batch.items()}

        optimizer.zero_grad(set_to_none=True)

        # Forward pass (fp16 autocast when AMP is enabled).
        if USE_AMP:
            with torch.autocast(device_type='cuda'):
                output = model(batch)
                loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
        else:
            output = model(batch)
            loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])

        # Skip NaN *and* inf losses before backward(): an inf loss in the fp32 path
        # would produce non-finite gradients that clipping cannot repair.
        if not torch.isfinite(loss):
            n_skipped += 1
            continue

        if USE_AMP:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['grad_clip'])
            optimizer.step()

        loss_value = loss.item()
        train_loss += loss_value

        with torch.no_grad():
            rmse = _masked_rmse(output['predictions'], batch['y'], batch['mask'])
        train_rmse += rmse.item()

        n_batches += 1
        global_step += 1
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})

        if batch_idx % 10 == 0:
            wandb.log({
                'batch/step': global_step,
                'batch/loss': loss_value,
                'batch/rmse': rmse.item(),
                'batch/grad_norm': float(grad_norm),
            })

    train_loss /= max(n_batches, 1)
    train_rmse /= max(n_batches, 1)
    if n_skipped:
        print(f'  [WARN] skipped {n_skipped} batch(es) with non-finite loss in epoch {epoch+1}')

    # Validation
    model.eval()
    val_loss = 0
    val_rmse = 0
    n_val = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v 
                     for k, v in batch.items()}

            if USE_AMP:
                with torch.autocast(device_type='cuda'):
                    output = model(batch)
                    loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])
            else:
                output = model(batch)
                loss, _ = criterion(output['predictions'], batch['y'], batch['mask'])

            val_loss += loss.item()
            val_rmse += _masked_rmse(output['predictions'], batch['y'], batch['mask']).item()
            n_val += 1

    val_loss /= max(n_val, 1)
    val_rmse /= max(n_val, 1)

    # Epoch-level LR schedule (see the scheduler definition for the restart cadence).
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_rmse'].append(train_rmse)
    history['val_rmse'].append(val_rmse)
    history['lr'].append(lr)

    epoch_time = time.time() - epoch_start

    # Log to wandb
    wandb.log({
        'epoch': epoch + 1,
        'train/loss': train_loss,
        'train/rmse': train_rmse,
        'train/skipped_batches': n_skipped,
        'val/loss': val_loss,
        'val/rmse': val_rmse,
        'learning_rate': lr,
        'epoch_time_seconds': epoch_time,
    })

    # Checkpointing: only an improvement larger than min_delta resets patience.
    # A NaN val_loss compares False here, so it never becomes the "best" model.
    improved = val_loss < best_val_loss - CONFIG['min_delta']

    if improved:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': CONFIG,
        }, f'{CHECKPOINT_DIR}/best_model.pt')
        marker = '* Best'
        wandb.run.summary['best_val_loss'] = val_loss
        wandb.run.summary['best_epoch'] = epoch + 1
    else:
        patience_counter += 1
        marker = f'({patience_counter}/{PATIENCE})'

    print(f'Epoch {epoch+1:3d}/{EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | {marker}')

    if patience_counter >= PATIENCE:
        print(f'\nEarly stopping at epoch {epoch+1}')
        break

    if (epoch + 1) % 5 == 0:
        torch.cuda.empty_cache()
        gc.collect()

elapsed = time.time() - start_time
print(f'\nTraining complete in {elapsed/60:.1f} minutes')
print(f'Best validation loss: {best_val_loss:.6f}')
if best_val_loss == float('inf'):
    # Any best_model.pt on Drive is then from an earlier run, not this one.
    print('[WARN] Validation loss never improved to a finite value - no checkpoint was saved this run')

## Cell 9: Evaluate and Scientific Audit

In [None]:
import os

import numpy as np
from sklearn.metrics import precision_recall_curve, f1_score, precision_score, recall_score
from tqdm.auto import tqdm

_best_ckpt_path = f'{CHECKPOINT_DIR}/best_model.pt'
if not os.path.exists(_best_ckpt_path):
    raise FileNotFoundError(
        f'No checkpoint at {_best_ckpt_path} - training never recorded an improving validation loss.'
    )

# Load best model
checkpoint = torch.load(_best_ckpt_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# The audit below calls model.network.reset_cache() before switching graphs, which
# suggests the network caches graph-dependent state. Reset it here too, so re-running
# this cell does not start from the 2N-node graph left by the previous graphon test.
model.network.reset_cache()

print(f'Loaded best model from epoch {checkpoint["epoch"]+1}')

# Evaluate on test set
all_preds = []
all_targets = []
all_masks = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Testing'):
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        if USE_AMP:
            with torch.autocast(device_type='cuda'):
                output = model(batch)
        else:
            output = model(batch)

        all_preds.append(output['predictions'].cpu())
        all_targets.append(batch['y'].cpu())
        all_masks.append(batch['mask'].cpu())

if not all_preds:
    raise RuntimeError('Test loader yielded no batches - the test split has no complete windows.')

preds = torch.cat(all_preds, dim=0)
targets = torch.cat(all_targets, dim=0)
masks = torch.cat(all_masks, dim=0)

# Regression metrics on output channel 0 only, restricted to observed (mask=True) entries.
pred_flat = preds[:, :, :, 0].numpy().flatten()
target_flat = targets[:, :, :, 0].numpy().flatten()
mask_flat = masks.numpy().flatten()

pred_valid = pred_flat[mask_flat]
target_valid = target_flat[mask_flat]

rmse = np.sqrt(np.mean((pred_valid - target_valid) ** 2))
mae = np.mean(np.abs(pred_valid - target_valid))

print(f'\nTest Set Results:')
print(f'  RMSE: {rmse:.4f}')
print(f'  MAE:  {mae:.4f}')

# Scientific Audit
print('\n' + '='*60)
print('SCIENTIFIC AUDIT')
print('='*60)

sample_batch = next(iter(val_loader), None)
if sample_batch is None:
    raise RuntimeError('Validation loader yielded no batches - cannot run the scientific audit.')
sample_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in sample_batch.items()}

# Test 1: Counterfactual - shifting input feature 0 up by 5 should increase the mean
# step-to-step change of the predictions.
# NOTE(review): assumes feature 0 is a temperature-like driver whose increase should
# raise growth - confirm against the feature order in tensors.npz.
with torch.no_grad():
    output_orig = model(sample_batch)
    growth_orig = (output_orig['predictions'][:, 1:] - output_orig['predictions'][:, :-1]).mean().item()

    x_hot = sample_batch['x'].clone()
    x_hot[..., 0] += 5.0
    batch_hot = {**sample_batch, 'x': x_hot}
    output_hot = model(batch_hot)
    growth_hot = (output_hot['predictions'][:, 1:] - output_hot['predictions'][:, :-1]).mean().item()

test1_pass = growth_hot > growth_orig
print(f'[TEST 1] Counterfactual: {"PASS" if test1_pass else "FAIL"}')

# Test 2: Long-horizon stability - tile the window 3x along time and check the rollout
# contains no NaN/inf.
with torch.no_grad():
    x = sample_batch['x']
    x_ext = x.repeat(1, 3, 1, 1)[:, :90]
    model.network.reset_cache()
    pred_ext, _ = model.network(x_ext, sample_batch['edge_index'])
    has_nan = torch.isnan(pred_ext).any().item()
    has_inf = torch.isinf(pred_ext).any().item()

test2_pass = not (has_nan or has_inf)
print(f'[TEST 2] Long-horizon: {"PASS" if test2_pass else "FAIL"}')

# Test 3: Graphon - duplicating the graph (two disjoint copies, 2N nodes) should leave
# the mean absolute prediction within 10% of the N-node result.
with torch.no_grad():
    model.network.reset_cache()
    pred_n, _ = model.network(x, sample_batch['edge_index'])
    mean_n = pred_n.abs().mean().item()

    N = x.shape[2]
    x_2n = x.repeat(1, 1, 2, 1)
    edge_2n = torch.cat([sample_batch['edge_index'], sample_batch['edge_index'] + N], dim=1)
    model.network.reset_cache()
    pred_2n, _ = model.network(x_2n, edge_2n)
    mean_2n = pred_2n.abs().mean().item()
    deviation = abs(mean_2n - mean_n) / (mean_n + 1e-8)
    # Leave the model without the 2N-graph cache for any later use.
    model.network.reset_cache()

test3_pass = deviation < 0.10
print(f'[TEST 3] Graphon: {"PASS" if test3_pass else "FAIL"} ({100*deviation:.1f}% deviation)')

print(f'\nAll tests: {"PASSED" if all([test1_pass, test2_pass, test3_pass]) else "SOME FAILED"}')

# Log to wandb
wandb.run.summary['audit/counterfactual'] = test1_pass
wandb.run.summary['audit/long_horizon'] = test2_pass
wandb.run.summary['audit/graphon'] = test3_pass
wandb.run.summary['test/rmse'] = float(rmse)
wandb.run.summary['test/mae'] = float(mae)

## Cell 10: Save Results and Finish

In [None]:
import json
from datetime import datetime
import matplotlib.pyplot as plt

OUTPUT_DIR = '/content/drive/MyDrive/GLKAN_Project/outputs'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Collect config, training summary, test metrics, audit outcomes and per-epoch
# history into a single JSON document for this run.
results = {
    'timestamp': timestamp,
    'config': CONFIG,
    'training': {'epochs': len(history['train_loss']), 'best_val_loss': best_val_loss},
    'regression': {'rmse': float(rmse), 'mae': float(mae)},
    'scientific_audit': {
        'counterfactual': test1_pass,
        'long_horizon': test2_pass,
        'graphon': test3_pass,
    },
    'history': history,
}

results_path = f'{OUTPUT_DIR}/results_{timestamp}.json'
with open(results_path, 'w') as results_file:
    json.dump(results, results_file, indent=2, default=str)

# Two side-by-side panels (loss on a log scale, RMSE on a linear scale), each
# showing the train and validation curves.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for panel_ax, metric, y_label, panel_title, log_scale in (
    (axes[0], 'loss', 'Loss', 'Training Loss', True),
    (axes[1], 'rmse', 'RMSE', 'Training RMSE', False),
):
    panel_ax.plot(history[f'train_{metric}'], label='Train')
    panel_ax.plot(history[f'val_{metric}'], label='Val')
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    if log_scale:
        panel_ax.set_yscale('log')

plt.tight_layout()
plot_path = f'{OUTPUT_DIR}/training_curves_{timestamp}.png'
plt.savefig(plot_path, dpi=150)
wandb.log({'training_curves': wandb.Image(fig)})
plt.show()

# Upload the best checkpoint as a versioned wandb model artifact.
artifact = wandb.Artifact('glkan-model', type='model')
artifact.add_file(f'{CHECKPOINT_DIR}/best_model.pt')
wandb.log_artifact(artifact)

print(f'Results saved to: {results_path}')
print(f'Training curves: {plot_path}')

# Close the wandb run so all pending data is flushed.
wandb.finish()
print('\nwandb run finished!')