<a href="https://colab.research.google.com/github/umfieldrobotics/shipwreck_finder_demo/blob/main/shipwreck_demo_oceans25.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# <img src="https://raw.githubusercontent.com/umfieldrobotics/shipwreck_finder_demo/main/figures/frog-logo.png" alt="FROG logo" style="height:1.2em; vertical-align:-0.2em; margin-right:0.4em;"> Shipwreck Finder Demo

<div align="center">
  <img src="https://raw.githubusercontent.com/umfieldrobotics/shipwreck_finder_demo/main/figures/teaser.png" alt="Teaser" width="750">
</div>

### Install requirements and import necessary packages

In [None]:
!pip install -U segmentation-models-pytorch --quiet
!git clone --quiet --recursive https://github.com/umfieldrobotics/shipwreck_finder_demo.git

In [None]:
import torch
from pathlib import Path
from torch.utils.data import DataLoader
from PIL import Image, ImageDraw
from torch import nn
from torch.optim import AdamW
import numpy as np
import segmentation_models_pytorch as smp
from shipwreck_finder_demo.utils.data import MBESDataset
from shipwreck_finder_demo.utils.utils import (
    compute_balanced_weights,
    foreground_iou_from_logits,
    visualize_triplets_inline,
    dump_visuals,
    download_public_gdrive_file,
)
from typing import Tuple
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Reaching this line means every import above resolved.
print("[SUCCESS] All packages imported successfully.")

### Download publicly available bathymetry data and shipwreck labels

In [None]:
# Shareable Google Drive link to the zipped tutorial dataset.
PUBLIC_GDRIVE_LINK = (
    "https://drive.google.com/file/d/"
    "1iswtn95_1LqB_u1M4u0KTbNwnG7S3C_r"
    "/view?usp=drive_link"
)
# Local destination of the downloaded archive (under Colab's /content directory).
OUTPUT_PATH = "/content/downloads/TUTORIAL_DATASET.zip"

In [None]:
download_public_gdrive_file(PUBLIC_GDRIVE_LINK, OUTPUT_PATH)
!unzip {OUTPUT_PATH} -d /content/downloads

### Prepare dataset for training

We will be using the Irish Bathymetry dataset, which contains bathymetry data with shipwrecks in it.

We went through and labeled the shipwrecks with a `foreground` (`1`) label and labeled everything else in the bathymetry data as `background` (`0`).

In [None]:
# Set parameters
BATCH_SIZE_TRAIN = 16
BATCH_SIZE_TEST = 1
LEARNING_RATE = 5e-4
IGNORE_INDEX = -1  # label value that nn.CrossEntropyLoss skips (ignore_index)
TARGET_DIR = Path("/content/segmentation_vis")  # checkpoints and per-epoch visuals go here
SEED = 42

# Seed the global RNGs so model weight initialisation and DataLoader shuffling repeat
# across "Restart & Run All". NumPy is seeded too in case the dataset draws from it.
# GPU kernels may still be nondeterministic, so results can differ slightly on CUDA.
torch.manual_seed(SEED)
np.random.seed(SEED)

TARGET_DIR.mkdir(parents=True, exist_ok=True)

Here, we will split the data into `train`, `validation`, and `test` sets.

We will use the `train` dataset to train the model.

We will use the `validation` dataset to validate the model performance after each training epoch.

Finally, we will test our model's performance on the `test` dataset. It is completely unseen during training, and its metrics are the final numbers we use to evaluate our model.

### Visualizing the dataset
In the cell below, we do four things:
1. We split the training dataset into `train` and `validation`.
2. We load the `test` dataset.
3. We visualize a few train and test samples: the bathymetry data, the label, and the label overlaid on the bathymetry data.
4. We wrap each split in a `DataLoader` so the data can be fed to the model in batches.

In [None]:
from shipwreck_finder_demo.utils.visualization_utils import plot_train_test_grid

IMG_SIZE = 512  # passed to MBESDataset as img_size
VAL_FRACTION = 0.1  # share of the training directory held out for validation
# Assumes the archive extracts to a folder named after the zip file
# (".../TUTORIAL_DATASET.zip" -> ".../TUTORIAL_DATASET").
DATASET_ROOT = Path(OUTPUT_PATH).with_suffix("")

