In [None]:
!pip install torchvision
!pip install torchmetrics

# Data Import & Visualization

In [None]:
import torch
from torch import nn
import os
from os import path
import torchvision
import torchvision.transforms as T
from typing import Sequence
from torchvision.transforms import functional as F
import numbers
import random
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torchmetrics as TM

# Convert a pytorch tensor into a PIL image
t2img = T.ToPILImage()
# Convert a PIL image into a pytorch tensor
img2t = T.ToTensor()

# Root directory handed to torchvision as the dataset location.
# NOTE(review): this is a machine-specific absolute path; anyone else running
# the notebook has to edit it before the dataset cells will work.
dataset_path = "/Users/weijia/college/DS4400/FinalProject"

In [None]:
# Oxford IIIT Pets Segmentation dataset loaded via torchvision.
# Both splits are requested with segmentation masks as targets and with
# download=True, using dataset_path as the root. No transforms are applied
# here; pets_train_orig is indexed by the preview cell further down.
pets_train_orig = torchvision.datasets.OxfordIIITPet(root=dataset_path, split="trainval", target_types="segmentation", download=True)
pets_test_orig = torchvision.datasets.OxfordIIITPet(root=dataset_path, split="test", target_types="segmentation", download=True)

In [None]:
from enum import IntEnum
class TrimapClasses(IntEnum):
    """Integer class ids used for the segmentation labels in this notebook."""
    # These ids are compared against label tensors (binary_tensor_trimap) and
    # passed as ignore_index to the torchmetrics IoU metrics.
    PET = 0
    BACKGROUND = 1
    BORDER = 2

In [None]:
# Convert a trimap (stored values {1, 2, 3}, i.e. {1, 2, 3} / 255.0 once
# turned into a tensor) into a float tensor with pixel values in the
# range 0.0 to 1.0 so that the border pixels can be properly displayed.
def trimap2f(trimap):
    """Rescale a trimap to floats suitable for display.

    The trimap is converted with ``img2t``, multiplied by 255, shifted down
    by one and halved, so stored values 1, 2, 3 become 0.0, 0.5, 1.0.
    """
    stored_values = img2t(trimap) * 255.0
    return (stored_values - 1) / 2

In [None]:
# Preview one training sample: the image on the left and its segmentation
# trimap (rescaled by trimap2f so all three classes are visible) on the right.
idx_pet = 17 #@param {type:"slider", min:0, max:50, step:1}
(train_pets_input, train_pets_target) = pets_train_orig[idx_pet]

plt.figure(figsize=(10,5))
plt.subplot(1, 2, 1)
plt.imshow(train_pets_input)
plt.title("Image"); plt.grid(False)

plt.subplot(1, 2, 2)
plt.imshow(t2img(trimap2f(train_pets_target)))
plt.title("Label"); plt.grid(False)

# Helper Functions & Classes

In [None]:
def save_model_checkpoint(model, cp_name, checkpoint_dir=None):
    """Save ``model.state_dict()`` to ``<checkpoint_dir>/<cp_name>``.

    Args:
        model: Module whose ``state_dict()`` is written.
        cp_name: File name of the checkpoint.
        checkpoint_dir: Directory to write into. When omitted, a notebook-level
            ``working_dir`` global is used if one has been defined; otherwise
            the current working directory is used.
    """
    if checkpoint_dir is None:
        # The original code read a global ``working_dir`` that the notebook
        # never defines; fall back gracefully instead of raising NameError.
        checkpoint_dir = globals().get("working_dir", os.getcwd())
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, cp_name))


