In [None]:
!pip install torchviz

In [None]:
! pip install torchsummary

In [None]:
!pip install torch

In [None]:
%matplotlib inline
import torch
import torchvision
import os
import shutil
from PIL import Image, ImageFile
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import time
import glob
from torch.utils.data import Dataset, DataLoader
from skimage.color import rgb2lab, lab2rgb
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
import torch.nn.functional as F
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
# Folder holding the training images. Override it with the COLORIZATION_DATA_DIR
# environment variable instead of editing this absolute, machine-specific default.
dataset_path = os.environ.get("COLORIZATION_DATA_DIR", r"D:\imagecol\new")
# NOTE(review): make_dataloader is called without `batch=`, so the loaders use its
# default of 8; this constant is only stored on CustomDataset.
BATCH_SIZE = 4
INPUT_SHAPE = 256  # images are resized to INPUT_SHAPE x INPUT_SHAPE

# Seed the global RNGs (numpy for sampling, torch for weight init / dropout /
# RandomHorizontalFlip / shuffling) so Restart & Run All reproduces a run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def DisplayImages(imagepaths):
    """Show the images at ``imagepaths`` side by side in a single row, axes hidden.

    Works for any number of paths: an empty list draws nothing, and a single path
    is handled too (``squeeze=False`` keeps ``axes`` 2-D even for one column).
    Each file is closed as soon as matplotlib has copied its pixels.
    """
    num_images = len(imagepaths)
    if num_images == 0:
        return
    fig, axes = plt.subplots(1, num_images, figsize=(15, 6), squeeze=False)
    for ax, path in zip(axes[0], imagepaths):
        with Image.open(path) as img:
            ax.imshow(img)
        ax.axis('off')
    plt.show()


In [None]:
images = os.listdir(dataset_path)
# Preview up to 5 *distinct* images (randint could draw the same index twice).
n_preview = min(5, len(images))
rdx = np.random.choice(len(images), size=n_preview, replace=False)
ipath = [os.path.join(dataset_path, images[i]) for i in rdx]

if ipath:
    DisplayImages(ipath)

In [None]:
# 80/20 split that is identical on every run: os.listdir order is
# filesystem-dependent, so sort first, then permute with a fixed seed. Without
# this, a checkpoint loaded after a kernel restart would be evaluated on a test
# split that overlaps the images it was trained on.
SPLIT_SEED = 42
image_files = sorted(images)
rand_idx = np.random.default_rng(SPLIT_SEED).permutation(len(image_files))
train_idx = int(len(image_files) * 0.8)
train_idxs = rand_idx[:train_idx]
test_idxs = rand_idx[train_idx:]

train_images = [os.path.join(dataset_path, image_files[i]) for i in train_idxs]
test_images = [os.path.join(dataset_path, image_files[i]) for i in test_idxs]

print(train_images[0])
print(len(train_images), len(test_images))
DisplayImages(train_images[:5])

