# Modeling

This notebook only cover the training of the SegFormer model.

### Data Loading

In [1]:
import os
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image

In [2]:
class TumorSegDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)
        self.image_files = sorted(self.image_dir.glob("*.jpg"))
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        mask_path = Path(self.mask_dir) / img_path.name

        image = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

In [3]:
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

In [4]:
dataset = TumorSegDataset("../data/converted/images", "../data/converted/masks", transform = transform)

### Data Splitting

In [5]:
from torch.utils.data import random_split

In [6]:
total_size = len(dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

In [7]:
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

In [8]:
train_loader = DataLoader(train_dataset, batch_size = 8, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = 8, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size = 8, shuffle = False)

### Modeling

From https://github.com/mkara44/transunet_pytorch/blob/main/utils/transunet.py

In [9]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import copy
import logging
import math
from os.path import join as pjoin

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [11]:
criterion = nn.BCEWithLogitsLoss()

In [12]:
def dice_score(preds, targets, smooth = 1e-6):
    """
    Computes Dice Score (per batch).
    preds: tensor, shape (N, C, H, W) after applying sigmoid/softmax or thresholding
    targets: tensor, shape (N, C, H, W) one-hot encoded or same shape as preds
    """
    preds = preds.contiguous().view(preds.shape[0], -1)
    targets = targets.contiguous().view(targets.shape[0], -1)

    intersection = (preds * targets).sum(dim = 1)
    dice = (2. * intersection + smooth) / (preds.sum(dim = 1) + targets.sum(dim = 1) + smooth)
    return dice.mean().item()

In [13]:
def iou_score(preds, targets, smooth=1e-6):
    """
    Computes IoU (Jaccard Index) per batch.
    """
    preds = preds.contiguous().view(preds.shape[0], -1)
    targets = targets.contiguous().view(targets.shape[0], -1)

    intersection = (preds * targets).sum(dim=1)
    union = preds.sum(dim=1) + targets.sum(dim=1) - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.mean().item()

##### SegFormer

In [14]:
from transformers import SegformerForSemanticSegmentation

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
segformer_model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b5-finetuned-ade-640-640",
    num_labels = 1, 
    ignore_mismatched_sizes = True
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b5-finetuned-ade-640-640 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([1, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
segformer_model = segformer_model.to(device)
optimizer = optim.Adam(segformer_model.parameters(), lr = 1e-4)

In [17]:
best_val_loss = float('inf')
patience = 10
trigger_times = 0

In [18]:
EPOCHS = 100

for epoch in range(EPOCHS):
    segformer_model.train()
    train_loss = 0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)

        outputs = segformer_model(images).logits
        outputs = nn.functional.interpolate(outputs, size=masks.shape[-2:], mode="bilinear", align_corners=False)

        loss = criterion(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation
    segformer_model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)

            if images.shape[1] == 1:
                images = images.repeat(1, 3, 1, 1)

            outputs = segformer_model(images).logits
            outputs = nn.functional.interpolate(outputs, size=masks.shape[-2:], mode="bilinear", align_corners=False)
            
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        trigger_times = 0
        torch.save(segformer_model.state_dict(), "../models/segformer.pth")
    else:
        trigger_times += 1
        print(f"Early Stopping counter: {trigger_times} out of {patience}")
        if trigger_times >= patience:
            print("Early stopping triggered. Stopping training.")
            break

Epoch 1: Train Loss = 0.1741, Val Loss = 0.0507
Epoch 2: Train Loss = 0.0367, Val Loss = 0.0287
Epoch 3: Train Loss = 0.0205, Val Loss = 0.0216
Epoch 4: Train Loss = 0.0147, Val Loss = 0.0189
Epoch 5: Train Loss = 0.0108, Val Loss = 0.0164
Epoch 6: Train Loss = 0.0087, Val Loss = 0.0169
Early Stopping counter: 1 out of 10
Epoch 7: Train Loss = 0.0075, Val Loss = 0.0161
Epoch 8: Train Loss = 0.0089, Val Loss = 0.0180
Early Stopping counter: 1 out of 10
Epoch 9: Train Loss = 0.0068, Val Loss = 0.0163
Early Stopping counter: 2 out of 10
Epoch 10: Train Loss = 0.0056, Val Loss = 0.0165
Early Stopping counter: 3 out of 10
Epoch 11: Train Loss = 0.0050, Val Loss = 0.0175
Early Stopping counter: 4 out of 10
Epoch 12: Train Loss = 0.0046, Val Loss = 0.0173
Early Stopping counter: 5 out of 10
Epoch 13: Train Loss = 0.0045, Val Loss = 0.0179
Early Stopping counter: 6 out of 10
Epoch 14: Train Loss = 0.0044, Val Loss = 0.0179
Early Stopping counter: 7 out of 10
Epoch 15: Train Loss = 0.0041, Val 

In [22]:
segformer_model.eval()
dice_total = 0
iou_total = 0
num_batches = 0

with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)

        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)

        outputs = segformer_model(images).logits
        outputs = nn.functional.interpolate(outputs, size=masks.shape[-2:], mode="bilinear", align_corners=False)

        if outputs.shape[1] == 1:  # Binary segmentation
            # probs = torch.sigmoid(outputs)
            probs = outputs
            preds = (probs > 0.5).float()
            targets = (masks > 0.5).float()
        else:  # Multiclass segmentation
            preds = torch.argmax(outputs, dim = 1)  # (N, H, W)
            targets = masks.long().squeeze(1)     # (N, H, W)

            # Convert preds and targets to one-hot for Dice and IoU
            preds = torch.nn.functional.one_hot(preds, num_classes=outputs.shape[1])  # (N,H,W,C)
            preds = preds.permute(0, 3, 1, 2).float()  # (N,C,H,W)
            targets = torch.nn.functional.one_hot(targets, num_classes=outputs.shape[1])
            targets = targets.permute(0, 3, 1, 2).float()

        dice = dice_score(preds, targets)
        iou = iou_score(preds, targets)

        dice_total += dice
        iou_total += iou
        num_batches += 1

print(f"Test Dice Score: {dice_total / num_batches:.4f}")
print(f"Test IoU Score: {iou_total / num_batches:.4f}")

Test Dice Score: 0.8337
Test IoU Score: 0.7550
