<a href="https://colab.research.google.com/github/umfieldrobotics/shipwreck_finder_demo/blob/main/shipwreck_demo_oceans25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Shipwreck Finder: Training Models to Find Shipwrecks

Make sure this notebook runs on a GPU: click `Runtime → Change runtime type` and select a GPU hardware accelerator.

In [None]:
!pip install -U segmentation-models-pytorch --quiet
!git clone --quiet --recursive https://github.com/umfieldrobotics/shipwreck_finder_demo.git

In [None]:
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from PIL import Image, ImageDraw
from torch import nn
from torch.optim import AdamW
import numpy as np
import segmentation_models_pytorch as smp
from shipwreck_finder_demo.utils.data import MBESDataset
from shipwreck_finder_demo.utils.utils import (compute_balanced_weights, 
                                                foreground_iou_from_logits, 
                                                visualize_triplets_inline, 
                                                dump_test_visuals, 
                                                download_public_gdrive_file)
from typing import Tuple
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
PUBLIC_GDRIVE_LINK = "https://drive.google.com/file/d/1iswtn95_1LqB_u1M4u0KTbNwnG7S3C_r/view?usp=drive_link"
OUTPUT_PATH = "/content/downloads/TUTORIAL_DATASET.zip"
download_public_gdrive_file(PUBLIC_GDRIVE_LINK, OUTPUT_PATH)
!unzip {OUTPUT_PATH} -d /content/downloads

In [None]:
# Seed the global RNGs so decoder initialisation and DataLoader shuffling are
# repeatable across runs. torch.manual_seed also seeds every CUDA device.
# NOTE(review): cuDNN kernels may still be non-deterministic on GPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# U-Net with an ImageNet-pretrained ResNet-18 encoder, 1 input channel and
# 2 output classes (one logit map per class).
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=1,
    classes=2,
).to(DEVICE)

# The archive is extracted into its own parent directory. This assumes its
# top-level folder is named after the zip file (TUTORIAL_DATASET), so dropping
# the ".zip" suffix yields the dataset root.
IMG_SIZE = 256
DATASET_ROOT = Path(OUTPUT_PATH).with_suffix("")
train_dataset = MBESDataset(str(DATASET_ROOT / "train"), img_size=IMG_SIZE)
test_dataset = MBESDataset(str(DATASET_ROOT / "test"), img_size=IMG_SIZE)

In [None]:
from shipwreck_finder_demo.utils.visualization_utils import plot_train_test_grid

# Visual sanity check of the data before training. N presumably controls how
# many examples are drawn from each dataset.
N_PREVIEW_SAMPLES = 5
plot_train_test_grid(train_dataset, test_dataset, N=N_PREVIEW_SAMPLES)

In [None]:
# ---- Optimisation -------------------------------------------------------
NUM_EPOCHS = 100
LEARNING_RATE = 5e-4      # AdamW learning rate
BATCH_SIZE_TRAIN = 16
BATCH_SIZE_TEST = 1
IGNORE_INDEX = -1         # label value skipped by the loss and the IoU metric

# ---- Logging / visualisation -------------------------------------------
LOG_EVERY_EPOCHS = 5      # dump test-set visuals every N epochs
NUM_VIS_TRIPLETS = 3      # NOTE(review): not referenced by any cell visible here
TARGET_DIR = Path("/content/segmentation_vis")
TARGET_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
# Data loaders, class-weighted loss and optimizer.
NUM_WORKERS = 2
# Pinned host memory only helps host->GPU copies. On a CPU-only runtime it
# just triggers a warning, so enable it only when training on CUDA.
PIN_MEMORY = DEVICE == 'cuda'

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# One weight per class (presumably class 0 = background, class 1 = foreground /
# shipwreck) to counter class imbalance in the training labels.
weight0, weight1 = compute_balanced_weights(train_loader)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(
    # float() accepts either Python numbers or 0-d tensors from the helper.
    weight=torch.tensor([float(weight0), float(weight1)], dtype=torch.float32, device=DEVICE),
    ignore_index=IGNORE_INDEX,
)

