In [1]:
# !/usr/bin/env python3
!pip install torch>=2.0.0
!pip install torchvision>=0.15.0
!pip install numpy>=1.24.0
!pip install Pillow>=9.0.0
!pip install opencv-python>=4.5.0
!pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.3-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.8.1-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading pytorch_lightning-2.5.3-py3-none-any.whl (828 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m828.2/828.2 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.1-py3-none-any.whl (982 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.0/983.0 kB[0m [31m44.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch-lightning
Successfully installed lightning-utilities-0.15.2 pytorch-lightning-2.5.3 torchmetrics-1.8.1


In [2]:
import os
import glob
import argparse
from PIL import Image
import random
import numpy as np
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import torchvision.models.segmentation as seg_models


In [3]:
# -------------------------
# Dataset
# -------------------------
class VideoSegmentationDataset(Dataset):
    """
    Paired (frame, mask) samples gathered from a directory of videos.

    Expects root_dir with structure:
      root_dir/
        video_01/
          frames_original/
            000001.png
            ...
          segmentation/
            000001.png
            ...
        video_02/...
    Frames in frames_original/ are RGB images. Masks in segmentation/ are
    single-channel images holding integer class ids 0..9 (0 = background).
    A frame is used only when a mask with the same filename exists; video
    folders missing either subfolder are skipped.

    Args:
        root_dir: dataset root laid out as above.
        img_size: target size handed to PIL ``Image.resize``, which reads it as
            (width, height).
        augment: when True, randomly flip frame and mask horizontally together.

    Raises:
        ValueError: root_dir is not a directory.
        RuntimeError: no (frame, mask) pairs were found.

    Each item is ``(image, mask)``: a float tensor (3, H, W) normalized with the
    ImageNet mean/std, and an int64 tensor (H, W) of class ids.
    """
    def __init__(self, root_dir, img_size=(256,256), augment=False):
        self.samples = []
        self.img_size = tuple(img_size)
        self.augment = augment

        if not os.path.isdir(root_dir):
            raise ValueError(f"{root_dir} is not a directory")

        videos = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        for video in videos:
            frame_dir = os.path.join(root_dir, video, "frames_original")
            mask_dir = os.path.join(root_dir, video, "segmentation")
            if not os.path.isdir(frame_dir) or not os.path.isdir(mask_dir):
                continue
            frames = sorted(glob.glob(os.path.join(frame_dir, "*.png")))
            for f in frames:
                filename = os.path.basename(f)
                mask_path = os.path.join(mask_dir, filename)
                if os.path.exists(mask_path):
                    self.samples.append((f, mask_path))

        if len(self.samples) == 0:
            raise RuntimeError(f"No samples found in {root_dir}.")

    def __len__(self):
        return len(self.samples)

    def _sync_transform(self, image, mask):
        """Resize (and optionally flip) image and mask identically, then tensorize."""
        image = image.resize(self.img_size, resample=Image.BILINEAR)
        # NEAREST keeps mask values as existing class ids (no blended labels).
        mask = mask.resize(self.img_size, resample=Image.NEAREST)

        # Random horizontal flip, applied to both so they stay aligned
        if self.augment and random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Convert to tensors
        image = TF.to_tensor(image)  # float [0,1] CxHxW
        image = TF.normalize(image, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])

        mask = np.array(mask, dtype=np.int64)
        mask = torch.from_numpy(mask)
        return image, mask

    @staticmethod
    def _mask_to_labels(mask):
        """Return a single-channel copy of ``mask`` whose pixel values are class ids."""
        # Palette ("P") PNGs store class ids as palette indices; converting them
        # to "L" would replace each id with the luminance of its palette colour.
        if mask.mode == "P":
            return mask.copy()
        return mask.convert("L")

    def __getitem__(self, idx):
        img_path, mask_path = self.samples[idx]
        # Context managers close the file handles as soon as the pixels are loaded.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            m = self._mask_to_labels(mask_file)
        img, m = self._sync_transform(img, m)
        return img, m


In [4]:
# -------------------------
# Metrics
# -------------------------
def fast_confusion_matrix(preds, labels, num_classes):
    """
    Build a pixel confusion matrix from predicted and ground-truth class maps.

    preds: HxW or (N,H,W) tensor of predicted class ids
    labels: tensor with the same number of elements, ground-truth class ids
    num_classes: number of classes C
    returns: (C, C) int64 numpy array where
      conf[i,j] = count of pixels where true=i and pred=j
    Pixels whose label is outside [0, C) (e.g. an ignore value such as 255) are
    skipped, and so are predictions outside [0, C).
    """
    preds = preds.reshape(-1)
    labels = labels.reshape(-1).to(preds.device)
    valid = (labels >= 0) & (labels < num_classes) & (preds >= 0) & (preds < num_classes)
    # Count on the tensors' own device; only the C*C counts are copied to host,
    # not the full-resolution prediction/label maps.
    flat_idx = num_classes * labels[valid].long() + preds[valid].long()
    hist = torch.bincount(flat_idx, minlength=num_classes ** 2)
    return hist.reshape(num_classes, num_classes).cpu().numpy()

def compute_metrics_from_confusion(conf, ignore_index=None):
    """
    Summarize a confusion matrix into segmentation metrics.

    Args:
        conf: (num_classes, num_classes) array; conf[i, j] = pixels with true
            class i predicted as class j (the layout fast_confusion_matrix builds).
        ignore_index: optional class id whose ground-truth pixels are dropped
            (its row is removed) before any metric is computed, so it gets no IoU
            and does not count toward pixel accuracy. Ids outside
            [0, num_classes) have no effect. The input array is not modified.

    Returns:
        dict with
          "per_class_iou": {"class_<c>": IoU or None}; None for classes with no
              ground-truth pixels (even if they were predicted) and for ignore_index,
          "mean_iou": mean of the non-None IoUs (NaN if there are none),
          "pixel_acc": correctly classified pixels / counted pixels.
    """
    # Copy as float so zeroing an ignored row never touches the caller's array.
    conf = np.array(conf, dtype=np.float64)
    num_classes = conf.shape[0]
    if ignore_index is not None and 0 <= ignore_index < num_classes:
        conf[ignore_index, :] = 0

    tp = np.diag(conf)
    pos_gt = conf.sum(axis=1)
    pos_pred = conf.sum(axis=0)
    union = pos_gt + pos_pred - tp

    eps = 1e-12
    iou = tp / (union + eps)
    # If a class has zero gt pixels (including the ignored class), set IoU to nan
    iou[pos_gt == 0] = np.nan

    present = ~np.isnan(iou)
    # Explicit branch instead of np.nanmean, which warns on an all-NaN input.
    mean_iou = float(iou[present].mean()) if present.any() else float("nan")
    pixel_acc = tp.sum() / (conf.sum() + eps)

    per_class = {f"class_{c}": float(iou[c]) if present[c] else None
                 for c in range(num_classes)}
    return {"per_class_iou": per_class, "mean_iou": mean_iou, "pixel_acc": float(pixel_acc)}

# -------------------------
# Focal Loss
# -------------------------
class FocalLoss(nn.Module):
    """
    Multi-class focal loss (Lin et al., 2017) for per-pixel classification.

    For each non-ignored position i with target t_i:
        loss_i = alpha[t_i] * (1 - p_i) ** gamma * CE_i,   p_i = exp(-CE_i)
    where CE_i is the unweighted cross-entropy, so p_i is the softmax
    probability of the target class. The result is the mean over
    non-ignored positions.

    Args:
        gamma: focusing parameter; gamma=0 gives per-pixel cross-entropy
            (times alpha, if given) averaged over non-ignored pixels.
        alpha: optional per-class weights (sequence or 1-D tensor, one entry per
            class). None means no class weighting.
        ignore_index: target value excluded from the loss and from the averaging
            denominator. None means no value is ignored.
    """
    def __init__(self, gamma=2.0, alpha=None, ignore_index=None):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        """
        Args:
            logits: raw class scores, shaped as nn.functional.cross_entropy
                expects (classes on dim 1).
            targets: integer class ids with logits' shape minus dim 1.
        Returns:
            Scalar loss tensor.
        """
        # cross_entropy rejects ignore_index=None; -100 is its own default sentinel.
        ignore = -100 if self.ignore_index is None else self.ignore_index
        # Unweighted CE, so exp(-ce) is the true target-class probability;
        # folding alpha into CE (as `weight=`) would distort p_t.
        ce = nn.functional.cross_entropy(logits, targets, ignore_index=ignore, reduction='none')
        pt = torch.exp(-ce)
        loss = (1 - pt) ** self.gamma * ce
        valid = targets != ignore
        if self.alpha is not None:
            alpha = torch.as_tensor(self.alpha, dtype=loss.dtype, device=loss.device)
            # Ignored positions may hold out-of-range ids (e.g. 255); look up class 0
            # there instead — their loss is already zero.
            safe_targets = torch.where(valid, targets, torch.zeros_like(targets))
            loss = loss * alpha[safe_targets]
        # Average over contributing positions only; clamp avoids 0/0 when all are ignored.
        return loss.sum() / valid.sum().clamp(min=1)


In [5]:
# -------------------------
# Training / Validation loops
# -------------------------
def train_one_epoch(model, dataloader, optimizer, device, criterion, num_classes, epoch, log_every=20):
    """
    Run one training epoch and return metrics accumulated over it.

    Args:
        model: model whose forward returns a dict with an 'out' entry of
            per-class logits (classes on dim 1), as read below.
        dataloader: yields (images, masks) batches.
        optimizer: stepped once per batch.
        device: where each batch is moved before the forward pass.
        criterion: callable(logits, masks) -> scalar loss tensor.
        num_classes: size of the confusion matrix.
        epoch: only used in progress messages.
        log_every: print the current batch loss every `log_every` batches.

    Returns:
        dict from compute_metrics_from_confusion (per-class IoU, mean IoU, pixel
        accuracy of the training-mode predictions), plus 'loss': the per-sample
        mean loss over the samples actually processed (correct with drop_last).
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)
    num_batches = len(dataloader)
    for step, (images, masks) in enumerate(dataloader, start=1):
        # non_blocking lets copies from pinned host memory overlap with compute.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        outputs = model(images)['out']  # (N, C, H, W)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        loss_value = loss.item()
        loss_sum += loss_value * batch_size
        seen += batch_size

        preds = outputs.argmax(dim=1)  # (N,H,W)
        conf += fast_confusion_matrix(preds, masks, num_classes)

        if step % log_every == 0:
            print(f"  [Epoch {epoch}] Iter {step}/{num_batches}  loss={loss_value:.4f}")

    metrics = compute_metrics_from_confusion(conf, ignore_index=None)
    # Divide by samples seen, not len(dataset): a loader with drop_last=True
    # skips the final partial batch.
    metrics['loss'] = loss_sum / max(seen, 1)
    return metrics

@torch.no_grad()
def validate(model, dataloader, device, criterion, num_classes):
    """
    Evaluate ``model`` on every batch of ``dataloader`` with gradients disabled.

    The model is switched to eval mode. Returns the dict produced by
    compute_metrics_from_confusion (per-class IoU, mean IoU, pixel accuracy)
    over all batches, plus 'loss': batch losses weighted by batch size and
    divided by len(dataloader.dataset).
    """
    model.eval()
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    weighted_loss = 0.0
    for batch_images, batch_masks in dataloader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        logits = model(batch_images)['out']
        weighted_loss += criterion(logits, batch_masks).item() * batch_images.size(0)
        confusion += fast_confusion_matrix(logits.argmax(dim=1), batch_masks, num_classes)

    summary = compute_metrics_from_confusion(confusion, ignore_index=None)
    summary['loss'] = weighted_loss / len(dataloader.dataset)
    return summary


In [6]:
# -------------------------
# Arguments
# -------------------------

# NOTE(review): val_dir points at the *test* split, and main() uses it both for
# LR scheduling and for picking best_model.pth, so the reported best mIoU is not
# an unbiased test score. Consider holding out a separate validation split.
train_dir = '/content/drive/MyDrive/endovis256/train/'
val_dir = '/content/drive/MyDrive/endovis256/test/'
out_dir = '/content/drive/MyDrive/endovis256/focal_checkpoints/'
epochs = 20
batch_size = 16
img_size = 256      # frames and masks are resized to img_size x img_size
lr = 1e-4
# Cap DataLoader workers at the CPU count: more workers than CPUs (e.g. 32 on a
# small Colab VM) only adds overhead, and PyTorch warns about it.
num_workers = min(32, os.cpu_count() or 1)
num_classes = 10    # mask ids 0..9, 0 = background
seed = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the pipeline draws from: Python `random` (flip augmentation),
# NumPy, and torch (weight init, DataLoader shuffling and worker seeds).
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
print(device)

cuda


In [7]:
# -------------------------
# Main
# -------------------------

def _save_checkpoint(path, epoch, model, optimizer, miou):
    """Save model/optimizer state together with the epoch and validation mIoU."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'miou': miou,
    }, path)


def main(save_every_epoch=False):
    """
    Train DeepLabV3-ResNet50 with cross-entropy + focal loss; keep the best checkpoint.

    Uses the notebook-level settings (train_dir, val_dir, out_dir, epochs,
    batch_size, img_size, lr, num_workers, num_classes). Each epoch trains,
    validates, steps ReduceLROnPlateau on validation mIoU, and overwrites
    ``out_dir/best_model.pth`` whenever validation mIoU improves.

    Args:
        save_every_epoch: also write ``out_dir/epoch_XXX.pth`` after every epoch.
    """
    # Local import: the weights enum belongs to torchvision's multi-weight API
    # (>= 0.13), which replaces the deprecated ``pretrained=`` keyword.
    from torchvision.models import ResNet50_Weights

    print(out_dir)
    # Resolved again here (not taken from the config cell) because Colab kernels
    # occasionally lose state between cells.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    os.makedirs(out_dir, exist_ok=True)

    print("Preparing datasets...")
    train_dataset = VideoSegmentationDataset(train_dir, img_size=(img_size, img_size), augment=True)
    val_dataset = VideoSegmentationDataset(val_dir, img_size=(img_size, img_size), augment=False)

    # Pinned host memory only helps host->GPU copies; without CUDA it just warns.
    pin = device.type == 'cuda'
    # DeepLabV3's ASPP pooling branch applies BatchNorm to a 1x1 pooled map, so a
    # trailing training batch with a single sample raises in train mode. Drop the
    # incomplete last batch (unless the whole set is smaller than one batch).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin,
                              drop_last=len(train_dataset) >= batch_size)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin)

    print("Creating model...")
    # weights=None: the segmentation head starts from scratch with num_classes outputs.
    # weights_backbone: ImageNet ResNet-50 backbone — the same weights the legacy
    # ``pretrained=False`` call loaded by default, without the deprecation warning.
    model = seg_models.deeplabv3_resnet50(
        weights=None,
        weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
        num_classes=num_classes,
    )
    # To start from a saved checkpoint instead, skip the backbone download and load
    # the full state dict (TODO: point the path at the run to resume from):
    # model = seg_models.deeplabv3_resnet50(weights=None, weights_backbone=None, num_classes=num_classes)
    # ckpt = torch.load(os.path.join(out_dir, 'best_model.pth'), map_location=device)
    # model.load_state_dict(ckpt['model_state_dict'])

    model.to(device)

    ce_criterion = nn.CrossEntropyLoss()
    focal_criterion = FocalLoss(gamma=2.0)

    def combined_loss(logits, targets):
        # Plain cross-entropy plus focal loss (gamma=2) on the same logits.
        return ce_criterion(logits, targets) + focal_criterion(logits, targets)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # mode='max': the monitored value is validation mIoU (higher is better); the LR
    # is halved once mIoU has not improved for more than 3 epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    best_miou = -1.0
    for epoch in range(1, epochs + 1):
        print(f"\n=== Epoch {epoch}/{epochs} ===")
        train_metrics = train_one_epoch(model, train_loader, optimizer, device, combined_loss, num_classes, epoch)
        print(f"Train loss {train_metrics['loss']:.4f}  pix_acc {train_metrics['pixel_acc']:.4f}  mean_iou {train_metrics['mean_iou']:.4f}")

        val_metrics = validate(model, val_loader, device, combined_loss, num_classes)
        print(f"Val   loss {val_metrics['loss']:.4f}  pix_acc {val_metrics['pixel_acc']:.4f}  mean_iou {val_metrics['mean_iou']:.4f}")
        for k, v in val_metrics['per_class_iou'].items():
            print(f"  {k}: {v}")

        scheduler.step(val_metrics['mean_iou'])

        if save_every_epoch:
            ckpt_path = os.path.join(out_dir, f"epoch_{epoch:03d}.pth")
            _save_checkpoint(ckpt_path, epoch, model, optimizer, val_metrics['mean_iou'])

        if val_metrics['mean_iou'] > best_miou:
            best_miou = val_metrics['mean_iou']
            best_path = os.path.join(out_dir, "best_model.pth")
            _save_checkpoint(best_path, epoch, model, optimizer, best_miou)
            print(f"  -> New best model saved to {best_path} (mIoU={best_miou:.4f})")

    print("Training finished. Best val mIoU:", best_miou)


