In [None]:
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
# !pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118
print(torch.__version__)
from unet import UNet
from trainer import Trainer

In [None]:
import os
# Dataset root. Set the URBANSCAPES_ROOT environment variable to point at a
# different copy instead of editing this hard-coded local path.
root = os.environ.get(
    "URBANSCAPES_ROOT",
    r'C:\Users\Lenovo\Downloads\SelfDriveSuite\Scene segmentation\dataset_large',
)

print(os.listdir(root))

In [None]:
class Urbanscapes():
    """Index of (image, target) file pairs for one split of an Urbanscapes-style dataset.

    Layout expected on disk::

        <root>/<mode>/images/<name>
        <root>/<mode>/targets/<name>   # label image with the same file name

    Samples are returned as PIL images; only ``transforms`` (a joint
    ``(image, target) -> (image, target)`` callable) is applied by
    ``__getitem__``. ``transform``, ``target_transform`` and ``target_type``
    are stored but not used here.
    """
    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        "CityscapesClass",
        ["name", "id", "color"],
    )

    classes = [
        CityscapesClass("unlabeled", 0, (0, 0, 0)),
        CityscapesClass("terrain", 1, (210, 0, 200)),
        CityscapesClass("sky", 2, (90, 200, 255)),
        CityscapesClass("tree", 3, (0, 199, 0)),
        CityscapesClass("vegetation", 4, (90, 240, 0)),
        CityscapesClass("building", 5, (140, 140, 140)),
        CityscapesClass("road", 6, (100, 60, 100)),
        CityscapesClass("guard rail", 7, (250, 100, 255)),
        CityscapesClass("traffic sign", 8, (255, 255, 0)),
        CityscapesClass("traffic light", 9, (200, 200, 0)),
        CityscapesClass("pole", 10, (255, 130, 0)),
        CityscapesClass("misc", 11, (80, 80, 80)),
        CityscapesClass("truck", 12, (160, 60, 60)),
        CityscapesClass("car", 13, (255, 127, 80)),
        CityscapesClass("van", 14, (0, 139, 139)),
    ]

    def __init__(
        self,
        root,
        mode="train",
        target_type= "instance",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.mode = mode
        self.transform = transform
        self.transforms = transforms
        self.target_transform = target_transform
        self.images_dir = os.path.join(self.root, self.mode, "images")
        self.targets_dir = os.path.join(self.root, self.mode, "targets")
        self.target_type = target_type

        # Image files sit directly in images_dir (there is no per-city
        # sub-folder level), so list the directory exactly once. Listing it
        # once per entry indexed every sample N times. Sorting gives a
        # deterministic order independent of the filesystem.
        file_names = sorted(os.listdir(self.images_dir))
        self.images = [os.path.join(self.images_dir, name) for name in file_names]
        self.targets = [os.path.join(self.targets_dir, name) for name in file_names]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) -- the RGB image and the target label image
            (both PIL images), after ``transforms`` if one was given.
        """

        image = Image.open(self.images[index]).convert("RGB")

        target = Image.open(self.targets[index])
        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target


    def __len__(self) -> int:
        return len(self.images)

In [None]:
# Raw (untransformed) training split, used only for the quick visual checks below.
dataset = Urbanscapes(
    root,
    mode='train',
    target_type='semantic',
)

In [None]:
sample_image = dataset[0][0]
sample_image.size  # PIL reports (width, height)

In [None]:
# Fetch the sample once: every index access re-opens both files from disk.
sample_image, sample_target = dataset[0]

fig, ax = plt.subplots(ncols=2, figsize=(12, 8))
ax[0].imshow(sample_image)
ax[0].set_title("image")
ax[1].imshow(sample_target, cmap='gray')
ax[1].set_title("target");

# Utility functions: label remapping and colour decoding

In [None]:
# Label ids that occur in the target masks. Id 0 ("unlabelled") is also
# the value named ignore_index.
ignore_index = 0
valid_classes = [ignore_index] + list(range(1, 15))
class_names = [
    'unlabelled', 'terrain', 'sky', 'tree', 'vegetation',
    'building', 'road', 'guard_rail', 'traffic_sign', 'traffic_light',
    'pole', 'misc', 'truck', 'car', 'van',
]

# Why 15 classes: https://stackoverflow.com/a/64242989

# Raw label id -> contiguous training id (an identity mapping for these ids).
class_map = {raw_id: train_id for train_id, raw_id in enumerate(valid_classes)}
n_classes = len(valid_classes)
class_map

In [None]:
# RGB palette, one entry per class id. It is derived from Urbanscapes.classes so
# the colours are defined in exactly one place (they were a verbatim copy).
colors = [list(cls.color) for cls in sorted(Urbanscapes.classes, key=lambda cls: cls.id)]
assert len(colors) == n_classes, "palette size must match the number of classes"

label_colours = dict(zip(range(n_classes), colors))

In [None]:
def encode_segmap(mask, mapping=None, ignore_value=None):
    """Remap raw label ids in ``mask`` to training ids, in place.

    Args:
        mask: integer torch tensor of raw label ids. It is modified in place
            and also returned.
        mapping: dict raw id -> training id. Defaults to the notebook's
            ``class_map``.
        ignore_value: id written for pixels whose raw id is not a key of
            ``mapping``. Defaults to the notebook's ``ignore_index``.

    Returns:
        The remapped ``mask``.
    """
    if mapping is None:
        mapping = class_map
    if ignore_value is None:
        ignore_value = ignore_index

    # Match against a snapshot of the original ids so that one replacement
    # cannot be picked up again by a later one (e.g. 1->2 followed by 2->1).
    original = mask.clone()
    known = torch.zeros_like(original, dtype=torch.bool)
    for raw_id, train_id in mapping.items():
        hit = original == raw_id
        mask[hit] = train_id
        known |= hit
    # Ids outside the mapping would otherwise reach one_hot / the IoU metric
    # with values >= num_classes.
    mask[~known] = ignore_value
    return mask

In [None]:
def decode_segmap(temp):
    """Colourise a 2-D CPU tensor of class ids using ``label_colours``.

    Returns a float (H, W, 3) numpy array with every channel divided by 255.
    Pixels whose id has no palette entry keep their raw id value in all three
    channels (before the division).
    """
    ids = temp.numpy()
    channels = [ids.copy() for _ in range(3)]
    for class_id in range(n_classes):
        selected = ids == class_id
        for channel, value in zip(channels, label_colours[class_id]):
            channel[selected] = value

    rgb = np.zeros((ids.shape[0], ids.shape[1], 3))
    for idx, channel in enumerate(channels):
        rgb[:, :, idx] = channel / 255.0
    return rgb

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
transform=A.Compose(
[
    A.Resize(256, 512),
    A.HorizontalFlip(),
    A.RandomSizedCrop(min_max_height=(64, 128), size=(256, 512), p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
]
)
val_transform=A.Compose(
[
    A.Resize(256, 512),
    ToTensorV2(),
]
)

In [None]:
from typing import Any, Callable, Dict, List, Optional, Union, Tuple


class Loaddataclass(Urbanscapes):
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        image = Image.open(self.images[index]).convert('RGB')

        target = Image.open(self.targets[index])
        if self.transforms is not None:
            transformed=transform(image=np.array(image), mask=np.array(target))            
        return transformed['image'],transformed['mask']

In [None]:
# Preview one validation sample pushed through the training pipeline `transform`.
dataset = Loaddataclass(
    root=root,
    mode='val',
    target_type='semantic',
    transforms=transform,
)
img, seg = dataset[0]
print(img.shape, seg.shape)

In [None]:
# `img` was normalised with the ImageNet mean/std by `transform`. Undo that for
# display, otherwise imshow clips the values and distorts the colours.
# These constants must match the A.Normalize step of `transform`.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img_display = (img * imagenet_std + imagenet_mean).clamp(0, 1).permute(1, 2, 0)

fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(16, 8))
ax[0].imshow(img_display)
ax[1].imshow(seg, cmap='gray');

In [None]:
#class labels before label correction
# Raw label ids present in the mask, before any remapping
labels_before = torch.unique(seg)
print(labels_before)
print(len(labels_before))

In [None]:
#class labels after label correction
# Label ids after encode_segmap, applied to a copy so `seg` is left untouched
res = encode_segmap(seg.clone())
labels_after = torch.unique(res)
print(res.shape)
print(labels_after)
print(len(labels_after))

In [None]:
#let do coloring
# Colourise the encoded labels. decode_segmap works on copies, so no clone is needed.
res1 = decode_segmap(res)

In [None]:
fig, ax = plt.subplots(ncols=2, figsize=(12, 10))
ax[0].imshow(res, cmap='gray')  # encoded label ids
ax[1].imshow(res1)              # the same labels, colour-decoded

# Training

In [None]:
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch

In [None]:
# Training split is given the augmenting `transform`; validation split is given `val_transform`.
dataset_train = Loaddataclass(
    root=root, mode='train', target_type='semantic', transforms=transform,
)
dataset_valid = Loaddataclass(
    root=root, mode='val', target_type='semantic', transforms=val_transform,
)

# Shuffled mini-batches for training; fixed order, one sample at a time for validation.
dataloader_training = DataLoader(dataset=dataset_train, batch_size=2, shuffle=True)
dataloader_validation = DataLoader(dataset=dataset_valid, batch_size=1, shuffle=False)


class MulticlassDiceLoss(nn.Module):
    """Soft Dice over all classes, returned as the loss ``-log(dice)``.

    Reference: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch#Dice-Loss

    Args:
        num_classes: number of classes used to one-hot encode the targets.
        softmax_dim: if given, softmax is applied over this dim of the input
            first. If ``None``, the input is used as-is and should already hold
            per-class probabilities.
    """
    def __init__(self, num_classes, softmax_dim=None):
        super().__init__()
        self.num_classes = num_classes
        self.softmax_dim = softmax_dim

    def forward(self, logits, targets, reduction='mean', smooth=1e-6):
        """Return ``(dice_loss, dice_coefficient)`` as scalar tensors.

        ``targets`` holds integer class ids without a class dim (e.g. N, H, W).
        ``logits`` has the class dim at position 1 (e.g. N, C, H, W).
        The ``reduction`` argument is ignored: Dice is computed once over every
        pixel and class together.
        """
        probabilities = logits
        if self.softmax_dim is not None:
            probabilities = torch.softmax(logits, dim=self.softmax_dim)

        targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=self.num_classes)
        # one_hot appends the class dim last; move it to position 1 to match
        # the logits layout (works for any number of spatial dims).
        targets_one_hot = targets_one_hot.movedim(-1, 1).to(probabilities.dtype)

        # Predicted probability mass that falls on the true class.
        intersection = (targets_one_hot * probabilities).sum()
        # Dice denominator |P| + |T|: total predicted mass plus total target mass.
        cardinality = probabilities.sum() + targets_one_hot.sum()

        # smooth in the numerator keeps -log finite when there is no overlap.
        dice_coefficient = (2.0 * intersection + smooth) / (cardinality + smooth)
        dice_loss = -dice_coefficient.log()
        return dice_loss, dice_coefficient

In [None]:
from torch.utils.data import Dataset
import time
import torchmetrics

class Trainer:
    """Epoch-based training loop with per-epoch validation and best-IoU checkpointing.

    Each epoch runs one pass over ``training_dataloader`` and, if
    ``validation_dataloader`` is given, one pass over it. The weights with the
    highest mean validation IoU are written to ``<checkpoint_dir>/epoch_best.pth``;
    without a validation loader nothing is saved. The learning-rate scheduler
    is stepped once per epoch, with the latest validation loss for
    ``ReduceLROnPlateau``.

    Args:
        model: network returning per-class logits.
        device: device the batches are moved to.
        criterion: callable ``(logits, targets) -> (loss, dice)``.
        optimizer: optimizer over the model parameters.
        training_dataloader: iterable of ``(image, mask)`` batches.
        validation_dataloader: optional iterable of ``(image, mask)`` batches.
        lr_scheduler: optional learning-rate scheduler.
        epochs: number of epochs run by ``run_trainer``.
        epoch: starting value of the epoch counter.
        notebook: use the notebook flavour of tqdm.
        checkpoint_dir: name prefix of the timestamped directory under ``checkpoints/``.
        save_frequency: stored for compatibility; not used by the loop.
        num_classes: number of classes for the IoU metric.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        training_dataloader: DataLoader,
        validation_dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[Any] = None,
        epochs: int = 100,
        epoch: int = 0,
        notebook: bool = False,
        checkpoint_dir: str = 'experiment',
        save_frequency: int = 1,
        num_classes: int = 15,
    ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.training_dataloader = training_dataloader
        self.validation_dataloader = validation_dataloader
        self.device = device
        self.epochs = epochs
        self.epoch = epoch
        self.notebook = notebook
        self.save_frequency = save_frequency
        self.training_loss = []
        self.validation_loss = []
        self.learning_rate = []
        self.data = []
        # Timestamp suffix keeps separate runs from overwriting each other's weights.
        self.checkpoint_dir = f"checkpoints/{checkpoint_dir}_{time.time()}"
        self.best_iou = 0.00001
        self.metrics = torchmetrics.classification.MulticlassJaccardIndex(num_classes=num_classes).to(device)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def run_trainer(self):
        """Run ``self.epochs`` epochs.

        Returns ``(training_loss, validation_loss, learning_rate, data)``:
        per-epoch histories plus the last training batch as
        ``[inputs, targets, detached outputs]``.
        """
        if self.notebook:
            from tqdm.notebook import trange
        else:
            from tqdm import trange

        for _ in trange(self.epochs, desc="Progress"):
            self.epoch += 1
            print("epoch: ", self.epoch)
            self._train()

            if self.validation_dataloader is not None:
                iou_value = self._validate()
                if iou_value > self.best_iou:
                    model_name = 'epoch_best.pth'
                    torch.save(self.model.state_dict(), os.path.join(self.checkpoint_dir, model_name))
                    print(f"weights saved at epoch: {self.epoch}")
                    self.best_iou = iou_value

            if self.lr_scheduler is not None:
                if (
                    self.validation_dataloader is not None
                    and self.lr_scheduler.__class__.__name__ == "ReduceLROnPlateau"
                ):
                    # Use the most recent validation loss. The loop index would be
                    # wrong once run_trainer is called more than once.
                    self.lr_scheduler.step(self.validation_loss[-1])
                else:
                    self.lr_scheduler.step()
        return self.training_loss, self.validation_loss, self.learning_rate, self.data

    def _train(self):
        """One optimisation pass over the training loader; records mean loss and LR."""
        if self.notebook:
            from tqdm.notebook import tqdm
        else:
            from tqdm import tqdm

        self.model.train()
        train_losses = []
        train_iou = []
        # Seed with the epoch number: runs stay reproducible, but each epoch gets
        # a different shuffle order. A constant seed replayed the same order every epoch.
        torch.manual_seed(self.epoch)
        batch_iter = tqdm(
            enumerate(self.training_dataloader),
            "Training",
            total=len(self.training_dataloader),
            leave=False,
            disable=True,
        )
        for _, (x, y) in batch_iter:
            input_x = x.to(self.device)
            target_y = encode_segmap(y.to(self.device))

            self.optimizer.zero_grad()
            out = self.model(input_x)
            loss, _dice_score = self.criterion(out, target_y.long())
            loss.backward()
            self.optimizer.step()

            loss_value = loss.item()
            train_losses.append(loss_value)
            # Detach so neither the metric update nor the kept sample holds on
            # to the autograd graph.
            out = out.detach()
            train_iou.append(self.metrics(out, target_y).item())
            self.data = [input_x, target_y, out]

            batch_iter.set_description(
                f"Training: (loss {loss_value:.4f})"
            )

        self.training_loss.append(np.mean(train_losses))
        print("train mean loss_value: ", np.mean(train_losses))
        print("train mean iou_value: ", np.mean(train_iou))
        self.learning_rate.append(self.optimizer.param_groups[0]["lr"])

        batch_iter.close()

    def _validate(self):
        """One pass over the validation loader; records mean loss and returns mean per-batch IoU."""
        if self.notebook:
            from tqdm.notebook import tqdm
        else:
            from tqdm import tqdm

        self.model.eval()
        valid_losses = []
        valid_iou = []
        batch_iter = tqdm(
            enumerate(self.validation_dataloader),
            "Validation",
            total=len(self.validation_dataloader),
            leave=False,
            disable=True,
        )

        with torch.no_grad():
            for _, (x, y) in batch_iter:
                inputs = x.to(self.device)
                target = encode_segmap(y.to(self.device))

                out = self.model(inputs)
                loss, _dice_score = self.criterion(out, target.long())
                loss_value = loss.item()
                valid_losses.append(loss_value)
                valid_iou.append(self.metrics(out, target).item())
                batch_iter.set_description(f"Validation: (loss {loss_value:.4f})")

        self.validation_loss.append(np.mean(valid_losses))
        print("val loss_value: ", np.mean(valid_losses))
        mean_iou = np.mean(valid_iou)
        print("val iou_value: ", mean_iou)
        batch_iter.close()
        return mean_iou