def get_device():
    """Return the CUDA device when a GPU is available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model weights from a saved checkpoint
def load_model_from_checkpoint(model, ckp_path):
    """Load the state dict stored at ``ckp_path`` into ``model``.

    The checkpoint is mapped onto the device reported by ``get_device()``.
    Returns whatever ``model.load_state_dict`` returns (not the model itself).
    """
    state = torch.load(ckp_path, map_location=get_device())
    return model.load_state_dict(state)

# Send the Tensor or Model (input argument x) to the right device
# for this notebook: CUDA when a GPU is available, CPU otherwise.
def to_device(x):
    """Return ``x.cuda()`` if CUDA is available, else ``x.cpu()``."""
    return x.cuda() if torch.cuda.is_available() else x.cpu()

def get_model_parameters(m):
    """Return the total number of elements across all parameters of ``m``."""
    count = 0
    for tensor in m.parameters():
        count += tensor.numel()
    return count

def print_model_parameters(m):
    """Print the total parameter count of ``m`` as computed by get_model_parameters."""
    print(f"The Model has {get_model_parameters(m)} parameters")
# end def

def close_figures():
    """Close open matplotlib figures one at a time until none remain."""
    while plt.get_fignums():
        plt.close()
    # end while
# end def

# Validation: Check if CUDA is available
print(f"CUDA: {torch.cuda.is_available()}")

In [None]:
# Simple torchvision compatible transform to send an input tensor
# to a pre-specified device.
class ToDevice(torch.nn.Module):
    """
    Sends the input object to the device specified in the
    object's constructor by calling .to(device) on the object.
    """
    def __init__(self, device):
        super().__init__()
        # Target device; forwarded verbatim to ``.to()`` in forward().
        self.device = device

    def forward(self, img):
        """Return ``img`` moved to ``self.device``."""
        return img.to(self.device)

    def __repr__(self) -> str:
        # Must read the instance attribute: a bare ``device`` is not defined
        # in this scope and made repr() raise NameError.
        return f"{self.__class__.__name__}(device={self.device})"

In [None]:
# Create a dataset wrapper that allows us to perform custom image augmentations
# on both the target and label (segmentation mask) images.
#
# These custom image augmentations are needed since we want to perform
# transforms such as:
# 1. Random horizontal flip
# 2. Image resize
#
# and these operations need to be applied consistently to both the input
# image as well as the segmentation mask.
class OxfordIIITPetsAugmented(torchvision.datasets.OxfordIIITPet):
    """OxfordIIITPet with three transform stages per sample.

    * ``pre_transform`` / ``pre_target_transform`` are handed to the parent
      class as its ``transform`` / ``target_transform``.
    * ``common_transform`` runs once on the input and target concatenated
      along dim 0, so random operations hit both identically.
    * ``post_transform`` / ``post_target_transform`` run last, separately on
      the input and on the target.
    """

    def __init__(
        self,
        root: str,
        split: str,
        target_types="segmentation",
        download=False,
        pre_transform=None,
        post_transform=None,
        pre_target_transform=None,
        post_target_transform=None,
        common_transform=None,
    ):
        super().__init__(
            root=root,
            split=split,
            target_types=target_types,
            download=download,
            transform=pre_transform,
            target_transform=pre_target_transform,
        )
        self.common_transform = common_transform
        self.post_transform = post_transform
        self.post_target_transform = post_target_transform

    def __len__(self):
        return super().__len__()

    def __getitem__(self, idx):
        image, mask = super().__getitem__(idx)

        if self.common_transform is not None:
            # Stack image and mask along dim 0, transform them together, then
            # split into a first chunk of 3 (the image) and the remainder
            # (the mask).
            stacked = self.common_transform(torch.cat([image, mask], dim=0))
            image, mask = torch.split(stacked, 3, dim=0)
        # end if

        if self.post_transform is not None:
            image = self.post_transform(image)
        if self.post_target_transform is not None:
            mask = self.post_target_transform(mask)

        return (image, mask)

# Model Training Functions

In [None]:
# Train the model for a single epoch
def train_model(model, loader, optimizer, cel=True, binary_classification=False):
    """Run one optimisation pass over ``loader`` and print the mean batch loss.

    Args:
        model: Network to train; switched to train mode and moved with to_device.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        cel: When True, use CrossEntropyLoss and squeeze dim 1 of the targets;
            when False, use ``IoULoss(softmax=True)``.
        binary_classification: When True, BCELoss replaces the loss chosen by
            ``cel`` (the target squeeze still follows ``cel``).
    """
    to_device(model.train())
    if binary_classification:
        criterion = nn.BCELoss(reduction='mean')
    elif cel:
        criterion = nn.CrossEntropyLoss(reduction='mean')
    else:
        criterion = IoULoss(softmax=True)
    # end if

    running_loss = 0.0
    running_samples = 0
    num_batches = 0

    for inputs, targets in loader:
        optimizer.zero_grad()
        inputs = to_device(inputs)
        targets = to_device(targets)
        outputs = model(inputs)

        # The ground truth labels have a channel dimension (NCHW).
        # We need to remove it before passing it into
        # CrossEntropyLoss so that it has shape (NHW) and each element
        # is a value representing the class of the pixel.
        if cel:
            targets = targets.squeeze(dim=1)
        # end if
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_samples += targets.size(0)
        running_loss += loss.item()
        num_batches += 1
    # end for

    # An empty loader used to crash on an unbound ``batch_idx``; report it
    # instead of dividing by a batch count that does not exist.
    if num_batches == 0:
        print("Trained 0 samples: the data loader produced no batches")
        return
    # end if

    print("Trained {} samples, Loss: {:.4f}".format(
        running_samples,
        running_loss / num_batches,
    ))
# end def

In [None]:
import matplotlib.pyplot as plt


# Define training loop. This will train the model for multiple epochs.
#
# epochs: A tuple containing the start epoch (inclusive) and end epoch (exclusive).
#         The model is trained for [epoch[0] .. epoch[1]) epochs.
#
def train_loop(model, train_data_loader, test_data_loader, epochs, optimizer, scheduler, save_path, cel=True, binary_classification=False, val_batches=10):
    """Train for several epochs, validating and plotting IoU after each one.

    Args:
        model: Network to train.
        train_data_loader: Batches passed to ``train_model`` every epoch.
        test_data_loader: Batches used for the per-epoch validation.
        epochs: ``(start, end)`` tuple; epochs ``start .. end - 1`` are run.
        optimizer: Optimizer handed to ``train_model``.
        scheduler: Optional LR scheduler stepped once per epoch (may be None).
        save_path: Directory for the per-epoch figure, or None to skip saving.
        cel: Forwarded to ``train_model``.
        binary_classification: Forwarded to ``train_model`` and the plot helper.
        val_batches: Number of leading batches of ``test_data_loader`` used for
            validation (default 10, the value previously hard-coded). At least
            one batch is always evaluated when the loader is non-empty.
    """
    epoch_i, epoch_j = epochs
    avg_iou = []
    for epoch in range(epoch_i, epoch_j):
        print(f"Epoch: {epoch:02d}, Learning Rate: {optimizer.param_groups[0]['lr']}")
        train_model(model, train_data_loader, optimizer, cel=cel, binary_classification=binary_classification)

        # Validation on the first ``val_batches`` batches of the test loader.
        to_device(model.eval())
        iou_accuracy_sum = 0.0
        evaluated = 0
        with torch.inference_mode():
            for batch_idx, (inputs, targets) in enumerate(test_data_loader, 0):
                inputs = to_device(inputs)
                targets = to_device(targets)
                predictions = model(inputs)
                pred, pred_labels, pred_mask = get_prediction_mask(predictions)
                # .item() keeps the history as plain floats so it can be
                # plotted even when the tensors live on the GPU.
                iou_accuracy_sum += IoUMetric(pred, targets).item()
                evaluated += 1
                # Only the first batch is rendered as a figure.
                if batch_idx == 0:
                    print_image_gt_pred(pred, pred_labels, pred_mask, inputs, targets, epoch=epoch, save_path=save_path, show_plot=True, binary_classification=binary_classification)
                # end if
                if batch_idx + 1 >= val_batches:
                    break
                # end if
            # end for
        # end with

        # Average over the batches actually evaluated rather than a fixed 10,
        # which under-reported IoU for loaders with fewer batches.
        avg_iou.append(iou_accuracy_sum / evaluated if evaluated else float("nan"))

        if scheduler is not None:
            scheduler.step()
        # end if
        print("")
    # end for
    plt.figure(figsize=(8, 6))
    epoch_numbers = range(epoch_i, epoch_j)
    plt.plot(epoch_numbers, avg_iou, linestyle='-')
    plt.xlabel('Epoch')
    plt.ylabel('IoU (All Classes)')
    plt.title('IoU vs. Training Epoch')
    plt.grid(True)
    plt.ylim(0, 1)
    plt.show()
# end def

Visualization helpers that display the input images, the ground-truth labels, and the predicted labels.

In [None]:
def print_test_dataset_masks(model, test_pets_targets, test_pets_labels, epoch, save_path, show_plot, binary_classification=False):
    """Run ``model`` on a batch and plot inputs, ground truth and predictions.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_pets_targets: Batch of input images (this notebook calls the
            input images "targets").
        test_pets_labels: Matching ground-truth label batch.
        epoch, save_path, show_plot, binary_classification: Forwarded to
            ``print_image_gt_pred``.
    """
    to_device(model.eval())
    predictions = model(to_device(test_pets_targets))
    test_pets_labels = to_device(test_pets_labels)
    # print_image_gt_pred expects the already-decoded prediction tensors; the
    # previous call still used its old raw-predictions signature and raised
    # TypeError.
    pred, pred_labels, pred_mask = get_prediction_mask(predictions)
    print_image_gt_pred(pred, pred_labels, pred_mask, test_pets_targets, test_pets_labels, epoch, save_path, show_plot, binary_classification=binary_classification)

def get_prediction_mask(predictions):
    """Decode raw model output into probabilities, labels and a float mask.

    Args:
        predictions: Either a 4-D tensor with the class scores along dim 1, or
            a tensor without a class dimension holding per-pixel values that
            are thresholded at 0.5.

    Returns:
        ``(pred, pred_labels, pred_mask)``:
        * 4-D input: ``pred`` is the softmax over dim 1, ``pred_labels`` the
          argmax with a size-1 dim inserted at position 1, ``pred_mask`` the
          same labels as float.
        * otherwise: all three are the 0/1 thresholded values with a size-1
          dim inserted at position 1.
    """
    if predictions.ndim == 4: # there is a class channel dimension
        pred = nn.Softmax(dim=1)(predictions) # the channel values of each pixel sum to 1
        pred_labels = pred.argmax(dim=1) # most probable class per pixel
    else:
        # Threshold into a new tensor; the old code overwrote the caller's
        # prediction tensor in place.
        pred_labels = (predictions >= 0.5).to(predictions.dtype)
        pred = pred_labels.unsqueeze(1)

    # Add a value 1 dimension at dim=1
    pred_labels = pred_labels.unsqueeze(1)
    pred_mask = pred_labels.to(torch.float)
    return pred, pred_labels, pred_mask

# def print_image_gt_pred(predictions, test_pets_targets, test_pets_labels, epoch, save_path, show_plot, binary_classification=False):
def print_image_gt_pred(pred, pred_labels, pred_mask, test_pets_targets, test_pets_labels, epoch, save_path, show_plot, binary_classification=False):
    """Print batch metrics and draw inputs, ground truth and predictions.

    Args:
        pred, pred_labels, pred_mask: Output of ``get_prediction_mask``.
        test_pets_targets: Batch of input images shown in the first row.
        test_pets_labels: Ground-truth labels for the batch.
        epoch: Epoch number used in the title and the saved file name.
        save_path: Directory for ``epoch_XX.png``; nothing is saved when None.
        show_plot: The figure is closed without showing only when this is
            exactly ``False``.
        binary_classification: Selects binary instead of 3-class metrics and
            skips the division by 2 used to display 3-class label maps.
    """
    if binary_classification:
        iou = to_device(TM.classification.BinaryJaccardIndex(ignore_index=TrimapClasses.BACKGROUND))
        pixel_metric = to_device(TM.classification.BinaryAccuracy())
    else:
        iou = to_device(TM.classification.MulticlassJaccardIndex(3, average='micro', ignore_index=TrimapClasses.BACKGROUND))
        pixel_metric = to_device(TM.classification.MulticlassAccuracy(3, average='micro'))
    # end if

    iou_accuracy_ignore_bg = iou(pred_mask, test_pets_labels)
    pixel_accuracy = pixel_metric(pred_labels, test_pets_labels)
    iou_all = IoUMetric(pred, test_pets_labels)

    title = f'Epoch: {epoch:02d}, Accuracy[Pixel: {pixel_accuracy:.4f}, IoU (ignore background): {iou_accuracy_ignore_bg:.4f}, IoU (all classes): {iou_all:.4f}]'
    print(title)

    # Label maps with classes {0, 1, 2} are halved so they land in [0, 1];
    # binary maps are shown as they are.
    if binary_classification:
        gt_panel = test_pets_labels.float()
        pred_panel = pred_mask
    else:
        gt_panel = test_pets_labels.float() / 2.0
        pred_panel = pred_mask / 2.0
    # end if

    # Close all previously open figures.
    close_figures()

    fig = plt.figure(figsize=(10, 12))
    fig.suptitle(title, fontsize=12)

    panels = (
        ("Targets", test_pets_targets),
        ("Ground Truth Labels", gt_panel),
        ("Predicted Labels", pred_panel),
    )
    for row, (caption, batch) in enumerate(panels, start=1):
        fig.add_subplot(3, 1, row)
        plt.imshow(t2img(torchvision.utils.make_grid(batch, nrow=7)))
        plt.axis('off')
        plt.title(caption)
    # end for

    if save_path is not None:
        plt.savefig(os.path.join(save_path, f"epoch_{epoch:02}.png"), format="png", bbox_inches="tight", pad_inches=0.4)
    # end if

    if show_plot is not False:
        plt.show()
    else:
        close_figures()
    # end if
# end def

# Model Evaluation Metrics

In [None]:
# Custom soft IoU metric used for validating the model.
def IoUMetric(pred, gt, softmax=False):
    """Mean per-sample IoU between predictions and one-hot ground truth.

    Args:
        pred: Prediction tensor; softmax over dim 1 is applied first when
            ``softmax`` is exactly ``True``.
        gt: Label tensor whose dim 1 is expanded into three masks
            (``gt == 0``, ``gt == 1``, ``gt == 2``) concatenated along dim 1.
        softmax: Set to True when ``pred`` holds logits.

    Returns:
        Scalar tensor: (intersection + 0.001) / (union + 0.001) summed over
        dims 1-3 for each sample, then averaged over the batch.
    """
    if softmax is True:
        pred = torch.softmax(pred, dim=1)
    # end if

    one_hot = torch.cat([gt == class_id for class_id in range(3)], dim=1)

    overlap = one_hot * pred
    combined = one_hot + pred - overlap

    # The small constant keeps the ratio defined when the union is empty.
    smooth = 0.001
    reduce_dims = (1, 2, 3)
    per_sample = (overlap.sum(dim=reduce_dims) + smooth) / (combined.sum(dim=reduce_dims) + smooth)

    return per_sample.mean()

class IoULoss(nn.Module):
    """Loss equal to the negative log of ``IoUMetric``."""

    def __init__(self, softmax=False):
        super().__init__()
        # Forwarded to IoUMetric: True when forward() receives logits.
        self.softmax = softmax

    # pred => Predictions (logits, B, 3, H, W)
    # gt => Ground Truth Labels (B, 1, H, W)
    def forward(self, pred, gt):
        # The negative log is used instead of 1.0 - IoU for stable training.
        score = IoUMetric(pred, gt, self.softmax)
        return -score.log()
    # end def
# end class

# Model Definitions

In [None]:
from torch.nn.functional import relu

### Logistic Regression

In [None]:
class LogisticRegression(nn.Module):
    """Per-pixel logistic regression over the 3 input channels.

    Every pixel's channel values go through one ``Linear(3, 1)`` followed by
    a sigmoid, producing one value in (0, 1) per pixel.
    """

    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(3, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, channels, height, width)`` tensor to ``(batch, height, width)``."""
        batch_size, channels, height, width = x.size()
        # Flatten the spatial dims and move channels last so the linear layer
        # sees one row of channel values per pixel.
        x = x.reshape(batch_size, channels, -1).permute(0, 2, 1)
        x = self.linear(x)
        x = self.sigmoid(x)
        return x.reshape(batch_size, height, width)

