# SwellSight Model Training

This notebook trains the SwellSight multi-task model for wave analysis.

## Purpose
- Train the model on mixed real and synthetic data
- Monitor training progress with live metrics
- Implement early stopping and checkpointing
- Visualize training curves and performance

In [None]:
import os
import json
import time
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np

In [None]:
# Import components from previous notebooks
# Each %run executes the named notebook in this kernel's namespace.
# NOTE(review): the cells below use names that are neither defined nor imported
# in this notebook (set_seed, ensure_dir, read_jsonl, build_vocabs,
# build_transforms, SwellSightDataset, SwellSightNet, MultiTaskLoss,
# regression_metrics, accuracy, macro_f1, plot_regression_results,
# plot_confusion_matrix) -- presumably they come from these notebooks; confirm
# which notebook provides which name.
%run 01_model_architecture.ipynb
%run 02_data_loading.ipynb
%run 03_loss_and_metrics.ipynb

## Training Configuration

In [None]:
# Training configuration
# All hyperparameters and paths for this run live in one flat dictionary so the
# whole thing can be dumped to JSON next to the checkpoints.
CONFIG = dict(
    # Data paths (out_dir is timestamped so every run gets a fresh directory)
    train_jsonl="data/synthetic/mix/train.jsonl",
    val_jsonl="data/synthetic/mix/val.jsonl",
    out_dir=f"runs/swell_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    # Model parameters
    image_size=224,
    dropout=0.35,
    # Training parameters
    batch_size=16,
    epochs=25,
    lr=3e-4,
    weight_decay=1e-4,
    # Loss weights, one per task head
    w_height=1.0,
    w_wave_type=1.0,
    w_direction=1.0,
    # Other
    seed=42,
    num_workers=2,
    pin_memory=True,
    patience=5,  # Early stopping patience
)

