In [10]:
import torch
import torch.nn as nn
import random
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import os
import csv
import pandas as pd
from PIL import Image
from torchvision import transforms
from pytorch_msssim import ms_ssim

from collections import OrderedDict

from tensorboardX import SummaryWriter 


## Setting the Global Variables

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
# (To force CPU execution, assign DEVICE = torch.device("cpu") instead.)
_accelerator = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_accelerator)

BATCH_SIZE = 16             # mini-batch size shared by the train/val/test DataLoaders
MODEL_NAME = "IMAGE2IMAGE"  # tag used for checkpoints, TensorBoard runs and sample folders

## Code for the Autoencoder

In [4]:
class EncoderBlock(nn.Module):
    """One stage of the multi-scale encoder: convolutions, dropout, then 2x2 max-pooling.

    Two wiring modes are selected by ``blk``:

    * ``"first"`` / ``"bottleneck"`` (plain): ``x -> conv1_a -> ReLU -> conv2 -> ReLU``.
    * any other name (multi-scale): ``scale_img`` (a downscaled 3-channel copy of the
      network input) is projected to ``in_channels`` by ``conv1_b`` and concatenated
      with ``x`` on the channel axis, giving ``2 * in_channels`` channels, which then
      pass through ``conv2 -> ReLU -> conv3 -> ReLU``.

    ``conv2`` is shared by both modes and expects ``out_channels`` inputs, so the
    multi-scale mode only works when ``out_channels == 2 * in_channels``. This is
    validated in ``__init__`` instead of failing later inside ``torch.cat``/``conv2``.

    Args:
        blk: stage name; ``"first"`` and ``"bottleneck"`` select the plain mode.
        in_channels: channels of ``x`` entering the block.
        out_channels: channels produced before pooling.

    Raises:
        ValueError: multi-scale block whose channel counts do not satisfy
            ``out_channels == 2 * in_channels``.
    """

    _PLAIN_BLOCKS = ("first", "bottleneck")

    def __init__(self, blk, in_channels, out_channels):
        super().__init__()
        if blk not in self._PLAIN_BLOCKS and out_channels != 2 * in_channels:
            raise ValueError(
                f"EncoderBlock '{blk}': multi-scale blocks need out_channels == 2 * in_channels "
                f"(got in_channels={in_channels}, out_channels={out_channels})"
            )
        self.blk = blk
        self.conv1_a = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        # Projects the 3-channel downscaled input image to in_channels for the skip concat.
        self.conv1_b = nn.Conv2d(3, in_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, 1, padding="same")
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, x, scale_img="none"):
        """Apply the block and halve the spatial size.

        Args:
            x: feature map with ``in_channels`` channels.
            scale_img: 3-channel image at the same spatial size as ``x``; required
                (must be a tensor) for multi-scale blocks, ignored otherwise.

        Raises:
            ValueError: a multi-scale block was called without a ``scale_img`` tensor.
        """
        if self.blk in self._PLAIN_BLOCKS:
            x1 = self.relu(self.conv1_a(x))
            x1 = self.relu(self.conv2(x1))
        else:
            if not torch.is_tensor(scale_img):
                raise ValueError(f"EncoderBlock '{self.blk}' requires a scale_img tensor")
            skip_x = self.relu(self.conv1_b(scale_img))
            x1 = torch.cat([skip_x, x], dim=1)
            x1 = self.relu(self.conv2(x1))
            x1 = self.relu(self.conv3(x1))
        return self.maxpool(self.dropout(x1))




class DecoderBlock(nn.Module):
    """Upsampling decoder stage.

    Doubles the spatial size with ``nn.Upsample(scale_factor=2)``. Then applies three
    3x3 "same"-padded convolutions, each followed by ReLU, and finishes with dropout
    (p=0.3). The first convolution maps ``in_channels -> out_channels``; the other
    two keep ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def conv3x3(cin, cout):
            return nn.Conv2d(cin, cout, 3, 1, padding="same")

        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.conv3 = conv3x3(out_channels, out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        features = self.upsample(x)
        for layer in (self.conv1, self.conv2, self.conv3):
            features = self.relu(layer(features))
        return self.dropout(features)




class DeepSupervisionBlock(nn.Module):
    """Output head: upsample x2, two channel-preserving 3x3 convs with ReLU, then a
    3x3 conv to ``out_channels`` followed by a sigmoid (outputs lie in [0, 1])."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, padding="same")
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, 1, padding="same")
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, 1, padding="same")
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.relu(self.conv1(self.upsample(x)))
        hidden = self.relu(self.conv2(hidden))
        logits = self.conv3(hidden)
        return self.sigmoid(logits)




