# Image Segmentation using FCN on Pascal VOC

### Data Loading

In [None]:
# !pip install lightning albumentations torchinfo torchmetrics wandb

In [None]:
# import wandb
# wandb.login()

In [None]:
import numpy as np

import torch
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

import lightning.pytorch as pl

import albumentations as A

In [None]:
class AugmentedDataset:
    """Wrap an ``(image, mask)`` dataset and post-process every sample.

    Each item of the wrapped dataset is unpacked into an image and a mask,
    both are converted to NumPy arrays, and then three optional stages run
    in this order: the joint augmentation, the image-only transform and the
    mask-only transform. A stage that is ``None`` (or otherwise falsy) is
    skipped.
    """

    def __init__(
        self,
        dataset,
        augmentation=None,
        img_transforms=None,
        mask_transforms=None,
    ):
        # Must support ``len()`` and indexing that yields an (image, mask) pair.
        self.dataset = dataset
        # Invoked as ``augmentation(image=..., mask=...)`` and expected to
        # return a mapping with "image" and "mask" entries.
        self.augmentation = augmentation
        # Single-argument callables applied to the image / mask respectively.
        self.img_transforms = img_transforms
        self.mask_transforms = mask_transforms

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed ``(image, mask)`` pair stored at ``idx``."""
        raw_img, raw_mask = self.dataset[idx]
        img, mask = np.array(raw_img), np.array(raw_mask)

        if self.augmentation:
            # Image and mask go through together so both receive the same
            # call of the augmentation.
            result = self.augmentation(image=img, mask=mask)
            img, mask = result["image"], result["mask"]

        if self.img_transforms:
            img = self.img_transforms(img)

        if self.mask_transforms:
            mask = self.mask_transforms(mask)

        return img, mask

In [None]:
def get_augmentation(img_shape):
    """Build the albumentations pipelines used for training and validation.

    Args:
        img_shape: Sequence unpacked positionally into ``A.Resize`` —
            presumably ``(height, width)``; confirm against the installed
            albumentations version.

    Returns:
        A ``(train_transform, val_transform)`` tuple of ``A.Compose`` objects.
        The training pipeline resizes, applies rotation, horizontal flip,
        RGB shift, blur, brightness/contrast and CLAHE (all with their
        default arguments), then resizes again. The validation pipeline only
        resizes.
    """

    def resize():
        # A fresh instance per pipeline slot, so no transform object is shared.
        return A.Resize(*img_shape)

    train_steps = [
        resize(),
        A.Rotate(),
        A.HorizontalFlip(),
        A.RGBShift(),
        A.Blur(),
        A.RandomBrightnessContrast(),
        A.CLAHE(),
        resize(),
    ]
    val_steps = [resize()]

    return A.Compose(train_steps), A.Compose(val_steps)

In [None]:
def _mask_to_tensor(mask):
    """Convert a segmentation mask to a ``torch.long`` tensor of class ids.

    Pixels with value 255 (what the original code called the "void" class)
    are rewritten to 0, i.e. they are merged into class 0.

    NOTE(review): this does not exclude those pixels from the loss; to truly
    ignore them, keep 255 here and configure the loss/metrics to ignore it.
    """
    mask = np.array(mask)  # copy, so the caller's array is never modified
    mask[mask == 255] = 0
    return torch.from_numpy(mask).long()


def load_VOCSegmentationDataset(data_dir, img_shape=(512, 512), seed=42):
    """Load the Pascal VOC "trainval" segmentation set and split it 90/10.

    Args:
        data_dir: Root directory handed to ``datasets.VOCSegmentation``
            (created with ``download=True``).
        img_shape: Target size forwarded to ``get_augmentation``.
        seed: Seed for the train/validation split. With a fixed seed every
            call yields the same partition, so datasets created in different
            places (e.g. at module level and inside a data module) agree on
            which images are validation images. Pass ``None`` to get a
            different random split on each call.

    Returns:
        ``(train_dataset, val_dataset, num_classes)`` where both datasets are
        ``AugmentedDataset`` wrappers and ``num_classes`` is 21.
    """
    # ToTensor followed by Normalize(0.5, 0.5) per channel.
    img_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    # A module-level function instead of a lambda/closure, so the resulting
    # dataset does not hold objects that cannot be pickled.
    mask_transforms = _mask_to_tensor

    dataset = datasets.VOCSegmentation(
        root=data_dir,
        image_set="trainval",
        download=True,
    )

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(dataset, [0.9, 0.1], **split_kwargs)

    train_augmentation, val_augmentation = get_augmentation(img_shape)
    train_dataset = AugmentedDataset(
        train_subset, train_augmentation, img_transforms, mask_transforms
    )
    val_dataset = AugmentedDataset(
        val_subset, val_augmentation, img_transforms, mask_transforms
    )

    num_classes = 21

    return train_dataset, val_dataset, num_classes

In [None]:
# Build module-level train/validation datasets rooted at "datasets", using the
# loader's default image shape. These are used by the visualisation cells.
train_dataset, val_dataset, num_classes = load_VOCSegmentationDataset(
    "datasets",
)

In [None]:
# Check shapes and value range on one sample. The sample is fetched once:
# every dataset access re-runs the training augmentation, so indexing
# repeatedly would report statistics of different augmented variants.
_sample_img, _sample_mask = train_dataset[0]
print(_sample_img.shape, _sample_mask.shape, _sample_img.min(), _sample_img.max())

import matplotlib.pyplot as plt

def visualize_dataset(dataset, num_samples=5):
    """Plot the first ``num_samples`` samples: images on top, masks below.

    Args:
        dataset: Indexable collection yielding ``(img, mask)`` tensor pairs.
            ``img`` is transposed from channel-first to channel-last for
            display and is assumed to lie in [-1, 1] (it is rescaled to
            [0, 1]).
        num_samples: Number of leading samples to show; the dataset must
            contain at least this many.

    Also prints the unique values of every displayed mask.
    """
    # squeeze=False keeps ``axs`` two-dimensional even for num_samples == 1,
    # so the ``axs[row, col]`` indexing below always works.
    _, axs = plt.subplots(2, num_samples, figsize=(15, 5), squeeze=False)
    for i in range(num_samples):
        img, mask = dataset[i]
        img = img.numpy().transpose(1, 2, 0)
        # Img in range -1 to 1, so we need to rescale to 0 to 1
        img = (img + 1) / 2
        axs[0, i].imshow(img)
        axs[1, i].imshow(mask)
        print(mask.unique())
    plt.show()

# Preview the first few samples of the training dataset.
visualize_dataset(train_dataset)

In [None]:
class SegmentationDataModule(pl.LightningDataModule):
    """Lightning data module serving the Pascal VOC segmentation datasets.

    Args:
        data_dir: Directory passed to ``load_VOCSegmentationDataset``.
        inp_size: Side length of the square input; forwarded as
            ``img_shape=(inp_size, inp_size)``.
        batch_size: Batch size for both loaders.
        num_workers: Worker count for both loaders.
    """

    def __init__(
        self,
        data_dir,
        inp_size=512,
        batch_size=1,
        num_workers=0,
    ):
        super().__init__()

        self.data_dir = data_dir
        self.inp_size = inp_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.save_hyperparameters()

    def setup(self, stage=None):
        """Create the train/validation datasets.

        ``stage`` is accepted for Lightning's hook signature but not used.
        The datasets are rebuilt on every call.
        """
        self.train_dataset, self.val_dataset, _ = load_VOCSegmentationDataset(
            self.data_dir,
            img_shape=(self.inp_size, self.inp_size),
        )

    def train_dataloader(self):
        """Return the shuffled training loader (incomplete last batch dropped)."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return the validation loader, in order and covering every sample."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            # Keep the final partial batch: dropping it would silently exclude
            # validation samples from the reported metrics.
            drop_last=False,
            num_workers=self.num_workers,
        )

