<a href="https://colab.research.google.com/github/umfieldrobotics/shipwreck_finder_demo/blob/main/shipwreck_demo_oceans25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Shipwreck Finder: Training Models to Find Shipwrecks

In [None]:
!pip install -U segmentation-models-pytorch --quiet
!git clone --quiet --recursive https://github.com/umfieldrobotics/shipwreck_finder_demo.git

In [None]:
from pathlib import Path

import numpy as np
import segmentation_models_pytorch as smp
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from shipwreck_finder_demo.utils.data import MBESDataset
from shipwreck_finder_demo.utils.utils import (
    compute_balanced_weights,
    foreground_iou_from_logits,
    visualize_triplets_inline,
    dump_test_visuals,
    download_public_gdrive_file,
)

In [None]:
PUBLIC_GDRIVE_LINK = "https://drive.google.com/file/d/1iswtn95_1LqB_u1M4u0KTbNwnG7S3C_r/view?usp=drive_link"
OUTPUT_PATH = "/content/downloads/TUTORIAL_DATASET.zip"
download_public_gdrive_file(PUBLIC_GDRIVE_LINK, OUTPUT_PATH)
!unzip {OUTPUT_PATH} -d /content/downloads

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# U-Net with a ResNet-18 encoder initialised from ImageNet weights,
# one input channel and two output classes.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=1,
    classes=2,
).to(DEVICE)

# OUTPUT_PATH is the downloaded .zip *file*, so "<OUTPUT_PATH>/train" can never
# be a directory. The archive is extracted into the folder that contains it, so
# look there for the shallowest directory holding both a train/ and a test/
# split instead of guessing the archive's internal layout.
extract_dir = Path(OUTPUT_PATH).parent
candidate_dirs = [extract_dir] + sorted(
    (path for path in extract_dir.rglob("*") if path.is_dir()),
    key=lambda path: (len(path.parts), str(path)),
)
DATASET_ROOT = next(
    (
        folder
        for folder in candidate_dirs
        if (folder / "train").is_dir() and (folder / "test").is_dir()
    ),
    None,
)
if DATASET_ROOT is None:
    raise FileNotFoundError(
        f"No directory containing both 'train' and 'test' found under {extract_dir}; "
        "run the download/unzip cell first."
    )

train_dataset = MBESDataset(f"{DATASET_ROOT}/train", img_size=256)
test_dataset = MBESDataset(f"{DATASET_ROOT}/test", img_size=256)

In [None]:
import matplotlib.pyplot as plt

N = 5  # Number of samples to visualize per split (one figure row per sample)

def overlay_label_on_image(image, label, alpha=0.4):
    """Blend a red foreground mask over a grayscale image for display.

    Args:
        image: Array holding a single grayscale image. Singleton dimensions
            (for example a leading channel axis) are squeezed away.
        label: Array of class ids matching the image's spatial shape.
            Singleton dimensions are squeezed away as well. Pixels equal to 1
            are painted red; all other values contribute a black overlay.
        alpha: Blend weight of the colour mask; the image gets ``1 - alpha``.

    Returns:
        Array of shape ``(*squeezed_image_shape, 3)`` equal to
        ``(1 - alpha) * normalised_gray + alpha * mask_colour``.
    """
    img = np.asarray(image).squeeze()
    # Squeeze the label too: with a (1, H, W) label the blend would otherwise
    # broadcast to a 4-D array that imshow cannot draw.
    mask = np.asarray(label).squeeze()

    # Min-max normalise for display. A constant image has zero range, which
    # would give 0/0 = NaN everywhere, so render it as black instead.
    lowest = img.min()
    value_range = img.max() - lowest
    if value_range > 0:
        img = (img - lowest) / value_range
    else:
        img = np.zeros(img.shape, dtype=np.float32)

    # Colour layer: red where the label marks foreground, black elsewhere.
    color_label = np.zeros((*mask.shape, 3), dtype=np.float32)
    color_label[mask == 1] = [1, 0, 0]  # Red for foreground

    # Replicate the gray image across three channels and blend.
    overlay = np.stack([img] * 3, axis=-1)
    return (1 - alpha) * overlay + alpha * color_label

def plot_samples(dataset, title, N=5):
    """Show the first ``N`` samples of ``dataset`` as image / label / overlay rows.

    Args:
        dataset: Indexable collection whose items are mappings with ``'image'``
            and ``'label'`` entries that expose ``.numpy()``.
        title: Prefix for the column headings, which are drawn on the first
            row only.
        N: Number of samples (figure rows) to draw, taken from index 0 upward.
    """
    plt.figure(figsize=(12, N * 3))
    for row in range(N):
        sample = dataset[row]
        image = sample['image'].numpy()
        label = sample['label'].numpy()

        # One entry per column: heading suffix, pixels to draw, imshow options.
        panels = (
            ("Image", image.squeeze(), {"cmap": 'gray'}),
            ("Label", label.squeeze(), {"cmap": 'gray', "vmin": 0, "vmax": 1}),
            ("Overlay", overlay_label_on_image(image, label), {}),
        )
        for column, (caption, pixels, imshow_options) in enumerate(panels, start=1):
            plt.subplot(N, 3, row * 3 + column)
            plt.imshow(pixels, **imshow_options)
            plt.axis('off')
            # Column headings only on the top row.
            if row == 0:
                plt.title(f"{title} {caption}")

    plt.tight_layout()
    plt.show()

# Preview the first N samples of each split, training data first.
for split_title, split_dataset in (("Train", train_dataset), ("Test", test_dataset)):
    plot_samples(split_dataset, split_title, N)

