# SwellSight Data Loading and Utilities

This notebook contains the data loading utilities, dataset classes, and helper functions for the SwellSight project.

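## Components
- Utility functions (seed setting, directory creation, JSONL reading/writing)
- Dataset class for loading wave images with height, wave-type, and direction labels
- Vocabulary building for classification tasks
- Confidence-based per-sample weighting
- Data transforms for training and inference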

In [None]:
import os
import json
import random
from typing import List, Dict, Any, Optional, Tuple

import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt

## Utility Functions

In [None]:
def set_seed(seed: int) -> None:
    """Seed every random number generator used in this project.

    Seeds, in order: Python's ``random`` module, NumPy's legacy global RNG,
    PyTorch's CPU generator, and the generators of all CUDA devices.

    Args:
        seed: The value handed to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def ensure_dir(path: str) -> None:
    """Create ``path`` (including missing parents) if it does not already exist.

    An empty ``path`` refers to the current working directory, which always exists,
    so it is treated as a no-op. Without this guard ``os.makedirs("")`` raises
    FileNotFoundError, which callers hit when passing ``os.path.dirname`` of a
    bare filename such as ``"out.jsonl"``.

    Args:
        path: Directory to create; ``""`` is accepted and ignored.
    """
    if path:
        os.makedirs(path, exist_ok=True)


def read_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file into a list of records.

    Whitespace-only lines (e.g. a trailing blank line left by an editor) are
    skipped rather than raising ``json.JSONDecodeError``. The file is decoded as
    ``utf-8-sig`` so a leading byte-order mark, as written by some Windows tools,
    is stripped; plain UTF-8 files are read unchanged.

    Args:
        path: Path to a UTF-8 encoded JSONL file.

    Returns:
        One parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8-sig") as f:
        return [json.loads(line) for line in f if line.strip()]


def write_jsonl(items: List[Dict[str, Any]], path: str) -> None:
    """Write records to a JSON Lines file, one JSON object per line.

    The parent directory is created if needed. A ``path`` with no directory
    component (e.g. ``"out.jsonl"``) is written to the current working directory;
    its empty dirname is not passed on for creation.

    Args:
        items: Records to serialize; each must be JSON-serializable.
        path: Destination file; overwritten if it exists.
    """
    parent = os.path.dirname(path)
    if parent:
        ensure_dir(parent)
    with open(path, "w", encoding="utf-8") as f:
        for record in items:
            # ensure_ascii=False keeps non-ASCII text readable; the file is UTF-8.
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

## Data Transforms

In [None]:
# Per-channel ImageNet statistics used for input normalization.
# NOTE(review): appropriate if the downstream backbone is ImageNet-pretrained — confirm.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_normalized_tensor():
    """Return the steps shared by every pipeline: PIL image -> tensor, then normalize."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


def build_transforms(train: bool, image_size: int = 224):
    """Build the image preprocessing pipeline for training or evaluation.

    Args:
        train: If True, add light random augmentation (horizontal flip, color
            jitter, occasional Gaussian blur) after resizing; otherwise return
            the deterministic evaluation pipeline.
        image_size: Side length of the square resize target.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized tensor.
    """
    if not train:
        return build_infer_transform(image_size)

    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.03),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.15),
        *_to_normalized_tensor(),
    ])


def build_infer_transform(image_size: int = 224):
    """Build the deterministic pipeline used for validation and single-image inference.

    Args:
        image_size: Side length of the square resize target.

    Returns:
        A ``transforms.Compose``: resize, convert to tensor, normalize.
    """
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        *_to_normalized_tensor(),
    ])

## Dataset and Vocabulary Functions

In [None]:
def confidence_to_weight(conf: str) -> float:
    """Map an annotator confidence label to a per-sample loss weight.

    Matching ignores case and surrounding whitespace. A missing/empty label
    counts as "medium", and any unrecognized label also falls back to the
    "medium" weight.

    Args:
        conf: One of "high", "medium", "low" (or None / empty).

    Returns:
        1.0 for high, 0.7 for medium, 0.4 for low.
    """
    weight_for_level = {"high": 1.0, "medium": 0.7, "low": 0.4}
    level = (conf if conf else "medium").lower().strip()
    return weight_for_level.get(level, 0.7)


def build_vocabs(items: List[Dict[str, Any]]) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Assign contiguous integer ids to the distinct wave-type and direction labels.

    Labels are sorted before numbering, so the mapping depends only on the set
    of labels present, not on record order.

    Args:
        items: Records each holding ``wave_type`` and ``direction`` keys.

    Returns:
        ``(wave_type_to_id, direction_to_id)``, each mapping label -> id starting at 0.
    """
    def index_field(field: str) -> Dict[str, int]:
        labels = sorted(set(record[field] for record in items))
        return {label: position for position, label in enumerate(labels)}

    return index_field("wave_type"), index_field("direction")

In [None]:
class SwellSightDataset(Dataset):
    """Map-style dataset pairing wave images with height, type, and direction targets.

    Each line of the JSONL index is expected to be a record with ``image_path``,
    ``height_meters``, ``wave_type`` and ``direction`` keys, plus an optional
    ``confidence`` label that is converted to a per-sample weight via
    ``confidence_to_weight``.

    Args:
        index_jsonl: Path to the JSONL index file.
        transform: Optional callable applied to the RGB PIL image, e.g. the
            output of ``build_transforms``.
        wave_type_to_id: Label -> id mapping for ``wave_type``. Built from this
            index when omitted; pass the training mapping for val/test splits so
            ids stay consistent across splits.
        direction_to_id: Same as ``wave_type_to_id``, for ``direction``.
    """

    def __init__(
        self,
        index_jsonl: str,
        transform=None,
        wave_type_to_id: Optional[Dict[str, int]] = None,
        direction_to_id: Optional[Dict[str, int]] = None,
    ):
        self.items = read_jsonl(index_jsonl)
        self.transform = transform

        # Build only the mapping(s) the caller omitted, so a supplied vocabulary
        # is never silently replaced when the other one is missing.
        if wave_type_to_id is None or direction_to_id is None:
            built_wave_types, built_directions = build_vocabs(self.items)
            if wave_type_to_id is None:
                wave_type_to_id = built_wave_types
            if direction_to_id is None:
                direction_to_id = built_directions

        self.wave_type_to_id = wave_type_to_id
        self.direction_to_id = direction_to_id

    def __len__(self) -> int:
        return len(self.items)

    @staticmethod
    def _label_id(vocab: Dict[str, int], record: Dict[str, Any], field: str) -> int:
        """Return ``vocab[record[field]]``, raising a KeyError that names the unseen label."""
        label = record[field]
        try:
            return vocab[label]
        except KeyError:
            raise KeyError(
                f"{field} label {label!r} is not in the vocabulary {sorted(vocab)}"
            ) from None

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Load sample ``idx``.

        Returns:
            A dict with ``image`` (transformed image, or RGB PIL image when no
            transform is set), ``height`` (float32, shape (1,)), ``wave_type`` and
            ``direction`` (int64 scalars), ``weight`` (float32, shape (1,)), and
            ``meta`` (a shallow copy of the raw index record).

        Raises:
            KeyError: If a required field is missing or a label is not in the vocabulary.
        """
        record = self.items[idx]

        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that remains valid after the block.
        with Image.open(record["image_path"]) as raw_image:
            img = raw_image.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        height = torch.tensor([float(record["height_meters"])], dtype=torch.float32)
        wave_type = torch.tensor(
            self._label_id(self.wave_type_to_id, record, "wave_type"), dtype=torch.long
        )
        direction = torch.tensor(
            self._label_id(self.direction_to_id, record, "direction"), dtype=torch.long
        )
        weight = torch.tensor(
            [confidence_to_weight(record.get("confidence", "medium"))], dtype=torch.float32
        )

        return {
            "image": img,
            "height": height,
            "wave_type": wave_type,
            "direction": direction,
            "weight": weight,
            # Copy so downstream mutation of the metadata cannot corrupt self.items.
            "meta": dict(record),
        }

## Example Usage

In [None]:
# Set seed for reproducibility
set_seed(42)

# Example of creating transforms
train_transform = build_transforms(train=True, image_size=224)
val_transform = build_transforms(train=False, image_size=224)

print("Training transforms:")
print(train_transform)
print("\nValidation transforms:")
print(val_transform)

In [None]:
# Load the training split and inspect one sample.
# Guarded on the index file existing so "Restart & Run All" succeeds before the
# data-preparation notebooks have been run; edit the path to point at your split.
TRAIN_INDEX_PATH = "data/processed/splits/train.jsonl"

if os.path.exists(TRAIN_INDEX_PATH):
    dataset = SwellSightDataset(TRAIN_INDEX_PATH, transform=train_transform)
    print(f"Dataset loaded with {len(dataset)} samples")
    print(f"Wave types: {dataset.wave_type_to_id}")
    print(f"Directions: {dataset.direction_to_id}")

    # An empty index yields a valid but empty dataset; indexing it would raise.
    if len(dataset) > 0:
        sample = dataset[0]
        print("\nFirst sample:")
        print(f"Image shape: {sample['image'].shape}")
        print(f"Height: {sample['height'].item():.2f}m")
        print(f"Wave type ID: {sample['wave_type'].item()}")
        print(f"Direction ID: {sample['direction'].item()}")
        print(f"Weight: {sample['weight'].item():.2f}")
else:
    print("Dataset file not found. Run data preparation notebooks first.")