In [None]:
# device
# device
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# model: one output channel per class, so it must agree with n_classes
model = UNet(
    in_channels=3,
    out_channels=n_classes,
    n_blocks=4,
    start_filters=128,
    activation="relu",
    normalization="batch",
    conv_mode="same",
    dim=2,
).to(device)

criterion = MulticlassDiceLoss(num_classes=n_classes, softmax_dim=1)
# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.004)  # earlier runs: lr=0.0025, min_lr=0
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# `verbose` is omitted: it defaults to off and is deprecated in newer torch releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.4,
    patience=15,
    threshold=0.0001,
    threshold_mode='rel',
    cooldown=0,
    min_lr=0.0001,
    eps=1e-08,
)

# trainer
trainer = Trainer(
    model=model,
    device=device,
    criterion=criterion,
    optimizer=optimizer,
    training_dataloader=dataloader_training,
    validation_dataloader=dataloader_validation,
    lr_scheduler=scheduler,
    epochs=100,
    epoch=0,
    notebook=True,
    checkpoint_dir="experiment_lr_on_plateau_dataset_large_color_aug",
    save_frequency=1
)


def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters of ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(count_parameters(model=model))

In [None]:
# Runs every epoch; returns the per-epoch histories plus the last training batch.
(training_losses, validation_losses, lr_rates, data) = trainer.run_trainer()