# Echo the configuration, one "  key: value" line per entry.
print("Training Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in CONFIG.items()))

## Setup and Data Loading

In [None]:
# Seed the run and pick the compute device (GPU when one is visible).
set_seed(CONFIG["seed"])
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Make sure the run directory exists before anything is written into it.
ensure_dir(CONFIG["out_dir"])
print(f"Output directory: {CONFIG['out_dir']}")

# Persist the exact configuration alongside the run artifacts.
config_path = os.path.join(CONFIG["out_dir"], "config.json")
with open(config_path, "w") as f:
    f.write(json.dumps(CONFIG, indent=2))

In [None]:
# Load training data and build vocabularies

# Label pools used only when the JSONL files are missing and placeholder
# records have to be fabricated.
_DUMMY_WAVE_TYPES = ["beach_break", "reef_break", "point_break", "closeout", "a_frame"]
_DUMMY_DIRECTIONS = ["left", "right", "both"]


def _make_dummy_items(split, count):
    """Build ``count`` placeholder annotation records for the given split.

    Args:
        split: Name embedded in the fake image path (e.g. "train" or "val").
        count: Number of records to generate.

    Returns:
        A list of dicts with the keys image_path, height_meters, wave_type,
        direction, confidence and source. Heights are drawn uniformly from
        [0.5, 2.5) and labels uniformly from the dummy label pools, using
        NumPy's global random state.
    """
    return [
        {
            "image_path": f"dummy_{split}_{i}.jpg",
            "height_meters": float(np.random.uniform(0.5, 2.5)),
            # Cast to plain str so records hold built-in strings, not numpy.str_.
            "wave_type": str(np.random.choice(_DUMMY_WAVE_TYPES)),
            "direction": str(np.random.choice(_DUMMY_DIRECTIONS)),
            "confidence": "high",
            "source": "dummy",
        }
        for i in range(count)
    ]


try:
    train_items = read_jsonl(CONFIG["train_jsonl"])
    print(f"✓ Loaded {len(train_items)} training samples")
except FileNotFoundError:
    print(f"✗ Training data not found: {CONFIG['train_jsonl']}")
    print("Please run the data preparation notebooks first.")
    # Fall back to fabricated records so the rest of the notebook still runs.
    print("Creating dummy training data...")
    train_items = _make_dummy_items("train", 200)

try:
    val_items = read_jsonl(CONFIG["val_jsonl"])
    print(f"✓ Loaded {len(val_items)} validation samples")
except FileNotFoundError:
    print(f"✗ Validation data not found: {CONFIG['val_jsonl']}")
    # Same fallback for the validation split.
    val_items = _make_dummy_items("val", 50)

# Build vocabularies from training data
wt2id, d2id = build_vocabs(train_items)
print(f"\nVocabularies:")
print(f"Wave types: {wt2id}")
print(f"Directions: {d2id}")

# Save vocabularies so the label <-> id mapping can be recovered with the checkpoints
vocabs = {"wave_type_to_id": wt2id, "direction_to_id": d2id}
with open(os.path.join(CONFIG["out_dir"], "vocabs.json"), "w") as f:
    json.dump(vocabs, f, indent=2)

In [None]:
# Create datasets and data loaders
print("Creating datasets...")

# Note: For demonstration with dummy data, we'll create a simple dataset
# In practice, you would use the actual SwellSightDataset

class DummyDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs real labels with random-noise images.

    Used for demonstration when the image files referenced by the annotation
    records are not available. Each sample is a dict with the keys
    "image", "height", "wave_type", "direction", "weight" and "meta".

    Args:
        items: Sequence of annotation dicts; each must provide
            "height_meters", "wave_type" and "direction".
        transform: Stored for interface parity with the real dataset; it is
            never applied because the images are synthetic noise.
        wave_type_to_id: Mapping from wave-type label to class index.
        direction_to_id: Mapping from direction label to class index.
        image_size: Side length of the generated square image. When None
            (the default) the value of CONFIG["image_size"] is read each time
            a sample is requested.
    """

    def __init__(self, items, transform, wave_type_to_id, direction_to_id, image_size=None):
        self.items = items
        self.transform = transform
        self.wave_type_to_id = wave_type_to_id
        self.direction_to_id = direction_to_id
        self.image_size = image_size

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return the sample dict for record ``idx``.

        Raises:
            KeyError: If the record's wave type or direction is missing from
                the corresponding vocabulary.
        """
        r = self.items[idx]

        # Fall back to the notebook-wide setting only when no size was given.
        size = self.image_size if self.image_size is not None else CONFIG["image_size"]

        # Random noise stands in for the (already transformed) image tensor.
        img = torch.randn(3, size, size)

        height = torch.tensor([float(r["height_meters"])], dtype=torch.float32)
        wave_type = torch.tensor(self.wave_type_to_id[r["wave_type"]], dtype=torch.long)
        direction = torch.tensor(self.direction_to_id[r["direction"]], dtype=torch.long)
        weight = torch.tensor([1.0], dtype=torch.float32)  # High confidence

        return {
            "image": img,
            "height": height,
            "wave_type": wave_type,
            "direction": direction,
            "weight": weight,
            "meta": r
        }

# Decide between real and dummy data: real images are used only when the first
# training record points at a file that exists on disk.
if train_items:
    use_dummy_data = not os.path.exists(train_items[0]["image_path"])
else:
    use_dummy_data = True

# Both dataset classes take the same vocabulary keyword arguments.
vocab_kwargs = dict(wave_type_to_id=wt2id, direction_to_id=d2id)

if use_dummy_data:
    print("⚠️ Using dummy data for demonstration (no real images found)")
    # The dummy dataset consumes the in-memory records directly.
    dataset_cls, train_source, val_source = DummyDataset, train_items, val_items
else:
    print("✓ Using real image data")
    # The real dataset is handed the JSONL paths instead.
    dataset_cls = SwellSightDataset
    train_source, val_source = CONFIG["train_jsonl"], CONFIG["val_jsonl"]

train_ds = dataset_cls(
    train_source,
    transform=build_transforms(True, CONFIG["image_size"]),
    **vocab_kwargs,
)
val_ds = dataset_cls(
    val_source,
    transform=build_transforms(False, CONFIG["image_size"]),
    **vocab_kwargs,
)

# Loader settings shared by both splits; only shuffling differs.
loader_kwargs = dict(
    batch_size=CONFIG["batch_size"],
    num_workers=CONFIG["num_workers"],
    pin_memory=CONFIG["pin_memory"],
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## Model Setup

In [None]:
# Build the network with one output per vocabulary entry and move it to the device.
model = SwellSightNet(
    num_wave_types=len(wt2id),
    num_directions=len(d2id),
    dropout=CONFIG["dropout"],
)
model = model.to(device)

param_count = sum(param.numel() for param in model.parameters())
print(f"Model created with {param_count:,} parameters")

# Multi-task loss, weighted per head from the configuration.
loss_fn = MultiTaskLoss(
    **{name: CONFIG[name] for name in ("w_height", "w_wave_type", "w_direction")}
)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG["lr"],
    weight_decay=CONFIG["weight_decay"],
)

# Halve the learning rate when the monitored value stops decreasing; the
# scheduler reacts after half the early-stopping patience.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=CONFIG["patience"] // 2,
    verbose=True,
)