class Encoder(nn.Module):
    """Six-stage multi-scale encoder.

    Stage 1 sees the full-resolution image. Stages 2-5 also receive the input
    average-pooled by 2, 4, 8 and 16 respectively. Stage 6 ("bottleneck") has no image
    input. Every stage halves the spatial size, so a 256x256 input gives a 512-channel
    4x4 latent.
    """

    def __init__(self):
        super().__init__()
        filters = [8, 16, 32, 64, 128, 512]
        self.drp_out = 0.3  # stored but not read by this class
        self.scale_img = nn.AvgPool2d(2, 2)

        stage_specs = [
            ("first", 3, filters[0]),
            ("second", filters[0], filters[1]),
            ("third", filters[1], filters[2]),
            ("fourth", filters[2], filters[3]),
            ("fifth", filters[3], filters[4]),
            ("bottleneck", filters[4], filters[5]),
        ]
        # Registers block_1 ... block_6 in the same order as explicit assignments would.
        for idx, (kind, cin, cout) in enumerate(stage_specs, start=1):
            setattr(self, f"block_{idx}", EncoderBlock(kind, cin, cout))

    def forward(self, x):
        # Image pyramid x/2, x/4, x/8, x/16 fed to stages 2-5.
        pyramid = []
        level = x
        for _ in range(4):
            level = self.scale_img(level)
            pyramid.append(level)

        features = self.block_1(x)
        for block, scaled in zip((self.block_2, self.block_3, self.block_4, self.block_5), pyramid):
            features = block(features, scaled)
        return self.block_6(features)



class Decoder(nn.Module):
    """Five upsampling DecoderBlocks (512 -> 128 -> 64 -> 32 -> 16 -> 8 channels)
    followed by a DeepSupervisionBlock that upsamples once more and emits a
    3-channel sigmoid image."""

    def __init__(self):
        super().__init__()
        filters = [512, 128, 64, 32, 16, 8]
        self.drp_out = 0.3  # stored but not read by this class

        # block_5 is built first (512 -> 128) down to block_1 (16 -> 8).
        for idx, (cin, cout) in zip(range(5, 0, -1), zip(filters, filters[1:])):
            setattr(self, f"block_{idx}", DecoderBlock(cin, cout))
        self.ds = DeepSupervisionBlock(filters[5], 3)

    def forward(self, x):
        for stage in (self.block_5, self.block_4, self.block_3, self.block_2, self.block_1):
            x = stage(x)
        return self.ds(x)