### FCN

In [None]:
class FCN(nn.Module):
    """Small fully-convolutional network with a single 16x upsampling step.

    Five encoder stages of two 3x3 convolutions each (8, 16, 32, 64, 128
    channels); the first four are followed by a 2x2 max-pool. A transposed
    convolution with kernel 16 and stride 16 then upsamples, followed by
    ReLU, batch norm and a 1x1 classifier.
    """

    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.relu    = nn.ReLU(inplace=True)

        # Encoder layers are registered as e<stage>1, e<stage>2, pool<stage>,
        # in that order, so attribute names and creation order stay fixed.
        widths = (3, 8, 16, 32, 64, 128)
        for stage in range(1, 6):
            c_in, c_out = widths[stage - 1], widths[stage]
            setattr(self, f"e{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"e{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            if stage < 5:
                # Each pool halves the spatial size: /2, /4, /8, /16.
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        self.deconv1 = nn.ConvTranspose2d(128, 32, kernel_size=16, stride=16)
        self.bn1 = nn.BatchNorm2d(32)
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        h = x
        for stage in range(1, 5):
            h = relu(getattr(self, f"e{stage}1")(h))
            h = relu(getattr(self, f"e{stage}2")(h))
            h = getattr(self, f"pool{stage}")(h)

        h = relu(self.e51(h))
        h = relu(self.e52(h))

        h = self.deconv1(h)
        h = self.relu(h)
        h = self.bn1(h)

        # size=(N, n_class, x.H/1, x.W/1)
        return self.classifier(h)

### U-net

In [None]:
class UNet(nn.Module):
    """Small U-Net: 5-stage encoder, 4-stage decoder with skip connections.

    Encoder stages hold two 3x3 convolutions each (8, 16, 32, 64, 128
    channels); all but the last are followed by a 2x2 max-pool. Each decoder
    stage upsamples with a 2x2 transposed convolution, concatenates the
    matching encoder output along the channel dim and applies two 3x3
    convolutions. A 1x1 convolution produces ``n_class`` channels.
    """

    def __init__(self, n_class):
        super().__init__()

        # Encoder: registered as e<stage>1, e<stage>2, pool<stage> in order.
        widths = (3, 8, 16, 32, 64, 128)
        for stage in range(1, 6):
            c_in, c_out = widths[stage - 1], widths[stage]
            setattr(self, f"e{stage}1", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"e{stage}2", nn.Conv2d(c_out, c_out, kernel_size=3, padding=1))
            if stage < 5:
                setattr(self, f"pool{stage}", nn.MaxPool2d(kernel_size=2, stride=2))

        # Decoder: registered as upconv<stage>, d<stage>1, d<stage>2 in order.
        # d<stage>1 takes 2 * width channels because its input is the
        # upsampled tensor concatenated with the encoder skip tensor.
        for stage, width in enumerate((64, 32, 16, 8), start=1):
            setattr(self, f"upconv{stage}", nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2))
            setattr(self, f"d{stage}1", nn.Conv2d(2 * width, width, kernel_size=3, padding=1))
            setattr(self, f"d{stage}2", nn.Conv2d(width, width, kernel_size=3, padding=1))

        # Output layer
        self.outconv = nn.Conv2d(8, n_class, kernel_size=1)

    def forward(self, x):
        # Encoder: remember each stage's pre-pool output for the skips.
        skips = []
        h = x
        for stage in range(1, 5):
            h = relu(getattr(self, f"e{stage}1")(h))
            h = relu(getattr(self, f"e{stage}2")(h))
            skips.append(h)
            h = getattr(self, f"pool{stage}")(h)

        h = relu(self.e51(h))
        h = relu(self.e52(h))

        # Decoder: consume the skips deepest-first.
        for stage in range(1, 5):
            h = getattr(self, f"upconv{stage}")(h)
            h = torch.cat([h, skips.pop()], dim=1)
            h = relu(getattr(self, f"d{stage}1")(h))
            h = relu(getattr(self, f"d{stage}2")(h))

        # Output layer
        return self.outconv(h)

### DeepLabV3+

In [None]:
from collections import OrderedDict
from torch.nn.functional import interpolate

class DeepLabV3Plus(nn.Module):
    """
    DeepLab v3+: Dilated ResNet with multi-grid + improved ASPP + decoder

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
        n_blocks: Four block counts, one per residual layer (layer2..layer5).
        atrous_rates: Dilation rates of the 3x3 ASPP branches.
        multi_grids: Per-block dilation multipliers for layer5; its length
            must equal ``n_blocks[3]``.
        output_stride: 8 or 16; selects the stride/dilation schedule of the
            residual layers.

    Raises:
        ValueError: If ``output_stride`` is neither 8 nor 16.
    """

    def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride):
        super(DeepLabV3Plus, self).__init__()

        # Stride and dilation
        if output_stride == 8:
            s = [1, 2, 1, 1]
            d = [1, 1, 2, 2]
        elif output_stride == 16:
            s = [1, 2, 2, 1]
            d = [1, 1, 1, 2]
        else:
            # Any other value used to leave ``s``/``d`` unbound and fail later
            # with an unrelated error.
            raise ValueError(
                f"output_stride must be 8 or 16, got {output_stride!r}"
            )

        # Encoder
        # Channel widths: [8, 16, 32, 64, 128, 256]
        ch = [8 * 2 ** p for p in range(6)]
        self.layer1 = _Stem(ch[0])
        self.layer2 = _ResLayer(n_blocks[0], ch[0], ch[2], s[0], d[0])
        self.layer3 = _ResLayer(n_blocks[1], ch[2], ch[3], s[1], d[1])
        self.layer4 = _ResLayer(n_blocks[2], ch[3], ch[4], s[2], d[2])
        self.layer5 = _ResLayer(n_blocks[3], ch[4], ch[5], s[3], d[3], multi_grids)
        self.aspp = _ASPP(ch[5], ch[2], atrous_rates)
        # ASPP concatenates one 1x1 branch, one branch per rate and the
        # image-pool branch, hence len(atrous_rates) + 2.
        concat_ch = ch[2] * (len(atrous_rates) + 2)
        self.add_module("fc1", _ConvBnReLU(concat_ch, ch[2], 1, 1, 0, 1))

        # Decoder
        self.reduce = _ConvBnReLU(ch[2], ch[0], 1, 1, 0, 1)
        self.fc2 = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", _ConvBnReLU(ch[2] + ch[0], ch[2], 3, 1, 1, 1)),
                    ("conv2", _ConvBnReLU(ch[2], ch[2], 3, 1, 1, 1)),
                    ("conv3", nn.Conv2d(ch[2], n_classes, kernel_size=1)),
                ]
            )
        )

    def forward(self, x):
        h = self.layer1(x)
        h = self.layer2(h)
        # Low-level features from layer2 feed the decoder skip branch.
        h_ = self.reduce(h)
        h = self.layer3(h)
        h = self.layer4(h)
        h = self.layer5(h)
        h = self.aspp(h)
        h = self.fc1(h)
        # Bring the ASPP output up to the skip branch's spatial size.
        h = interpolate(h, size=h_.shape[2:], mode="bilinear", align_corners=False)
        h = torch.cat((h, h_), dim=1)
        h = self.fc2(h)
        # Final upsampling back to the input's spatial size.
        h = interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)
        return h


