In [13]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torchinfo import summary

import segmentation_models_pytorch as smp 

from torchvision.datasets import Cityscapes
from torchvision import transforms as T
import torchvision.models as models

In [4]:
# Label bookkeeping for the 19-class Cityscapes setup used in this notebook.
# `encode_segmap` reads IGNORE_INDEX, void_classes, valid_classes and class_map;
# `decode_segmap` reads n_classes and label_colours.

# Value that every void label is collapsed to before the remap to train ids.
IGNORE_INDEX = 255

# Label ids found in the ground-truth masks that are not trained on.
void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]

# Label ids that are kept. IGNORE_INDEX comes first, so all void pixels end up
# as train id 0 ('unlabelled').
valid_classes = [IGNORE_INDEX, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

# Human-readable names, indexed by train id.
class_names = [
    'unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic_light',
    'traffic_sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus',
    'train', 'motorcycle', 'bicycle',
]

n_classes = len(valid_classes)

# label id -> contiguous train id (0 .. n_classes - 1)
class_map = {label_id: train_id for train_id, label_id in enumerate(valid_classes)}

# RGB colour per train id, in the same order as `class_names`.
colors = [
    [0, 0, 0],        # unlabelled
    [128, 64, 128],   # road
    [244, 35, 232],   # sidewalk
    [70, 70, 70],     # building
    [102, 102, 156],  # wall
    [190, 153, 153],  # fence
    [153, 153, 153],  # pole
    [250, 170, 30],   # traffic_light
    [220, 220, 0],    # traffic_sign
    [107, 142, 35],   # vegetation
    [152, 251, 152],  # terrain
    [70, 130, 180],   # sky (was [0, 130, 180]; the official Cityscapes colour has R=70)
    [220, 20, 60],    # person
    [255, 0, 0],      # rider
    [0, 0, 142],      # car
    [0, 0, 70],       # truck
    [0, 60, 100],     # bus
    [0, 80, 100],     # train
    [0, 0, 230],      # motorcycle
    [119, 11, 32],    # bicycle
]

# train id -> RGB colour
label_colours = dict(zip(range(n_classes), colors))

In [61]:
def encode_segmap(mask):
    """Remap a label mask to contiguous train ids.

    Every id listed in `void_classes` is first collapsed to `IGNORE_INDEX`;
    afterwards every id in `valid_classes` is replaced by its `class_map`
    entry. The mask is modified in place and the same object is returned.
    """
    # Step 1: fold all unwanted classes into the ignore value.
    for void_id in void_classes:
        mask[mask == void_id] = IGNORE_INDEX

    # Step 2: rectify the labels of the wanted classes.
    for label_id in valid_classes:
        train_id = class_map[label_id]
        mask[mask == label_id] = train_id

    return mask

def decode_segmap(mask, return_np=False):
    """Turn a 2-D tensor of train ids into an (H, W, 3) colour image in [0, 1].

    Args:
        mask: tensor indexed as (row, column) holding train ids.
        return_np: when True return a numpy array, otherwise a torch tensor.

    Pixels whose id is not in ``range(n_classes)`` keep their raw value
    (divided by 255) in all three channels.
    """
    labels = mask.cpu().numpy()

    # One working copy per colour channel (R, G, B), painted class by class.
    channels = [labels.copy() for _ in range(3)]
    for class_id in range(n_classes):
        selected = labels == class_id
        for channel, intensity in zip(channels, label_colours[class_id]):
            channel[selected] = intensity

    rgb = np.zeros((labels.shape[0], labels.shape[1], 3))
    for channel_idx, channel in enumerate(channels):
        rgb[:, :, channel_idx] = channel / 255.0

    return rgb if return_np else torch.from_numpy(rgb)

In [6]:
# Albumentations pipeline applied jointly to an image and its mask:
# resize -> random horizontal flip -> colour jitter (hue jitter disabled)
# -> normalisation -> conversion to torch tensors.
# (A 256x256 random crop was tried here and is currently disabled.)
transform = A.Compose(
    [
        A.Resize(height=256, width=512),
        A.HorizontalFlip(),
        A.ColorJitter(hue=0),
        A.Normalize(),
        ToTensorV2(),
    ]
)

class CityscapesDataset(Cityscapes):
    """Cityscapes dataset that applies an albumentations-style joint transform.

    ``self.transforms`` is called as ``transforms(image=..., mask=...)`` and is
    expected to return a mapping with ``'image'`` and ``'mask'`` entries. The
    mask is then remapped to train ids with `encode_segmap`.
    """

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample `index`.

        Both files are decoded to numpy arrays before anything else happens, so
        `encode_segmap` always receives an array-like mask. Previously, with
        ``transforms=None``, a PIL image was passed on and the boolean-mask
        assignment inside `encode_segmap` failed.
        """
        # Context managers close the file handles right after decoding;
        # np.array() copies the pixel data, so the arrays outlive the files.
        with Image.open(self.images[index]) as image_file:
            img = np.array(image_file.convert('RGB'))
        with Image.open(self.targets[index][0]) as mask_file:
            smnt = np.array(mask_file)

        if self.transforms is not None:
            transformed = self.transforms(image=img, mask=smnt)
            img = transformed['image']
            smnt = transformed['mask']

        smnt = encode_segmap(smnt)
        return img, smnt
        
# train_dataset = CityscapesDataset('./data/cityscapes', split='train', mode='fine',
#                                   target_type='semantic', transforms=transform)

# val_dataset = CityscapesDataset('./data/cityscapes', split='val', mode='fine',
#                                 target_type='semantic', transforms=transform)

In [7]:
class DecoderBlock(nn.Module):
    """Upsample, optionally concatenate a skip connection, then apply two
    3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        in_channels: channels of the tensor coming from the deeper layer.
        skip_channels: channels of the skip tensor (0 when no skip is used).
        out_channels: channels produced by both convolutions.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels + skip_channels, out_channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, skip=None):
        """Upsample `x` and fuse it with `skip`.

        With a skip tensor, `x` is resized to the skip's spatial size. This is
        the same as the former fixed 2x upsampling whenever the skip is exactly
        twice as large, and additionally avoids the shape mismatch in
        ``torch.cat`` when it is not (e.g. odd-sized encoder feature maps).
        Without a skip tensor, `x` is upsampled by a factor of 2.
        """
        if skip is not None:
            x = F.interpolate(x, size=skip.shape[-2:])
            x = torch.cat([x, skip], dim=1)
        else:
            x = F.interpolate(x, scale_factor=2)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)
        return x
    