In [None]:
# Launch training with the configuration above; best_model.pth is (re)written
# in out_dir each time validation mIoU improves.
main()

/content/drive/MyDrive/endovis256/focal_checkpoints/
cuda
Preparing datasets...




Creating model...


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 134MB/s]



=== Epoch 1/20 ===
  [Epoch 1] Iter 20/816  loss=0.1293
  [Epoch 1] Iter 40/816  loss=0.1473
  [Epoch 1] Iter 60/816  loss=0.1099
  [Epoch 1] Iter 80/816  loss=0.1636
  [Epoch 1] Iter 100/816  loss=0.1565
  [Epoch 1] Iter 120/816  loss=0.1383
  [Epoch 1] Iter 140/816  loss=0.2010
  [Epoch 1] Iter 160/816  loss=0.1351
  [Epoch 1] Iter 180/816  loss=0.1402
  [Epoch 1] Iter 200/816  loss=0.1317
  [Epoch 1] Iter 220/816  loss=0.2174
  [Epoch 1] Iter 240/816  loss=0.1551
  [Epoch 1] Iter 260/816  loss=0.1323
  [Epoch 1] Iter 280/816  loss=0.1416
  [Epoch 1] Iter 300/816  loss=0.1475
  [Epoch 1] Iter 320/816  loss=0.1691
  [Epoch 1] Iter 340/816  loss=0.1328
  [Epoch 1] Iter 360/816  loss=0.1582
  [Epoch 1] Iter 380/816  loss=0.1466
  [Epoch 1] Iter 400/816  loss=0.1277
  [Epoch 1] Iter 420/816  loss=0.1027
  [Epoch 1] Iter 440/816  loss=0.1614
  [Epoch 1] Iter 460/816  loss=0.1359
  [Epoch 1] Iter 480/816  loss=0.1745
  [Epoch 1] Iter 500/816  loss=0.1324
  [Epoch 1] Iter 520/816  loss=0.1