<a href="https://www.kaggle.com/code/prabhanjanjadhav/image-colorization-using-gans?scriptVersionId=114252738" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

<span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Image Colorization using GANs 🧨</h1></span> 

In this notebook you'll learn how to implement GANs and U-Nets, and see how pretraining can be used to improve model performance.
This notebook tackles image colorization using conditional GANs. The implementation is inspired by the pix2pix paper.
# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Intro: Generative Adversarial Networks 🍭<a id="Intro"></a></h1></span>
This section gives a brief introduction to GANs. Feel free to skip it if you're already familiar with them.

A GAN consists of two sub-models:

* The <b>generator</b> learns to generate plausible data. The generated instances become negative training examples for the discriminator.
* The <b>discriminator</b> learns to distinguish the generator's fake data from real data. The discriminator penalizes the generator for producing implausible results.

We simultaneously train two models: a generative model G that captures the data distribution, and a discriminative model D that estimates
the probability that a sample came from the training data rather than G. The training procedure for G is to maximize the probability of D making a mistake. This framework corresponds to a minimax two-player game. In the space of arbitrary functions G and D, a unique solution exists, with G recovering the training data distribution and D equal to ${1 \over 2}$ everywhere. 
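For a fixed generator, the discriminator that maximizes the GAN objective is $D^*(x) = {p_{data}(x) \over p_{data}(x) + p_g(x)}$ (Goodfellow et al., 2014). This is why D ends up at ${1 \over 2}$ once the generator distribution $p_g$ matches $p_{data}$. The NumPy sketch below evaluates that formula on two 1-D Gaussians standing in for the data and generator distributions. The densities and the helper names `gaussian_pdf` and `optimal_discriminator` are illustrative assumptions and are not part of this notebook.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of a 1-D normal distribution N(mean, std^2)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))

def optimal_discriminator(p_data, p_g):
    """D*(x) = p_data(x) / (p_data(x) + p_g(x)), the best D for a fixed G."""
    return p_data / (p_data + p_g)

xs = np.linspace(-3.0, 3.0, 7)            # points at which to evaluate D*
p_data = gaussian_pdf(xs, 0.0, 1.0)       # hypothetical "real" distribution

# Generator far from the data: D* is near 1 where real samples are more likely
# and near 0 where generated samples are more likely
d_far = optimal_discriminator(p_data, gaussian_pdf(xs, 2.0, 1.0))

# Generator that matches the data exactly: D* cannot tell them apart
d_match = optimal_discriminator(p_data, gaussian_pdf(xs, 0.0, 1.0))

print(np.round(d_far, 3))
print(d_match)  # 0.5 at every point
```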


<img src="https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs41524-020-00352-0/MediaObjects/41524_2020_352_Fig1_HTML.png" width="800"></img>


# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Install Required Libraries 📦 <a id="Installing required libraries"></a></h1></span>

In [None]:
# !pip install --upgrade torch torchvision

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Import Required Libraries 🧺<a id="Importing libraries"></a></h1></span>

In [None]:
import os
import glob
import time

# For data manipulation
import numpy as np
from PIL import Image
import cv2 as cv
from pathlib import Path

# Pytorch imports
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Utils
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
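from skimage.color import rgb2lab, lab2rgb
from torchinfo import summary  # used for model summaries; install with `pip install torchinfo` if missing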

In [None]:
!pip install fastai==2.4
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Making Dataset and Dataloaders🥡<a id="Datasets and Dataloaders"></a></h1></span>
We'll use a subset of the <b>COCO dataset</b>: 8000 training images and 2000 validation images, sampled at random from the 2017 training split.

In [None]:
# setting seed
np.random.seed(123)

paths = glob.glob('/kaggle/input/coco-2017-dataset/coco2017/train2017' + "/*.jpg") # Grabbing all the image file paths
paths_subset = np.random.choice(paths, 10_000, replace=False) # choosing 10000 paths randomly
rand_idxs = np.random.permutation(10_000) # generate a numpy array of numbers from 0 to 9999 in any random order
train_idxs = rand_idxs[:8000]
val_idxs = rand_idxs[8000:]
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]

print(len(train_paths), len(val_paths))

In [None]:
# Visualizing the dataset
_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.flatten(), train_paths):
    ax.imshow(Image.open(img_path))
    ax.axis("off")

In [None]:
SIZE = 256
class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([                  
                transforms.Resize((SIZE, SIZE),  Image.BICUBIC),
                transforms.RandomHorizontalFlip(), 
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE),  Image.BICUBIC)
        
        self.split = split
        self.size = SIZE
        self.paths = paths
    
    def __getitem__(self, idx): # function for getting L and ab tensors of an image idx
        img = Image.open(self.paths[idx]).convert("RGB") # converting images to RGB to tackle any grayscale image if present.
        img = self.transforms(img)
        img = np.array(img) 
        img_lab = rgb2lab(img).astype("float32") # Converting RGB to L*a*b; (h,w,c)
        img_lab = transforms.ToTensor()(img_lab) # Converting Lab img to tensor; also converts it to (c, h, w)
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        
        return {'L': L, 'ab': ab}
    
    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=2, pin_memory=True, **kwargs): # a handy function to make our dataloaders
                                                              
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,     
                            pin_memory=pin_memory)
    return dataloader
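`__getitem__` rescales the Lab channels to roughly [-1, 1]. L lies in [0, 100], so `L / 50 - 1` lands in [-1, 1]. For sRGB images, a and b stay within about ±110, so `ab / 110` also lands in about [-1, 1]. `lab_to_rgb` later undoes this scaling before calling `lab2rgb`. The NumPy sketch below applies the forward and inverse scaling to a synthetic Lab array. The helper names `normalize_lab` and `denormalize_lab` are hypothetical. The random array stands in for the output of `rgb2lab`, and it uses a channels-last layout for simplicity, whereas the dataset returns channels-first tensors.

```python
import numpy as np

def normalize_lab(lab):
    """Scale an (H, W, 3) Lab image the same way as ColorizationDataset."""
    L = lab[..., :1] / 50.0 - 1.0   # L in [0, 100] -> [-1, 1]
    ab = lab[..., 1:] / 110.0       # a, b in about [-110, 110] -> about [-1, 1]
    return L, ab

def denormalize_lab(L, ab):
    """Invert normalize_lab, as lab_to_rgb does before calling lab2rgb."""
    return np.concatenate([(L + 1.0) * 50.0, ab * 110.0], axis=-1)

rng = np.random.default_rng(0)
# Synthetic stand-in for rgb2lab(img): L in [0, 100], a/b in [-100, 100]
lab = np.stack([
    rng.uniform(0.0, 100.0, size=(4, 4)),
    rng.uniform(-100.0, 100.0, size=(4, 4)),
    rng.uniform(-100.0, 100.0, size=(4, 4)),
], axis=-1).astype("float32")

L, ab = normalize_lab(lab)
restored = denormalize_lab(L, ab)
print(L.min() >= -1.0, L.max() <= 1.0)
print(np.allclose(restored, lab, atol=1e-3))
```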

In [None]:
train_dl = make_dataloaders(paths=train_paths, split='train')
val_dl = make_dataloaders(paths=val_paths, split='val')
print(len(train_dl)) # 8000 / dataloader_batch_size(=16) 
print(len(val_dl))  # 2000 / 16
data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print(Ls.shape, abs_.shape)

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Generator✨<a id="Model Architecture"></a></h1></span>

![unet](https://raw.githubusercontent.com/prabhanjan-jadhav/image-colorization-using-gans/main/images/UNet%20architechture.png)

In [None]:
# Utility function required to crop a tensor while passing through skip connection
def crop_tensor(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2
    return tensor[
        :,
        :,
        delta:tensor_size - delta,
        delta:tensor_size - delta
    ]

In [None]:
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        self.dconv1 = nn.Conv2d(in_channels, 64, kernel_size=3) 
        self.dconv2 = nn.Conv2d(64, 128, kernel_size=4)
        self.dconv3 = nn.Conv2d(128, 256, kernel_size=3)
        self.dconv4 = nn.Conv2d(256, 512, kernel_size=3)
        self.dconv5 = nn.Conv2d(512, 512, kernel_size=3)

        self.uconv1 = nn.Conv2d(512, 256, kernel_size=3)
        self.uconv2 = nn.Conv2d(256, 128, kernel_size=3)
        self.uconv3 = nn.Conv2d(128, 64, kernel_size=3)
        self.uconv4 = nn.Conv2d(64, 2, kernel_size=1)

        self.maxpool2d = nn.MaxPool2d(2)

        self.trans1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.trans2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.trans3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

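    def forward(self, image):
        # Sizes in the comments assume a 1 x 256 x 256 input. All convolutions are
        # unpadded, so each 3x3 conv trims 2 pixels; no non-linear activation is
        # applied between the convolutions.
        x1 = self.dconv1(image)            # 64 x 254 x 254
        x2 = self.maxpool2d(x1)            # 64 x 127 x 127

        x3 = self.dconv2(x2)               # 128 x 124 x 124
        x4 = self.maxpool2d(x3)            # 128 x 62 x 62

        x5 = self.dconv3(x4)               # 256 x 60 x 60
        x6 = self.maxpool2d(x5)            # 256 x 30 x 30

        x7 = self.dconv4(x6)               # 512 x 28 x 28
        x8 = self.dconv5(x7)               # 512 x 26 x 26 (bottleneck)

        x9 = self.trans1(x8)               # 256 x 52 x 52
        y = crop_tensor(x5, x9)            # center-crop the skip connection to 52 x 52
        x9 = torch.cat([x9, y], dim=1)     # 512 x 52 x 52
        x10 = self.uconv1(x9)              # 256 x 50 x 50

        x11 = self.trans2(x10)             # 128 x 100 x 100
        y = crop_tensor(x3, x11)
        x11 = torch.cat([x11, y], dim=1)   # 256 x 100 x 100
        x12 = self.uconv2(x11)             # 128 x 98 x 98

        x13 = self.trans3(x12)             # 64 x 196 x 196
        y = crop_tensor(x1, x13)
        x13 = torch.cat([x13, y], dim=1)   # 128 x 196 x 196
        x14 = self.uconv3(x13)             # 64 x 194 x 194

        out = self.uconv4(x14)             # 2 x 194 x 194: predicted ab channels
        return out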
    
if __name__ == "__main__":
    model = UNet()
    print(summary(model, (1,1,256,256)))        
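Every convolution in this `UNet` is unpadded (`padding=0`), so the network returns 194×194 maps for a 256×256 input. This is also why the discriminator's shape check uses 194×194 inputs, and why the prediction must be resized to 256×256 before it is compared with the target. The pure-Python sketch below traces the spatial size through `UNet.forward`. It uses the standard output-size formulas for `nn.Conv2d`/`nn.MaxPool2d` and `nn.ConvTranspose2d`. It also checks that each skip connection can be center-cropped exactly, which `crop_tensor` needs an even size difference for. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of nn.Conv2d / nn.MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride):
    """Output size of nn.ConvTranspose2d with padding=0 and output_padding=0."""
    return (size - 1) * stride + kernel

def unet_spatial_sizes(size):
    """Trace the spatial size through UNet.forward; returns a dict of stages."""
    sizes = {}
    sizes["x1"] = conv_out(size, 3)                          # dconv1
    sizes["x3"] = conv_out(conv_out(sizes["x1"], 2, 2), 4)   # maxpool, dconv2
    sizes["x5"] = conv_out(conv_out(sizes["x3"], 2, 2), 3)   # maxpool, dconv3
    s = conv_out(sizes["x5"], 2, 2)                          # maxpool
    s = conv_out(conv_out(s, 3), 3)                          # dconv4, dconv5
    sizes["bottleneck"] = s
    for name, skip in (("up1", "x5"), ("up2", "x3"), ("up3", "x1")):
        up = conv_transpose_out(s, 2, 2)                     # trans1 / trans2 / trans3
        diff = sizes[skip] - up
        assert diff >= 0 and diff % 2 == 0, "skip connection cannot be center-cropped"
        s = conv_out(up, 3)                                  # uconv1 / uconv2 / uconv3
        sizes[name] = s
    sizes["out"] = conv_out(s, 1)                            # uconv4 is 1x1
    return sizes

print(unet_spatial_sizes(256))
```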

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Discriminator </h1></span>

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3,64,4,2,1)
        self.leakyrelu1 = nn.LeakyReLU(0.2, inplace=True)

        self.conv2 = nn.Conv2d(64,128,4,2,1)
        self.bn1 = nn.BatchNorm2d(128)
        self.leakyrelu2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv3 = nn.Conv2d(128,256,4,2,1)
        self.bn2 = nn.BatchNorm2d(256)
        self.leakyrelu3 = nn.LeakyReLU(0.2, inplace=True)

        self.conv4 = nn.Conv2d(256,512,4,1,1)
        self.bn3 = nn.BatchNorm2d(512)
        self.leakyrelu4 = nn.LeakyReLU(0.2, inplace=True)

        self.conv5 = nn.Conv2d(512,1,4,1,1)

    def forward(self, x):
        # Apply the layers registered in __init__ in order; no nn.Sequential needs to
        # be built (or assigned to self) inside forward.
        x = self.leakyrelu1(self.conv1(x))
        x = self.leakyrelu2(self.bn1(self.conv2(x)))
        x = self.leakyrelu3(self.bn2(self.conv3(x)))
        x = self.leakyrelu4(self.bn3(self.conv4(x)))
        return self.conv5(x)  # a grid of logits, one per input patch

Let's take a look at its blocks:

In [None]:
summary(Discriminator())

And its output shape:

In [None]:
discriminator = Discriminator()
dummy_input = torch.randn(16, 3, 194, 194) # batch_size, channels, size, size
out = discriminator(dummy_input)
print(out.shape)
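The discriminator does not output a single real/fake score. It is a PatchGAN-style network whose output is a grid of logits, each judging one patch of the input. Its configuration matches the 70×70 PatchGAN described in the pix2pix paper: three stride-2 and two stride-1 4×4 convolutions, all with padding 1. The pure-Python sketch below uses the layer hyperparameters from `Discriminator.__init__` to compute the output grid size and the receptive field of one output logit. With these layers, the cell above should print `torch.Size([16, 1, 22, 22])`. The helper names are hypothetical.

```python
# (kernel, stride, padding) of conv1 ... conv5 in Discriminator
DISC_LAYERS = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

def patchgan_output_size(size, layers=DISC_LAYERS):
    """Spatial size of the logit map for a size x size input."""
    for kernel, stride, padding in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size

def receptive_field(layers=DISC_LAYERS):
    """Input pixels (per side) that influence a single output logit."""
    rf = 1
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

print(patchgan_output_size(194))  # 22
print(patchgan_output_size(256))  # 30
print(receptive_field())          # 70
```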

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Loss Function 🧭<a id="Loss function"></a></h1></span>
* **Loss function we optimize** :
\begin{equation}
G^* = \arg\min_G \max_D L_{cGAN}(G,D) + \lambda L_{L1}(G)
\end{equation}


* **L1 Loss** : 
\begin{equation}
L_{L1} (G) = \mathbb{E}_{x,y,z}[\|y-G(x,z)\|_1]
\end{equation}


* **GAN Loss** : 
\begin{equation}
L_{cGAN}(G,D) = \mathbb{E}_{x,y}[\log D(x,y)] + \mathbb{E}_{x,z}[\log(1-D(x, G(x,z)))]
\end{equation}
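Here **x** is the grayscale image, **z** is the input noise for the generator, and **y** is the 2-channel output we want from the generator (it can also represent the 2 color channels of a real image). **G** is the generator model, **D** is the discriminator, and **λ** weights the L1 term (`lambda_L1=100` in `MainModel`). The generators in this notebook receive only **x**; no explicit noise vector **z** is fed to them.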
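In code, the expectations become averages over the batch and over the discriminator's output patches. `nn.BCEWithLogitsLoss` computes the log terms from raw logits. The discriminator minimizes the mean of its real and fake BCE terms. The generator uses the common non-saturating form, labelling its own outputs as real, and adds λ times the L1 term, with λ = 100 as in `MainModel`. The NumPy sketch below reproduces that arithmetic on toy arrays. The function names and inputs are hypothetical, and it only mirrors what the PyTorch criteria compute.

```python
import numpy as np

def bce_with_logits(logits, target):
    """Mean binary cross-entropy on raw logits (what nn.BCEWithLogitsLoss computes)."""
    # -log(sigmoid(x)) = log(1 + e^-x) and -log(1 - sigmoid(x)) = log(1 + e^x), computed stably
    return np.mean(target * np.logaddexp(0.0, -logits)
                   + (1.0 - target) * np.logaddexp(0.0, logits))

def discriminator_loss(real_logits, fake_logits):
    loss_real = bce_with_logits(real_logits, 1.0)
    loss_fake = bce_with_logits(fake_logits, 0.0)
    return 0.5 * (loss_real + loss_fake)

def generator_loss(fake_logits, fake_ab, real_ab, lambda_l1=100.0):
    loss_gan = bce_with_logits(fake_logits, 1.0)  # non-saturating: fakes labelled real
    loss_l1 = np.mean(np.abs(fake_ab - real_ab))
    return loss_gan + lambda_l1 * loss_l1

# Toy values: D outputs logit 0 (probability 0.5) everywhere, ab is off by 0.1
logits = np.zeros((2, 1, 30, 30))
real_ab = np.zeros((2, 2, 8, 8))
fake_ab = real_ab + 0.1

print(discriminator_loss(logits, logits))        # log 2, about 0.6931
print(generator_loss(logits, fake_ab, real_ab))  # log 2 + 100 * 0.1, about 10.6931
```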


In [None]:
class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # If you have parameters in your model, which should be saved and restored
        # in the state_dict, but not trained by the optimizer, you should register them as buffers.
        self.register_buffer('real_label', torch.tensor(real_label))    
        self.register_buffer('fake_label', torch.tensor(fake_label))  
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
    
    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)
    
    def forward(self, preds, target_is_real):
        # Called through nn.Module.__call__, e.g. self.GANcriterion(preds, True)
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
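`GANLoss` switches only the criterion between modes. `'vanilla'` applies `nn.BCEWithLogitsLoss` to the discriminator logits. `'lsgan'` applies `nn.MSELoss` directly to the same raw outputs, regressing them towards the real (1.0) or fake (0.0) label. The NumPy sketch below compares the two on a few hypothetical discriminator outputs. It mirrors the arithmetic of those criteria and is not the PyTorch implementation.

```python
import numpy as np

def vanilla_loss(preds, target_is_real):
    """BCE on logits, as nn.BCEWithLogitsLoss with label 1.0 (real) or 0.0 (fake)."""
    label = 1.0 if target_is_real else 0.0
    return np.mean(label * np.logaddexp(0.0, -preds)
                   + (1.0 - label) * np.logaddexp(0.0, preds))

def lsgan_loss(preds, target_is_real):
    """Squared error to the label, as nn.MSELoss on the raw outputs."""
    label = 1.0 if target_is_real else 0.0
    return np.mean((preds - label) ** 2)

preds = np.array([-2.0, 0.0, 2.0])  # hypothetical discriminator outputs
for name, fn in (("vanilla", vanilla_loss), ("lsgan", lsgan_loss)):
    print(name, "real:", round(float(fn(preds, True)), 4),
          "fake:", round(float(fn(preds, False)), 4))
```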

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Model Initialization⚙️</h1></span>

In [None]:
def init_weights(net, init='norm', gain=0.02):
    
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)
            
    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Model🪄</h1></span>

In [None]:
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, 
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        
        if net_G is None:
            self.net_G = init_model(UNet(), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(Discriminator(), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        # Move the normalized L (generator input) and ab (target) channels to the device
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        
    def forward(self):
        self.fake_color = self.net_G(self.L)
        # The plain UNet uses unpadded convolutions and outputs 194x194 maps, so its
        # prediction is resized back to the input resolution (256x256) to line up
        # with L and ab. The ResNet U-Net already outputs 256x256 maps.
        if self.fake_color.shape[-2:] != self.L.shape[-2:]:
            self.fake_color = nn.functional.interpolate(
                self.fake_color, size=self.L.shape[-2:], mode='bicubic', align_corners=False)
    
    def backward_D(self):
        # Stack L and ab along the channel dimension to form 3-channel Lab images
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach() stops the discriminator loss from back-propagating into the generator
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator is rewarded when D classifies its output as real
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        # L1 distance to the true colors, weighted by lambda_L1
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
        self.forward()
        # 1) Update the discriminator
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        # 2) Update the generator while the discriminator's weights are frozen
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Utility functions🎈</h1></span>

In [None]:
class AverageMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """
    
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)
    
def visualize(model, data, save=False):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))
    for i in range(min(5, L.size(0))): # show at most 5 images from the batch
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")
        
def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")
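`AverageMeter.update` weights each value by `count`, and `update_losses` passes the batch size. As a result, `avg` is the per-sample mean loss over the epoch even when a batch is smaller than the others. A quick check with hypothetical numbers follows; the class is repeated so that the snippet runs on its own.

```python
class AverageMeter:
    """Running average weighted by the number of samples behind each value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

meter = AverageMeter()
meter.update(1.0, count=16)  # mean loss 1.0 over a batch of 16 samples
meter.update(0.0, count=4)   # mean loss 0.0 over a final batch of 4 samples
print(meter.avg)             # 0.8, not the unweighted (1.0 + 0.0) / 2 = 0.5
```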

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Training💪</h1></span>

To train the model, run the cell below. Training for about 20 epochs (roughly 5 minutes per epoch) should give reasonable results. \
Alternatively, you can download the weights of a model trained for 20 epochs in the following cell.

In [None]:
def train_model(model, train_dl, epochs, display_every=200):
    val_batch = next(iter(val_dl)) # a fixed validation batch for visualizing the model output at regular intervals
    for e in range(epochs):
        loss_meter_dict = create_loss_meters() # dictionary of AverageMeter objects logging the losses of the whole network
        i = 0
        for data in tqdm(train_dl):
            model.setup_input(data) 
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0)) # update the running loss averages
            i += 1
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict) # print the average losses so far
                visualize(model, val_batch, save=False) # display the model's outputs on the fixed validation batch
        if e % 5 == 0: # save a checkpoint after epochs 1, 6, 11, 16, ...
            torch.save(model.state_dict(), f"/kaggle/working/model_{e+1}.pt")
            
model = MainModel()
# train_model(model, train_dl, 20)

In [None]:
# downloading weights for the model
!gdown "https://drive.google.com/uc?id=1owvzniVc_PQ3xqNHscGlR_mrbpdy7e_G"

In [None]:
# create the model object and load the weights
model = MainModel()
model.load_state_dict(torch.load('/kaggle/working/model_21.pt', map_location=model.device))

In [None]:
# Put the model in evaluation mode
model.eval()

In [None]:
# Generating output on the validation dataset
# visualize() runs the generator without gradients, so no weights are updated here
for data in tqdm(val_dl):
    visualize(model, data, save=False)

This concludes the first implementation.
The following section presents a modified implementation that improves the results.

<h1 style = "font-size:60px; font-family:Garamond ; font-weight : normal; background-color: #f6f5f5 ; color : #fe346e; text-align: center; border-radius: 100px 100px;">A Sophisticated Approach🚀<a id="Another approach"></a></h1>
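This section uses fastai and torchvision to build the generator.

The generator is a U-Net whose encoder (backbone) is a ResNet-18 pretrained on ImageNet classification; fastai's `DynamicUnet` builds the decoder on top of it. Before adversarial training, the generator is pretrained on its own for colorization with an L1 loss. The discriminator is the same as before.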

In [None]:
# !pip install fastai==2.4
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [None]:
def build_res_unet(n_input=1, n_output=2, size=256):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
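    # fastai 2.4 (installed above) expects the architecture function and loads the
    # ImageNet weights itself; newer fastai releases take an instantiated model instead.
    body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)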
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
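`cut=-2` keeps every child of the torchvision ResNet except the last two, the global average pool and the fully connected classifier. The encoder therefore returns a spatial feature map rather than class scores, and `DynamicUnet` builds the decoder and skip connections on top of it. The pure-Python sketch below assumes the standard torchvision ResNet-18 layout. It lists the children that `cut=-2` keeps and traces the feature-map size for a 256×256 input. The child names follow torchvision's `ResNet`; the size arithmetic is an illustration, not fastai code.

```python
# Children of torchvision's ResNet-18, in registration order
RESNET18_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]

# (kernel, stride, padding) of the layers that change the spatial size;
# in layer2-layer4 the downsampling comes from a 3x3 conv with stride 2
DOWNSAMPLING = {"conv1": (7, 2, 3), "maxpool": (3, 2, 1),
                "layer2": (3, 2, 1), "layer3": (3, 2, 1), "layer4": (3, 2, 1)}

def encoder_sizes(size, cut=-2):
    """Spatial size after each child kept by create_body(..., cut=cut)."""
    kept = RESNET18_CHILDREN[:cut]
    sizes = {}
    for name in kept:
        if name in DOWNSAMPLING:
            kernel, stride, padding = DOWNSAMPLING[name]
            size = (size + 2 * padding - kernel) // stride + 1
        sizes[name] = size
    return sizes

# layer4 of ResNet-18 has 512 channels, so the encoder output is 512 x 8 x 8
print(encoder_sizes(256))
```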

In [None]:
def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    ''' Pretraining generator on image colorization task using L1 loss.
    ResNet backbone has pretrained weights'''
    for e in range(epochs):
        loss_meter = AverageMeter()
        for data in tqdm(train_dl):
            L, ab = data['L'].to(device), data['ab'].to(device)
            preds = net_G(L)
            loss = criterion(preds, ab)
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            loss_meter.update(loss.item(), L.size(0))
            
        print(f"Epoch {e + 1}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")

net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt = optim.Adam(net_G.parameters(), lr=1e-4)
criterion = nn.L1Loss()  
pretrain_generator(net_G, train_dl, opt, criterion, 20)
torch.save(net_G.state_dict(), "res18-unet.pt")

In [None]:
'''
Training the entire model for 20 epochs
'''
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
train_model(model, train_dl, 20)
torch.save(model.state_dict(), "final_model_weights.pt")

In [None]:
''' Download the weights for trained model
'''
!gdown "https://drive.google.com/uc?id=1UY5a07bVofwAyV7rI8konccnFCJjIKPZ"
!gdown "https://drive.google.com/uc?id=1lR6DcS4m5InSbZ5y59zkH2mHt_4RQ2KV"

'''

    Build the model object and initialize the weights

'''
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("/kaggle/working/res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
model.load_state_dict(torch.load("/kaggle/working/final_model_weights.pt", map_location=device))

In [None]:
'''
Visualize the outputs on validation dataset

'''
for data in tqdm(val_dl):
    # visualize() runs the generator without gradients, so no weights are updated here
    visualize(model, data, save=False)

In [None]:
'''
Generate output on your own set of images

'''
img_path = '/path/to/folder'
print(img_path)
paths = glob.glob(img_path + "/*")

test_dl = make_dataloaders(paths=paths, split='val')
for data in tqdm(test_dl):
    # visualize() runs the generator without gradients, so no weights are updated here
    visualize(model, data, save=False)