# UNet model with ResNet backbone
class UNet(nn.Module):
    """U-Net style segmentation network with a pre-trained ResNet-18/34 encoder.

    Args:
        out_channels: number of channels produced by the 1x1 segmentation head.
        resnet_layers: 18 or 34; any other value raises ``ValueError``.
    """

    def __init__(self, out_channels, resnet_layers=34):
        super().__init__()

        # Pick the pre-trained backbone.
        if resnet_layers == 18:
            backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        elif resnet_layers == 34:
            backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        else:
            raise ValueError("ResNet layers must be 18 or 34.")

        # Drop the last two children (average pooling and fully connected layer).
        stages = list(backbone.children())[:-2]

        # Five encoder stages whose outputs feed the decoder as skip connections.
        self.encoder_blocks = nn.ModuleList([
            nn.Sequential(*stages[:3]),   # Conv7x7 - BN - ReLU
            nn.Sequential(*stages[3:5]),  # MaxPool - ResNet L1
            *stages[5:8],                 # ResNet L2, L3, L4
        ])

        # (in, skip, out) channels of each decoder stage, deepest first.
        decoder_channels = [
            (512, 256, 256),
            (256, 128, 128),
            (128, 64, 64),
            (64, 64, 32),
            (32, 0, 16),
        ]
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(in_ch, skip_ch, out_ch)
            for in_ch, skip_ch, out_ch in decoder_channels
        ])

        self.segmentation_head = nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        """Run encoder, decoder and segmentation head; returns the head's output."""
        # Encoder: remember the output of every stage.
        features = []
        for stage in self.encoder_blocks:
            x = stage(x)
            features.append(x)

        # All but the deepest feature map, deepest first; the final decoder
        # stage has no skip connection.
        skips = features[:-1][::-1] + [None]

        # Decoder
        for block, skip in zip(self.decoder_blocks, skips):
            x = block(x, skip)

        # Segmentation head
        return self.segmentation_head(x)

In [8]:
def get_smp_unet(out_channels, resnet_layers=34):
    """Build a `segmentation_models_pytorch` U-Net with a ResNet encoder.

    Args:
        out_channels: number of output channels (classes in the dataset).
        resnet_layers: depth suffix of the encoder name, e.g. 34 -> "resnet34".
    """
    encoder_name = f"resnet{resnet_layers}"
    return smp.Unet(
        encoder_name=encoder_name,   # which encoder to use
        encoder_weights="imagenet",  # initialise the encoder with `imagenet` pre-trained weights
        in_channels=3,               # RGB input
        classes=out_channels,        # model output channels
    )