full_train_dataset = MBESDataset(str(DATASET_ROOT / "train"), img_size=IMG_SIZE)
val_size = int(VAL_FRACTION * len(full_train_dataset))
# An empty validation split would make the per-epoch validation means divide by zero.
assert val_size > 0, (
    f"Validation split is empty ({len(full_train_dataset)} training samples); "
    "increase VAL_FRACTION or add data."
)
train_size = len(full_train_dataset) - val_size
# A fixed generator seed gives the same train/val split on every run.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

test_dataset = MBESDataset(str(DATASET_ROOT / "test"), img_size=IMG_SIZE)
plot_train_test_grid(train_dataset, test_dataset, N=3)


def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook's shared worker settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=2,
        # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runtimes.
        pin_memory=torch.cuda.is_available(),
    )


# create loaders (validation/test honour BATCH_SIZE_TEST instead of a hard-coded 1)
train_loader = _make_loader(train_dataset, BATCH_SIZE_TRAIN, shuffle=True)
val_loader = _make_loader(val_dataset, BATCH_SIZE_TEST, shuffle=False)
test_loader = _make_loader(test_dataset, BATCH_SIZE_TEST, shuffle=False)


In [None]:
# Report how many samples ended up in each split.
for split_name, split_dataset in (
    ("training", train_dataset),
    ("validation", val_dataset),
    ("test", test_dataset),
):
    print(f"Number of {split_name} samples: {len(split_dataset)}")

### Define model and optimizer

**Model:** For this demonstration, we will be using a [UNet](https://arxiv.org/abs/1505.04597) segmentation model, one that is very common in segmentation tasks. We will be using the implementation from the [PyTorch Segmentation Models library](https://github.com/qubvel-org/segmentation_models.pytorch), which readily comes with many segmentation models.

**Loss:** We use cross-entropy loss for the segmentation task. Pixels labeled `IGNORE_INDEX` (`-1`) are excluded from the loss.

In [None]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Two-class UNet (background / foreground) on single-channel input, with a
# ResNet-34 encoder initialised from ImageNet weights.
unet_kwargs = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=1,
    classes=2,
)
model = smp.Unet(**unet_kwargs).to(DEVICE)

# Pixels labelled IGNORE_INDEX contribute nothing to the loss or its gradient.
criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

### Let's train the model

In [None]:
NUM_EPOCHS = 100
LOG_EVERY_EPOCHS = 5
NUM_VIS_TRIPLETS = 3


def _batch_to_device(batch):
    """Move a batch's ``"image"`` and ``"label"`` tensors to DEVICE.

    4-D label tensors are squeezed along dim 1 (presumably a singleton channel
    dim). ``squeeze(1)`` is a no-op when that dim is not of size 1.
    """
    images = batch["image"].to(DEVICE, non_blocking=True)
    labels = batch["label"].to(DEVICE, non_blocking=True)
    if labels.dim() == 4:
        labels = labels.squeeze(1)
    return images, labels


def _train_one_epoch():
    """Run one optimisation pass over ``train_loader``; return the per-sample mean loss."""
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        images, labels = _batch_to_device(batch)
        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch mean is per sample, not per batch.
        running_loss += loss.item() * images.size(0)
    return running_loss / len(train_loader.dataset)


def _validate():
    """Evaluate on ``val_loader`` without gradients.

    Returns:
        (mean loss, mean foreground IoU), both weighted per sample. The model is
        left in eval mode.
    """
    model.eval()
    running_loss, running_iou = 0.0, 0.0
    with torch.no_grad():
        for batch in val_loader:
            images, labels = _batch_to_device(batch)
            logits = model(images)
            running_loss += criterion(logits, labels).item() * images.size(0)
            running_iou += foreground_iou_from_logits(
                logits, labels, ignore_index=IGNORE_INDEX
            ) * images.size(0)
    n_samples = len(val_loader.dataset)
    return running_loss / n_samples, running_iou / n_samples


epoch_train_losses, epoch_val_losses, epoch_val_ious = [], [], []
best_val_loss = float("inf")
best_epoch = -1