### Model Definition

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

In [None]:
# Check backbone architecture
# resnet = models.resnet18(pretrained=True)
# resnet

In [None]:
class FCN(nn.Module):
    """Fully convolutional segmentation network on a VGG16 feature extractor.

    The ImageNet-pretrained VGG16 ``features`` stack produces a feature map,
    a 1x1 convolution turns it into ``num_classes`` score channels, and the
    scores are bilinearly resized back to the spatial size of the input.

    Args:
        num_classes: Number of output channels (one per class).
    """

    def __init__(self, num_classes):
        super().__init__()

        self.num_classes = num_classes

        # ``weights=`` replaces the deprecated ``pretrained=True`` flag and
        # selects the same ImageNet checkpoint.
        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.backbone = nn.Sequential(*backbone.features.children())

        # Lazy conv: the input channel count is inferred on the first forward.
        self.classifier = nn.LazyConv2d(num_classes, 1)

    def forward(self, x):
        """Return per-pixel class scores with the same height/width as ``x``."""
        # Remember the input resolution so the output matches it for any
        # input size, not only 512x512.
        out_size = tuple(x.shape[-2:])
        x = self.backbone(x)
        x = self.classifier(x)
        x = F.interpolate(x, size=out_size, mode="bilinear", antialias=True)
        return x

