## Vision Transformer (google/vit-base-patch16-224-in21k)
We chose google/vit-base-patch16-224-in21k over other ViT checkpoints because it offers the strongest balance of accuracy, stability, and computational efficiency for a project like ours. It is pretrained on ImageNet-21k, a much larger dataset than ImageNet-1k, which gives it richer visual features than standard ImageNet-1k ViTs. It uses the well-established Base architecture, which is powerful without being too heavy for a single-GPU Colab workflow. Finally, it avoids the higher memory requirements and slower training of larger ViT or Swin variants while still transferring well to small and medium-sized image datasets like ours.

In [None]:
# Allow imports from the src folder
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
# Import libraries
import pandas as pd  # needed by CoralReefDataset (pd.read_csv)
import torch
import torch.nn as nn
from torch.optim import AdamW
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel
from tqdm import tqdm

# Import custom functions
from src.s3_loader import get_image_s3

### Data Augmentations
We apply transformations and augmentations so the model learns to recognize benthic features under many lighting conditions, angles, and image variations instead of memorizing a narrow set of appearances. This improves generalization by exposing the Vision Transformer to more realistic underwater variation, helping it perform better on new, unseen reef images.
#### Transforms

In [None]:
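# ImageNet statistics for normalization.
# Note: the Hugging Face preprocessor config for google/vit-base-patch16-224-in21k
# uses mean = std = (0.5, 0.5, 0.5); ImageNet statistics are used here instead.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)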

# Define data augmentations and transformations
train_transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.3
    ),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Define test/validation transformations
test_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
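
Both pipelines end with `T.ToTensor()` and `T.Normalize(...)`. As a minimal sketch of the per-channel arithmetic those two steps perform, the plain-Python helpers below process a single RGB pixel. The helper names are illustrative stand-ins, not torchvision functions, and the random crops, flips, and color jitter are left out.

```python
# Per-channel ImageNet statistics, as in the transform cell above
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_tensor_scale(rgb_uint8):
    """Mimic the scaling in T.ToTensor() for one pixel: map 0-255 integers to [0, 1]."""
    return tuple(v / 255.0 for v in rgb_uint8)


def normalize(rgb_float, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Mimic T.Normalize(): subtract the channel mean, then divide by the channel std."""
    return tuple((v - m) / s for v, m, s in zip(rgb_float, mean, std))


pixel = (255, 128, 0)  # an arbitrary orange pixel
scaled = to_tensor_scale(pixel)
normalized = normalize(scaled)
print([round(v, 3) for v in normalized])  # [2.249, 0.205, -1.804]
```

After normalization, each channel is roughly centered on zero relative to the ImageNet statistics, which is the input distribution the pretrained backbone's first layers are used to.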

### Build the CoralReefDataset Class
We use a Dataset class because it gives PyTorch a consistent way to load individual training examples, apply transforms, and retrieve labels on demand. This keeps the data-loading logic organized and allows the DataLoader to automatically batch, shuffle, and efficiently feed data to the model during training.

In [None]:
class CoralReefDataset(Dataset):
    """
    PyTorch Dataset for MERMAID multi-label classification.
    Streams images from S3, applies transforms, and returns label vectors.

    CSV format expected:
        image_id, region_name, label1, label2, ..., label16
    """

    def __init__(self, csv_path, transform=None):
        self.df = pd.read_csv(csv_path)

        # Drop region_name: it is metadata, not a prediction target
        if "region_name" in self.df.columns:
            self.df = self.df.drop(columns=["region_name"])

        self.transform = transform
        self.image_ids = self.df["image_id"].tolist()

        # Label columns are everything except image_id (independent of column order)
        self.label_cols = [c for c in self.df.columns if c != "image_id"]

        # Extract labels once as a float32 array of shape [N, num_labels]
        self.labels = self.df[self.label_cols].to_numpy(dtype="float32")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # --- Get image_id ---
        image_id = self.image_ids[idx]

        # --- Load image from S3 ---
        image = get_image_s3(image_id)

        # --- Apply transform ---
        if self.transform is not None:
            image = self.transform(image)

        # --- Get labels (float32 tensor, shape [num_labels]; 16 for MERMAID) ---
        # Positional indexing avoids relying on the DataFrame's index labels
        labels = torch.from_numpy(self.labels[idx])

        return image, labels
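
The label handling in `CoralReefDataset` reduces to three steps: drop `region_name`, keep `image_id` as the key, and turn every remaining column into a float. Below is a minimal standard-library sketch of that mapping on an in-memory CSV. The three label names are hypothetical placeholders for the 16 MERMAID label columns, and `rows_to_label_vectors` is an illustrative helper, not part of the project's `src` package.

```python
import csv
import io

# Hypothetical CSV in the documented layout: image_id, region_name, then label
# columns (only three label columns here to keep the example short).
SAMPLE_CSV = """image_id,region_name,hard_coral,soft_coral,algae
img_001,Region A,1,0,1
img_002,Region B,0,0,1
"""


def rows_to_label_vectors(csv_text):
    """Return {image_id: [float labels]}, dropping region_name like CoralReefDataset."""
    reader = csv.DictReader(io.StringIO(csv_text))
    label_cols = [c for c in reader.fieldnames if c not in ("image_id", "region_name")]
    return {
        row["image_id"]: [float(row[c]) for c in label_cols]
        for row in reader
    }


labels = rows_to_label_vectors(SAMPLE_CSV)
print(labels["img_001"])  # [1.0, 0.0, 1.0]
```

Each image maps to a multi-hot vector: several entries can be 1.0 at once, which is what distinguishes this multi-label setup from single-label classification.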

### Create Train / Test Datasets

In [None]:
train_csv = "../data/processed/final_labels_train.csv"
test_csv  = "../data/processed/final_labels_test.csv"

train_dataset = CoralReefDataset(train_csv, transform=train_transform)
test_dataset  = CoralReefDataset(test_csv,  transform=test_transform)

### Create DataLoaders
A DataLoader efficiently feeds batches of data to your model during training by pulling samples from a Dataset, grouping them into mini-batches, and handling details like shuffling and parallel loading.
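
As a simplified model of what a `DataLoader` with `shuffle=True` does with indices: it shuffles once per epoch, slices the permutation into consecutive mini-batches, and keeps a smaller final batch (the default `drop_last=False`). This sketch covers index sampling only. The real `DataLoader` also collates samples into tensors and can load them in worker processes. `minibatch_indices` is a hypothetical helper.

```python
import math
import random


def minibatch_indices(num_samples, batch_size, shuffle, seed=None):
    """Yield lists of dataset indices, one list per mini-batch.

    Shuffle once (if requested), then slice into consecutive chunks.
    The final chunk may be smaller than batch_size.
    """
    order = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, num_samples, batch_size):
        yield order[start:start + batch_size]


batches = list(minibatch_indices(num_samples=40, batch_size=16, shuffle=True, seed=0))
num_batches = math.ceil(40 / 16)   # 3 batches: 16 + 16 + 8 samples
print([len(b) for b in batches])   # [16, 16, 8]
```

Every index appears exactly once per epoch. Only the order changes between epochs, which is why shuffling does not alter what the model sees, only the sequence in which it sees it.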

In [None]:
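train_loader = DataLoader(
    train_dataset,
    batch_size=16,   # number of samples per batch
    shuffle=True,    # reshuffle every epoch so training does not depend on data order
    num_workers=0,   # 0 = load in the main process; >0 uses parallel worker processes
    pin_memory=torch.cuda.is_available()  # page-locked memory speeds up copies to the GPU
)

test_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=False,   # keep a fixed order for evaluation
    num_workers=0,
    pin_memory=torch.cuda.is_available()
)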

### Build and Configure the ViT Model
#### Implement the Model Wrapper
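A model wrapper is a small class that sits around the pretrained Vision Transformer backbone and adds a new classification head on top of it. The `ViTModel` backbone outputs embeddings rather than class predictions, so we attach a linear layer that maps the [CLS] token embedding to one logit per label. This keeps the backbone unchanged while adapting the model to our multi-label coral reef prediction task.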

In [None]:
class ViTMultiLabel(nn.Module):
    """
    Vision Transformer for multi-label classification.
    Uses google/vit-base-patch16-224-in21k as the backbone.
    """

    def __init__(self, num_labels=16):
        super().__init__()

        # Load pretrained backbone
        self.backbone = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k"
        )

        # Hidden size of CLS token representation
        hidden_dim = self.backbone.config.hidden_size

        # Custom classifier for multi-label prediction
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, pixel_values):
        """
        pixel_values: tensor of shape (B, 3, 224, 224)
        """
        outputs = self.backbone(pixel_values=pixel_values)

        # CLS token is at index 0 of the sequence
        cls_embedding = outputs.last_hidden_state[:, 0, :]

        logits = self.classifier(cls_embedding)  # shape: (B, num_labels)

        return logits
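
The `[:, 0, :]` indexing in `forward` relies on how the encoder lays out its output. A quick arithmetic sketch follows, assuming the standard ViT-Base/16 configuration (16×16 patches, hidden size 768) at 224×224 input. The image becomes a 14×14 grid of patches plus one prepended [CLS] token, so `last_hidden_state` has shape `(B, 197, 768)`. The helper names below are illustrative only.

```python
# ViT-Base/16 values at 224x224 input; in the notebook, hidden_dim is read from
# self.backbone.config.hidden_size rather than hard-coded.
IMAGE_SIZE = 224
PATCH_SIZE = 16
HIDDEN_DIM = 768
NUM_LABELS = 16


def vit_sequence_length(image_size, patch_size):
    """Number of tokens the encoder sees: one per patch plus the [CLS] token."""
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


def linear_head_params(in_features, out_features):
    """Parameter count of nn.Linear(in_features, out_features): weights plus biases."""
    return in_features * out_features + out_features


seq_len = vit_sequence_length(IMAGE_SIZE, PATCH_SIZE)
print(seq_len)                                     # 197
print((2, seq_len, HIDDEN_DIM))                    # last_hidden_state shape for a batch of 2
print(linear_head_params(HIDDEN_DIM, NUM_LABELS))  # 12304
```

Index 0 along the sequence dimension is the [CLS] token, so slicing it out yields one 768-dimensional vector per image. That vector is the only input the new head sees.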

#### Instantiate the Model

In [None]:
model = ViTMultiLabel(num_labels=16)
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

#### Loss Function and Optimizer
We use `BCEWithLogitsLoss` because multi-label classification treats each of the 16 benthic attributes as an independent binary prediction. The loss applies a sigmoid to each logit and computes binary cross-entropy in a single, numerically stable step, then averages over all labels and samples in the batch. We use the AdamW optimizer because it is the standard choice for fine-tuning transformer models. It decouples weight decay from the gradient-based update, which tends to help generalization.
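
Below is a minimal plain-Python sketch of both ideas, for illustration only. `bce_with_logits` computes the mean loss over one image's labels using the numerically stable form that fuses the sigmoid into the cross-entropy. `adamw_first_step` shows AdamW's first update on a single scalar, where weight decay shrinks the parameter directly instead of being added to the gradient. Both function names are hypothetical. PyTorch's implementations are vectorized and handle further details (for example the loss's `reduction` and `pos_weight` options, and moment tracking across many steps).

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed from raw logits.

    Uses the stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))] without overflow.
    """
    terms = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(terms) / len(terms)


def adamw_first_step(param, grad, lr, weight_decay, eps=1e-8):
    """One AdamW update at step 1 for a single scalar parameter.

    Weight decay is applied to the parameter itself (decoupled from the gradient).
    At step 1 the bias-corrected moments are m_hat = grad and v_hat = grad**2.
    """
    param = param * (1 - lr * weight_decay)  # decoupled weight decay
    m_hat, v_hat = grad, grad ** 2
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


# Three independent labels for one image: present, absent, present
print(round(bce_with_logits([2.0, -1.0, 0.0], [1.0, 0.0, 1.0]), 4))          # 0.3778
print(round(adamw_first_step(param=1.0, grad=0.5, lr=0.1, weight_decay=0.01), 4))  # 0.899
```

Each label contributes its own term to the loss, so a confident correct prediction on one attribute does not mask a wrong prediction on another.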

In [None]:
# Loss function
criterion = torch.nn.BCEWithLogitsLoss()

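# Stage 1 optimizer: only the new classifier head is trained.
# Passing the head's parameters explicitly avoids depending on when the backbone
# is frozen (a requires_grad filter is evaluated once, when the optimizer is built).
optimizer_stage1 = AdamW(
    model.classifier.parameters(),
    lr=5e-4
)

# Stage 2 optimizer: full fine-tuning with a much smaller learning rate.
# The backbone must be unfrozen (requires_grad=True) before Stage 2 training.
optimizer_stage2 = AdamW(
    model.parameters(),
    lr=2e-5
)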

### Stage 1 Training (Backbone Frozen)
We freeze the pretrained ViT backbone during Stage 1 so that the randomly initialized classifier head can first learn a stable mapping from the existing features. Without freezing, large gradients from the untrained head could disrupt the visual features learned on ImageNet-21k. Training the head first reduces the risk of early overfitting and gives the later full fine-tuning stage a better starting point.

In [None]:
# Freeze Backbone
for param in model.backbone.parameters():
    param.requires_grad = False
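
A `requires_grad` filter only sees the flags at the moment the optimizer is constructed, so the order of freezing and optimizer creation matters. The sketch below uses a tiny stand-in model with hypothetical layer sizes and the same `backbone`/`classifier` attribute names as `ViTMultiLabel`. It shows how freezing changes the trainable parameter count, and which parameters end up in an optimizer built after freezing.

```python
import torch.nn as nn
from torch.optim import AdamW

# Tiny stand-in for ViTMultiLabel: a "backbone" layer and a "classifier" head.
toy = nn.Module()
toy.backbone = nn.Linear(8, 4)    # 8*4 + 4 = 36 parameters
toy.classifier = nn.Linear(4, 2)  # 4*2 + 2 = 10 parameters


def count_trainable(module):
    """Number of scalar parameters that will receive gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


print(count_trainable(toy))  # 46: everything is trainable by default

# Freeze the backbone first, then build the optimizer from what is still trainable.
for param in toy.backbone.parameters():
    param.requires_grad = False

optimizer = AdamW([p for p in toy.parameters() if p.requires_grad], lr=5e-4)
print(count_trainable(toy))  # 10: only the classifier head
print(sum(p.numel() for group in optimizer.param_groups for p in group["params"]))  # 10
```

The same check (`count_trainable(model)`) is a cheap sanity test on the real model before starting Stage 1.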

#### Define Training Loop

In [None]:
def train_stage1(model, train_loader, criterion, optimizer, device, epochs=5):
    model.train() 

    for epoch in range(epochs):
        running_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            logits = model(images)

            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")

    print("Stage 1 training complete.")
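
`train_stage1` reports `running_loss / len(train_loader)`, the plain mean of per-batch losses. When the dataset size is not a multiple of the batch size, the smaller final batch counts as much as a full one. For large datasets the difference is small. The sketch below compares that figure with a per-sample weighted mean on made-up numbers. Both helper names are hypothetical.

```python
def epoch_loss_unweighted(batch_losses):
    """What train_stage1 prints: the plain mean of per-batch mean losses."""
    return sum(batch_losses) / len(batch_losses)


def epoch_loss_weighted(batch_losses, batch_sizes):
    """Per-sample mean: weight each batch's mean loss by its number of samples."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical epoch: 36 samples with batch_size=16 -> batches of 16, 16, and 4
losses = [0.5, 0.5, 1.0]
sizes = [16, 16, 4]
print(round(epoch_loss_unweighted(losses), 4))       # 0.6667
print(round(epoch_loss_weighted(losses, sizes), 4))  # 0.5556
```

In the training loop, the weighted version would accumulate `loss.item() * images.size(0)` and divide by `len(train_loader.dataset)`.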

#### Train

In [None]:
epochs_stage1 = 5  
train_stage1(
    model,
    train_loader,
    criterion,
    optimizer_stage1,
    device,
    epochs=epochs_stage1
)