In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import os

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torchinfo import summary

import segmentation_models_pytorch as smp 

from torchvision.datasets import Cityscapes
from torchvision import transforms as T
import torchvision.models as models

from torchmetrics.segmentation import MeanIoU

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Label bookkeeping: raw dataset label ids -> contiguous training ids, names and colours.
IGNORE_INDEX = 255

# Raw ids that encode_segmap collapses into IGNORE_INDEX.
void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]

# Raw ids that are kept. IGNORE_INDEX comes first, so it is assigned training id 0,
# which lines up with the 'unlabelled' entry of class_names and the black colour.
valid_classes = [IGNORE_INDEX, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]

# Human-readable name per training id (position in the list == training id).
class_names = [
    'unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic_light',
    'traffic_sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus',
    'train', 'motorcycle', 'bicycle',
]

# raw id -> training id (0 .. n_classes - 1), in the order given by valid_classes
class_map = dict(zip(valid_classes, range(len(valid_classes))))
n_classes = len(valid_classes)

# RGB colour (0-255 per channel) per training id, same ordering as class_names.
colors = [
    [0, 0, 0],
    [128, 64, 128],
    [244, 35, 232],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [0, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
    ]

# The three per-class tables are zipped together by position; a length mismatch would
# silently drop classes, so fail loudly instead.
assert len(class_names) == n_classes, "class_names must have one entry per valid class"
assert len(colors) == n_classes, "colors must have one entry per valid class"

# training id -> RGB colour
label_colours = dict(zip(range(n_classes), colors))

In [3]:
def encode_segmap(mask):
    """Rewrite raw label ids in ``mask`` to training ids, in place.

    First every id listed in ``void_classes`` is overwritten with
    ``IGNORE_INDEX``; then every id listed in ``valid_classes`` is overwritten
    with ``class_map[id]``. The replacements run sequentially on the same
    buffer, in list order.

    Args:
        mask: label map supporting boolean-mask assignment
            (presumably a numpy array or torch tensor -- confirm with callers).

    Returns:
        The same ``mask`` object that was passed in, modified in place.
    """
    # Step 1: fold all unwanted ids into the ignore value.
    for void_id in void_classes:
        is_void = mask == void_id
        mask[is_void] = IGNORE_INDEX

    # Step 2: relabel the ids we keep with their training ids.
    for raw_id in valid_classes:
        train_id = class_map[raw_id]
        mask[mask == raw_id] = train_id

    return mask

def decode_segmap(mask, return_np=False, color_last=True):
    """Turn a map of training ids into an RGB image with values scaled by 1/255.

    Args:
        mask: torch tensor of class ids; it is moved to the CPU and converted
            to numpy before colouring.
        return_np: when True return a numpy array, otherwise a torch tensor
            created from that array.
        color_last: when True the colour axis is appended as the last axis.
            When False it is moved in front of the two spatial axes, which is
            supported only for inputs that give a 3-D or 4-D RGB array.

    Returns:
        Array/tensor with the shape of ``mask`` plus a colour axis of size 3.
        Ids outside ``range(n_classes)`` are not recoloured: their raw value
        (divided by 255) is used for all three channels.

    Raises:
        ValueError: if ``color_last`` is False and the RGB array is neither
            3-D nor 4-D.
    """
    class_ids = mask.cpu().numpy()

    # One plane per colour channel, seeded with the raw ids so that ids without a
    # palette entry pass through unchanged.
    channel_planes = [class_ids.copy() for _ in range(3)]
    for class_id in range(0, n_classes):
        selected = class_ids == class_id
        colour = label_colours[class_id]
        for channel, plane in enumerate(channel_planes):
            plane[selected] = colour[channel]

    rgb = np.zeros(list(class_ids.shape) + [3])
    for channel, plane in enumerate(channel_planes):
        rgb[..., channel] = plane / 255.0

    if not color_last:
        if rgb.ndim == 4:
            # (N, H, W, 3) -> (N, 3, H, W)
            rgb = np.transpose(rgb, (0, 3, 1, 2))
        elif rgb.ndim == 3:
            # (H, W, 3) -> (3, H, W)
            rgb = np.transpose(rgb, (2, 0, 1))
        else:
            raise ValueError(f'mask must have shape either (H,W) or (N,H,W), but {mask.size()} is provided')

    return rgb if return_np else torch.from_numpy(rgb)

In [4]:
# Albumentations pipeline applied jointly to an image and its mask.
transform = A.Compose(
    [
        A.HorizontalFlip(),    # left-right flip augmentation
        A.ColorJitter(hue=0),  # colour jitter with the hue shift disabled
        A.Normalize(),         # normalisation with the library's default statistics
        ToTensorV2(),          # numpy arrays -> torch tensors
    ]
)

class CityscapesPreprocessedDataset(Dataset):
    """Map-style dataset over samples stored as individual ``.npy`` files.

    Sample ``i`` of a split is read from ``<root>/<split>/<i>.npy`` (passed to
    the transform as ``image``) and ``<root>/<split>/y<i>.npy`` (passed as
    ``mask``).

    Args:
        root: directory containing one sub-directory per split.
        split: name of the split sub-directory. ``len()`` is only defined for
            ``'train'``, ``'val'`` and ``'test'``.
        transforms: optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
    """

    # Hard-coded number of samples per known split.
    _SPLIT_SIZES = {'train': 2975, 'val': 500, 'test': 1525}

    def __init__(self, root, split, transforms=None):
        self.root = root
        self.split = split
        self.transforms = transforms

    def __getitem__(self, index):
        """Load sample ``index`` and return it as ``(x, y)``, transformed if configured.

        Raises:
            IndexError: if the split has a known size and ``index`` is outside
                ``[0, size)``. This lets plain iteration over the dataset stop
                cleanly instead of failing on a missing file.
        """
        size = self._SPLIT_SIZES.get(self.split)
        if size is not None and not 0 <= index < size:
            raise IndexError(f'index {index} is out of range for split {self.split} with {size} samples')

        split_dir = os.path.join(self.root, self.split)
        x = np.load(os.path.join(split_dir, f'{index}.npy'))
        y = np.load(os.path.join(split_dir, f'y{index}.npy'))

        if self.transforms is not None:
            # np.load already returns fresh arrays, so no defensive copy is needed here.
            transformed = self.transforms(image=x, mask=y)
            x = transformed['image']
            y = transformed['mask']
        return x, y

    def __len__(self):
        """Return the fixed sample count of the split.

        Raises:
            ValueError: if ``split`` is not one of the known split names.
        """
        try:
            return self._SPLIT_SIZES[self.split]
        except KeyError:
            # Tuple literal instead of a set so the message is deterministic.
            raise ValueError(
                f"split must be one of ('train', 'val', 'test'), but {self.split} is given"
            ) from None

In [5]:
class DecoderBlock(nn.Module):
    """Decoder stage: 2x upsample, optional skip concatenation, two conv-BN-ReLU layers.

    Args:
        in_channels: channels of the tensor arriving from the previous stage.
        skip_channels: channels of the skip tensor concatenated after
            upsampling (use 0 when the block is called without a skip).
        out_channels: channels produced by both 3x3 convolutions.
    """

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        merged_channels = in_channels + skip_channels
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv1 = nn.Conv2d(merged_channels, out_channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, skip=None):
        """Upsample ``x`` by 2, append ``skip`` on the channel axis if given, then refine."""
        upsampled = F.interpolate(x, scale_factor=2)
        features = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)

        # Two identical conv -> batch-norm -> ReLU layers applied in sequence.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = F.relu(norm(conv(features)), inplace=True)
        return features
    
