# SwellSight Loss Functions and Metrics

This notebook defines the multi-task loss function and evaluation metrics for the SwellSight model.

## Components
- Multi-task loss combining regression and classification
- Regression metrics (MAE, MSE, RMSE)
- Classification metrics (accuracy, macro F1)
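- Per-sample loss weighting based on label confidence
- Visualization helpers (confusion matrices, predicted-vs-true height scatter and error histogram)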

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

## Multi-Task Loss Function

In [None]:
class MultiTaskLoss(nn.Module):
    """Multi-task loss combining regression and classification losses.

    The total loss is::

        w_height * L_h + w_wave_type * L_wt + w_direction * L_dir

    where each term is ``mean(per_sample_loss * sample_w)`` over the batch:
    Smooth L1 for wave height, cross-entropy for wave type and direction.
    The weighted per-sample losses are divided by the batch size (not by the
    sum of the weights), so uniformly down-weighting a batch scales its loss
    down proportionally.

    Args:
        w_height: Weight of the height (regression) term.
        w_wave_type: Weight of the wave-type (classification) term.
        w_direction: Weight of the direction (classification) term.
    """

    def __init__(self, w_height=1.0, w_wave_type=1.0, w_direction=1.0):
        super().__init__()
        self.w_height = w_height
        self.w_wave_type = w_wave_type
        self.w_direction = w_direction

        # reduction="none" keeps per-sample losses so they can be weighted.
        self.reg = nn.SmoothL1Loss(reduction="none")
        self.ce = nn.CrossEntropyLoss(reduction="none")

    def forward(self, pred_h, pred_wt, pred_dir, y_h, y_wt, y_dir, sample_w=None):
        """Compute the combined multi-task loss.

        Args:
            pred_h: Height predictions with N elements, e.g. shape (N,) or (N, 1).
            pred_wt: Wave-type logits, shape (N, C_wt).
            pred_dir: Direction logits, shape (N, C_dir).
            y_h: Height targets with N elements; cast to ``pred_h``'s dtype.
            y_wt: Wave-type class indices with N elements; cast to long.
            y_dir: Direction class indices with N elements; cast to long.
            sample_w: Per-sample weights with N elements, or ``None`` for
                uniform weights of 1.

        Returns:
            ``(total, parts)``. ``total`` is a scalar tensor that can be
            backpropagated. ``parts`` maps ``"loss_h"``, ``"loss_wt"`` and
            ``"loss_dir"`` to Python floats holding each weighted-mean term
            before the task weights are applied.
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        pred_h = pred_h.reshape(-1)
        y_h = y_h.reshape(-1).to(pred_h.dtype)
        # CrossEntropyLoss needs integer (long) class indices for 1-D targets.
        y_wt = y_wt.reshape(-1).long()
        y_dir = y_dir.reshape(-1).long()

        if sample_w is None:
            w = torch.ones_like(pred_h)
        else:
            w = sample_w.reshape(-1).to(pred_h.dtype)

        loss_h = (self.reg(pred_h, y_h) * w).mean()
        loss_wt = (self.ce(pred_wt, y_wt) * w).mean()
        loss_dir = (self.ce(pred_dir, y_dir) * w).mean()

        total = (
            self.w_height * loss_h
            + self.w_wave_type * loss_wt
            + self.w_direction * loss_dir
        )
        return total, {
            "loss_h": loss_h.item(),
            "loss_wt": loss_wt.item(),
            "loss_dir": loss_dir.item(),
        }

## Regression Metrics

In [None]:
def regression_metrics(y_true: torch.Tensor, y_pred: torch.Tensor) -> dict:
    """
    Compute common regression metrics.

    Both inputs are detached, cast to float32 and flattened, so (N,) and
    (N, 1) shapes may be mixed freely.

    Args:
        y_true: Ground-truth values with N elements.
        y_pred: Predicted values with N elements.

    Returns:
        dict with "mae", "mse", "rmse" as Python floats. For empty inputs
        every value is NaN (mean over zero elements).

    Raises:
        ValueError: If the inputs have different numbers of elements.
            Without this check a single prediction would silently be
            broadcast against every target.
    """
    # reshape instead of view: also works for non-contiguous tensors.
    y_true = y_true.detach().float().reshape(-1)
    y_pred = y_pred.detach().float().reshape(-1)
    if y_true.numel() != y_pred.numel():
        raise ValueError(
            "y_true and y_pred must have the same number of elements, "
            f"got {y_true.numel()} and {y_pred.numel()}"
        )

    err = y_pred - y_true
    mse = err.pow(2).mean()

    return {
        "mae": err.abs().mean().item(),
        "mse": mse.item(),
        "rmse": torch.sqrt(mse).item(),
    }

## Classification Metrics

In [None]:
def _to_labels(y_pred: torch.Tensor) -> torch.Tensor:
    """
    Convert logits/probs to labels if needed.
    If y_pred is (N, C) -> argmax to (N,)
    If y_pred is (N,) -> assume already labels
    """
    if y_pred.dim() == 2:
        return torch.argmax(y_pred, dim=1)
    return y_pred.view(-1)


def accuracy(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """
    Fraction of samples whose predicted class equals the target class.

    y_true: (N,) class indices (cast to long)
    y_pred: (N, C) logits/probabilities, or (N,) predicted class indices
    Returns a Python float (NaN when there are no samples).
    """
    targets = y_true.detach().view(-1).long()
    predicted = _to_labels(y_pred.detach()).view(-1).long()
    hits = predicted.eq(targets)
    return hits.float().mean().item()


def macro_f1(y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int | None = None, eps: float = 1e-9) -> float:
    """
    Macro F1 score without sklearn.

    y_true: (N,) class indices (cast to long)
    y_pred: (N, C) logits/probabilities (argmax is taken) or (N,) class indices
    num_classes: number of classes C; the average runs over classes 0..C-1.
        If None, it is inferred as max(label) + 1 over targets and predictions.
    eps: added to every denominator so classes without support score 0
        instead of producing NaN.

    Every class in 0..C-1 has equal weight, including classes that appear in
    neither y_true nor the predictions (those contribute an F1 of 0).
    Labels outside 0..C-1 match no class. Returns 0.0 when there are no
    classes to average, e.g. empty input with num_classes=None.
    """
    y_true = y_true.detach().view(-1).long()
    y_hat = _to_labels(y_pred.detach()).view(-1).long()

    if num_classes is None:
        all_labels = torch.cat([y_true, y_hat])
        # torch.max raises on an empty tensor, so handle empty input explicitly.
        num_classes = int(all_labels.max().item()) + 1 if all_labels.numel() > 0 else 0
    if num_classes <= 0:
        return 0.0

    # Membership masks of shape (N, C): is_pred[i, c] is True iff y_hat[i] == c.
    # These reproduce the per-class tp/fp/fn counts without a Python loop
    # (and without one device sync per class).
    classes = torch.arange(num_classes, device=y_true.device)
    is_pred = y_hat.unsqueeze(1) == classes
    is_true = y_true.unsqueeze(1) == classes

    tp = (is_pred & is_true).sum(dim=0).float()
    fp = (is_pred & ~is_true).sum(dim=0).float()
    fn = (~is_pred & is_true).sum(dim=0).float()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2.0 * precision * recall / (precision + recall + eps)
    return f1.mean().item()

## Visualization Functions

In [None]:
def plot_confusion_matrix(y_true, y_pred, class_names, title="Confusion Matrix"):
    """Plot confusion matrix for classification results.

    Args:
        y_true: True class indices (tensor or array-like); flattened.
        y_pred: Predictions. Tensors go through ``_to_labels``, so (N, C)
            logits are reduced with argmax. Other inputs are used as labels.
        class_names: Display names; ``class_names[i]`` labels class index ``i``.
        title: Figure title.

    The matrix always has ``len(class_names)`` rows and columns. Classes
    missing from the given batch therefore still appear (as zero rows or
    columns), and the tick labels stay aligned with the matrix. Samples whose
    true or predicted label is outside ``range(len(class_names))`` are left
    out of the matrix.
    """
    if torch.is_tensor(y_true):
        y_true = y_true.detach().cpu().numpy()
    if torch.is_tensor(y_pred):
        y_pred = _to_labels(y_pred.detach()).cpu().numpy()
    y_true_np = np.asarray(y_true).reshape(-1)
    y_pred_np = np.asarray(y_pred).reshape(-1)

    # Without explicit labels sklearn only uses the classes present in the
    # data, which yields a smaller matrix that no longer matches class_names.
    cm = confusion_matrix(y_true_np, y_pred_np, labels=list(range(len(class_names))))

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.show()


def plot_regression_results(y_true, y_pred, title="Height Prediction Results"):
    """Plot regression results with scatter plot and metrics.

    The left panel scatters predicted against true heights, with the y = x
    reference line and MAE/RMSE in the title. The right panel is a histogram
    of prediction errors (prediction - truth).

    Args:
        y_true: True heights (tensor or array-like); flattened to 1-D.
        y_pred: Predicted heights with the same number of elements; flattened.
        title: Title of the scatter panel (metrics are appended).

    Raises:
        ValueError: If the inputs have different numbers of elements.
    """
    def _as_flat_array(values):
        # Flatten so that mixing (N,) and (N, 1) inputs cannot broadcast the
        # error computation into an (N, N) matrix; also accepts plain lists.
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=float).reshape(-1)

    y_true_np = _as_flat_array(y_true)
    y_pred_np = _as_flat_array(y_pred)
    if y_true_np.size != y_pred_np.size:
        raise ValueError(
            "y_true and y_pred must have the same number of elements, "
            f"got {y_true_np.size} and {y_pred_np.size}"
        )

    metrics = regression_metrics(torch.from_numpy(y_true_np), torch.from_numpy(y_pred_np))

    plt.figure(figsize=(10, 4))

    # Scatter plot
    plt.subplot(1, 2, 1)
    plt.scatter(y_true_np, y_pred_np, alpha=0.6)
    min_val = min(y_true_np.min(), y_pred_np.min())
    max_val = max(y_true_np.max(), y_pred_np.max())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8)
    plt.xlabel('True Height (m)')
    plt.ylabel('Predicted Height (m)')
    plt.title(f'{title}\nMAE: {metrics["mae"]:.3f}, RMSE: {metrics["rmse"]:.3f}')

    # Error distribution
    plt.subplot(1, 2, 2)
    errors = y_pred_np - y_true_np
    plt.hist(errors, bins=30, alpha=0.7, edgecolor='black')
    plt.xlabel('Prediction Error (m)')
    plt.ylabel('Frequency')
    plt.title('Error Distribution')
    plt.axvline(0, color='red', linestyle='--', alpha=0.8)

    plt.tight_layout()
    plt.show()

## Example Usage and Testing

In [None]:
# Test loss function on random dummy data.
# Seed the RNG so the example (and its printed numbers) is reproducible.
torch.manual_seed(0)

batch_size = 8
num_wave_types = 5
num_directions = 3

# Create dummy predictions and targets.
# Heights are drawn around 1 m and clamped at 0, because a negative wave
# height is not physically meaningful.
pred_h = (torch.randn(batch_size, 1) * 0.5 + 1).clamp_min(0.0)
pred_wt = torch.randn(batch_size, num_wave_types)
pred_dir = torch.randn(batch_size, num_directions)

y_h = (torch.randn(batch_size, 1) * 0.5 + 1).clamp_min(0.0)
y_wt = torch.randint(0, num_wave_types, (batch_size,))
y_dir = torch.randint(0, num_directions, (batch_size,))
sample_w = torch.ones(batch_size, 1) * 0.8  # Confidence weights

# Test loss computation
loss_fn = MultiTaskLoss(w_height=1.0, w_wave_type=1.0, w_direction=1.0)
total_loss, loss_dict = loss_fn(pred_h, pred_wt, pred_dir, y_h, y_wt, y_dir, sample_w)
assert torch.isfinite(total_loss), "loss must be finite on well-formed inputs"

print(f"Total loss: {total_loss.item():.4f}")
print(f"Height loss: {loss_dict['loss_h']:.4f}")
print(f"Wave type loss: {loss_dict['loss_wt']:.4f}")
print(f"Direction loss: {loss_dict['loss_dir']:.4f}")

In [None]:
# Evaluate the dummy predictions from the previous cell with the metric helpers.
reg_metrics = regression_metrics(y_h, pred_h)
wt_acc = accuracy(y_wt, pred_wt)
dir_acc = accuracy(y_dir, pred_dir)
wt_f1 = macro_f1(y_wt, pred_wt, num_classes=num_wave_types)
dir_f1 = macro_f1(y_dir, pred_dir, num_classes=num_directions)

print("\nRegression Metrics:")
print("\n".join(f"{key.upper()}: {reg_metrics[key]:.4f}" for key in ("mae", "mse", "rmse")))

print("\nClassification Metrics:")
print("\n".join(
    f"{label}: {value:.4f}"
    for label, value in (
        ("Wave type accuracy", wt_acc),
        ("Direction accuracy", dir_acc),
        ("Wave type macro F1", wt_f1),
        ("Direction macro F1", dir_f1),
    )
))

In [None]:
# Render the example predictions: one regression figure, then one confusion
# matrix per classification head.
wave_types = [
    "beach_break",
    "reef_break",
    "point_break",
    "closeout",
    "a_frame",
]
directions = ["left", "right", "both"]

plot_regression_results(y_h, pred_h, title="Example Height Predictions")

plot_confusion_matrix(y_wt, pred_wt, class_names=wave_types, title="Wave Type Predictions")
plot_confusion_matrix(y_dir, pred_dir, class_names=directions, title="Direction Predictions")