In [None]:
class CustomDataset(Dataset):
    """Dataset that turns image files into normalized Lab tensors for colorization.

    Each item is ``{"L": (1, H, W) tensor, "ab": (2, H, W) tensor}`` with
    H = W = INPUT_SHAPE, where
      * L  = L*/50 - 1  (skimage's L* nominally spans 0..100, so roughly [-1, 1])
      * ab = ab*/110

    Args:
        paths: list of image file paths.
        split: "train" (bicubic resize + random horizontal flip) or
            "val" (bicubic resize only).

    Raises:
        ValueError: for any other ``split``. No transform would be defined for it,
            and the failure would only surface later inside ``__getitem__``.
    """

    def __init__(self, paths, split="train"):
        resize = transforms.Resize(
            (INPUT_SHAPE, INPUT_SHAPE),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        if split == "train":
            self.transforms = transforms.Compose([resize, transforms.RandomHorizontalFlip()])
        elif split == "val":
            self.transforms = resize
        else:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.split = split
        self.paths = paths
        # NOTE(review): not read by this class; batch size is decided by the
        # DataLoader. Kept so code inspecting `.size` keeps working.
        self.size = BATCH_SIZE

    def __getitem__(self, idx):
        # convert() loads the pixels, so the file handle can be closed right away.
        with Image.open(self.paths[idx]) as src:
            image = src.convert("RGB")
        image = self.transforms(image)
        img_lab = rgb2lab(np.array(image)).astype("float32")
        # ToTensor: HWC -> CHW; float32 input is not rescaled.
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50.0 - 1.0
        ab = img_lab[[1, 2], ...] / 110.0
        return {"L": L, "ab": ab}

    def __len__(self):
        return len(self.paths)

def make_dataloader(batch=8, pin_memory=True, shuffle=None, **kwargs):
    """Wrap a CustomDataset built from ``kwargs`` in a DataLoader.

    Args:
        batch: batch size.
        pin_memory: forwarded to DataLoader.
        shuffle: reshuffle every epoch. Defaults to True for the "train" split
            (CustomDataset's default split) and False otherwise, so training
            batches are no longer drawn in the same fixed order each epoch.
        **kwargs: passed to CustomDataset (``paths``, ``split``).
    """
    dataset = CustomDataset(**kwargs)
    if shuffle is None:
        shuffle = kwargs.get("split", "train") == "train"
    return DataLoader(dataset, batch_size=batch, shuffle=shuffle, pin_memory=pin_memory)

In [None]:
# One loader per split; the "val" split skips the random-flip augmentation.
traindl, testdl = (
    make_dataloader(paths=train_images, split="train"),
    make_dataloader(paths=test_images, split="val"),
)

In [None]:
# Grab a single training batch ({"L": ..., "ab": ...}) for the sanity plots below.
_first_batch_iter = iter(traindl)
data = next(_first_batch_iter)
del _first_batch_iter

In [None]:
# L was scaled to [-1, 1] (L/50 - 1) by CustomDataset; map it back to [0, 1]
# so imshow does not clip the darker half of every image to black.
grid_image = make_grid((data["L"] + 1.0) / 2.0, nrow=4, padding=2, pad_value=1)
plt.imshow(grid_image.permute(1, 2, 0))
plt.axis('off')
plt.show()

In [None]:
# Plot the a and b channels of (up to) the first 8 images of the batch;
# capping by the batch size avoids an IndexError on smaller batches.
n_show = min(8, data["ab"].shape[0])
for i in range(n_show):
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for j, ax in enumerate(axes):
        ax.imshow(data["ab"][i, j])
        ax.set_title(f'Channel {j+1}')
        ax.axis('off')
    fig.suptitle(f'Image {i+1}')
    plt.show()

In [None]:
# # class ASSP(nn.Module):
# #   def __init__(self,in_channels,out_channels = 256):
# #     super(ASSP,self).__init__()


# #     self.relu = nn.ReLU(inplace=True)

# #     self.conv1 = nn.Conv2d(in_channels = in_channels,
# #                           out_channels = out_channels,
# #                           kernel_size = 1,
# #                           padding = 0,
# #                           dilation=1,
# #                           bias=False)

# #     self.bn1 = nn.BatchNorm2d(out_channels)

# #     self.conv2 = nn.Conv2d(in_channels = in_channels,
# #                           out_channels = out_channels,
# #                           kernel_size = 3,
# #                           stride=1,
# #                           padding = 6,
# #                           dilation = 6,
# #                           bias=False)

# #     self.bn2 = nn.BatchNorm2d(out_channels)

# #     self.conv3 = nn.Conv2d(in_channels = in_channels,
# #                           out_channels = out_channels,
# #                           kernel_size = 3,
# #                           stride=1,
# #                           padding = 12,
# #                           dilation = 12,
# #                           bias=False)

# #     self.bn3 = nn.BatchNorm2d(out_channels)

# #     self.conv4 = nn.Conv2d(in_channels = in_channels,
# #                           out_channels = out_channels,
# #                           kernel_size = 3,
# #                           stride=1,
# #                           padding = 18,
# #                           dilation = 18,
# #                           bias=False)

# #     self.bn4 = nn.BatchNorm2d(out_channels)

# #     self.conv5 = nn.Conv2d(in_channels = in_channels,
# #                           out_channels = out_channels,
# #                           kernel_size = 1,
# #                           stride=1,
# #                           padding = 0,
# #                           dilation=1,
# #                           bias=False)

# #     self.bn5 = nn.BatchNorm2d(out_channels)

# #     self.convf = nn.Conv2d(in_channels = out_channels * 5,
# #                           out_channels = out_channels,
# #                           kernel_size = 1,
# #                           stride=1,
# #                           padding = 0,
# #                           dilation=1,
# #                           bias=False)

# #     self.bnf = nn.BatchNorm2d(out_channels)

# #     self.adapool = nn.AdaptiveAvgPool2d(1)


# #   def forward(self,x):
# #     x1 = self.conv1(x)
# #     x1 = self.bn1(x1)
# #     x1 = self.relu(x1)

# #     x2 = self.conv2(x)
# #     x2 = self.bn2(x2)
# #     x2 = self.relu(x2)

# #     x3 = self.conv3(x)
# #     x3 = self.bn3(x3)
# #     x3 = self.relu(x3)

# #     x4 = self.conv4(x)
# #     x4 = self.bn4(x4)
# #     x4 = self.relu(x4)

# #     x5 = self.adapool(x)
# #     x5 = self.conv5(x5)
# #     x5 = self.bn5(x5)
# #     x5 = self.relu(x5)
# #     x5 = F.interpolate(x5, size = tuple(x4.shape[-2:]), mode='bilinear')

# #     x = torch.cat((x1,x2,x3,x4,x5), dim = 1) #channels first
# #     x = self.convf(x)
# #     x = self.bnf(x)
# #     x = self.relu(x)

# #     return x



class ASSP(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLabV3-style ASPP).

    Five parallel branches read the same input feature map:
      * a 1x1 convolution,
      * three 3x3 atrous convolutions with dilation rates ``dilations``
        (padding == dilation, so H x W is preserved),
      * global average pooling -> 1x1 convolution, upsampled back to H x W.
    Every branch is Conv -> BatchNorm -> ReLU with ``out_channels`` outputs. The
    five results are concatenated on the channel axis and fused by a final
    1x1 Conv -> BatchNorm -> ReLU down to ``final_out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by each of the five branches.
        final_out_channels: channels of the fused output.
        dilations: the three dilation rates of the 3x3 branches. The default
            (3, 6, 9) is what this module has always used.

    Notes:
        * The final ReLU makes the output non-negative, so this is meant to feed
          further layers rather than emit signed predictions directly.
        * The pooled branch runs BatchNorm on a 1x1 map, so in training mode a
          batch must contain more than one sample.
    """

    def __init__(self, in_channels, out_channels=256, final_out_channels=2, dilations=(3, 6, 9)):
        super(ASSP, self).__init__()
        if len(dilations) != 3:
            raise ValueError(f"ASSP expects 3 dilation rates, got {len(dilations)}: {dilations!r}")
        d2, d3, d4 = dilations

        self.relu = nn.ReLU(inplace=True)

        # Branch 1: 1x1 convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Branches 2-4: 3x3 atrous convolutions (padding == dilation keeps H x W)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=d2, dilation=d2, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=d3, dilation=d3, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=d4, dilation=d4, bias=False)
        self.bn4 = nn.BatchNorm2d(out_channels)

        # Branch 5: 1x1 convolution applied after global average pooling
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.bn5 = nn.BatchNorm2d(out_channels)

        # Fusion: 1x1 convolution over the 5 concatenated branches
        self.convf = nn.Conv2d(out_channels * 5, final_out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.bnf = nn.BatchNorm2d(final_out_channels)

        # Global average pooling for branch 5
        self.adapool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # Branch 1: 1x1 convolution
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu(x1)

        # Branch 2: 3x3 convolution, dilation dilations[0]
        x2 = self.conv2(x)
        x2 = self.bn2(x2)
        x2 = self.relu(x2)

        # Branch 3: 3x3 convolution, dilation dilations[1]
        x3 = self.conv3(x)
        x3 = self.bn3(x3)
        x3 = self.relu(x3)

        # Branch 4: 3x3 convolution, dilation dilations[2]
        x4 = self.conv4(x)
        x4 = self.bn4(x4)
        x4 = self.relu(x4)

        # Branch 5: global average pool -> 1x1 conv -> broadcast back to H x W
        x5 = self.adapool(x)
        x5 = self.conv5(x5)
        x5 = self.bn5(x5)
        x5 = self.relu(x5)
        x5 = F.interpolate(x5, size=x4.shape[-2:], mode='bilinear', align_corners=True)

        # Concatenate all branches along the channel axis
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        # Fuse with a final 1x1 convolution
        x = self.convf(x)
        x = self.bnf(x)
        x = self.relu(x)

        return x

# import math
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from models.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

# class _ASPPModule(nn.Module):
#     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
#         super(_ASPPModule, self).__init__()
#         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
#                                             stride=1, padding=padding, dilation=dilation, bias=False)
#         self.bn = BatchNorm(planes)
#         self.relu = nn.ReLU()

#         self._init_weight()

#     def forward(self, x):
#         x = self.atrous_conv(x)
#         x = self.bn(x)

#         return self.relu(x)

#     def _init_weight(self):
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 torch.nn.init.kaiming_normal_(m.weight)
#             elif isinstance(m, SynchronizedBatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()
#             elif isinstance(m, nn.BatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()

# class ASPP(nn.Module):
#     def __init__(self, backbone, output_stride, BatchNorm):
#         super(ASPP, self).__init__()
#         if backbone == 'drn':
#             inplanes = 512
#         elif backbone == 'mobilenet':
#             inplanes = 320
#         else:
#             inplanes = 2048
#         if output_stride == 16:
#             dilations = [1, 6, 12, 18]
#         elif output_stride == 8:
#             dilations = [1, 12, 24, 36]
#         else:
#             raise NotImplementedError

#         self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
#         self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
#         self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
#         self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

#         self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
#                                              nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
#                                              BatchNorm(256),
#                                              nn.ReLU())
#         self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
#         self.bn1 = BatchNorm(256)
#         self.relu = nn.ReLU()
#         self.dropout = nn.Dropout(0.5)
#         self._init_weight()

#     def forward(self, x):
#         x1 = self.aspp1(x)
#         x2 = self.aspp2(x)
#         x3 = self.aspp3(x)
#         x4 = self.aspp4(x)
#         x5 = self.global_avg_pool(x)
#         x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
#         x = torch.cat((x1, x2, x3, x4, x5), dim=1)

#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)

#         return self.dropout(x)

#     def _init_weight(self):
#         for m in self.modules():
#             if isinstance(m, nn.Conv2d):
#                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
#                 # m.weight.data.normal_(0, math.sqrt(2. / n))
#                 torch.nn.init.kaiming_normal_(m.weight)
#             elif isinstance(m, SynchronizedBatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()
#             elif isinstance(m, nn.BatchNorm2d):
#                 m.weight.data.fill_(1)
#                 m.bias.data.zero_()


# def build_aspp(backbone, output_stride, BatchNorm):
#     return ASPP(backbone, output_stride, BatchNorm)

In [None]:
from torchvision import models
import torch.nn as nn

# class ResNet_50 (nn.Module):
#   def __init__(self, in_channels = 1, conv1_out = 64):
#     super(ResNet_50,self).__init__()

#     self.resnet_50 = models.resnet50(pretrained = True)

#     self.relu = nn.ReLU(inplace=True)

# #   def forward(self,x):
# #     x = self.relu(self.resnet_50.bn1(self.resnet_50.conv1(x)))
# #     x = self.resnet_50.maxpool(x)
# #     x = self.resnet_50.layer1(x)
# #     x = self.resnet_50.layer2(x)
# #     x = self.resnet_50.layer3(x)

# #     return x

class ResNet_50(nn.Module):
    """ResNet-50 trunk (everything up to, but excluding, avgpool and fc).

    The stem convolution is replaced so the network accepts ``in_channels``
    input channels. For the common grayscale case (``in_channels == 1``) with
    pretrained weights, the new stem is initialized with the pretrained RGB
    filters summed over the colour axis. That is exactly what the pretrained
    stem would compute for a grey image replicated to 3 channels, so the
    pretrained features stay usable instead of starting from a random stem.

    Args:
        in_channels: number of input channels (1 for the L channel).
        pretrained: load torchvision's ImageNet weights. Pass False to skip the
            download, e.g. when a checkpoint will be loaded afterwards.
    """

    def __init__(self, in_channels=1, pretrained=True):
        super(ResNet_50, self).__init__()

        backbone = models.resnet50(pretrained=pretrained)

        # New stem conv accepting `in_channels` inputs
        new_conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained and in_channels == 1:
            with torch.no_grad():
                new_conv1.weight.copy_(backbone.conv1.weight.sum(dim=1, keepdim=True))
        backbone.conv1 = new_conv1

        # Drop the classification head (avgpool + fc) and keep the feature trunk
        self.resnet_50 = nn.Sequential(*list(backbone.children())[:-2])
        # NOTE(review): not used in forward; kept so existing attribute access still works.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.resnet_50(x)
        return x


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    """DeepLabV3+-style decoder fusing high-level features with low-level features.

    The low-level features are projected to 48 channels (1x1 conv + BN + ReLU).
    The high-level features ``x`` are bilinearly upsampled to the low-level
    spatial size. Both are concatenated and refined by two 3x3 conv blocks,
    and a final 1x1 conv maps to ``num_classes`` channels.

    Args:
        num_classes: output channels.
        backbone: 'resnet'/'drn' (256 low-level channels), 'xception' (128)
            or 'mobilenet' (24).
        BatchNorm: normalization layer class, e.g. ``nn.BatchNorm2d``.

    Forward expects ``x`` with 256 channels (last_conv takes 256 + 48 = 304) and
    returns a tensor at ``low_level_feat``'s spatial resolution.
    """

    def __init__(self, num_classes, backbone, BatchNorm):
        super(Decoder, self).__init__()
        if backbone == 'resnet' or backbone == 'drn':
            low_level_inplanes = 256
        elif backbone == 'xception':
            low_level_inplanes = 128
        elif backbone == 'mobilenet':
            low_level_inplanes = 24
        else:
            raise NotImplementedError(f"unsupported backbone {backbone!r}")

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.5),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       BatchNorm(256),
                                       nn.ReLU(),
                                       nn.Dropout(0.1),
                                       nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
        self._init_weight()

    def forward(self, x, low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        # Upsample x to the resolution of low_level_feat, rather than to a fixed
        # 512x512 unrelated to the input size.
        x = F.interpolate(x, size=low_level_feat.shape[-2:], mode='bilinear', align_corners=True)

        x = torch.cat((x, low_level_feat), dim=1)
        x = self.last_conv(x)

        return x

    def _init_weight(self):
        """Kaiming-normal conv weights; BatchNorm2d weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def build_decoder(num_classes, backbone, BatchNorm):
    """Factory helper: construct and return a ``Decoder`` with the given settings."""
    decoder = Decoder(num_classes=num_classes, backbone=backbone, BatchNorm=BatchNorm)
    return decoder

In [None]:
# class deeplabv3(nn.Module):
#     def __init__(self, input_channels=1, output_channels=2):
#         super(deeplabv3, self).__init__()
#         self.resnet = ResNet_50(in_channels=input_channels)
#         self.aspp = ASSP(in_channels=2048, final_out_channels=output_channels)
#         self.conv = nn.Conv2d(in_channels=2, out_channels=output_channels, kernel_size=1, stride=1, padding=0)

#     def forward(self, x):
#         _, _, h, w = x.shape
#         x = self.resnet(x)  # Output should be [batch_size, 2048, H/32, W/32]RuntimeError: Given groups=1, weight of size [256, 2048, 1, 1], expected input[2, 1024, 32, 32] to have 2048 channels, but got 1024 channels instead, 32, 32] to have 2048 channels, but got 1024 channels instead
#         x = self.aspp(x)
#         x = self.conv(x)
#         x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
#         return x



class deeplabv3_encoder_decoder(nn.Module):
    """Colorization generator: ResNet-50 encoder -> ASPP -> transposed-conv decoder.

    The ResNet trunk yields 2048 channels at 1/32 resolution (per the original
    shape note). ASSP fuses them to 1024 channels, and five stride-2 transposed
    convolutions upsample by 32 back to the input size, ending in
    ``output_channels`` channels with no final activation.

    Args:
        input_channels: channels of the input (1 for the L channel).
        output_channels: channels of the prediction (2 for ab).
    """

    def __init__(self, input_channels=1, output_channels=2):
        super(deeplabv3_encoder_decoder, self).__init__()
        self.resnet = ResNet_50(in_channels=input_channels)
        self.aspp = ASSP(in_channels=2048, final_out_channels=1024)

        # 5 x (stride-2 ConvTranspose) = x32 upsampling, mirroring the encoder stride
        self.decoder = nn.Sequential(
                nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                # Honour output_channels instead of hard-coding 2.
                nn.ConvTranspose2d(64, output_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            )

    def forward(self, x):
        h, w = x.shape[-2:]
        x = self.resnet(x)
        x = self.aspp(x)
        x = self.decoder(x)
        # The x32 decoder only reproduces H x W when both are multiples of 32;
        # resize otherwise so the prediction always aligns with the input.
        if x.shape[-2:] != (h, w):
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        return x



In [None]:
import torch
from torchsummary import summary

# torchsummary's `device` defaults to "cuda"; pass the detected device so this
# also runs on CPU-only machines, and summarize at the training resolution.
summary_model = deeplabv3_encoder_decoder().to(device)

summary(summary_model, input_size=(1, INPUT_SHAPE, INPUT_SHAPE), device=device)
# Drop the throw-away copy so the cleanup cell below can actually free it.
del summary_model

In [None]:
import torch
import gc

# Reclaim memory between experiments: collect unreachable Python objects first
# (releasing any tensors they held), then hand PyTorch's cached-but-unused CUDA
# blocks back to the driver (harmless when CUDA was never initialised).
gc.collect()
torch.cuda.empty_cache()


In [None]:
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator that scores overlapping input patches.

    Layout (all 4x4 convolutions with padding 1):
      * block 0: input_c -> num_filters, stride 2, LeakyReLU, no BatchNorm;
      * blocks 1..n_down: channels double each time; stride 2 except for the
        last of these blocks, which uses stride 1; BatchNorm + LeakyReLU;
      * head: -> 1 channel, stride 1, no BatchNorm and no activation, so the
        output is raw logits.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for depth in range(n_down):
            is_last = depth == n_down - 1
            blocks.append(
                self.get_layers(
                    num_filters * 2 ** depth,
                    num_filters * 2 ** (depth + 1),
                    s=1 if is_last else 2,
                )
            )
        head = self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)
        blocks.append(head)
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one Conv2d [-> BatchNorm2d] [-> LeakyReLU(0.2)] block.

        The convolution carries a bias only when no BatchNorm follows it
        (BatchNorm's own shift makes the bias redundant).
        """
        block = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            block.append(nn.BatchNorm2d(nf))
        if act:
            block.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*block)

    def forward(self, x):
        return self.model(x)

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss comparing discriminator outputs with constant real/fake targets.

    Args:
        gan_mode: "vanilla" (BCEWithLogitsLoss on raw logits) or "lsgan" (MSELoss).
        real_label: target value used when ``target_is_real`` is True.
        fake_label: target value used when ``target_is_real`` is False.

    Raises:
        ValueError: for any other ``gan_mode``. No criterion would be defined,
            and the failure would otherwise only surface on the first call.

    The labels are registered as buffers so ``.to(device)`` moves them along
    with the module.
    """

    def __init__(self, gan_mode="vanilla", real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer("real_label", torch.tensor(real_label))
        self.register_buffer("fake_label", torch.tensor(fake_label))
        if gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f"gan_mode must be 'vanilla' or 'lsgan', got {gan_mode!r}")

    def get_labels(self, preds, target_is_real):
        """Return the real/fake target broadcast (without copying) to ``preds``' shape."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        # Implemented as forward (not by overriding __call__) so nn.Module's
        # hook machinery still runs when the criterion is called.
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)

In [None]:
def init_weights(net, init="norm", gain=0.02):
    """Re-initialize every Conv*/BatchNorm2d layer of ``net`` in place.

    Layers whose class name contains "Conv" (e.g. Conv2d, ConvTranspose2d) get
    their weight drawn according to ``init`` and their bias (if any) zeroed.
    Layers whose class name contains "BatchNorm2d" get weight ~ N(1, gain) and
    bias 0. Layers without a weight tensor (including BatchNorm with
    affine=False) are skipped.

    Args:
        net: module to initialize (modified in place).
        init: "norm" (N(0, gain)), "xavier" (xavier_normal_ with ``gain``) or
            "kaiming" (kaiming_normal_, fan_in; ``gain`` is not used).
        gain: std for "norm" and for BatchNorm weights; gain for "xavier".

    Returns:
        ``net`` itself.

    Raises:
        ValueError: for an unknown ``init``, instead of silently leaving conv
            weights untouched while still zeroing their biases.
    """
    if init not in ("norm", "xavier", "kaiming"):
        raise ValueError(f"unknown init {init!r}; expected 'norm', 'xavier' or 'kaiming'")

    def init_func(m):
        classname = m.__class__.__name__
        weight = getattr(m, "weight", None)
        if weight is None:
            return
        if "Conv" in classname:
            if init == "norm":
                nn.init.normal_(weight.data, mean=0.0, std=gain)
            elif init == "xavier":
                nn.init.xavier_normal_(weight.data, gain=gain)
            else:  # "kaiming"
                nn.init.kaiming_normal_(weight.data, a=0, mode="fan_in")

            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            nn.init.normal_(weight.data, 1.0, gain)
            nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net


def init_model(model, device):
    """Move ``model`` to ``device``, apply init_weights' default scheme and return it."""
    return init_weights(model.to(device))

In [None]:
class MainModel(nn.Module):
    """Pix2pix-style colorization GAN: a generator predicts ab from L, and a
    PatchGAN discriminator judges the concatenated (L, ab) images.

    Args:
        net_G: optional generator. It is expected to map the (B, 1, H, W) L tensor
            to a (B, 2, H, W) ab tensor. When None, a deeplabv3_encoder_decoder
            is built.
        lr_G, lr_D: Adam learning rates for generator / discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        lambda_L1: weight of the L1 reconstruction term in the generator loss.
    """

    def __init__(
        self, net_G=None, lr_G=2e-4, lr_D=2e-4, beta1=0.5, beta2=0.999, lambda_L1=100.0
    ):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            net_G = deeplabv3_encoder_decoder().to(self.device)
            # Initialise only the freshly created ASPP head and decoder. Running
            # init_weights over the whole generator would also re-draw every
            # Conv/BatchNorm inside ResNet_50 and discard the ImageNet weights
            # it just loaded.
            init_weights(net_G.aspp)
            init_weights(net_G.decoder)
            self.net_G = net_G
        else:
            self.net_G = net_G.to(self.device)

        self.net_D = init_model(
            PatchDiscriminator(input_c=3, num_filters=64, n_down=3),
            self.device
        )

        self.GANcriterion = GANLoss(gan_mode="vanilla").to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Enable/disable gradients for every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Move a batch dict with "L" and "ab" tensors onto the model's device."""
        self.L = data["L"].to(self.device)
        self.ab = data["ab"].to(self.device)

    def forward(self):
        """Predict ab for the current L (stored in ``self.fake_color``)."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Discriminator loss: fake (L, predicted ab) -> 0, real (L, ab) -> 1."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())  # no gradient into G here
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator loss: fool D, plus lambda_L1 * L1 distance to the true ab."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """One training step: forward, update D, then update G with D frozen."""
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

In [None]:
class AverageMeter:
    """Running, count-weighted mean of a scalar such as a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated total, total weight and mean."""
        self.sum = 0.0
        self.count = 0.0
        self.avg = 0.0

    def update(self, val, count=1):
        """Fold in ``val`` with weight ``count`` and refresh the mean."""
        self.sum = self.sum + val * count
        self.count = self.count + count
        self.avg = self.sum / self.count


def create_loss_meters():
    """Return a fresh AverageMeter for each loss MainModel records.

    The keys are the loss attribute names MainModel sets in backward_D /
    backward_G, which update_losses reads back with getattr.
    """
    loss_names = (
        "loss_D_fake",
        "loss_D_real",
        "loss_D",
        "loss_G_GAN",
        "loss_G_L1",
        "loss_G",
    )
    return {name: AverageMeter() for name in loss_names}


def update_losses(model, loss_meter_dict, count):
    """For each meter, read the same-named scalar loss tensor off ``model`` and
    record its value with weight ``count`` (typically the batch size)."""
    for name in loss_meter_dict:
        value = getattr(model, name).item()
        loss_meter_dict[name].update(value, count=count)


def lab_to_rgb(L, ab):
    """Convert a batch of normalized Lab tensors back to RGB numpy images.

    Undoes the dataset scaling (L = L*/50 - 1, ab = ab*/110), then converts each
    image with skimage's ``lab2rgb``.

    Args:
        L: 4-D tensor, expected (B, 1, H, W), as produced by the dataloaders.
        ab: 4-D tensor, expected (B, 2, H, W), on the same device as ``L``.

    Returns:
        numpy array of shape (B, H, W, 3) with one RGB image per batch item.
    """
    L = (L + 1.0) * 50.0
    ab = ab * 110.0
    # detach(): .numpy() refuses tensors that are still attached to autograd.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).detach().cpu().numpy()
    rgb_imgs = [lab2rgb(img) for img in Lab]
    return np.stack(rgb_imgs, axis=0)


def visualize(model, data, save=True, n_images=5):
    """Plot grey input, predicted colour and ground truth for a batch.

    Runs the generator without gradients in eval mode on ``data``, then puts
    net_G back into train mode. Draws a 3-row grid: L channel, prediction and
    ground truth, one column per image.

    Args:
        model: MainModel-like object (net_G, setup_input, forward, L, ab,
            fake_color).
        data: batch dict with "L" and "ab" tensors.
        save: also write the figure to ``colorization_<timestamp>.png``.
        n_images: maximum number of columns. It is capped at the batch size, so
            batches with fewer than 5 images no longer raise IndexError.
    """
    print("Started Visualizing")
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    n_cols = min(n_images, L.size(0))
    if n_cols == 0:
        return
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))
    for i in range(n_cols):
        ax = plt.subplot(3, n_cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap="gray")
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + n_cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, n_cols, i + 1 + 2 * n_cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")


def log_results(loss_meter_dict):
    """Print each meter's running average as ``name: value`` (5 decimals), one per line."""
    lines = (f"{name}: {meter.avg:.5f}" for name, meter in loss_meter_dict.items())
    for line in lines:
        print(line)


def create_lab_tensors(image):
    """Turn one image into a batch-of-one {"L", "ab"} dict for inference.

    The image is converted to RGB, bicubically resized to INPUT_SHAPE x INPUT_SHAPE,
    converted to Lab and scaled like the training data (L -> L/50 - 1,
    ab -> ab/110). No random flip is applied: this feeds prediction, and a random
    mirror would make the colorized output randomly mirrored relative to the
    input.

    Args:
        image: path to an image file (str or os.PathLike), or an already-open
            PIL image.

    Returns:
        dict with "L" of shape (1, 1, H, W) and "ab" of shape (1, 2, H, W).
    """
    if isinstance(image, (str, os.PathLike)):
        # convert() loads the pixels, so the file can be closed immediately.
        with Image.open(image) as src:
            img = src.convert("RGB")
    else:
        img = image.convert("RGB")

    resize = transforms.Resize(
        (INPUT_SHAPE, INPUT_SHAPE),
        interpolation=transforms.InterpolationMode.BICUBIC,
    )
    img = np.array(resize(img))
    img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
    img_lab = transforms.ToTensor()(img_lab)  # HWC -> CHW
    L = img_lab[[0], ...] / 50.0 - 1.0  # Between -1 and 1
    ab = img_lab[[1, 2], ...] / 110.0
    # Add the batch dimension to both tensors so they match a DataLoader batch.
    return {"L": L.unsqueeze(0), "ab": ab.unsqueeze(0)}


def predict_and_visualize_single_image(model, data, save=True):
    """Colorize a prepared {"L", "ab"} batch and show its first item.

    Runs net_G in eval mode without gradients and leaves it in eval mode. Shows
    the grey input and the colorized result side by side, and optionally saves
    the figure as ``colorization_<timestamp>.png``.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    predicted_ab = model.fake_color.detach()
    grey = model.L
    colored = lab_to_rgb(grey, predicted_ab)
    fig, (ax_grey, ax_color) = plt.subplots(1, 2, figsize=(8, 4))
    panels = (
        (ax_grey, grey[0][0].cpu(), "Grey Image", {"cmap": "gray"}),
        (ax_color, colored[0], "Colored Image", {}),
    )
    for ax, picture, title, imshow_kwargs in panels:
        ax.imshow(picture, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")


def predict_color(model, image, save=False):
    """Colorize a single image with ``model`` and display it next to its grey input.

    Args:
        model: trained MainModel (grey-scale -> colour).
        image: path to an image file, or an already-open PIL image.
        save: if True, also write the figure to ``colorization_<timestamp>.png``.
    """
    data = create_lab_tensors(image)
    predict_and_visualize_single_image(model, data, save)


def predict_and_return_image(image, model=None):
    """Colorize one image and return the RGB result as a numpy array.

    Args:
        image: path to an image file, or an already-open PIL image.
        model: MainModel to use. Defaults to the notebook-global ``model``, so
            existing callers that pass only ``image`` keep working.

    Returns:
        The first item of lab_to_rgb's output: an (H, W, 3) RGB array.

    Raises:
        RuntimeError: if no model is passed and no global ``model`` is set.
    """
    if model is None:
        model = globals().get("model")
    if model is None:
        raise RuntimeError(
            "predict_and_return_image: no model passed and the global `model` is not set"
        )
    data = create_lab_tensors(image)
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    fake_color = model.fake_color.detach()
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    return fake_imgs[0]

In [None]:
# Checkpoint used by the final cell: loaded if it exists, otherwise the model is
# trained from scratch. Relative to the current working directory.
model_path = "./ImageColorizationModel.pth"


def save_model(model, file_path):
    """Serialize ``model.state_dict()`` to ``file_path`` with ``torch.save``.

    Only the parameters and buffers are written, not the Python class, so loading
    requires constructing the same model first (see ``load_model``).
    """
    state = model.state_dict()
    torch.save(state, file_path)


def load_model(model_class, file_path, map_location="cpu"):
    """Build ``model_class()`` and fill it with the state dict stored at ``file_path``.

    Args:
        model_class: zero-argument callable (e.g. a Module class) returning the model.
        file_path: checkpoint written by ``save_model``.
        map_location: where ``torch.load`` first places the stored tensors.
            The default "cpu" lets a checkpoint saved from a GPU load on a
            CPU-only machine. ``load_state_dict`` then copies the values into
            the model's parameters on whatever device the model already uses.

    Returns:
        The model with the loaded weights.
    """
    model = model_class()
    state_dict = torch.load(file_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

In [None]:
from tqdm import tqdm
def train_model(model, train_dl, epochs, display_every=200, val_dl=None):
    """Train ``model`` for ``epochs`` epochs, logging and checkpointing as it goes.

    Every ``display_every`` iterations the running loss averages are printed and
    predictions on one fixed held-out batch are plotted. After each epoch the
    state dict is written to ``ImageColorizationModel{epoch}.pth``.

    Args:
        model: MainModel (setup_input / optimize plus the loss attributes that
            update_losses reads).
        train_dl: training DataLoader yielding {"L", "ab"} batches.
        epochs: number of passes over ``train_dl``.
        display_every: iterations between loss printouts / visualizations.
        val_dl: loader the fixed visualization batch is taken from. Defaults to
            the notebook-global ``testdl``.
    """
    if val_dl is None:
        val_dl = testdl
    # Keep this under its own name. Reusing `data` as the loop variable below
    # would overwrite the held-out batch with the current training batch before
    # it is ever visualized.
    val_batch = next(iter(val_dl))
    print("started")
    for e in range(epochs):
        loss_meter_dict = create_loss_meters()  # fresh running averages each epoch
        for i, batch in enumerate(tqdm(train_dl), start=1):
            model.setup_input(batch)
            model.optimize()
            update_losses(model, loss_meter_dict, count=batch["L"].size(0))
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict)
                visualize(model, val_batch, save=False)
        save_model(model=model, file_path=f"ImageColorizationModel{e}.pth")
                


# Load the finished checkpoint if it exists; otherwise train from scratch and
# save to the same path, so the next run loads it instead of retraining.
N_EPOCHS = 100
model = None
if os.path.exists(model_path):
    model = load_model(model_class=MainModel, file_path=model_path)
    print("Model Loaded")
else:
    print("Model not found")
    model = MainModel()
    train_model(model, traindl, N_EPOCHS)
    save_model(model=model, file_path=model_path)