# Mixed precision scaler (only created when running on CUDA)
scaler = None
if device == "cuda":
    scaler = GradScaler()

print("✓ Model, optimizer, and scheduler created")

## Training Functions

In [None]:
def run_epoch(model, loader, optimizer, loss_fn, device, train: bool, scaler=None):
    """Run one pass over ``loader`` in training or evaluation mode.

    Args:
        model: Network called as ``model(images)`` and returning the triple
            ``(pred_height, pred_wave_type, pred_direction)``.
        loader: Iterable of batch dicts with the keys "image", "height",
            "wave_type", "direction" and "weight".
        optimizer: Optimizer stepped once per batch when ``train`` is True;
            unused otherwise.
        loss_fn: Callable invoked as
            ``loss_fn(pred_h, pred_wt, pred_dir, y_h, y_wt, y_dir, w)`` and
            returning ``(loss, loss_dict)`` where ``loss_dict`` has the keys
            "loss_h", "loss_wt" and "loss_dir".
        device: Target device, as a string or ``torch.device``.
        train: True to update weights, False to only evaluate.
        scaler: Optional gradient scaler. Mixed precision is enabled only
            when a scaler is given and the device is a CUDA device.

    Returns:
        Tuple ``(mean_loss, metrics)``: the sample-weighted mean loss over the
        epoch and a dict holding the output of ``regression_metrics`` plus
        "wave_type_acc", "direction_acc", "wave_type_macro_f1" and
        "direction_macro_f1".

    Raises:
        RuntimeError: If ``loader`` yields no batches.
    """
    model.train(train)

    # Accept "cuda", "cuda:0" and torch.device objects alike.
    use_amp = torch.device(device).type == "cuda" and scaler is not None

    total_loss = 0.0
    n = 0

    all_h_pred, all_h_true = [], []
    all_wt_logits, all_wt_true = [], []
    all_dir_logits, all_dir_true = [], []

    progress_bar = tqdm(loader, desc=f"{'Train' if train else 'Val'} Epoch", leave=False)

    for batch in progress_bar:
        x = batch["image"].to(device)
        y_h = batch["height"].to(device)
        y_wt = batch["wave_type"].to(device).view(-1)
        y_dir = batch["direction"].to(device).view(-1)
        w = batch["weight"].to(device)

        if train:
            optimizer.zero_grad(set_to_none=True)

        with torch.set_grad_enabled(train):
            with autocast(enabled=use_amp):
                pred_h, pred_wt, pred_dir = model(x)
                loss, loss_dict = loss_fn(pred_h, pred_wt, pred_dir, y_h, y_wt, y_dir, w)

            if train and scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            elif train:
                loss.backward()
                optimizer.step()

        # Read the scalar once; it is reused for the running total and the bar.
        batch_loss = loss.item()
        batch_size = x.size(0)
        total_loss += batch_loss * batch_size
        n += batch_size

        # Collect predictions for metrics. Outputs produced under autocast may
        # be half precision, so promote them to float32 before storing.
        all_h_pred.append(pred_h.detach().float().cpu())
        all_h_true.append(y_h.detach().cpu())
        all_wt_logits.append(pred_wt.detach().float().cpu())
        all_wt_true.append(y_wt.detach().cpu())
        all_dir_logits.append(pred_dir.detach().float().cpu())
        all_dir_true.append(y_dir.detach().cpu())

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'h_loss': f'{loss_dict["loss_h"]:.4f}',
            'wt_loss': f'{loss_dict["loss_wt"]:.4f}',
            'dir_loss': f'{loss_dict["loss_dir"]:.4f}'
        })

    if not all_h_pred:
        # torch.cat would fail on the empty lists with a much less helpful message.
        raise RuntimeError(
            f"run_epoch received no batches from the {'training' if train else 'validation'} loader"
        )

    # Compute metrics over the whole epoch
    h_pred = torch.cat(all_h_pred, 0)
    h_true = torch.cat(all_h_true, 0)
    wt_logits = torch.cat(all_wt_logits, 0)
    wt_true = torch.cat(all_wt_true, 0)
    dir_logits = torch.cat(all_dir_logits, 0)
    dir_true = torch.cat(all_dir_true, 0)

    metrics = {}
    metrics.update(regression_metrics(h_true, h_pred))
    metrics["wave_type_acc"] = accuracy(wt_true, wt_logits)
    metrics["direction_acc"] = accuracy(dir_true, dir_logits)
    # The class count is taken from the width of the logits.
    metrics["wave_type_macro_f1"] = macro_f1(wt_true, wt_logits, wt_logits.shape[1])
    metrics["direction_macro_f1"] = macro_f1(dir_true, dir_logits, dir_logits.shape[1])

    return (total_loss / max(1, n)), metrics