class _ResLayer(nn.Sequential):
    """
    Residual layer with multi grids
    """

    def __init__(self, n_layers, in_ch, out_ch, stride, dilation, multi_grids=None):
        super(_ResLayer, self).__init__()

        if multi_grids is None:
            multi_grids = [1 for _ in range(n_layers)]
        else:
            assert n_layers == len(multi_grids)

        # Downsampling is only in the first block
        for i in range(n_layers):
            self.add_module(
                "block{}".format(i + 1),
                _Bottleneck(
                    in_ch=(in_ch if i == 0 else out_ch),
                    out_ch=out_ch,
                    stride=(stride if i == 0 else 1),
                    dilation=dilation * multi_grids[i],
                    downsample=(True if i == 0 else False),
                ),
            )


class _Stem(nn.Sequential):
    """
    The 1st conv layer: a 7x7 stride-2 conv-bn-relu ("conv1") followed by a
    3x3 stride-2 max-pool ("pool"); each step halves the spatial size.
    """
    def __init__(self, out_ch):
        super(_Stem, self).__init__()
        stem_layers = (
            ("conv1", _ConvBnReLU(3, out_ch, 7, 2, 3, 1)),  # size /=2
            ("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True)),  # size /=2
        )
        for layer_name, layer in stem_layers:
            self.add_module(layer_name, layer)


class _ASPP(nn.Module):
    """
    Atrous spatial pyramid pooling with image-level feature.

    Branches (children of ``self.stages``): "c0" is a 1x1 conv, "c1".."cN"
    are 3x3 convs with padding and dilation equal to each rate, and
    "imagepool" is the image-level branch. forward() concatenates all branch
    outputs along dim 1.
    """

    def __init__(self, in_ch, out_ch, rates):
        super(_ASPP, self).__init__()
        branches = nn.Module()
        branches.add_module("c0", _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1))
        for idx, rate in enumerate(rates, start=1):
            atrous = _ConvBnReLU(in_ch, out_ch, 3, 1, padding=rate, dilation=rate)
            branches.add_module("c{}".format(idx), atrous)
        branches.add_module("imagepool", _ImagePool(in_ch, out_ch))
        self.stages = branches

    def forward(self, x):
        outputs = []
        for branch in self.stages.children():
            outputs.append(branch(x))
        return torch.cat(outputs, dim=1)