In [None]:
# Check output shape
# fcn = FCN(num_classes=21)
# test_input = torch.randn(1, 3, 512, 512)
# test_output = fcn(test_input)
# test_output.shape

### Training Setup

In [None]:
import torch.optim as optim

from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

from torchinfo import summary

import torchmetrics as tm

In [None]:
import numpy as np

import matplotlib
from PIL import Image


def generate_mask(pred, num_classes, one_hot=True):
    """Render the first sample of a batch as a colour-coded RGB PIL image.

    Args:
        pred: Tensor whose first dimension indexes the batch. With
            ``one_hot=True`` it is reduced with ``argmax`` over dim 1 to get
            label ids; otherwise it is used as-is.
        num_classes: Number of colours sampled evenly from the "gnuplot2"
            colormap.
        one_hot: Whether ``pred`` still needs the ``argmax`` reduction.

    Returns:
        A ``PIL.Image.Image`` built from the RGB channels of the coloured
        first sample.
    """
    if one_hot:
        pred_labels = pred.argmax(dim=1, keepdim=False).cpu().numpy()
    else:
        # detach() so tensors that are part of an autograd graph convert too.
        pred_labels = pred.detach().cpu().numpy()

    # Create color map given number of classes
    base_cmap = matplotlib.colormaps.get_cmap("gnuplot2")
    color_map = matplotlib.colors.ListedColormap(
        base_cmap(np.linspace(0, 1, num_classes))
    )

    labels = pred_labels[0]
    if np.issubdtype(labels.dtype, np.integer):
        # Integer ids index the colour table directly. Converting them to
        # floats (label / num_classes) makes matplotlib multiply back and
        # truncate, where rounding can select the neighbouring colour.
        rgba = color_map(labels)
    else:
        # Float inputs keep the original scaling into the colormap's range.
        rgba = color_map(labels / num_classes)

    rgb_mask = (rgba[:, :, :3] * 255).astype(np.uint8)
    rgb_image = Image.fromarray(rgb_mask)

    return rgb_image

