# SwellSight Model Architecture

This notebook defines the SwellSight multi-task neural network architecture for wave analysis.

## Overview
- **Backbone**: EfficientNet-B0 with pretrained weights
- **Tasks**: Wave height (regression), wave type (classification), direction (classification)
- **Architecture**: Shared feature extractor with task-specific heads

In [None]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import matplotlib.pyplot as plt
from torchsummary import summary

In [None]:
class SwellSightNet(nn.Module):
    """Multi-task network: one EfficientNet-B0 backbone feeding three heads.

    The backbone's classifier is replaced by an identity, so the backbone
    emits the features that would otherwise have fed its classifier. Those
    features go through a shared LayerNorm + Dropout stage and then into
    three independent heads (height, wave type, direction).

    Args:
        num_wave_types: Output size of the wave-type head.
        num_directions: Output size of the direction head.
        dropout: Dropout probability used in the shared stage and in every
            head.
        pretrained: Keyword-only. When True (the default, matching the
            previous behavior) the backbone is built with
            ``EfficientNet_B0_Weights.DEFAULT``. When False it is built with
            ``weights=None``, i.e. no pretrained weights are loaded, which is
            handy when a checkpoint is restored right afterwards.
    """

    def __init__(
        self,
        num_wave_types: int,
        num_directions: int,
        dropout: float = 0.35,
        *,
        pretrained: bool = True,
    ):
        super().__init__()

        weights = EfficientNet_B0_Weights.DEFAULT if pretrained else None
        self.backbone = efficientnet_b0(weights=weights)

        # Read the classifier's input width before discarding the classifier;
        # every layer defined below is sized from it.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        self.shared = nn.Sequential(
            nn.LayerNorm(in_features),
            nn.Dropout(dropout),
        )

        # Heads are created in the same order and with the same layer layout
        # as before, so parameter names (state_dict keys) are unchanged.
        self.head_height = self._make_head(in_features, 1, dropout)
        self.head_wave_type = self._make_head(in_features, num_wave_types, dropout)
        self.head_direction = self._make_head(in_features, num_directions, dropout)

    @staticmethod
    def _make_head(
        in_features: int,
        out_features: int,
        dropout: float,
        hidden: int = 256,
    ) -> nn.Sequential:
        """Build one task head: Linear -> ReLU -> Dropout -> Linear."""
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Run the backbone and shared stage, then apply all three heads.

        Args:
            x: Input passed straight to the EfficientNet-B0 backbone.

        Returns:
            A 3-tuple ``(height, wave_type, direction)`` of head outputs.
            Each is the output of a final ``nn.Linear`` with no activation
            applied; the last dimensions are ``1``, ``num_wave_types`` and
            ``num_directions`` respectively.
        """
        feats = self.shared(self.backbone(x))
        height = self.head_height(feats)
        wave_type = self.head_wave_type(feats)
        direction = self.head_direction(feats)
        return height, wave_type, direction

In [None]:
# Example model instantiation
# The label lists below fix the output size of the two classification heads.
WAVE_TYPES = [
    "beach_break",
    "reef_break",
    "point_break",
    "closeout",
    "a_frame",
]
DIRECTIONS = ["left", "right", "both"]

model = SwellSightNet(
    num_wave_types=len(WAVE_TYPES),
    num_directions=len(DIRECTIONS),
)

# Report the head sizes and the total parameter count (backbone + heads).
print(f"Model created with {len(WAVE_TYPES)} wave types and {len(DIRECTIONS)} directions")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Test forward pass
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Create dummy input
batch_size = 4
x = torch.randn(batch_size, 3, 224, 224).to(device)

# Run the smoke test in eval mode. torch.no_grad() only disables gradient
# tracking: in training mode dropout would still be active and the backbone's
# BatchNorm layers would still fold this random batch into their running
# statistics. The previous mode is restored afterwards so later cells see the
# model in the state they expect.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        pred_h, pred_wt, pred_dir = model(x)
finally:
    model.train(was_training)

print(f"Input shape: {x.shape}")
print(f"Height prediction shape: {pred_h.shape}")
print(f"Wave type prediction shape: {pred_wt.shape}")
print(f"Direction prediction shape: {pred_dir.shape}")

In [None]:
# Visualize model architecture (if torchsummary is available)
try:
    # Import here instead of relying on the notebook's import cell: if the
    # package is missing, the failure then surfaces as ImportError in this
    # cell and reaches the handler below, rather than as a NameError on
    # `summary`.
    from torchsummary import summary
except ImportError:
    print("Install torchsummary for detailed model summary: pip install torchsummary")
else:
    # torchsummary pushes a random batch through the model to record layer
    # shapes; eval mode keeps that pass from updating BatchNorm running
    # statistics. The previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        summary(model, (3, 224, 224))
    finally:
        model.train(was_training)