# U-Net style segmentation network on top of a torchvision ResNet encoder
class UNet(nn.Module):
    """Encoder-decoder network with skip connections and a 1x1 classification head.

    The encoder is a pretrained torchvision ResNet (18 or 34 layers) with its
    last two children removed, split into five stages. The decoder is a chain
    of five ``DecoderBlock`` modules; the first four receive an encoder
    feature map as skip input, the last one receives none.

    Args:
        out_channels: number of output channels of the segmentation head.
        resnet_layers: ResNet depth, 18 or 34.

    Raises:
        ValueError: if ``resnet_layers`` is neither 18 nor 34.

    NOTE(review): the upsampled maps must match the skip maps spatially, so
    the input height/width presumably need to be multiples of 32 -- confirm.
    """

    def __init__(self, out_channels, resnet_layers=34):
        super().__init__()

        # Pretrained backbone (default torchvision weights)
        if resnet_layers == 18:
            backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        elif resnet_layers == 34:
            backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        else:
            raise ValueError("ResNet layers must be 18 or 34.")

        # Keep every child module except the last two (pooling and classifier)
        stages = list(backbone.children())[:-2]

        # Five encoder stages whose outputs are tapped for the skip connections
        self.encoder_blocks = nn.ModuleList([
            nn.Sequential(*stages[:3]),   # Conv7x7 - BN - ReLU
            nn.Sequential(*stages[3:5]),  # MaxPool - ResNet L1
            stages[5],                    # ResNet L2
            stages[6],                    # ResNet L3
            stages[7],                    # ResNet L4
        ])

        # (in, skip, out) channel counts, deepest decoder stage first
        decoder_specs = [
            (512, 256, 256),
            (256, 128, 128),
            (128, 64, 64),
            (64, 64, 32),
            (32, 0, 16),
        ]
        self.decoder_blocks = nn.ModuleList([DecoderBlock(*spec) for spec in decoder_specs])

        self.segmentation_head = nn.Conv2d(16, out_channels, kernel_size=1)

    def forward(self, x):
        """Run encoder, decoder and head; returns the head's raw (pre-softmax) output."""
        # Encoder: remember the output of every stage
        stage_outputs = []
        for stage in self.encoder_blocks:
            x = stage(x)
            stage_outputs.append(x)

        # The deepest output is the decoder input; the shallower ones serve as skips,
        # deepest first. The final decoder block runs without a skip.
        skips = stage_outputs[:-1][::-1] + [None]

        # Decoder
        for block, skip in zip(self.decoder_blocks, skips):
            x = block(x, skip)

        # Segmentation head
        return self.segmentation_head(x)