In [None]:
def train_step(model, optimizer, criterion, batch):
    """Run a single optimisation step on one batch and return its loss.

    Args:
        model: Network called on the batch images to produce logits.
        optimizer: Optimizer whose gradients are cleared before the backward
            pass and which is stepped after it.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        batch: Mapping with ``"image"`` and ``"label"`` tensors.

    Returns:
        The batch loss as a Python number (``loss.item()``).
    """
    # Training mode (dropout, batchnorm, etc. active) and a clean gradient
    # buffer so nothing accumulates from the previous step.
    model.train()
    optimizer.zero_grad()

    # Move the batch to the module-level DEVICE.
    inputs = batch["image"].to(DEVICE, non_blocking=True)
    targets = batch["label"].to(DEVICE, non_blocking=True)
    # 4-D labels get axis 1 squeezed (a no-op unless that axis has size 1);
    # presumably this drops a singleton channel axis - TODO confirm with MBESDataset.
    targets = targets.squeeze(1) if targets.dim() == 4 else targets

    # Forward pass and loss, then backpropagate and update the weights.
    batch_loss = criterion(model(inputs), targets)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

In [None]:
def test_step(model, criterion, batch):
    """Evaluate one batch without tracking gradients.

    Args:
        model: Network called on the batch images to produce logits.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        batch: Mapping with ``"image"`` and ``"label"`` tensors.

    Returns:
        Tuple ``(loss, iou, logits, labels)``. ``loss`` and ``iou`` are Python
        numbers obtained with ``.item()``; ``logits`` and ``labels`` are the
        tensors moved back to the CPU. The IoU comes from
        ``foreground_iou_from_logits`` using the module-level ``IGNORE_INDEX``,
        which is looked up when the function is called.
    """
    # Put the model in eval mode (dropout, batchnorm, etc. inactive)
    model.eval()

    # Load the images and labels onto the GPU (if available)
    images = batch["image"].to(DEVICE, non_blocking=True)
    labels = batch["label"].to(DEVICE, non_blocking=True)
    # Same label handling as train_step, so the loss sees identically shaped
    # targets during training and evaluation.
    if labels.dim() == 4:
        labels = labels.squeeze(1)

    # Forward pass, loss and foreground IoU, all without building a graph
    with torch.no_grad():
        logits = model(images)
        loss = criterion(logits, labels)
        iou = foreground_iou_from_logits(logits, labels, ignore_index=IGNORE_INDEX)

    return loss.item(), iou.item(), logits.cpu(), labels.cpu()

In [None]:
# ---- Configuration ----
BATCH_SIZE_TRAIN = 4
BATCH_SIZE_TEST = 1
NUM_EPOCHS = 10
LEARNING_RATE = 5e-4
IGNORE_INDEX = -1  # label value passed as ignore_index to the loss and the IoU metric
TARGET_DIR = Path("/content/segmentation_vis")  # where test-set visuals are written
TARGET_DIR.mkdir(parents=True, exist_ok=True)

# ---- Data loaders ----
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE_TRAIN,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE_TEST,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# ---- Optimizer and class-weighted loss ----
# The two values returned by compute_balanced_weights become the per-class
# weights of the cross-entropy loss (index 0 and index 1 respectively).
weight0, weight1 = compute_balanced_weights(train_loader, ignore_index=IGNORE_INDEX)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
weight_tensor = torch.tensor([weight0, weight1], dtype=torch.float32, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=weight_tensor, ignore_index=IGNORE_INDEX)

# ---- Training / evaluation loop ----
# train_step and test_step switch the model between train and eval mode
# themselves, and test_step already disables gradient tracking, so the loop
# only has to accumulate the per-batch numbers.
# The former `if epoch_idx == 0` evaluation branch was unreachable (epochs are
# numbered from 1) and its results were overwritten before being shown, so it
# has been removed.
epoch_bar = tqdm(range(1, NUM_EPOCHS + 1), desc="Epochs")
for epoch_idx in epoch_bar:
    # Training pass: weight each batch loss by its size so the mean is per sample.
    running_train_loss = 0.0
    for batch in train_loader:
        train_loss = train_step(model, optimizer, criterion, batch)
        # The batch size must come from the batch itself; `images` only exists
        # inside train_step.
        running_train_loss += train_loss * batch["image"].size(0)

    mean_train_loss = running_train_loss / len(train_loader.dataset)

    # Evaluation pass over the test split.
    running_test_loss = 0.0
    running_test_iou = 0.0
    for batch in test_loader:
        test_loss, test_iou, test_logits, test_labels = test_step(
            model, criterion, batch
        )
        batch_size = batch["image"].size(0)
        running_test_loss += test_loss * batch_size
        # test_step already returns the foreground IoU for this batch, so it is
        # not recomputed here.
        running_test_iou += test_iou * batch_size

    mean_test_loss = running_test_loss / len(test_loader.dataset)
    mean_test_iou = running_test_iou / len(test_loader.dataset)

    epoch_bar.set_postfix(
        {
            "train_loss": f"{mean_train_loss:.4f}",
            "test_loss": f"{mean_test_loss:.4f}",
            "fg_IoU": f"{mean_test_iou:.4f}",
        }
    )


# ---- Qualitative results ----
visualize_triplets_inline(model, test_loader, DEVICE, num_items=3)
dump_test_visuals(test_loader, TARGET_DIR, max_items=None)
print(f"Saved test-set visuals to: {TARGET_DIR.resolve()}")