epoch_bar = tqdm(range(1, NUM_EPOCHS + 1), desc="Epochs")
for epoch_idx in epoch_bar:
    mean_train_loss = _train_one_epoch()
    mean_val_loss, mean_val_iou = _validate()
    epoch_train_losses.append(mean_train_loss)
    epoch_val_losses.append(mean_val_loss)
    epoch_val_ious.append(mean_val_iou)

    # Checkpoint whenever the validation loss improves (saved as best_model.pth).
    if mean_val_loss < best_val_loss:
        best_val_loss = mean_val_loss
        best_epoch = epoch_idx
        torch.save(model.state_dict(), TARGET_DIR / "best_model.pth")
        # tqdm's write() prints above the progress bar instead of breaking it like print().
        epoch_bar.write(
            f"[INFO] New best model saved at epoch {best_epoch} with val_loss {best_val_loss:.4f}"
        )

    epoch_bar.set_postfix(
        {
            "train_loss": f"{mean_train_loss:.4f}",
            "val_loss": f"{mean_val_loss:.4f}",
            "fg_IoU": f"{mean_val_iou:.4f}",
        }
    )
    torch.save(model.state_dict(), TARGET_DIR / "latest_model.pth")
    # Every LOG_EVERY_EPOCHS epochs, dump validation-set visuals (model is in eval mode here).
    if (epoch_idx % LOG_EVERY_EPOCHS) == 0:
        dump_visuals(
            val_loader,
            model,
            out_dir=TARGET_DIR / f"epoch_{epoch_idx:03d}",
            max_items=None,
            device=DEVICE,
        )

### Visualize the performance

In [None]:
# The per-epoch visuals were produced by dump_visuals(val_loader, ...) in the training loop,
# so they show the validation set, not the test set.
print(f"Saved validation-set visuals under: {TARGET_DIR.resolve()}")

# Epochs are numbered from 1 in the training loop; plot against those numbers, not list indices.
epochs = range(1, len(epoch_train_losses) + 1)

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
axs[0].plot(epochs, epoch_train_losses, label="Train Loss")
axs[0].plot(epochs, epoch_val_losses, label="Val Loss")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
axs[0].set_title("Loss per epoch")
axs[0].legend()
axs[0].grid(True)

axs[1].plot(epochs, epoch_val_ious, label="Val Foreground IoU", color="orange")
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Foreground IoU")
axs[1].set_title("Validation foreground IoU per epoch")
axs[1].legend()
axs[1].grid(True)
plt.tight_layout()
plt.show()

### Let's evaluate the model on the test set!

In [None]:
# Load the checkpoint with the lowest validation loss (best_model.pth) for testing.
# map_location puts the tensors on DEVICE regardless of where they were saved.
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs.
best_state_dict = torch.load(
    TARGET_DIR / "best_model.pth", map_location=DEVICE, weights_only=True
)
model.load_state_dict(best_state_dict)
model.eval()

running_test_loss, running_test_iou = 0.0, 0.0
with torch.no_grad():
    for batch in test_loader:
        images = batch["image"].to(DEVICE, non_blocking=True)
        labels = batch["label"].to(DEVICE, non_blocking=True)
        # Drop dim 1 from 4-D label tensors (presumably a singleton channel dim).
        if labels.dim() == 4:
            labels = labels.squeeze(1)
        logits = model(images)
        loss = criterion(logits, labels)
        # Weight by batch size so the means below are per sample.
        running_test_loss += loss.item() * images.size(0)
        running_test_iou += foreground_iou_from_logits(
            logits, labels, ignore_index=IGNORE_INDEX
        ) * images.size(0)

n_test_samples = len(test_loader.dataset)
mean_test_loss = running_test_loss / n_test_samples
mean_test_iou = running_test_iou / n_test_samples
print(f"Test Loss: {mean_test_loss:.4f}, Test Foreground IoU: {mean_test_iou:.4f}")

visualize_triplets_inline(model, test_loader, device=DEVICE)


### Visualize the progression of the model

In [None]:
from IPython.display import Image as IPyImage, display
from shipwreck_finder_demo.utils.visualization_utils import (
    make_labeled_segmentation_gif,
)

# One frame per epoch_XXX folder written by the training loop (frame file name as produced
# by dump_visuals — presumably its first saved item).
FRAME_GLOB = "epoch_*/batch_000000.png"

gif_path = make_labeled_segmentation_gif(TARGET_DIR, duration_ms=500, frame_glob=FRAME_GLOB)
if gif_path is not None:
    display(IPyImage(filename=str(gif_path)))
else:
    # Say so explicitly instead of ending the cell with no output.
    print(f"No GIF was created (no frames matching '{FRAME_GLOB}' under {TARGET_DIR}?)")