## Training Loop

In [None]:
# Per-epoch training history: one list per tracked series.
history = {
    series: []
    for series in (
        "train_loss",
        "val_loss",
        "train_metrics",
        "val_metrics",
        "lr",
        "epoch_times",
    )
}

# Early-stopping bookkeeping: best loss so far, epochs since it improved,
# and the (1-based) epoch at which it was reached.
best_val_loss, patience_counter, best_epoch = float('inf'), 0, 0

# Banner describing the run that is about to start.
print(
    f"Starting training for {CONFIG['epochs']} epochs...",
    f"Device: {device}",
    f"Batch size: {CONFIG['batch_size']}",
    f"Learning rate: {CONFIG['lr']}",
    "-" * 80,
    sep="\n",
)

In [None]:
# Main training loop

def _save_history(history, out_dir):
    """Write ``history`` to ``<out_dir>/history.json`` using native floats.

    Metric dicts and scalar series are converted with ``float`` so numpy
    values serialise cleanly. The payload is built before the file is opened,
    so a conversion error cannot truncate a previously written file.
    """
    serializable = {}
    for key, values in history.items():
        if key in ("train_metrics", "val_metrics"):
            serializable[key] = [
                {k: float(v) for k, v in epoch_metrics.items()} for epoch_metrics in values
            ]
        else:
            serializable[key] = [float(v) for v in values]
    with open(os.path.join(out_dir, "history.json"), "w") as f:
        json.dump(serializable, f, indent=2)


for epoch in range(CONFIG["epochs"]):
    epoch_start_time = time.time()

    print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
    print("-" * 40)

    # Training phase
    train_loss, train_metrics = run_epoch(
        model, train_loader, optimizer, loss_fn, device, train=True, scaler=scaler
    )

    # Validation phase (no scaler, so it runs without mixed precision)
    val_loss, val_metrics = run_epoch(
        model, val_loader, optimizer, loss_fn, device, train=False, scaler=None
    )

    # Update learning rate based on the validation loss
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Record history
    epoch_time = time.time() - epoch_start_time
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_metrics"].append(train_metrics)
    history["val_metrics"].append(val_metrics)
    history["lr"].append(current_lr)
    history["epoch_times"].append(epoch_time)

    # Print epoch results
    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.2e}")
    print(f"Train MAE: {train_metrics['mae']:.4f} | Val MAE: {val_metrics['mae']:.4f}")
    print(f"Train WT Acc: {train_metrics['wave_type_acc']:.4f} | Val WT Acc: {val_metrics['wave_type_acc']:.4f}")
    print(f"Train Dir Acc: {train_metrics['direction_acc']:.4f} | Val Dir Acc: {val_metrics['direction_acc']:.4f}")
    print(f"Epoch time: {epoch_time:.1f}s")

    # Save checkpoint
    checkpoint = {
        "epoch": epoch + 1,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        # Keep the loss-scaling state too so a mixed-precision run can be resumed.
        "scaler": scaler.state_dict() if scaler is not None else None,
        "val_loss": val_loss,
        "config": CONFIG
    }

    # Save last checkpoint
    torch.save(checkpoint, os.path.join(CONFIG["out_dir"], "last.pt"))

    # Save best checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        patience_counter = 0
        torch.save(checkpoint, os.path.join(CONFIG["out_dir"], "best.pt"))
        print(f"✓ New best model saved (val_loss: {val_loss:.4f})")
    else:
        patience_counter += 1
        print(f"No improvement for {patience_counter} epochs")

    # Early stopping
    if patience_counter >= CONFIG["patience"]:
        print(f"\nEarly stopping triggered after {patience_counter} epochs without improvement")
        print(f"Best model was at epoch {best_epoch} with val_loss: {best_val_loss:.4f}")
        break

    # Save history every few epochs so a crash loses at most a handful of epochs
    if (epoch + 1) % 5 == 0:
        _save_history(history, CONFIG["out_dir"])