In [9]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss on raw logits and integer class targets.

    Args:
        smooth: constant added to numerator and denominator of the Dice score.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - mean Dice score`` over all (sample, class) pairs.

        Args:
            pred: logits of shape (batch_size, num_classes, height, width).
            target: class indices of shape (batch_size, height, width),
                not one-hot encoded.
        """
        batch_size, n_channels = pred.size(0), pred.size(1)

        # Class probabilities and one-hot targets, both laid out as (B, C, H, W).
        probabilities = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target.to(torch.int64), n_channels).permute(0, 3, 1, 2)

        # Collapse the spatial dimensions: (B, C, H*W).
        probabilities = probabilities.reshape(batch_size, n_channels, -1)
        one_hot = one_hot.reshape(one_hot.size(0), one_hot.size(1), -1)

        overlap = (probabilities * one_hot).sum(2)
        cardinality = probabilities.sum(2) + one_hot.sum(2)
        dice_score = (2. * overlap + self.smooth) / (cardinality + self.smooth)

        return 1 - dice_score.mean()  # Dice loss


dice_loss = DiceLoss()

In [10]:
# writer = SummaryWriter('runs/feature_maps_experiment')

# model.eval()
# with torch.no_grad():
#     for img, smnt in train_loader:
#         break

# # Forward pass to get feature maps
# def log_feature_maps(model, inputs, writer, step):
#     with torch.no_grad():
#         # Forward pass through the model
#         x = inputs
#         for name, layer in model.named_children():
#             x = layer(x)
#             print(name)
#             # Log feature maps from a specific layer
#             #if isinstance(layer, torch.nn.Conv2d) or isinstance(layer, torch.nn.ReLU):
#             if True:
#                 writer.add_images(f'feature_maps/{name}', x, step)
#                 print(name, 'written')
#             if 'bottle' in name:
#                 break

# # Log feature maps
# log_feature_maps(model, img[:1].to(device), writer, step=0)

# # Close the writer
# writer.close()

In [11]:
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torchmetrics.segmentation import MeanIoU
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.tuner import Tuner

In [28]:
BATCH_SIZE = 32

# Per-channel statistics that the inverse transform below undoes. They are
# meant to match the defaults of `A.Normalize()` used in the `transform`
# pipeline (ImageNet mean/std).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Inverse of `(x - mean) / std`, expressed as another Normalize:
#   (y - (-mean / std)) / (1 / std) == y * std + mean
# Fix: the blue-channel std used to be written as 1/0.255 instead of 1/0.225.
INV_NORMALIZE = T.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD],
)

# Deterministic counterpart of `transform` for validation/test: same resize,
# normalisation and tensor conversion, but no random flip or colour jitter.
eval_transform = A.Compose([
    A.Resize(256, 512),
    A.Normalize(),
    ToTensorV2()])


class CityscapesDataModule(LightningDataModule):
    """Lightning data module that wraps the Cityscapes 'fine' semantic splits.

    Args:
        data_dir: root directory of the Cityscapes data.
        batch_size: batch size used by all three dataloaders.
        transforms: joint image/mask transform for the training split.
        eval_transforms: joint transform for the validation and test splits.
            ``None`` (default) reuses `transforms`, which is the previous
            behaviour.
        num_workers: forwarded to every ``DataLoader`` (default 0, as before).
    """

    def __init__(self, data_dir="./data/cityscapes", batch_size=BATCH_SIZE, transforms=transform,
                 eval_transforms=None, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transforms = transforms
        self.eval_transforms = transforms if eval_transforms is None else eval_transforms
        self.num_workers = num_workers

    def _make_dataset(self, split, transforms):
        # All splits share mode and target type; only split and transform differ.
        return CityscapesDataset(self.data_dir, split=split, mode='fine',
                                 target_type='semantic', transforms=transforms)

    def _make_loader(self, dataset, shuffle):
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers)

    def setup(self, stage: str):
        """Create the datasets required for `stage`."""
        if stage == "fit":
            self.train_dataset = self._make_dataset('train', self.transforms)
        # The validation split is also needed for a stand-alone validation run;
        # previously it was only created for stage == "fit".
        if stage in ("fit", "validate"):
            self.val_dataset = self._make_dataset('val', self.eval_transforms)
        if stage == "test":
            self.test_dataset = self._make_dataset('test', self.eval_transforms)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)


# Validation/test use the deterministic pipeline so that `val_loss` is not
# perturbed by random augmentation.
cityscapes_data_module = CityscapesDataModule(eval_transforms=eval_transform)

In [77]:
class CityscapesSemanticSegmentation(LightningModule):
    """Lightning wrapper that trains a segmentation model with `dice_loss`.

    Args:
        model: network that maps an image batch to per-class logits.
        learning_rate: learning rate handed to Adam; stored as ``self.lr``.
    """

    def __init__(self, model, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters(ignore='model')
        # Zero-filled example batch; `torch.Tensor(...)` would hand out
        # uninitialised memory, which may contain NaN/inf values.
        self.example_input_array = torch.zeros(BATCH_SIZE, 3, 256, 512)
        self.model = model
        self.lr = learning_rate
        self.metrics = MeanIoU(num_classes=n_classes)

    def forward(self, x):
        return self.model(x)

    def _calculate_loss_and_iou(self, x, y, return_prediction=False):
        """Return ``(loss, iou)`` or ``(loss, iou, prediction)``.

        The IoU is currently a placeholder: `self.metrics` is not evaluated
        and 0 is returned instead.
        """
        prediction = self.model(x)
        metric = 0  # #self.metrics(prediction, y)
        loss = dice_loss(prediction, y)
        if return_prediction:
            return loss, metric, prediction
        return loss, metric

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss, iou = self._calculate_loss_and_iou(x, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_iou", iou)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss, iou, prediction = self._calculate_loss_and_iou(x, y, return_prediction=True)
        self.log("val_loss", loss)
        self.log("val_iou", iou)

        # Log example images once per validation run (first batch only)
        # instead of once per batch.
        if batch_idx == 0:
            self._log_example_images(x, y, prediction)

    def _log_example_images(self, x, y, prediction, sample_idx=0):
        """Write input/target and predicted segmentation of one sample to the logger."""
        experiment = getattr(self.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_image"):
            # No logger attached, or one that cannot store images.
            return

        invimg = INV_NORMALIZE(x[sample_idx])
        # decode_segmap returns (H, W, 3); match dtype/device of the image and
        # move the channel axis to the front.
        decoded_target = decode_segmap(y[sample_idx]).to(invimg).permute(2, 0, 1)
        decoded_prediction = decode_segmap(
            torch.argmax(prediction[sample_idx], 0)).to(invimg).permute(2, 0, 1)

        # Input image and ground truth next to each other (concatenated along width).
        image = torch.concat([invimg, decoded_target], dim=2)
        experiment.add_image('image', image, self.current_epoch)
        experiment.add_image('image_output', decoded_prediction, self.current_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# checkpoint_callback = ModelCheckpoint(monitor='val_loss',dirpath='checkpoints',
#                                         filename='file',save_last=True)

In [78]:
# --- Run configuration ---
DEBUG = False            # forwarded to Trainer(fast_dev_run=...)
LR = 7e-3                # initial learning rate of the Lightning module
log_every_n_steps = 5

# --- Model ---
model = UNet(n_classes, resnet_layers=18)
lightning_model = CityscapesSemanticSegmentation(model, learning_rate=LR)

# --- Logging and trainer ---
logger = TensorBoardLogger('./')

# NOTE(review): patience (3) exceeds max_epochs (2), so the patience criterion
# presumably never fires with this configuration.
trainer = Trainer(
    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
    fast_dev_run=DEBUG,
    limit_train_batches=0.1,
    limit_val_batches=0.1,
    profiler='simple',
    max_epochs=2,
    precision='bf16-mixed',
    logger=logger,
    log_every_n_steps=log_every_n_steps,
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [34]:
# Learning-rate range test on the current model and data.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(lightning_model, cityscapes_data_module, num_training=50)

# Plot 
fig = lr_finder.plot(suggest=True)
fig.show()

# Apply the suggestion where the module actually reads it. `configure_optimizers`
# uses `lightning_model.lr`, and the saved hyperparameter is named
# `learning_rate`; the former `hparams.lr` assignment created a key that
# nothing reads.
suggested_lr = lr_finder.suggestion()
if suggested_lr is not None:  # no suggestion available -> keep the current LR
    lightning_model.lr = suggested_lr
    lightning_model.hparams.learning_rate = suggested_lr

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
Finding best initial lr: 100%|██████████| 50/50 [03:18<00:00,  2.45s/it]`Trainer.fit` stopped: `max_steps=50` reached.
Finding best initial lr: 100%|██████████| 50/50 [03:18<00:00,  3.96s/it]
Learning rate set to 0.036307805477010104
Restoring states from the checkpoint path at c:\main\repos\semantic_segmentation\.lr_find_b8f309

In [79]:
# Run the fit loop of `trainer` for `lightning_model`, with the dataloaders
# supplied by `cityscapes_data_module`.
trainer.fit(lightning_model, datamodule=cityscapes_data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type    | Params | Mode  | In sizes          | Out sizes         
-------------------------------------------------------------------------------------
0 | model   | UNet    | 14.3 M | train | [32, 3, 256, 512] | [32, 20, 256, 512]
1 | metrics | MeanIoU | 0      | train | ?                 | ?                 
-------------------------------------------------------------------------------------
14.3 M    Trainable params
0         Non-trainable params
14.3 M    Total params
57.314    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 9/9 [00:24<00:00,  0.37it/s, v_num=13, train_loss=0.902]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 9/9 [00:24<00:00,  0.36it/s, v_num=13, train_loss=0.902]

FIT Profiler Report

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                          	|  -              	|  958            