In [6]:
def get_smp_unet(out_channels, resnet_layers=34):
    """Build a ``segmentation_models_pytorch`` U-Net with a ResNet encoder.

    Args:
        out_channels: number of output classes of the model.
        resnet_layers: ResNet depth, used to form the encoder name
            ``resnet<resnet_layers>``.

    Returns:
        An ``smp.Unet`` for 3-channel input with ``imagenet`` encoder weights.
    """
    encoder_name = f"resnet{resnet_layers}"
    return smp.Unet(
        encoder_name=encoder_name,    # backbone architecture
        encoder_weights="imagenet",   # pretrained encoder initialisation
        in_channels=3,                # RGB input
        classes=out_channels,         # one output channel per class
    )

In [7]:
class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed from logits and an index target.

    Args:
        smooth: constant added to numerator and denominator of the Dice ratio.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - mean Dice`` over every (sample, class) pair.

        Args:
            pred: logits of shape (batch_size, num_classes, height, width).
            target: class indices of shape (batch_size, height, width);
                values must be valid one-hot indices, i.e. below num_classes.
        """
        num_classes = pred.size(1)

        # Class probabilities and a one-hot target laid out like the prediction
        probabilities = F.softmax(pred, dim=1)
        one_hot = F.one_hot(target.to(torch.int64), num_classes).permute(0, 3, 1, 2)

        # Collapse the spatial axes: (batch, classes, pixels)
        prob_flat = probabilities.flatten(2)
        one_hot_flat = one_hot.reshape(one_hot.size(0), one_hot.size(1), -1)

        overlap = (prob_flat * one_hot_flat).sum(2)
        total = prob_flat.sum(2) + one_hot_flat.sum(2) + self.smooth
        dice_score = (2. * overlap + self.smooth) / total

        return 1 - dice_score.mean()


# Shared loss instance used by the training code
dice_loss = DiceLoss()

In [8]:
class FormattedMeanIoU(nn.Module):
    """Adapter that feeds raw network output into torchmetrics' ``MeanIoU``.

    Args:
        num_classes: forwarded to ``MeanIoU``.
        include_background: forwarded to ``MeanIoU``.
        per_class: forwarded to ``MeanIoU``.
    """

    def __init__(self, num_classes, include_background=True, per_class=False):
        super().__init__()
        self.metric = MeanIoU(num_classes, include_background, per_class)

    def forward(self, pred, target):
        """Compute the metric for one batch.

        ``pred`` is the batched pre-softmax output; it is reduced to a class
        index map with an argmax over dim 1. ``target`` is passed on after an
        int64 cast only, so it reaches the metric in whatever encoding the
        caller supplies.

        NOTE(review): how ``MeanIoU`` interprets index-valued tensors depends
        on the installed torchmetrics version (newer releases expose an
        ``input_format`` argument) -- confirm against the pinned version.
        """
        predicted_labels = torch.argmax(pred, dim=1)
        target_labels = target.to(torch.int64)
        return self.metric(predicted_labels, target_labels)

In [9]:
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.tuner import Tuner

In [10]:
BATCH_SIZE = 32

# Training-time pipeline applied jointly to image and mask.
transform = A.Compose([
    A.HorizontalFlip(),
    A.ColorJitter(hue=0),
    A.Normalize(),
    ToTensorV2()])

# Per-channel statistics that INV_NORMALIZE undoes; these are the ImageNet values,
# which A.Normalize() uses when called without arguments.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Inverse of (x - mean) / std expressed as another Normalize:
#   x * std + mean == (x - (-mean / std)) / (1 / std)
# The blue-channel std previously read 0.255 (typo for 0.225), which skewed the
# de-normalised preview images.
INV_NORMALIZE = T.Normalize(
            mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
            std=[1 / s for s in IMAGENET_STD]
        )

class CityscapesDataModule(LightningDataModule):
    """Lightning data module serving the preprocessed train / val / test splits.

    Args:
        data_dir: root directory handed to ``CityscapesPreprocessedDataset``.
        batch_size: batch size of every dataloader.
        transforms: pipeline applied to training samples.
        eval_transforms: pipeline applied to validation and test samples.
            When omitted (and ``transforms`` is not ``None``) a deterministic
            ``Normalize`` + ``ToTensorV2`` pipeline is used, so that
            evaluation metrics are not perturbed by random augmentation.
            Pass it explicitly when ``transforms`` uses non-default
            normalisation. If both are ``None``, no transform is applied.
        num_workers: worker processes per dataloader (0 loads in the main
            process, matching the previous behaviour).
    """

    def __init__(self, data_dir="./data/cityscapes_preprocessed", batch_size=BATCH_SIZE, transforms=transform,
                 eval_transforms=None, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transforms = transforms
        if eval_transforms is None and transforms is not None:
            # Evaluation must be repeatable: keep only the non-random steps.
            eval_transforms = A.Compose([A.Normalize(), ToTensorV2()])
        self.eval_transforms = eval_transforms
        self.num_workers = num_workers

    def setup(self, stage: str):
        """Create the datasets required for ``stage`` (all of them when ``stage`` is None)."""
        if stage in ("fit", None):
            self.train_dataset = CityscapesPreprocessedDataset(self.data_dir, 'train', self.transforms)
        # The validation set is needed both while fitting and for a standalone validate run.
        if stage in ("fit", "validate", None):
            self.val_dataset = CityscapesPreprocessedDataset(self.data_dir, 'val', self.eval_transforms)
        if stage in ("test", None):
            self.test_dataset = CityscapesPreprocessedDataset(self.data_dir, 'test', self.eval_transforms)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the module-wide batch size and worker count."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)


In [11]:
# Data module instance shared by training and the preview cell; built with all constructor defaults.
cityscapes_data_module = CityscapesDataModule()

In [12]:
class CityscapesSemanticSegmentation(LightningModule):
    """Lightning module that trains a segmentation network with Dice loss.

    Logs train/validation loss and IoU. After every validation epoch it also
    logs one image built from the first validation batch seen: de-normalised
    inputs, colour-coded ground truth and colour-coded predictions.

    Args:
        model: network mapping an image batch to per-class logits.
        learning_rate: learning rate of the Adam optimiser.
        lr_gamma: decay factor of the ``ExponentialLR`` scheduler.
    """

    def __init__(self, model, learning_rate=1e-3, lr_gamma=0.7):
        super().__init__()
        self.save_hyperparameters(ignore='model')
        # torch.Tensor(*sizes) returns uninitialised memory (possibly NaN/inf);
        # use a well-defined dummy input instead.
        self.example_input_array = torch.zeros(BATCH_SIZE, 3, 256, 512)
        self.model = model
        self.lr = learning_rate
        self.lr_gamma = lr_gamma
        self.metrics = FormattedMeanIoU(n_classes)
        # First validation batch, cached so the same samples are visualised every epoch.
        self.validation_batch = None

    def forward(self, x):
        return self.model(x)

    def _calculate_loss_and_iou(self, x, y, return_prediction=False):
        """Run the model on ``x`` and return ``(loss, iou)`` or ``(loss, iou, prediction)``."""
        prediction = self.model(x)
        metric = self.metrics(prediction, y)
        loss = dice_loss(prediction, y)
        if return_prediction:
            return loss, metric, prediction
        return loss, metric

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss, iou = self._calculate_loss_and_iou(x, y)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("train_iou", iou, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss, iou = self._calculate_loss_and_iou(x, y)
        self.log("val_loss", loss)
        self.log("val_iou", iou)
        if self.validation_batch is None:
            # Cache inputs and pre-render the static rows (input, ground truth) once.
            self.validation_batch = x
            self.decoded_targets = decode_segmap(y).to(self.validation_batch).permute(0, 3, 1, 2)
            # Samples are stacked along the height axis: (3, N*H, W)
            self.validation_images = torch.concat(list(INV_NORMALIZE(self.validation_batch)), dim=1)
            self.validation_truth = torch.concat(list(self.decoded_targets), dim=1)

    def on_validation_epoch_end(self):
        # Nothing to draw if no validation batch was seen or logging is disabled.
        if self.validation_batch is None or self.logger is None:
            return

        prediction = self.forward(self.validation_batch)
        decoded_predictions = decode_segmap(torch.argmax(prediction, 1)).to(self.validation_batch).permute(0, 3, 1, 2)

        validation_predictions = torch.concat(list(decoded_predictions), dim=1)
        # Columns: input | ground truth | prediction
        image = torch.concat([self.validation_images, self.validation_truth, validation_predictions], dim=2)
        # De-normalisation can leave values slightly outside [0, 1]; clamp so they
        # do not wrap around when the image is converted for display.
        image = image.clamp(0, 1)
        self.logger.experiment.add_image('image', image, self.current_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, self.lr_gamma)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

In [13]:
# Run configuration
DEBUG = False                 # forwarded to Trainer(fast_dev_run=...)
LR = 0.01                     # Adam learning rate
LR_GAMMA = 0.96               # ExponentialLR decay factor
RESNET_LAYERS = 18            # encoder depth passed to UNet
LOAD_LAST_CHECKPOINT = False  # resume from ./checkpoints/last.ckpt
OVERRIDE_HP = False           # overwrite the (possibly restored) LR settings with the values above

# Derive the run name from the configured backbone so that logs and checkpoints
# are not mislabelled when RESNET_LAYERS is changed.
version_name = f'resnet{RESNET_LAYERS}_LR{LR}_LRG{LR_GAMMA}'

model = UNet(n_classes, resnet_layers=RESNET_LAYERS)
if LOAD_LAST_CHECKPOINT:
    lightning_model = CityscapesSemanticSegmentation.load_from_checkpoint('./checkpoints/last.ckpt', model=model)
else:
    lightning_model = CityscapesSemanticSegmentation(model, learning_rate=LR, lr_gamma=LR_GAMMA)
if OVERRIDE_HP:
    # Keep the saved hyperparameters and the attributes read by configure_optimizers in sync.
    lightning_model.hparams['learning_rate'] = LR
    lightning_model.lr = LR
    lightning_model.hparams['lr_gamma'] = LR_GAMMA
    lightning_model.lr_gamma = LR_GAMMA
print(lightning_model.hparams)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join('lightning_logs', version_name),  # Where to save the checkpoints
    filename='{epoch}-{val_loss:.4f}',  # Filename format
    monitor='val_loss',  # Metric to monitor
    save_last=True,  # Always save the last checkpoint
    enable_version_counter = False
)

logger = TensorBoardLogger('./', version=version_name, log_graph=True)

log_every_n_steps = 50
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=400),
                             LearningRateMonitor(logging_interval='epoch'),
                             checkpoint_callback],
                  fast_dev_run=DEBUG,
                #   limit_train_batches=0.1,
                #   limit_val_batches=0.1,
                #   profiler='simple',
                #  max_epochs=30,
                  precision='bf16-mixed',
                  logger=logger,
                #  log_every_n_steps=log_every_n_steps
                 )

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


"learning_rate": 0.01
"lr_gamma":      0.96


In [14]:
# tuner = Tuner(trainer)
# lr_finder = tuner.lr_find(lightning_model, cityscapes_data_module)

# # Plot 
# fig = lr_finder.plot(suggest=True)
# fig.show()

# lightning_model.lr = lr_finder.suggestion()  # configure_optimizers reads self.lr, not hparams.lr

In [15]:
# Start training; the data module supplies the train and validation dataloaders.
trainer.fit(lightning_model, datamodule=cityscapes_data_module)

c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\loops\utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA GeForce RTX 4070') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:652: Checkpoint directory C:\main\repos\semantic_segmentation\lightning_logs\resnet18_LR0.01_LRG0.96 exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode  | In sizes          | Out sizes         
----------------------------------------------------------------------------

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


                                                                           

c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 202:  78%|███████▊  | 73/93 [00:14<00:04,  4.88it/s, v_num=0.96, train_loss=0.327]

c:\Users\zaits\anaconda3\envs\cuda\Lib\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Visual spot check: predict the first validation batch and show one decoded sample.
lightning_model.eval()  # use running batch-norm statistics instead of batch statistics
for x, y in cityscapes_data_module.val_dataloader():
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        prediction = lightning_model(x.to(device=lightning_model.device))
    print(prediction.size())
    decoded_predictions = decode_segmap(torch.argmax(prediction, 1))
    print(decoded_predictions.size())
    plt.imshow(decoded_predictions[1].cpu().numpy())
    plt.show()
    break