# Always write the final history, including when early stopping left the loop
# before the periodic save could run.
_save_history(history, CONFIG["out_dir"])

print(f"\nTraining completed!")
print(f"Best model: epoch {best_epoch}, val_loss: {best_val_loss:.4f}")
print(f"Total training time: {sum(history['epoch_times']):.1f}s")

## Training Visualization

In [None]:
# Plot training curves
def plot_training_curves(history):
    """Draw a 2x3 grid of training diagnostics, save it, and show it.

    Panels: loss, height MAE, wave-type accuracy, direction accuracy (each
    with train and validation curves), learning rate on a log axis, and
    epoch duration. The figure is written to ``training_curves.png`` inside
    ``CONFIG["out_dir"]``.

    Args:
        history: Dict with the per-epoch lists "train_loss", "val_loss",
            "train_metrics", "val_metrics", "lr" and "epoch_times".
    """
    epochs = range(1, len(history["train_loss"]) + 1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    def metric_series(split, name):
        # Pull one metric out of every per-epoch metrics dict for a split.
        return [epoch_metrics[name] for epoch_metrics in history[f"{split}_metrics"]]

    def draw_train_val(ax, train_values, val_values, title, ylabel):
        # One panel comparing the train (blue) and validation (red) curves.
        ax.plot(epochs, train_values, 'b-', label='Train', linewidth=2)
        ax.plot(epochs, val_values, 'r-', label='Validation', linewidth=2)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Top row: loss, height MAE, wave type accuracy
    draw_train_val(
        axes[0, 0], history["train_loss"], history["val_loss"],
        'Loss Curves', 'Loss',
    )
    draw_train_val(
        axes[0, 1], metric_series("train", "mae"), metric_series("val", "mae"),
        'Height MAE', 'MAE (meters)',
    )
    draw_train_val(
        axes[0, 2], metric_series("train", "wave_type_acc"), metric_series("val", "wave_type_acc"),
        'Wave Type Accuracy', 'Accuracy',
    )

    # Bottom row, first panel: direction accuracy
    draw_train_val(
        axes[1, 0], metric_series("train", "direction_acc"), metric_series("val", "direction_acc"),
        'Direction Accuracy', 'Accuracy',
    )

    # Learning rate, on a log axis
    lr_ax = axes[1, 1]
    lr_ax.plot(epochs, history["lr"], 'g-', linewidth=2)
    lr_ax.set_title('Learning Rate')
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_yscale('log')
    lr_ax.grid(True, alpha=0.3)

    # Wall-clock duration of each epoch
    time_ax = axes[1, 2]
    time_ax.plot(epochs, history["epoch_times"], 'purple', linewidth=2)
    time_ax.set_title('Epoch Duration')
    time_ax.set_xlabel('Epoch')
    time_ax.set_ylabel('Time (seconds)')
    time_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    figure_path = os.path.join(CONFIG["out_dir"], "training_curves.png")
    plt.savefig(figure_path, dpi=300, bbox_inches='tight')
    plt.show()

plot_training_curves(history)

In [None]:
# Training summary statistics

def _print_val_metrics(title, metrics):
    """Print the validation metrics of one epoch under the given heading."""
    print(f"\n{title}:")
    print(f"  Height MAE: {metrics['mae']:.4f}m")
    print(f"  Height RMSE: {metrics['rmse']:.4f}m")
    print(f"  Wave Type Accuracy: {metrics['wave_type_acc']:.4f}")
    print(f"  Direction Accuracy: {metrics['direction_acc']:.4f}")
    print(f"  Wave Type Macro F1: {metrics['wave_type_macro_f1']:.4f}")
    print(f"  Direction Macro F1: {metrics['direction_macro_f1']:.4f}")


print("Training Summary:")
print("=" * 50)
print(f"Total epochs: {len(history['train_loss'])}")
print(f"Best epoch: {best_epoch}")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Final learning rate: {history['lr'][-1]:.2e}")
print(f"Total training time: {sum(history['epoch_times']):.1f}s")
print(f"Average epoch time: {np.mean(history['epoch_times']):.1f}s")

# Best model metrics. best_epoch is 1-based and stays 0 when no epoch ever
# improved on the initial loss; the lower bound keeps index -1 from silently
# reporting the last epoch as the best one.
best_idx = best_epoch - 1
if 0 <= best_idx < len(history['val_metrics']):
    best_metrics = history['val_metrics'][best_idx]
    _print_val_metrics("Best Model Performance", best_metrics)

# Final model metrics (last completed epoch)
final_metrics = history['val_metrics'][-1]
_print_val_metrics("Final Model Performance", final_metrics)

## Model Analysis

In [None]:
# Restore the weights of the best checkpoint and switch to evaluation mode.
best_path = os.path.join(CONFIG["out_dir"], "best.pt")
best_checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(best_checkpoint["model"])
model.eval()

print(f"Loaded best model from epoch {best_checkpoint['epoch']}")

# Run the validation set through the model, keeping per-batch outputs and
# targets on the CPU.
task_names = ("height", "wave_type", "direction")
all_predictions, all_targets = [], []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Getting predictions"):
        pred_h, pred_wt, pred_dir = model(batch["image"].to(device))

        all_predictions.append(
            {task: out.cpu() for task, out in zip(task_names, (pred_h, pred_wt, pred_dir))}
        )
        all_targets.append({task: batch[task].cpu() for task in task_names})


def _stack(records, task):
    """Concatenate one task's tensors from every batch along dimension 0."""
    return torch.cat([record[task] for record in records], 0)


# Concatenate all predictions and targets
pred_heights = _stack(all_predictions, "height")
pred_wave_types = _stack(all_predictions, "wave_type")
pred_directions = _stack(all_predictions, "direction")

true_heights = _stack(all_targets, "height")
true_wave_types = _stack(all_targets, "wave_type")
true_directions = _stack(all_targets, "direction")

print(f"Collected predictions for {len(pred_heights)} validation samples")

In [None]:
# Visualize predictions vs targets
# Class labels are listed in the iteration order of the vocabulary dicts.
# NOTE(review): this assumes that order matches ascending class id -- confirm
# against build_vocabs.
wave_type_names = list(wt2id.keys())
direction_names = list(d2id.keys())

# Height regression results
plot_regression_results(true_heights, pred_heights, "Validation Set - Height Predictions")

# Classification confusion matrices
# pred_wave_types / pred_directions are the model's raw outputs gathered from
# the validation loader, not argmax'd class indices.
# NOTE(review): confirm plot_confusion_matrix accepts raw outputs in that form.
plot_confusion_matrix(true_wave_types, pred_wave_types, wave_type_names, "Wave Type Predictions")
plot_confusion_matrix(true_directions, pred_directions, direction_names, "Direction Predictions")

## Save Final Results

In [None]:
# Save final training summary

def _native_metrics(metrics):
    """Return a copy of a metrics dict with every value converted to float.

    Mirrors the conversion applied when history.json is written, so values
    that are not plain Python floats still serialise to JSON.
    """
    return {name: float(value) for name, value in metrics.items()}


# best_epoch is 1-based and remains 0 if no epoch ever improved; guard the
# lower bound so index -1 does not pass the last epoch off as the best one.
_has_best = 1 <= best_epoch <= len(history['val_metrics'])

training_summary = {
    "config": CONFIG,
    "total_epochs": len(history['train_loss']),
    "best_epoch": best_epoch,
    "best_val_loss": float(best_val_loss),
    "final_lr": float(history['lr'][-1]),
    "total_training_time": float(sum(history['epoch_times'])),
    "avg_epoch_time": float(np.mean(history['epoch_times'])),
    "best_metrics": _native_metrics(history['val_metrics'][best_epoch - 1]) if _has_best else None,
    "final_metrics": _native_metrics(history['val_metrics'][-1]),
    "model_parameters": sum(p.numel() for p in model.parameters()),
    "vocabularies": vocabs
}

with open(os.path.join(CONFIG["out_dir"], "training_summary.json"), "w") as f:
    json.dump(training_summary, f, indent=2)

# Recap of the run artifacts for the reader.
print(f"✓ Training completed successfully!")
print(f"✓ Model checkpoints saved to: {CONFIG['out_dir']}")
print(f"✓ Best model: best.pt (epoch {best_epoch}, val_loss: {best_val_loss:.4f})")
print(f"✓ Latest model: last.pt")
print(f"✓ Training history: history.json")
print(f"✓ Vocabularies: vocabs.json")
print(f"✓ Configuration: config.json")
print(f"✓ Summary: training_summary.json")
print(f"✓ Training curves: training_curves.png")