class AutoEncoder(nn.Module):
    """Encoder/decoder pair. ``forward`` returns ``(latent, reconstruction)``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z)



# Shape smoke test for the autoencoder on one random 256x256 RGB image.
print("AUTOENCODER")
data = torch.rand(size=(1, 3, 256, 256))
AE = AutoEncoder()
with torch.no_grad():  # only shapes are inspected; no autograd graph is needed
    img_out = AE(data)
print("Latent's Shape:", img_out[0].shape)
print("Output's Shape:", img_out[1].shape)

AUTOENCODER
Latent's Shape: torch.Size([1, 512, 4, 4])
Output's ShapeL torch.Size([1, 3, 256, 256])


## Code for UNet

In [5]:

class UNet(nn.Module):
    """Four-level U-Net with BatchNorm and a sigmoid output.

    Encoder:
        Four ``(conv3x3-BN-ReLU) x 2`` blocks with 2x2 max-pooling between them,
        then a bottleneck block.
    Decoder:
        Each level uses a 2x2 transposed convolution that doubles the spatial size,
        concatenates the matching encoder feature map (skip connection), and applies
        another ``(conv3x3-BN-ReLU) x 2`` block.
    Output:
        A final 1x1 convolution maps to ``out_channels``; a sigmoid squashes values
        into (0, 1).

    Input height and width should be divisible by 16 so the skip connections line up.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        init_features: channels of the first encoder block; doubled at each level.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")
        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(features * 2, features, name="dec1")
        self.conv = nn.Conv2d(in_channels=features, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        bottleneck = self.bottleneck(self.pool4(enc4))
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        """Build two ``conv3x3 -> BatchNorm -> ReLU`` layers; layer names are prefixed with ``name``.

        Declared ``@staticmethod`` so it works both as ``UNet._block(...)`` and ``self._block(...)``.
        """
        # Convs skip the bias term because the following BatchNorm supplies its own shift.
        return nn.Sequential(
            OrderedDict(
                [
                    (name + "conv1", nn.Conv2d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False)),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (name + "conv2", nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False)),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

# Shape smoke test: feed the autoencoder's reconstruction from the cell above into a fresh UNet.
print("\nUNET")
unet = UNet()
data = img_out[1]
seg_out = unet(data)
print("Output's Shape", seg_out.shape)


UNET
Output's Shape torch.Size([1, 1, 256, 256])


## Combined Model: Image-to-Image Autoencoder Followed by a Reconstructed-Image-to-Mask UNet

In [6]:
class Image2Image2Mask(nn.Module):
    """Chains the autoencoder and the UNet: image -> reconstructed image -> mask.

    ``forward`` returns ``(latent, reconstruction, mask)``. The UNet sees only the
    reconstruction, never the original input.
    """

    def __init__(self):
        super().__init__()
        self.image2imageAE = AutoEncoder()
        self.unet = UNet()

    def forward(self, x):
        ae_outputs = self.image2imageAE(x)
        latent, recon = ae_outputs[0], ae_outputs[1]
        mask = self.unet(recon)
        return latent, recon, mask


# Shape smoke test for the chained model on a batch of four random 256x256 images.
print("Combined Model")
data = torch.rand(size=(4, 3, 256, 256))
i2i2m = Image2Image2Mask()
imageLatent, reconsImage, segMask = i2i2m(data)
for caption, t in (
    ("Latent's Shape: ", imageLatent),
    ("Reconstructed Image's Shape: ", reconsImage),
    ("Segmentation Mask's Shape: ", segMask),
):
    print(caption, t.shape)
    

Combined Model
Latent's Shape:  torch.Size([4, 512, 4, 4])
Reconstructed Image's Shape:  torch.Size([4, 3, 256, 256])
Segmentation Mask's Shape:  torch.Size([4, 1, 256, 256])


## Loading the Dataset

In [9]:
class CustomDataset(Dataset):
    """Image-only dataset backed by a CSV file whose first column holds image paths.

    Each item is the image converted to RGB and passed through ``transform``. By
    default that is resize to 256x256 then ``transforms.ToTensor``, giving a
    ``(3, 256, 256)`` float tensor in [0, 1]. No mean/std normalisation is applied.

    Args:
        csv_file: path to a CSV readable by ``pandas.read_csv``; only column 0 is used.
        transform: optional callable applied to the PIL image, replacing the default
            resize + ``ToTensor`` pipeline.
    """

    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        # Build the preprocessing pipeline once rather than on every __getitem__ call.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data.iloc[index, 0]
        # The context manager closes the file handle promptly; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return self.transform(image)


class CustomDataLoader:
    """Builds the train/val/test DataLoaders from ``{data_dir}/{split}.csv``.

    Args:
        batch_size: batch size shared by all three loaders.
        data_dir: directory containing ``train.csv``, ``val.csv`` and ``test.csv``.
    """

    def __init__(self, batch_size, data_dir="Datasets/BDD 100K/images"):
        self.batch_size = batch_size
        self.data_dir = data_dir

    def get_data(self):
        """Return ``(train_loader, val_loader, test_loader)``.

        Only the training loader shuffles. Validation and test keep a fixed order, so
        their metrics (including the last, possibly partial, batch) are reproducible
        across epochs.
        """
        loaders = []
        for split in ("train", "val", "test"):
            dataset = CustomDataset(os.path.join(self.data_dir, f"{split}.csv"))
            loaders.append(
                DataLoader(dataset, batch_size=self.batch_size, shuffle=(split == "train"))
            )
        train_loader, val_loader, test_loader = loaders
        return train_loader, val_loader, test_loader



FileNotFoundError: [WinError 3] The system cannot find the path specified: 'Datasets/BDD 100K\\images/100k\\train'

## Loss Functions

In [6]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(P*T) + s) / (sum(P) + sum(T) + s)`` with ``s = 1``.

    The overlap is pooled over every element of the inputs (all batch items and all
    channels). The original version looked only at channel 0, so for multi-channel
    inputs (e.g. RGB reconstructions) the other channels never contributed to the
    loss. For single-channel masks the result is unchanged.

    Args:
        num_classes: accepted for backward compatibility; not used by the computation.
    """

    def __init__(self, num_classes=8):
        super(DiceLoss, self).__init__()
        self.smooth = 1.0
        # Kept for compatibility with existing code that may read them; forward() does not.
        self.classes = 3
        self.ignore_index = None
        self.eps = 1e-7

    def forward(self, y_pred, y_true):
        """Return the scalar Dice loss between two same-shaped tensors."""
        assert y_pred.size() == y_true.size()
        pred_flat = y_pred.reshape(-1)
        true_flat = y_true.reshape(-1)
        intersection = (pred_flat * true_flat).sum()
        dsc = (2. * intersection + self.smooth) / (
            pred_flat.sum() + true_flat.sum() + self.smooth
        )
        return 1. - dsc




class JaccardScore(nn.Module):
    """Intersection-over-union between a thresholded prediction and a target mask.

    Predictions are binarised with ``y_pred > threshold``. Previously any non-zero
    value counted as positive, so sigmoid outputs were always "all foreground".
    Targets count as positive where they are non-zero. All elements (batch and
    channels) are pooled. When both masks are empty the score is defined as 1.0
    instead of 0/0 = NaN.

    Args:
        threshold: cut-off applied to ``y_pred`` (default 0.5).
    """

    def __init__(self, threshold=0.5):
        super(JaccardScore, self).__init__()
        self.threshold = threshold

    def forward(self, y_pred, y_true):
        """Return the IoU as a 0-dim float tensor."""
        assert y_pred.size() == y_true.size()
        pred_mask = y_pred.reshape(-1) > self.threshold
        true_mask = y_true.reshape(-1) != 0
        intersection = torch.logical_and(pred_mask, true_mask).sum()
        union = torch.logical_or(pred_mask, true_mask).sum()
        if union == 0:
            # Both masks are empty: treat as perfect agreement.
            return torch.ones((), device=y_pred.device)
        return intersection / union




class MixedLoss(nn.Module):
    """Weighted sum of an MS-SSIM loss and an L2 (MSE) loss.

    ``loss = alpha * (1 - MS-SSIM(y_pred, y_true)) + beta * MSE(y_pred, y_true)``

    ``pytorch_msssim.ms_ssim`` defaults to ``data_range=255``. The images in this
    notebook are in [0, 1] (``ToTensor`` inputs, sigmoid outputs), so ``data_range``
    is passed explicitly; with 255 the SSIM stabilising constants dominate and the
    MS-SSIM term is nearly constant.

    Args:
        alpha: weight of the MS-SSIM term.
        beta: weight of the MSE term.
        data_range: value range of the inputs (max - min), default 1.0.
    """

    def __init__(self, alpha, beta, data_range=1.0):
        super(MixedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.data_range = data_range
        self.mse = nn.MSELoss()  # built once instead of on every forward call

    def forward(self, y_pred, y_true):
        # y_pred and y_true: (batch_size, channels, height, width)
        msssim_loss = 1 - ms_ssim(y_pred, y_true, data_range=self.data_range)
        l2_loss = self.mse(y_pred, y_true)
        return self.alpha * msssim_loss + self.beta * l2_loss

## `Model`: Training, Validation, Testing and Inference for the Image-to-Image Autoencoder

In [11]:
class Model():
    """Training and evaluation harness for the image-to-image (denoising) autoencoder.

    Wraps an ``AutoEncoder`` on ``DEVICE``. Provides per-epoch ``train``, ``validate``
    and ``test`` loops, plus a ``fit`` driver that logs to TensorBoard and writes
    checkpoints. The loops expect DataLoaders that yield image tensors directly
    (no ``(input, target)`` tuples).
    """

    def __init__(self, trained=False):
        # ``trained`` is accepted for compatibility but currently unused.
        self.model = AutoEncoder().to(DEVICE)
        self.jaccard = JaccardScore()

    def psnr(self, reconstructed, original, max_val=1.0):
        """Peak signal-to-noise ratio in dB: ``20 * log10(max_val / sqrt(MSE))``.

        Returns a 0-dim tensor; identical inputs give ``inf``.
        """
        return 20 * torch.log10(max_val / torch.sqrt(F.mse_loss(reconstructed, original)))

    @staticmethod
    def _safe_mean(total, count):
        """Average ``total`` over ``count`` batches; NaN for an empty loader instead of ZeroDivisionError."""
        return total / count if count else float("nan")

    def train(self, dataset, loss_func, optimizer, noise_std=0.05):
        """Run one denoising training epoch.

        Gaussian noise with standard deviation ``noise_std`` is added to each batch.
        The model is trained to reconstruct the clean batch.

        Returns:
            ``(mean batch loss, mean batch PSNR)`` over the epoch.
        """
        self.model.train()
        running_loss = 0.0
        running_psnr = 0.0
        counter = 0

        for img in tqdm(dataset, total=len(dataset)):
            counter += 1
            image = img.to(DEVICE)
            noise_image = image + torch.randn_like(image) * noise_std

            # Clear gradients from the previous step; otherwise they accumulate
            # over every batch of the epoch.
            optimizer.zero_grad()
            output = self.model(noise_image)
            loss = loss_func(output[1], image)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # PSNR is only a metric; keep it out of the autograd graph.
            with torch.no_grad():
                running_psnr += self.psnr(output[1], image).item()

        # loss_func already returns a per-batch mean, so average over batches only.
        return self._safe_mean(running_loss, counter), self._safe_mean(running_psnr, counter)

    def validate(self, dataset):
        """Return the mean per-batch PSNR of clean-input reconstructions over ``dataset``."""
        self.model.eval()
        running_psnr = 0.0
        counter = 0

        with torch.no_grad():
            for img in tqdm(dataset, total=len(dataset)):
                counter += 1
                img = img.to(DEVICE)
                output = self.model(img)
                running_psnr += self.psnr(output[1], img).item()

        return self._safe_mean(running_psnr, counter)

    def test(self, dataset, epoch):
        """Return the mean per-batch PSNR over ``dataset`` and save one sample pair.

        The first item of one randomly chosen batch is saved as input (left) next to
        its reconstruction (right) in ``saved_samples/{MODEL_NAME}/{epoch}.jpg``.
        """
        self.model.eval()
        running_psnr = 0.0
        counter = 0
        # Any batch may be picked (len(dataset) is the number of batches).
        sample_batch = random.randrange(len(dataset)) if len(dataset) > 0 else -1

        with torch.no_grad():
            for i, img in enumerate(tqdm(dataset, total=len(dataset))):
                counter += 1
                img = img.to(DEVICE)
                pred = self.model(img)[1]
                running_psnr += self.psnr(pred, img).item()
                if i == sample_batch:
                    self._save_sample(img[0], pred[0], epoch)

        return self._safe_mean(running_psnr, counter)

    @staticmethod
    def _save_sample(image, pred, epoch):
        """Save two (C, H, W) tensors with values in [0, 1] side by side as a JPEG."""
        os.makedirs(f"saved_samples/{MODEL_NAME}", exist_ok=True)

        def to_pil(tensor):
            # Clamp before the uint8 cast so out-of-range values saturate instead of wrapping.
            arr = tensor.clamp(0, 1).cpu().numpy().transpose((1, 2, 0))
            return Image.fromarray((arr * 255).astype('uint8'))

        image_pil = to_pil(image)
        pred_pil = to_pil(pred)
        stacked_image = Image.new('RGB', (image_pil.width * 2, image_pil.height))
        stacked_image.paste(image_pil, (0, 0))
        stacked_image.paste(pred_pil, (image_pil.width, 0))
        stacked_image.save(f"saved_samples/{MODEL_NAME}/{epoch}.jpg")

    def fit(self, epochs, lr):
        """Train for ``epochs`` epochs with AdamW at learning rate ``lr``.

        Every epoch the method:
            * trains and validates;
            * logs to TensorBoard under ``runs/{MODEL_NAME}/``;
            * overwrites ``checkpoints/{MODEL_NAME}.pth`` whenever validation PSNR is at
              least as good as every earlier epoch.

        Every 5th epoch it also writes ``saved_model/{MODEL_NAME}_{epoch}.pth`` and
        runs :meth:`test`.
        """
        print(f"Using {DEVICE} device...")
        print("Loading Datasets...")
        train_data, val_data, test_data = CustomDataLoader(BATCH_SIZE).get_data()
        print("Dataset Loaded.")

        print("Initializing Parameters...")
        self.model = self.model.to(DEVICE)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"The total parameters of the model are: {total_params}")

        print("Initializing the Optimizer")
        optimizer = optim.AdamW(self.model.parameters(), lr)
        print("Beginning to train...")

        # NOTE(review): a Dice loss is used as the reconstruction objective for RGB
        # images; MixedLoss (MS-SSIM + L2) may be the intended choice — confirm.
        diceloss = DiceLoss()

        best_val_psnr = float("-inf")
        writer = SummaryWriter(f'runs/{MODEL_NAME}/')
        os.makedirs("checkpoints/", exist_ok=True)
        os.makedirs("saved_model/", exist_ok=True)

        try:
            for epoch in range(1, epochs + 1):
                print(f"Epoch No: {epoch}")

                train_loss, train_psnr = self.train(dataset=train_data, loss_func=diceloss, optimizer=optimizer)
                val_psnr = self.validate(dataset=val_data)

                print(f"Train Loss:{train_loss}, Train PSNR:{train_psnr}, Validation PSNR:{val_psnr}")

                writer.add_scalar("Loss/Train", train_loss, epoch)
                writer.add_scalar("PSNR/Train", train_psnr, epoch)
                writer.add_scalar("PSNR/Val", val_psnr, epoch)

                # Keep the best-so-far weights; ties go to the newer epoch.
                if val_psnr >= best_val_psnr:
                    best_val_psnr = val_psnr
                    torch.save(self.model.state_dict(), f"checkpoints/{MODEL_NAME}.pth")

                if epoch % 5 == 0:
                    print("Saving model")
                    torch.save(self.model.state_dict(), f"saved_model/{MODEL_NAME}_{epoch}.pth")
                    test_psnr = self.test(test_data, epoch)
                    writer.add_scalar("PSNR/Test", test_psnr, epoch)
                    print("Model Saved")

                print("Epoch Completed. Proceeding to next epoch...")
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            writer.close()

        print(f"Training Completed for {epochs} epochs.")

    def infer_a_random_sample(self):
        """Unfinished: only prepares the output directory and the inference transform.

        TODO: load a sample, run ``self.model`` on it and save the result under
        ``test_samples/{MODEL_NAME}``.
        """
        os.makedirs(f"test_samples/{MODEL_NAME}", exist_ok=True)

        # Use the training resolution (256x256, as in CustomDataset). The autoencoder
        # max-pools by 2 six times and upsamples by 2 six times, so a 224x224 input
        # would come back as 192x192.
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])


In [None]:

# Train from scratch: 250 epochs at a learning rate of 5e-5.
EPOCHS = 250
LEARNING_RATE = 5e-5

model = Model()
model.fit(EPOCHS, LEARNING_RATE)