In [None]:
class FCNModule(pl.LightningModule):
    """Lightning module that trains ``FCN`` with cross-entropy and SGD.

    Args:
        inp_size: Side length of the square dummy input used to initialise
            lazy layers and to print the model summary.
        num_classes: Number of segmentation classes.
        lr: SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        compile: If true, compile the wrapped model via ``nn.Module.compile``.
    """

    def __init__(
        self,
        inp_size=512,
        num_classes=21,
        lr=1e-4,
        momentum=0.9,
        weight_decay=0.0625,
        compile=False,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.model = FCN(num_classes)

        # One dry run materialises the lazy classifier layer; no gradients
        # are needed for it.
        test_input_shape = (1, 3, inp_size, inp_size)
        test_input = torch.randn(test_input_shape)
        with torch.no_grad():
            self.model(test_input)

        print(summary(self.model, input_size=test_input_shape))

        if compile:
            # nn.Module.compile() compiles in place and returns None, so its
            # result must not be assigned back to ``self.model``.
            self.model.compile()

        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.num_classes = num_classes

        self.criterion = nn.CrossEntropyLoss()

        self.accuracy = tm.Accuracy(task="multiclass", num_classes=num_classes)
        self.iou = tm.JaccardIndex(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as "loss") and return the cross-entropy loss."""
        data, target = batch
        output = self(data)
        loss = self.criterion(output, target)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and update the accuracy / IoU metrics.

        The raw outputs and targets of the batch are kept on the module so
        ``on_validation_epoch_end`` can render the most recent batch.
        """
        data, target = batch
        output = self(data)
        loss = self.criterion(output, target)
        self.log("val_loss", loss)

        pred = output.argmax(dim=1, keepdim=False)

        self.val_pred = output
        self.val_target = target

        self.accuracy(pred, target)
        self.log("val_acc", self.accuracy, on_step=False, on_epoch=True)

        self.iou(pred, target)
        self.log("val_iou", self.iou, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        """Log colour-coded masks of the last validation batch to the loggers."""
        if self.num_classes == 1:
            self.val_pred = torch.sigmoid(self.val_pred)
            mask_image = generate_mask(self.val_pred, self.num_classes + 1, False)
        else:
            mask_image = generate_mask(self.val_pred, self.num_classes + 1)
        target_mask = generate_mask(self.val_target, self.num_classes + 1, False)

        for logger in self.loggers:
            if isinstance(logger, TensorBoardLogger):
                # Reorder the image from height/width/channel to channel-first
                # before handing it to add_image.
                np_mask_image = np.array(mask_image.convert("RGB"))
                np_mask_image = np_mask_image.transpose(2, 0, 1)
                logger.experiment.add_image(
                    "val_mask", np_mask_image, global_step=self.current_epoch
                )

            if isinstance(logger, WandbLogger):
                logger.log_image(key="val_mask", images=[mask_image])
                logger.log_image(key="val_target", images=[target_mask])

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters of this module."""
        optimizer = optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        return optimizer

### Training

In [None]:
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.tuner import Tuner

In [None]:
# Data module and model, both with their default hyperparameters
# (512x512 inputs, batch size 1, 21 classes).
datamodule = SegmentationDataModule(
    "datasets",
)
fcn_module = FCNModule()
# Log to both Weights & Biases and TensorBoard; also track LR and momentum.
wandb_logger = WandbLogger(project="semantic_segmentation")
tensorboard_logger = TensorBoardLogger("tensorboard_logs/")
lr_monitor = LearningRateMonitor(log_momentum=True)

# CPU training for up to 100 epochs, with accumulate_grad_batches set to 20.
trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=100,
    logger=[wandb_logger, tensorboard_logger],
    callbacks=[lr_monitor],
    log_every_n_steps=10,
    # fast_dev_run=True,
    accumulate_grad_batches=20,
)

# Run the learning-rate finder before training and print its raw results and
# suggested value.
# NOTE(review): presumably lr_find also writes the suggestion back to
# fcn_module.lr — confirm against the Lightning Tuner documentation.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(fcn_module, datamodule=datamodule)
print(lr_finder.results)
print(lr_finder.suggestion())

trainer.fit(fcn_module, datamodule=datamodule)

### Inference / Testing

In [None]:
# Test model on some images from the validation dataset and visualize
def visualize_predictions(model, dataset, num_samples=5):
    """Plot inputs, ground-truth masks and predicted masks side by side.

    Args:
        model: ``nn.Module`` mapping a ``(1, C, H, W)`` image batch to
            per-class scores; the class is taken with ``argmax`` over dim 1.
        dataset: Indexable collection yielding ``(img, mask)`` tensor pairs,
            with ``img`` channel-first and assumed to lie in [-1, 1].
        num_samples: Number of leading samples to show; the dataset must
            contain at least this many.

    The model is switched to eval mode for the duration of the call and its
    previous training/eval mode is restored afterwards.
    """
    # squeeze=False keeps ``axs`` two-dimensional even for num_samples == 1.
    _, axs = plt.subplots(3, num_samples, figsize=(15, 5), squeeze=False)

    was_training = model.training
    model.eval()
    try:
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for i in range(num_samples):
                img, mask = dataset[i]
                scores = model(img.unsqueeze(0))
                pred = scores.argmax(dim=1, keepdim=False).squeeze(0)

                # Rescale the image from [-1, 1] to [0, 1] so imshow does not
                # clip it (same convention as visualize_dataset).
                shown = (img.numpy().transpose(1, 2, 0) + 1) / 2
                axs[0, i].imshow(shown)
                axs[1, i].imshow(mask.numpy())
                axs[2, i].imshow(pred.cpu().numpy())
    finally:
        model.train(was_training)
    plt.show()

# `val_dataset` is the module-level split created by the direct
# load_VOCSegmentationDataset call, whereas training used the split built in
# SegmentationDataModule.setup.
# NOTE(review): the two contain the same images only if the split is
# deterministic — confirm before reading these plots as held-out results.
visualize_predictions(fcn_module, val_dataset)