In [None]:
from visual import plot_training

fig = plot_training(
    training_losses,
    validation_losses,
    lr_rates,
    gaussian=True,
    sigma=1,
    figsize=(10, 4),
)
print(validation_losses)
print(training_losses)
# epoch_best.pth is selected by best validation IoU, not by validation loss, so
# report the loss minimum as what it is rather than as "best weights".
if len(validation_losses) > 0:
    lowest_loss_epoch = int(np.argmin(validation_losses)) + 1
    print(f"lowest validation loss at epoch {lowest_loss_epoch}: {np.min(validation_losses)}")
else:
    print("no validation losses recorded")

# Testing (& Saving the predictions)

In [None]:
def save_image(target, prediction, iteration, output_dir="output_images"):
    """Write colour-decoded ground-truth and prediction PNGs for one sample.

    Args:
        target: 2-D tensor of class ids.
        prediction: per-class scores with the class dim first; the predicted
            class is the argmax over dim 0.
        iteration: number used in the output file names.
        output_dir: destination directory, created if it does not exist.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    target_rgb = decode_segmap(target.cpu())
    predicted_ids = prediction.argmax(dim=0)
    prediction_rgb = decode_segmap(predicted_ids.cpu())

    plt.imsave(os.path.join(output_dir, f"target_iter{iteration}.png"), target_rgb)
    plt.imsave(os.path.join(output_dir, f"prediction_iter{iteration}.png"), prediction_rgb)

# Evaluate the best checkpoint on the held-out test split and save colour-decoded
# target/prediction images for the first few samples.
test_root = r"C:\Users\Lenovo\Downloads\SelfDriveSuite\Scene segmentation\dataset_large\test_data"

# Deterministic pipeline with the same normalisation as training.
# NOTE: this rebinds the notebook-level `transform`; re-run the transforms cell
# before training again in the same kernel.
transform=A.Compose(
[
    A.Resize(256, 512),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
]
)
# Separate names so the training/validation loaders defined earlier are not clobbered.
dataset_test = Loaddataclass(root=test_root, mode='val', target_type='semantic', transforms=transform)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=1, shuffle=False)

from tqdm import tqdm, trange

checkpoint_path = r'C:\Users\Lenovo\Downloads\SelfDriveSuite\Scene segmentation\checkpoints\experiment_lr_on_plateau_dataset_large_color_aug\epoch_best.pth'
# map_location lets a GPU-trained checkpoint load on a CPU-only machine as well.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()  # evaluation mode

# Only the first ceil(len / 339) samples (at least one) are evaluated and saved.
SAMPLE_DIVISOR = 339
num_eval_batches = max(1, -(-len(dataloader_test) // SAMPLE_DIVISOR))

print(len(dataloader_test))
metrics = torchmetrics.classification.MulticlassJaccardIndex(num_classes=n_classes).to(device)
valid_iou = []
batch_iter = tqdm(
    enumerate(dataloader_test),
    "Validation",
    total=len(dataloader_test),
    leave=False,
    disable=True,
)
with torch.no_grad():
    for i, (x, y) in batch_iter:
        inputs = x.to(device)
        target = encode_segmap(y.to(device))

        out = model(inputs)
        iou_value = metrics(out, target).item()
        print(f"i: {i}; iou_value: ", iou_value)
        valid_iou.append(iou_value)

        # Drop the batch dim (batch_size=1) before decoding and saving.
        save_image(target.squeeze(0), out.squeeze(0), i)
        if i + 1 >= num_eval_batches:
            break

mean_iou = np.mean(valid_iou)
print("val iou_value: ", mean_iou)
batch_iter.close()