class _ImagePool(nn.Module):
    """Image-level feature branch: global average pool, 1x1 conv-bn-relu,
    then bilinear upsampling back to the input's spatial size."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = _ConvBnReLU(in_ch, out_ch, 1, 1, 0, 1)

    def forward(self, x):
        _batch, _channels, height, width = x.shape
        pooled = self.conv(self.pool(x))
        return interpolate(pooled, size=(height, width), mode="bilinear", align_corners=False)


_BOTTLENECK_EXPANSION = 4

class _Bottleneck(nn.Module):
    """
    Bottleneck block of MSRA ResNet: 1x1 reduce -> 3x3 (dilated) -> 1x1
    increase, added to a shortcut and passed through ReLU. The middle width
    is ``out_ch // _BOTTLENECK_EXPANSION``.
    """

    def __init__(self, in_ch, out_ch, stride, dilation, downsample):
        super(_Bottleneck, self).__init__()
        mid_ch = out_ch // _BOTTLENECK_EXPANSION
        self.reduce = _ConvBnReLU(in_ch, mid_ch, 1, stride, 0, 1, True)
        self.conv3x3 = _ConvBnReLU(mid_ch, mid_ch, 3, 1, dilation, dilation, True)
        self.increase = _ConvBnReLU(mid_ch, out_ch, 1, 1, 0, 1, False)
        # Projection shortcut (1x1 conv + bn, no ReLU) when downsampling,
        # plain identity otherwise.
        if downsample:
            self.shortcut = _ConvBnReLU(in_ch, out_ch, 1, stride, 0, 1, False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.increase(self.conv3x3(self.reduce(x)))
        residual += self.shortcut(x)
        return relu(residual)

class _ConvBnReLU(nn.Sequential):
    """
    Cascade of 2D convolution, batch norm, and ReLU.
    """

    def __init__(
        self, in_ch, out_ch, kernel_size, stride, padding, dilation, relu=True
    ):
        super(_ConvBnReLU, self).__init__()
        self.add_module(
            "conv",
            nn.Conv2d(
                in_ch, out_ch, kernel_size, stride, padding, dilation, bias=False
            ),
        )
        self.add_module("bn", nn.BatchNorm2d(out_ch, eps=1e-5, momentum=1 - 0.999))

        if relu:
            self.add_module("relu", nn.ReLU())

# Model Training

In [None]:
# Create a tensor for a segmentation trimap.
# Input: Float tensor with values in [0.0 .. 1.0]
# Output: Long tensor with values in {0, 1, 2}
def tensor_trimap(t):
    """Scale ``t`` by 255, cast to long (truncating), and subtract 1."""
    return (t * 255).to(torch.long) - 1

def args_to_dict(**kwargs):
    """Return the keyword arguments as a dict (keyword name -> value)."""
    collected = kwargs
    return collected

# Transform stages for the 3-class task, in the order the dataset wrapper
# applies them: pre_* (PIL -> tensor), common (device move, resize, flip on
# image and mask together), post_* (colour jitter on the image, label
# conversion on the mask).
# NOTE(review): the same dict is passed to the test split below, so the random
# flip and colour jitter also run during evaluation — confirm this is intended.
transform_dict = args_to_dict(
    pre_transform=T.ToTensor(),
    pre_target_transform=T.ToTensor(),
    common_transform=T.Compose([
        ToDevice(get_device()),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
        # Random Horizontal Flip as data augmentation.
        T.RandomHorizontalFlip(p=0.5),
    ]),
    post_transform=T.Compose([
        # Color Jitter as data augmentation.
        T.ColorJitter(contrast=0.3),
    ]),
    post_target_transform=T.Compose([
        T.Lambda(tensor_trimap),
    ]),
)

# Create the train and test instances of the data loader for the
# Oxford IIIT Pets dataset with random augmentations applied.
# The images are resized to 128x128 squares, so the aspect ratio
# will be changed. We use the nearest neighbour resizing algorithm
# to avoid disturbing the pixel values in the provided segmentation
# mask.
pets_train = OxfordIIITPetsAugmented(
    root=dataset_path,
    split="trainval",
    target_types="segmentation",
    download=False,
    **transform_dict,
)
pets_test = OxfordIIITPetsAugmented(
    root=dataset_path,
    split="test",
    target_types="segmentation",
    download=False,
    **transform_dict,
)

# Training batches are shuffled; test batches keep dataset order, which the
# fixed "first 10 batches for validation / rest for testing" split relies on.
pets_train_loader = torch.utils.data.DataLoader(
    pets_train,
    batch_size=64,
    shuffle=True,
)
pets_test_loader = torch.utils.data.DataLoader(
    pets_test,
    batch_size=21,
    shuffle=False,
)

In [None]:
# Create a tensor for a binary segmentation mask.
# Input: Float tensor with values in [0.0 .. 1.0]
# Output: Float tensor with values in {0, 1} with all border classified as pet
def binary_tensor_trimap(t):
    """Scale ``t`` by 255, subtract 1, and relabel BORDER pixels as PET."""
    labels = t * 255 - 1
    is_border = labels == TrimapClasses.BORDER
    labels[is_border] = TrimapClasses.PET
    return labels

# Same pipeline as the 3-class transforms, except that the mask goes through
# binary_tensor_trimap so the labels stay float and border pixels count as pet.
# NOTE(review): as with the 3-class loaders, the test split shares the random
# flip and colour jitter — confirm this is intended.
binary_transform_dict = args_to_dict(
    pre_transform=T.ToTensor(),
    pre_target_transform=T.ToTensor(),
    common_transform=T.Compose([
        ToDevice(get_device()),
        T.Resize((128, 128), interpolation=T.InterpolationMode.NEAREST),
        # Random Horizontal Flip as data augmentation.
        T.RandomHorizontalFlip(p=0.5),
    ]),
    post_transform=T.Compose([
        # Color Jitter as data augmentation.
        T.ColorJitter(contrast=0.3),
    ]),
    post_target_transform=T.Compose([
        T.Lambda(binary_tensor_trimap),
    ]),
)

# Binary-label train and test datasets (data is expected to be on disk
# already: download=False).
binary_pets_train = OxfordIIITPetsAugmented(
    root=dataset_path,
    split="trainval",
    target_types="segmentation",
    download=False,
    **binary_transform_dict,
)
binary_pets_test = OxfordIIITPetsAugmented(
    root=dataset_path,
    split="test",
    target_types="segmentation",
    download=False,
    **binary_transform_dict,
)

# Same batch sizes and shuffling as the 3-class loaders.
binary_pets_train_loader = torch.utils.data.DataLoader(
    binary_pets_train,
    batch_size=64,
    shuffle=True,
)
binary_pets_test_loader = torch.utils.data.DataLoader(
    binary_pets_test,
    batch_size=21,
    shuffle=False,
)

In [None]:
# Baseline: per-pixel logistic regression trained on the binary loaders with
# Adam (lr=0.0001) for epochs 1..10, without a scheduler or saved figures.
logistics_regression_model = LogisticRegression()
optimizer = torch.optim.Adam(logistics_regression_model.parameters(), lr=0.0001)
to_device(logistics_regression_model)
train_loop(logistics_regression_model, binary_pets_train_loader, binary_pets_test_loader, (1, 11), optimizer, scheduler=None, save_path=None, binary_classification=True)

In [None]:
# FCN on the 3-class task: Adam (lr=0.001) for epochs 1..50, with StepLR
# multiplying the learning rate by 0.9 every 10 epochs.
fcn_model = FCN(n_class=3)
optimizer = torch.optim.Adam(fcn_model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
to_device(fcn_model)
train_loop(fcn_model, pets_train_loader, pets_test_loader, (1, 51), optimizer, scheduler, save_path=None)

In [None]:
# U-Net on the 3-class task: Adam (lr=0.002) for epochs 1..50, with StepLR
# multiplying the learning rate by 0.9 every 10 epochs.
unet_model = UNet(n_class=3)
optimizer = torch.optim.Adam(unet_model.parameters(), lr=0.002)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
to_device(unet_model)
train_loop(unet_model, pets_train_loader, pets_test_loader, (1, 51), optimizer, scheduler, save_path=None)

In [None]:
# DeepLabV3+ on the 3-class task: Adam (lr=0.001) for epochs 1..100, with
# StepLR multiplying the learning rate by 0.9 every 10 epochs.
# multi_grids has 3 entries to match n_blocks[3] == 3 (the last layer).
deeplab_model = DeepLabV3Plus(n_classes=3, n_blocks=[3, 4, 9, 3], atrous_rates=[1, 2, 3], multi_grids=[1, 2, 1], output_stride=8)
optimizer = torch.optim.Adam(deeplab_model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
to_device(deeplab_model)
train_loop(deeplab_model, pets_train_loader, pets_test_loader, (1, 101), optimizer, scheduler, save_path=None)

# Testing results

In [None]:
def test_dataset_accuracy(model, loader, binary_classification=False, skip_batches=10):
    """Print pixel accuracy and IoU averaged over the held-out batches.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        binary_classification: Use the binary torchmetrics classes instead of
            the 3-class ones.
        skip_batches: Number of leading batches to skip. The default of 10
            matches the batches ``train_loop`` uses for validation by default.
    """
    to_device(model.eval())

    if (not binary_classification):
        iou = to_device(TM.classification.MulticlassJaccardIndex(3, average='micro', ignore_index=TrimapClasses.BACKGROUND))
        pixel_metric = to_device(TM.classification.MulticlassAccuracy(3, average='micro'))
    else:
        iou = to_device(TM.classification.BinaryJaccardIndex(ignore_index=TrimapClasses.BACKGROUND))
        pixel_metric = to_device(TM.classification.BinaryAccuracy())

    iou_accuracies = []
    pixel_accuracies = []
    custom_iou_accuracies = []

    print_model_parameters(model)

    for batch_idx, (inputs, targets) in enumerate(loader, 0):
        if batch_idx < skip_batches:
            continue
        inputs = to_device(inputs)
        targets = to_device(targets)
        predictions = model(inputs)

        pred_probabilities, pred_labels, pred_mask = get_prediction_mask(predictions)

        iou_accuracy = iou(pred_mask, targets)
        iou_accuracies.append(iou_accuracy.item())

        pixel_accuracy = pixel_metric(pred_labels, targets)
        custom_iou = IoUMetric(pred_probabilities, targets)
        pixel_accuracies.append(pixel_accuracy.item())
        custom_iou_accuracies.append(custom_iou.item())

        # Drop references promptly so batch memory can be reclaimed.
        del inputs
        del targets
        del predictions
    # end for

    # With skip_batches or fewer batches in the loader nothing is evaluated;
    # say so instead of printing the NaN mean of an empty tensor.
    if not iou_accuracies:
        print(f"Test Dataset Accuracy : no batches evaluated (loader has at most {skip_batches} batches)")
        return
    # end if

    iou_tensor = torch.FloatTensor(iou_accuracies)
    pixel_tensor = torch.FloatTensor(pixel_accuracies)
    custom_iou_tensor = torch.FloatTensor(custom_iou_accuracies)

    print("Test Dataset Accuracy :")
    print(f"Pixel Accuracy: {pixel_tensor.mean():.4f}, IoU (ignore background): {iou_tensor.mean():.4f}, IoU (all classes): {custom_iou_tensor.mean():.4f}")


In [None]:
# Logistic regression model.
# binary_classification=True selects the binary metrics, matching the flag
# this model was trained with; without it the 3-class metrics were applied
# to the two-class predictions.
with torch.inference_mode():
    test_dataset_accuracy(logistics_regression_model, binary_pets_test_loader, binary_classification=True)

In [None]:
# FCN model: 3-class metrics on the test batches that follow the validation ones.
with torch.inference_mode():
    test_dataset_accuracy(fcn_model, pets_test_loader)

In [None]:
# U-Net model: 3-class metrics on the test batches that follow the validation ones.
with torch.inference_mode():
    test_dataset_accuracy(unet_model, pets_test_loader)

In [None]:
# DeepLabV3+ model: 3-class metrics on the test batches that follow the validation ones.
with torch.inference_mode():
    test_dataset_accuracy(deeplab_model, pets_test_loader)