In [1]:
import os
import os.path as osp
import numpy as np
import math

from tqdm import tqdm
from PIL import Image
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms

from datasets.cityscapes import trainDataset, valDataset, testDataset

Total Images in train dataset =  2975
Total Images in val dataset =  500
Total Images in test dataset =  1525
Shape of image =  torch.Size([32, 3, 1024, 2048])
Shape of smnt =  torch.Size([32, 4, 1024, 2048])
Shape of image =  torch.Size([32, 3, 1024, 2048])
Shape of smnt =  torch.Size([32, 4, 1024, 2048])
Shape of image =  torch.Size([32, 3, 1024, 2048])
Shape of smnt =  torch.Size([32, 4, 1024, 2048])


In [4]:
# Pick the compute device once: use the GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device = ", DEVICE)

Device =  cuda


In [2]:
# Number of samples per batch, shared by all three loaders.
BATCH_SIZE = 32

# Each loader wraps the split its name says it does (the validation and test
# datasets were previously swapped). Only the training split is shuffled so
# that evaluation passes iterate in a fixed, reproducible order.
train_dataloader = DataLoader(trainDataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(valDataset, shuffle=False, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(testDataset, shuffle=False, batch_size=BATCH_SIZE)

# Sanity check: print the tensor shapes of the first three training batches.
for i, (img, smnt) in enumerate(train_dataloader):
    print("Shape of image = ", img.shape)
    print("Shape of smnt = ", smnt.shape)

    if i == 2:
        break

Total Images in train dataset =  2975
Total Images in val dataset =  500
Total Images in test dataset =  1525
Shape of image =  torch.Size([32, 3, 1024, 2048])
Shape of smnt =  torch.Size([32, 4, 1024, 2048])
Shape of image =  torch.Size([32, 3, 1024, 2048])
Shape of smnt =  torch.Size([32, 4, 1024, 2048])
Shape of image =  torch.Size([32, 3, 1024, 2048])
Shape of smnt =  torch.Size([32, 4, 1024, 2048])


In [11]:
# Double convolution block.
class Conv2x_Block(torch.nn.Module):
    """Two consecutive 3x3 convolutions, each followed by a ReLU.

    Both convolutions use stride 1 and no padding, so one pass through the
    block shrinks the height and width of the input by 4 pixels in total
    (2 per convolution).

    Args:
        inChannelCount: Number of channels of the input tensor.
        outChannelCount: Number of channels produced by both convolutions.
    """

    def __init__(self, inChannelCount, outChannelCount):
        super().__init__()
        # The Sequential has to be stored on ``self``: forward() reads
        # ``self.conv2x`` and nn.Module only registers (and therefore trains)
        # sub-modules that are assigned as attributes.
        self.conv2x = nn.Sequential(
            # Remove bias if adding batchnorm later.
            nn.Conv2d(inChannelCount, outChannelCount, kernel_size=3, stride=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(outChannelCount, outChannelCount, kernel_size=3, stride=1, bias=True),
            nn.ReLU()
        )

    def forward(self, X):
        """Run ``X`` through both conv + ReLU stages and return the result."""
        return self.conv2x(X)

class UNet_Encoder(torch.nn.Module):
    """Contracting path: a stack of Conv2x_Block stages, each followed by 2x2 max-pooling.

    Args:
        channels_per_layer: Channel counts along the path. Each consecutive
            pair ``(channels_per_layer[i], channels_per_layer[i + 1])`` gives
            the input/output channels of one Conv2x_Block, so a sequence of
            length ``n`` builds ``n - 1`` blocks.
    """

    # A tuple default avoids sharing one mutable list between instances.
    def __init__(self, channels_per_layer=(3, 64, 128, 256, 512)):
        super().__init__()
        self.encoder_conv2x = torch.nn.ModuleList(
            [Conv2x_Block(channels_per_layer[i], channels_per_layer[i + 1])
             for i in range(len(channels_per_layer) - 1)]
        )
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        """Encode ``X`` and collect every block's pre-pooling output.

        Returns:
            A tuple ``(X, residual_layers_output)`` where ``X`` is the tensor
            after the last max-pooling and ``residual_layers_output`` is the
            list of each block's output before pooling, in the order the
            blocks were applied.
        """
        residual_layers_output = []
        for conv_block in self.encoder_conv2x:
            X = conv_block(X)
            residual_layers_output.append(X)
            X = self.max_pool(X)

        # Also exposed as an attribute, as in the original implementation.
        self.residual_layers_output = residual_layers_output
        return X, residual_layers_output


class UNet_Decoder(torch.nn.Module):
    """Expanding path: upsample, concatenate a skip connection, then double-convolve.

    Stage ``i`` applies a ``ConvTranspose2d`` (kernel 2, stride 2) mapping
    ``channels_per_layer[i]`` to ``channels_per_layer[i + 1]`` channels,
    concatenates the matching encoder output along the channel dimension and
    feeds the result to a ``Conv2x_Block`` that expects
    ``channels_per_layer[i]`` input channels. The skip connection for stage
    ``i`` must therefore carry
    ``channels_per_layer[i] - channels_per_layer[i + 1]`` channels.

    Args:
        channels_per_layer: Channel counts along the path; a sequence of
            length ``n`` builds ``n - 1`` stages.
    """

    # A tuple default avoids sharing one mutable list between instances.
    def __init__(self, channels_per_layer=(512, 256, 128, 64)):
        super().__init__()
        self.channels_per_layer = channels_per_layer
        self.transpose_conv = torch.nn.ModuleList(
            [nn.ConvTranspose2d(channels_per_layer[i], channels_per_layer[i + 1], kernel_size=2, stride=2)
             for i in range(len(channels_per_layer) - 1)]
        )
        self.decoder_conv2x = torch.nn.ModuleList(
            [Conv2x_Block(channels_per_layer[i], channels_per_layer[i + 1])
             for i in range(len(channels_per_layer) - 1)]
        )

    def forward(self, X, encoder_layers_output):
        """Decode ``X`` using the given skip connections.

        Args:
            X: Tensor to upsample; its channel count must equal
                ``channels_per_layer[0]``.
            encoder_layers_output: Skip-connection tensors, one per stage, in
                the order the stages consume them. Iteration stops at the
                shorter of this sequence and the number of stages.

        Returns:
            The output of the last Conv2x_Block.
        """
        for trans_conv, conv2x, encoder_layer in zip(self.transpose_conv, self.decoder_conv2x, encoder_layers_output):
            X = trans_conv(X)
            # Only height/width have to agree before concatenating; the channel
            # counts of the two tensors are allowed to differ.
            if X.shape[2:] != encoder_layer.shape[2:]:
                # Bilinear resize to the skip connection's spatial size, done
                # with torch itself so the result does not depend on
                # torchvision's version-specific ``antialias`` default.
                X = F.interpolate(X, size=encoder_layer.shape[2:], mode="bilinear", align_corners=False)

            X_concatenated = torch.cat((encoder_layer, X), dim=1)
            X = conv2x(X_concatenated)
        return X

class UNet(torch.nn.Module):
    """Encoder/decoder segmentation network with a 1x1 classification head.

    Args:
        encoder_channels: Channel counts passed to ``UNet_Encoder``.
        decoder_channels: Channel counts passed to ``UNet_Decoder``; the last
            entry is the input channel count of the final 1x1 convolution.
        num_classes: Number of output channels (one score map per class).

    NOTE(review): Conv2x_Block as written uses unpadded 3x3 convolutions, so
    the output is spatially smaller than the input — confirm whether it must
    be resized before comparing against full-resolution masks.
    """

    def __init__(self, encoder_channels, decoder_channels, num_classes=19):
        super().__init__()
        self.encoder = UNet_Encoder(encoder_channels)
        # The expanding path must be a decoder, not a second encoder.
        self.decoder = UNet_Decoder(decoder_channels)
        # nn.Conv2d has no ``num_classes`` keyword; the class count is simply
        # its number of output channels.
        self.final_conv = nn.Conv2d(decoder_channels[-1], num_classes, kernel_size=1)

    def forward(self, X):
        """Return per-class score maps for the input batch ``X``."""
        _, encoder_layers_output = self.encoder(X)
        # The deepest block's (un-pooled) output is the bottleneck that gets
        # upsampled; the remaining outputs, deepest first, are the skips.
        # Passing the deepest output as a skip as well can never match the
        # decoder's channel arithmetic (upsampled + skip channels would exceed
        # the block's expected input channels).
        bottleneck = encoder_layers_output[-1]
        skip_connections = encoder_layers_output[-2::-1]
        X_decoded = self.decoder(bottleneck, skip_connections)
        out = self.final_conv(X_decoded)
        return out