In [None]:
def _to_device_batch(batch):
    """Move one DataLoader batch onto DEVICE and return (images, labels).

    Args:
        batch: dict holding "image" and "label" tensors, as yielded by the loaders.

    Returns:
        (images, labels). 4-D labels have dim 1 squeezed away. Labels are cast
        to int64 because nn.CrossEntropyLoss expects class-index targets of
        dtype long; the cast is a no-op if they already are.
    """
    images = batch["image"].to(DEVICE, non_blocking=True)
    labels = batch["label"].to(DEVICE, non_blocking=True)
    if labels.dim() == 4:
        labels = labels.squeeze(1)
    return images, labels.long()


def _train_one_epoch():
    """Run one optimisation pass over train_loader; return the per-sample mean loss."""
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        images, labels = _to_device_batch(batch)
        optimizer.zero_grad(set_to_none=True)
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so a smaller final batch is not over-counted.
        # NOTE: with class weights, CrossEntropyLoss's 'mean' divides by the summed
        # target weights, so this per-sample average is an approximation.
        running_loss += loss.item() * images.size(0)
    return running_loss / len(train_loader.dataset)


@torch.no_grad()
def _evaluate():
    """Evaluate on test_loader; return (mean loss, mean foreground IoU) per sample."""
    model.eval()
    running_loss, running_iou = 0.0, 0.0
    for batch in test_loader:
        images, labels = _to_device_batch(batch)
        logits = model(images)
        running_loss += criterion(logits, labels).item() * images.size(0)
        # float() keeps the accumulator a plain Python number even if the helper
        # returns a (possibly CUDA) tensor, which matplotlib could not plot later.
        batch_iou = foreground_iou_from_logits(logits, labels, ignore_index=IGNORE_INDEX)
        running_iou += float(batch_iou) * images.size(0)
    num_samples = len(test_loader.dataset)
    return running_loss / num_samples, running_iou / num_samples


epoch_train_losses, epoch_test_losses, epoch_test_ious = [], [], []
epoch_bar = tqdm(range(1, NUM_EPOCHS + 1), desc="Epochs")

for epoch_idx in epoch_bar:
    mean_train_loss = _train_one_epoch()
    mean_test_loss, mean_test_iou = _evaluate()

    epoch_train_losses.append(mean_train_loss)
    epoch_test_losses.append(mean_test_loss)
    epoch_test_ious.append(mean_test_iou)

    epoch_bar.set_postfix(
        {
            "train_loss": f"{mean_train_loss:.4f}",
            "test_loss": f"{mean_test_loss:.4f}",
            "fg_IoU": f"{mean_test_iou:.4f}",
        }
    )

    if (epoch_idx % LOG_EVERY_EPOCHS) == 0:
        dump_test_visuals(
            test_loader,
            model,
            out_dir=TARGET_DIR / f"epoch_{epoch_idx:03d}",
            max_items=None,
            device=DEVICE,
        )

# Visuals are only dumped on multiples of LOG_EVERY_EPOCHS.
if NUM_EPOCHS >= LOG_EVERY_EPOCHS:
    print(f"Saved test-set visuals under: {TARGET_DIR.resolve()}")

# Epochs are 1-based (matching the progress bar and the epoch_XXX folders).
epochs = range(1, len(epoch_train_losses) + 1)
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].plot(epochs, epoch_train_losses, label="Train Loss")
axs[0].plot(epochs, epoch_test_losses, label="Test Loss")
axs[0].set_title("Loss per epoch")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
axs[0].legend()
axs[0].grid(True)

axs[1].plot(epochs, epoch_test_ious, label="Test Foreground IoU", color="orange")
axs[1].set_title("Test foreground IoU per epoch")
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Foreground IoU")
axs[1].legend()
axs[1].grid(True)
plt.tight_layout()
plt.show()

In [None]:
from IPython.display import Image as IPyImage, display
from shipwreck_finder_demo.utils.visualization_utils import make_labeled_segmentation_gif

# Build an animation from the visual dumps under TARGET_DIR and show it inline.
# The helper may return None, in which case nothing is displayed.
FRAME_DURATION_MS = 500
gif_path = make_labeled_segmentation_gif(TARGET_DIR, duration_ms=FRAME_DURATION_MS)
gif_available = gif_path is not None
if gif_available:
    gif_image = IPyImage(filename=str(gif_path))
    display(gif_image)
