# ***PART 1: DATA PROCESSING***
- Download the COCO-2017 dataset (through the FiftyOne dataset zoo)
- Wrap the images in PyTorch Datasets and DataLoaders


In [None]:
#Import Libraries

from PIL import Image
import os
#!pip install fiftyone
import fiftyone.zoo as foz
import fiftyone as fo
import numpy as np
import pickle
import cv2
import matplotlib.pyplot as plt
import wandb
#Importing Libraries related to Pytorch
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchviz import make_dot

from tqdm import tqdm

In [None]:
#### Import a random subset of COCO-2017 with the FiftyOne dataset zoo

SAMPLE_TRAIN = 10000  # number of training images to download
SAMPLE_TEST = 500     # number of test images to download

print(f"{'#' * 51}IMPORTING COCO DATASET{'#' * 51}")

print(f"{'#' * 51}LOADING TRAINING SET{'#' * 51}")
# shuffle=True asks the zoo for a random selection of training samples.
train_dataset = foz.load_zoo_dataset("coco-2017", split="train", max_samples=SAMPLE_TRAIN, shuffle=True)
# Mark the dataset persistent so FiftyOne keeps it between sessions.
train_dataset.persistent = True

print(f"{'#' * 51}LOADING TEST SET{'#' * 51}")
test_dataset = foz.load_zoo_dataset("coco-2017", split="test", max_samples=SAMPLE_TEST)
test_dataset.persistent = True

print(f"{'#' * 51}COCO DATASET IMPORTED{'#' * 51}")

In [None]:

class CocoDataset(Dataset):
  """Image dataset yielding (grayscale input, colour target) pairs for colourisation.

  Every file in `root` is treated as an image. Items are read with OpenCV (BGR order),
  optionally transformed, and split according to `color_space`:
    * "RGB": input = 1-channel grayscale image, target = 3-channel RGB image;
    * "Lab": input = the L channel, target = the two a/b channels.
  Both returned tensors are rescaled from [0, 1] to [-1, 1].
  """
  #Constructor
  def __init__(self, root:str, color_space: str="RGB", size_limit=None, transform=None):
    """
    :param root: directory containing the image files
    :param color_space: "RGB" or "Lab"
    :param size_limit: if not None, keep only the first `size_limit` entries of os.listdir(root)
    :param transform: optional callable applied to the BGR uint8 array returned by cv2.imread
    :raises ValueError: if `color_space` is neither "RGB" nor "Lab"
    """
    if color_space not in ("RGB", "Lab"):
      raise ValueError(f'color_space must be "RGB" or "Lab", got {color_space!r}')
    # os.path.join works whether or not `root` ends with a path separator.
    self.paths = [os.path.join(root, fname) for fname in os.listdir(root)]
    if size_limit is not None:
      self.paths = self.paths[:size_limit]
    self.transform = transform
    self.color_space = color_space

  def __len__(self):
    return len(self.paths)

  def __getitem__(self, index):
    """Return (bw_image, target_image) tensors for the image at `index`.

    :raises OSError: if OpenCV cannot read the file
    """
    sel_image_path = self.paths[index]
    sel_image = cv2.imread(sel_image_path)
    if sel_image is None:
      # cv2.imread reports unreadable / non-image files by returning None instead of raising.
      raise OSError(f"Could not read image file: {sel_image_path}")

    if self.transform is not None:
      sel_image = self.transform(sel_image)

    # The transform may return a PIL image; OpenCV needs a NumPy array.
    sel_image = np.array(sel_image)
    # OpenCV loads images in BGR (not RGB) channel order.
    if self.color_space == "RGB":
      target_image = cv2.cvtColor(sel_image, cv2.COLOR_BGR2RGB)     # colour target
      bw_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2GRAY)     # grayscale input
    else:  # "Lab" (validated in __init__)
      lab_image = cv2.cvtColor(sel_image, cv2.COLOR_BGR2Lab)
      bw_image = lab_image[:, :, [0]]          # L (lightness) channel only
      target_image = lab_image[:, :, [1, 2]]   # a and b chroma channels

    to_tensor = transforms.ToTensor()
    # ToTensor maps uint8 [0, 255] to [0, 1]; rescale to [-1, 1] to match the generators' tanh output.
    target_image = 2.0 * to_tensor(target_image) - 1.0
    bw_image = 2.0 * to_tensor(bw_image) - 1.0
    # (input image to the network = BW, label image = COLOR)
    return bw_image, target_image

In [None]:
IMG_SIZE = 128 #size of images
BATCH_SIZE = 8
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Transformations for the images.
# cv2.imread returns a NumPy array, while Resize and the random flips operate on PIL
# images (or tensors), hence ToPILImage() first. ToPILImage does not reorder the
# channels, so the BGR order from OpenCV is kept and CocoDataset's BGR->RGB / BGR->Lab
# conversion is still correct afterwards.
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip()
])

test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE))
])

cars = True #using the car dataset or not
if cars:
    train_root = 'C:/Users/alber/fiftyone/cars_dataset/train/'
    test_root = 'C:/Users/alber/fiftyone/cars_dataset/test/'
else:
    train_root = 'C:/Users/alber/fiftyone/coco-2017/train/data/'
    test_root = 'C:/Users/alber/fiftyone/coco-2017/test/data/'

# Pinned host memory only helps when batches are copied to a GPU; on a CPU-only
# machine recent PyTorch versions just warn and ignore the flag.
pin_memory = device == "cuda"

#Creating the Datasets and Dataloaders (RGB)
train_dataRGB = CocoDataset(root=train_root, color_space="RGB", size_limit=SAMPLE_TRAIN, transform=train_transforms)
test_dataRGB = CocoDataset(root=test_root, color_space="RGB", size_limit=SAMPLE_TEST, transform=test_transforms)
train_dl_RGB = DataLoader(train_dataRGB, batch_size=BATCH_SIZE, num_workers=0, pin_memory=pin_memory, shuffle=True)
test_dl_RGB = DataLoader(test_dataRGB, batch_size=BATCH_SIZE, num_workers=0, pin_memory=pin_memory, shuffle=False)

print("RGB DATASET")
print(f"Training Set has: {len(train_dataRGB)} images")
print(f"Test Set has: {len(test_dataRGB)} images")


#Creating the Datasets and Dataloaders (Lab)
train_dataLAB = CocoDataset(root=train_root, color_space="Lab", size_limit=SAMPLE_TRAIN, transform=train_transforms)
test_dataLAB = CocoDataset(root=test_root, color_space="Lab", size_limit=SAMPLE_TEST, transform=test_transforms)
train_dl_LAB = DataLoader(train_dataLAB, batch_size=BATCH_SIZE, num_workers=0, pin_memory=pin_memory, shuffle=True)
test_dl_LAB = DataLoader(test_dataLAB, batch_size=BATCH_SIZE, num_workers=0, pin_memory=pin_memory, shuffle=False)

print("\n")
print("LAB DATASET")
print(f"Training Set has: {len(train_dataLAB)} images")
print(f"Test Set has: {len(test_dataLAB)} images")

In [None]:
def _generate_for_display(model, batch, color_space):
    """Run model.generator on one batch and return (inputs, generated, targets) in [0, 1].

    The generator runs in eval mode without gradients and is then put back into the
    train/eval mode it had before. Tensors coming from CocoDataset are in [-1, 1] and
    are mapped back to [0, 1]. For "Lab", the grayscale input (L channel) is
    concatenated in front of the predicted / true ab channels to form 3-channel images.
    """
    x, label = batch
    was_training = model.generator.training
    model.generator.eval()
    try:
        with torch.no_grad():
            output = model.generator(x.to(model.device)).cpu()
    finally:
        model.generator.train(was_training)

    x = (x.cpu() + 1.0) / 2.0
    output = (output + 1.0) / 2.0
    label = (label.cpu() + 1.0) / 2.0

    if color_space == "Lab":
        return x, torch.cat([x, output], dim=1), torch.cat([x, label], dim=1)
    return x, output, label


def _to_rgb(img, color_space):
    """Return `img` in a form ToPILImage renders as RGB.

    For "Lab", the [0, 1] tensor is quantised to uint8 and converted to an RGB NumPy
    array with OpenCV; for "RGB" the tensor is returned unchanged.
    """
    if color_space == "Lab":
        lab = np.array(transforms.ToPILImage()(img))
        return cv2.cvtColor(lab, cv2.COLOR_Lab2RGB)
    return img


def visualize(model, batch, color_space="RGB", images_per_row=8, save = False ,filename = 'images_results',
              save_dir='C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/'):
    """Plot input / generated / ground-truth images of one batch in a grid.

    Every group of three grid rows shows, for up to `images_per_row` samples, the
    grayscale input, the generator's colourisation and the true colour image.

    :param model: object exposing `generator` and `device` (e.g. GANModel)
    :param batch: (x, label) batch as produced by CocoDataset
    :param color_space: "RGB" or "Lab", matching how the model was trained
    :param images_per_row: number of samples per grid row
    :param save: if True, the figure is also written to `save_dir`/`filename`.png
    :param filename: file name (without extension) used when saving
    :param save_dir: directory the figure is saved into
    """
    x, generated_images, true_images = _generate_for_display(model, batch, color_space)

    num_images = generated_images.shape[0]
    if num_images == 0:
        return
    num_rows = int(np.ceil(num_images / images_per_row))

    # squeeze=False keeps axs 2-D even when images_per_row == 1.
    fig, axs = plt.subplots(3 * num_rows, images_per_row, figsize=(15, 8 * num_rows), squeeze=False)
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    # Hide every axis up front so unused cells of a partially filled last row stay blank.
    for ax in axs.flat:
        ax.axis("off")

    to_pil = transforms.ToPILImage()
    for i in range(num_images):
        row, col = divmod(i, images_per_row)
        axs[3 * row, col].imshow(to_pil(x[i]), cmap='gray')
        axs[3 * row + 1, col].imshow(to_pil(_to_rgb(generated_images[i], color_space)))
        axs[3 * row + 2, col].imshow(to_pil(_to_rgb(true_images[i], color_space)))

    if save:
        fig.savefig(os.path.join(save_dir, filename + '.png'))
    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate open figures.
    plt.close(fig)



#The following function is used to save the output images instead of displaying them in the output and was used to produce the final report
def save_outputs(model, batch, color_space="RGB", save=True, folder_path='C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/outputs',
                 start_index=0):
    """Write input / generated / ground-truth images of one batch as PNG files.

    Files are named input_{k}.png, generated_{k}.png and true_{k}.png with
    k = start_index + position in the batch; pass a different `start_index` per batch
    to avoid overwriting earlier batches.

    :param model: object exposing `generator` and `device` (e.g. GANModel)
    :param batch: (x, label) batch as produced by CocoDataset
    :param color_space: "RGB" or "Lab", matching how the model was trained
    :param save: if False, nothing is written
    :param folder_path: output directory (created if missing)
    :param start_index: index used for the first image of the batch
    """
    if not save:
        return

    x, generated_images, true_images = _generate_for_display(model, batch, color_space)

    # Create folder if it doesn't exist
    os.makedirs(folder_path, exist_ok=True)

    to_pil = transforms.ToPILImage()
    for offset in range(generated_images.shape[0]):
        k = start_index + offset
        to_pil(x[offset]).save(os.path.join(folder_path, f'input_{k}.png'))
        to_pil(_to_rgb(generated_images[offset], color_space)).save(os.path.join(folder_path, f'generated_{k}.png'))
        to_pil(_to_rgb(true_images[offset], color_space)).save(os.path.join(folder_path, f'true_{k}.png'))

    print(f"Images saved in {folder_path}")



# ***PART 2: MODELS***
**Definition of the various models which will be used as generators:**
- U-net
- Autoencoder
- ResNet

In [None]:
REAL_LABELS = 1.
FAKE_LABELS = 0.

### Functions for the Initialization of the Model and its Weights

def init_weights(net, init='norm', gain=0.02, device=None):
    """Initialise the convolution and BatchNorm2d layers of `net` in place.

    Every module whose class name contains 'Conv' (Conv2d, ConvTranspose2d, ...) gets
    its weight drawn according to `init` and its bias (if any) set to 0. Every
    BatchNorm2d gets weight ~ N(1, gain) and bias 0. The whole module tree is walked
    with `net.apply`, so call this once on the top-level network.

    :param net: module to initialise (modified in place)
    :param init: 'norm' (N(0, gain)), 'xavier' (Xavier-normal with `gain`) or
                 'kaiming' (Kaiming-normal, fan_in)
    :param gain: std for 'norm' and for BatchNorm weights, gain for 'xavier'
    :param device: unused; kept for backward compatibility
    :return: `net`
    :raises NotImplementedError: if `init` is not one of the supported schemes
    """
    # function from the repository of the authors of the paper
    if init not in ('norm', 'xavier', 'kaiming'):
        raise NotImplementedError(f"initialization method '{init}' is not implemented")

    def init_func(m):
        classname = m.__class__.__name__
        if 'Conv' in classname and getattr(m, 'weight', None) is not None:
            if init == 'norm':
                nn.init.normal_(m.weight, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')

            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm layers built with affine=False have no weight/bias to initialise.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1., gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)

    net.apply(init_func)
    return net

def init_model(model, device):
    """Move `model` to `device`, initialise its weights with init_weights and return it."""
    model = model.to(device)
    model = init_weights(model)
    return model

**U-NET**

In [None]:
class Generator(nn.Module):
    """U-Net generator: 7 stride-2 down-convolutions, 7 stride-2 up-convolutions, skip connections.

    Each decoder stage is concatenated with the encoder feature of the same resolution.
    The first three decoder stages use dropout (p=0.5 when use_dropout is True) and the
    output goes through tanh, so values lie in [-1, 1].
    """
    def __init__(self, input_nc, output_nc, ngf=64, use_bias=False, use_dropout=True):
        
        #:param input_nc: number of input channels
        #:param output_nc: number of output channels
        #:param ngf: number of generator filters in the first convolutional layer
        #:param use_bias: whether the (transposed) convolutions have a bias term
        #:param use_dropout: if True, dropout with p=0.5 follows the first three decoder stages
        
        super().__init__()
        self.downrelu = nn.LeakyReLU(0.2, True)
        self.uprelu = nn.ReLU(True)
        self.tanh = nn.Tanh()
        self.drop_rate = 0.5 if use_dropout else 0.0

        # Encoder: 4x4 kernels, stride 2 -> each layer halves height and width.
        self.downconv1 = nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.downconv2 = nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.downbn2   = nn.BatchNorm2d(ngf*2)
        self.downconv3 = nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.downbn3   = nn.BatchNorm2d(ngf*4)
        self.downconv4 = nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.downbn4   = nn.BatchNorm2d(ngf*8)
        self.downconv5 = nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.downbn5   = nn.BatchNorm2d(ngf*8)
        self.downconv6 = nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.downbn6   = nn.BatchNorm2d(ngf*8)
        self.downconv7 = nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        # NOTE(review): downbn7 and downconv8 are registered but never used in forward().
        # Removing them would change the state_dict keys expected by previously saved checkpoints.
        self.downbn7   = nn.BatchNorm2d(ngf*8)
        self.downconv8 = nn.Conv2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)

        # Decoder: transposed 4x4 convs, stride 2 -> each layer doubles height and width.
        # Input channels are doubled (ngf*8*2, ...) because of the skip-connection concatenation.
        self.upconv1   = nn.ConvTranspose2d(ngf*8, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.upbn1     = nn.BatchNorm2d(ngf*8)
        self.updrop1   = nn.Dropout(self.drop_rate)
        self.upconv2   = nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.upbn2     = nn.BatchNorm2d(ngf*8)
        self.updrop2   = nn.Dropout(self.drop_rate)
        self.upconv3   = nn.ConvTranspose2d(ngf*8*2, ngf*8, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.upbn3     = nn.BatchNorm2d(ngf*8)
        self.updrop3   = nn.Dropout(self.drop_rate)
        self.upconv4   = nn.ConvTranspose2d(ngf*8*2, ngf*4, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.upbn4     = nn.BatchNorm2d(ngf*4)
        self.upconv5   = nn.ConvTranspose2d(ngf*4*2, ngf*2, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.upbn5     = nn.BatchNorm2d(ngf*2)
        self.upconv6   = nn.ConvTranspose2d(ngf*2*2, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias)
        self.upbn6     = nn.BatchNorm2d(ngf*1)
        self.upconv7   = nn.ConvTranspose2d(ngf*2, output_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)

    def forward(self, x):
        # Output shapes in the comments assume x is (input_nc) x 128 x 128.
        # Encoder
        e1  = self.downconv1(x)                                        # (ngf) x 64 x 64
        e2  = self.downbn2(self.downconv2(self.downrelu(e1)))          # (ngf * 2) x 32 x 32
        e3  = self.downbn3(self.downconv3(self.downrelu(e2)))          # (ngf * 4) x 16 x 16
        e4  = self.downbn4(self.downconv4(self.downrelu(e3)))          # (ngf * 8) x 8 x 8
        e5  = self.downbn5(self.downconv5(self.downrelu(e4)))          # (ngf * 8) x 4 x 4
        e6  = self.downbn6(self.downconv6(self.downrelu(e5)))          # (ngf * 8) x 2 x 2
        # Bottleneck: no BatchNorm is applied here.
        e7  = self.downconv7(self.downrelu(e6))                        # (ngf * 8) x 1 x 1

        # Decoder: every stage is concatenated (dim=1) with the encoder feature of equal size.
        d1_ = self.updrop1(self.upbn1(self.upconv1(self.uprelu(e7))))  # (ngf * 8) x 2 x 2
        d1  = torch.cat([d1_, e6], dim=1)                              # (ngf * 8 * 2) x 2 x 2
        d2_ = self.updrop2(self.upbn2(self.upconv2(self.uprelu(d1))))  # (ngf * 8) x 4 x 4
        d2  = torch.cat([d2_, e5], dim=1)                              # (ngf * 8 * 2) x 4 x 4
        d3_ = self.updrop3(self.upbn3(self.upconv3(self.uprelu(d2))))  # (ngf * 8) x 8 x 8
        d3  = torch.cat([d3_, e4], dim=1)                              # (ngf * 8 * 2) x 8 x 8
        d4_ = self.upbn4(self.upconv4(self.uprelu(d3)))                # (ngf * 4) x 16 x 16
        d4  = torch.cat([d4_, e3], dim=1)                              # (ngf * 4 * 2) x 16 x 16
        d5_ = self.upbn5(self.upconv5(self.uprelu(d4)))                # (ngf * 2) x 32 x 32
        d5  = torch.cat([d5_, e2], dim=1)                              # (ngf * 2 * 2) x 32 x 32
        d6_ = self.upbn6(self.upconv6(self.uprelu(d5)))                # (ngf) x 64 x 64
        d6  = torch.cat([d6_, e1], dim=1)                              # (ngf * 2) x 64 x 64
        d7 =  self.upconv7(self.uprelu(d6))                            # (output_nc) x 128 x 128
        # tanh squashes the output into [-1, 1].
        o1  = self.tanh(d7)
        return o1

# Quick shape check: a 1-channel 128x128 input should give a 3-channel 128x128 output.
net = Generator(1,3)
x = torch.randn(1,1,128,128)


#visualization = make_dot(net(x), params=dict(net.named_parameters()))
#visualization.render("Unet_render", format="png", cleanup=True)

print(net(x).shape)

**Autoencoder**

In [None]:
class Autoencoder(nn.Module):
    """Small convolutional encoder-decoder without skip connections.

    Two stride-2 convolutions halve the spatial size twice, and two stride-2
    transposed convolutions restore it. The final tanh keeps the output in [-1, 1].

    :param input_nc: number of input channels
    :param output_nc: number of output channels
    :param ngf: filters of the first layer (the bottleneck has ngf * 2)
    :param use_bias: whether the (transposed) convolutions have a bias term
    :param use_dropout: accepted for interface parity with Generator; not used
    """

    def __init__(self, input_nc, output_nc, ngf=64, use_bias=False, use_dropout=True):
        super().__init__()
        conv_kw = dict(kernel_size=4, stride=2, padding=1, bias=use_bias)
        hidden = ngf * 2

        encoder_layers = [
            nn.Conv2d(input_nc, ngf, **conv_kw),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ngf, hidden, **conv_kw),
            nn.BatchNorm2d(hidden),
            nn.LeakyReLU(0.2, True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(hidden, ngf, **conv_kw),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, output_nc, **conv_kw),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        # Compress, then reconstruct at the original resolution.
        return self.decoder(self.encoder(x))


# Quick shape check on a single 1-channel 128x128 input.
net = Autoencoder(1, 3)
x = torch.randn(1, 1, 128, 128)
#visualization = make_dot(net(x), params=dict(net.named_parameters()))
#visualization.render("Autoencoder_render", format="png", cleanup=True)
print(net(x).shape)

**ResNet**

In [None]:
class ResidualBlock(nn.Module):
    """Basic ResNet block: two 3x3 conv + BatchNorm layers plus a skip connection.

    When the stride or the channel count changes, the skip path is a 1x1 conv +
    BatchNorm projection so that it matches the main path; otherwise the input is
    added unchanged. ReLU is applied after the addition.
    """

    def __init__(self, in_channels, out_channels, stride=1, use_bias=False):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=use_bias)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=use_bias),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)


class GeneratorResNet(nn.Module):
    """ResNet-style generator.

    The layers are: a 7x7 stem, two stride-2 downsampling stages, `num_residual_blocks`
    residual blocks at ngf * 4 channels, two stride-2 transposed-conv upsampling stages,
    and a 7x7 output convolution. The output goes through tanh, so values lie in [-1, 1].
    """

    def __init__(self, input_nc, output_nc, ngf=64, num_residual_blocks=6, use_bias=False):
        super().__init__()

        # One shared in-place ReLU module is reused everywhere (it has no parameters).
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

        # Stem: full-resolution 7x7 convolution.
        self.conv1 = nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=3, bias=use_bias)
        self.bn1 = nn.BatchNorm2d(ngf)

        # Two stride-2 stages: ngf -> 2*ngf -> 4*ngf channels, spatial size / 4.
        down_layers = []
        for c_in, c_out in ((ngf, ngf * 2), (ngf * 2, ngf * 4)):
            down_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=use_bias),
                nn.BatchNorm2d(c_out),
                self.relu,
            ]
        self.downsampling = nn.Sequential(*down_layers)

        self.residual_blocks = nn.ModuleList(
            ResidualBlock(ngf * 4, ngf * 4, stride=1, use_bias=use_bias) for _ in range(num_residual_blocks)
        )

        # Mirror of the downsampling: 4*ngf -> 2*ngf -> ngf channels, spatial size x 4.
        up_layers = []
        for c_in, c_out in ((ngf * 4, ngf * 2), (ngf * 2, ngf)):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
                nn.BatchNorm2d(c_out),
                self.relu,
            ]
        self.upsampling = nn.Sequential(*up_layers)

        self.output_layer = nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=3, bias=use_bias)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.downsampling(features)
        for res_block in self.residual_blocks:
            features = res_block(features)
        features = self.upsampling(features)
        return self.tanh(self.output_layer(features))

# Quick shape check on a single 1-channel 128x128 input.
net = GeneratorResNet(1, 3)
x = torch.randn(1, 1, 128, 128)
#visualization = make_dot(net(x), params=dict(net.named_parameters()))
#visualization.render("ResNet_render", format="png", cleanup=True)
print(net(x).shape)


In [None]:
class Discriminator(nn.Module):
    """Patch discriminator that returns a spatial map of real/fake logits.

    The layers are three stride-2 4x4 convolutions followed by two stride-1 4x4
    convolutions. For a 128x128 input the output is 1 x 14 x 14. No sigmoid is applied,
    because the adversarial loss used with it is BCEWithLogitsLoss.
    """

    def __init__(self, input_nc, ndf=64):
        """
        :param input_nc: number of input channels
        :param ndf: number of discriminator filters in the first convolutional layer
        """
        super().__init__()

        self.leaky_relu = nn.LeakyReLU(0.2, True)

        # Original author notes (translated from Italian): the paper's PixelGAN variant of
        # pix2pix uses different kernel/stride/padding values and fewer layers here.
        halve = dict(kernel_size=4, stride=2, padding=1)
        keep = dict(kernel_size=4, stride=1, padding=1)
        self.conv1 = nn.Conv2d(input_nc, ndf, **halve)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, **halve)
        self.conv2_bn = nn.BatchNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, **halve)
        self.conv3_bn = nn.BatchNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, **keep)
        self.conv4_bn = nn.BatchNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, 1, **keep)

    def forward(self, x):
        act = self.leaky_relu
        h = act(self.conv1(x))  # first layer has no BatchNorm
        for conv, bn in ((self.conv2, self.conv2_bn),
                         (self.conv3, self.conv3_bn),
                         (self.conv4, self.conv4_bn)):
            h = act(bn(conv(h)))
        # Raw logits (no sigmoid): meant for BCEWithLogitsLoss.
        return self.conv5(h)

# Shape check on one 3-channel 128x128 image.
net = Discriminator(3)
x = torch.randn(1, 3, 128, 128)
print(net(x).shape)

# Side length of the discriminator's output map for 128x128 inputs (14 with this architecture).
PATCH_SIZE = net(x).shape[2]

In [None]:
class GANModel(nn.Module):
  """Conditional GAN for colourisation: generator, patch discriminator and their Adam optimizers."""
  def __init__(self, device, criterion, lr, betas, im_type="RGB", ngf=64, ndf=64, use_bias=False, use_dropout=True, lambda_L1=0.0, gen_type='Unet'):
    """
    :param device: device (or device string) on which both networks are placed
    :param criterion: adversarial loss applied to the discriminator's logits
    :param lr: pair (generator_lr, discriminator_lr)
    :param betas: Adam betas shared by both optimizers
    :param im_type: a string ("Lab" or "RGB") indicating the image type; with "RGB" the
                    generator predicts 3 channels, with "Lab" only the 2 ab channels
    :param ngf: number of generator filters in the first convolutional layer
    :param ndf: number of discriminator filters in the first convolutional layer
    :param use_bias: whether the generator's convolutions use a bias term
    :param use_dropout: whether the generator uses dropout (ignored by the ResNet generator)
    :param lambda_L1: weight of the L1 reconstruction term in the generator loss
    :param gen_type: 'Unet', 'Autoencoder' or 'ResNet'
    :raises ValueError: on an unknown im_type or gen_type
    """
    super().__init__()

    if im_type not in ("RGB", "Lab"):
      raise ValueError(f'im_type must be "RGB" or "Lab", got {im_type!r}')

    self.device = device
    self.in_channels = 1
    self.im_type = im_type
    self.out_channels = 3 if im_type == "RGB" else 2

    if gen_type == 'Unet':
      generator = Generator(self.in_channels, self.out_channels, ngf=ngf, use_bias=use_bias, use_dropout=use_dropout)
    elif gen_type == 'Autoencoder':
      generator = Autoencoder(self.in_channels, self.out_channels, ngf=ngf, use_bias=use_bias, use_dropout=use_dropout)
    elif gen_type == 'ResNet':
      generator = GeneratorResNet(self.in_channels, self.out_channels, ngf=ngf, use_bias=use_bias)
    else:
      raise ValueError(f"gen_type must be 'Unet', 'Autoencoder' or 'ResNet', got {gen_type!r}")
    self.generator = generator.to(self.device)

    # The discriminator always sees 3-channel images: RGB, or L concatenated with ab.
    self.discriminator = Discriminator(3, ndf=ndf).to(self.device)
    # init_weights already walks the whole module tree, so call it once per network.
    # net.apply(init_weights) would re-initialise every sub-tree repeatedly.
    init_weights(self.generator)
    init_weights(self.discriminator)

    (self.generator_lr , self.discriminator_lr) = lr
    self.betas = betas

    self.criterion = criterion
    self.lambdaL1 = lambda_L1
    self.L1loss = nn.L1Loss()

    self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=self.generator_lr, betas=betas)
    self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=self.discriminator_lr, betas=betas)

  def forward(self, x):
    return self.generator(x)

  def set_requires_grad(self, model, requires_grad=True):
    """Enable or disable gradient tracking for all parameters of `model`."""
    for param in model.parameters():
      param.requires_grad = requires_grad

  def save(self, epoch, log, path="./checkpoint.pt", download=True):
    """Saves epoch, log and the state_dicts of generator, discriminator and optimizers to path.

    `download` is accepted for backward compatibility and ignored.
    """
    torch.save({
        'epoch' : epoch,
        'log' : log,
        'generator_state_dict' : self.generator.state_dict(),
        'discriminator_state_dict' : self.discriminator.state_dict(),
        'generator_optimizer_state_dict' : self.generator_optimizer.state_dict(),
        'discriminator_optimizer_state_dict' : self.discriminator_optimizer.state_dict()
        }, path)

  def load(self, path="./checkpoint.pt"):
    """Loads state_dicts for generator, discriminator and optimizers from path.

    :return: (epoch, log) stored in the checkpoint
    """
    # map_location lets a checkpoint written on a GPU be restored on this model's device.
    # The log contains NumPy scalars, so weights_only=False is needed on PyTorch >= 2.6;
    # only load checkpoints you produced yourself.
    checkpoint = torch.load(path, map_location=self.device, weights_only=False)
    self.generator.load_state_dict(checkpoint['generator_state_dict'])
    self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    self.generator_optimizer.load_state_dict(checkpoint['generator_optimizer_state_dict'])
    self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['log']



## Training loop with data logging ## 

In [None]:
def train_step(model, data, real_label, fake_label):
    """Run one optimisation step on a batch: one discriminator update, then one generator update.

    :param model: object exposing generator, discriminator, their optimizers, criterion,
                  L1loss, lambdaL1, im_type, device and set_requires_grad (e.g. GANModel)
    :param data: (x, label) batch, where x is the grayscale input and label the colour target
    :param real_label: target value for discriminator outputs on real images
    :param fake_label: target value for discriminator outputs on generated images
    :return: (generator_loss, discriminator_loss) as Python floats
    """
    x, label = data
    x = x.to(model.device)
    label = label.to(model.device)

    # For Lab the discriminator judges full L+ab images, so the input L channel is prepended.
    if model.im_type == "Lab":
      real_images = torch.cat([x, label], dim=1)
    else:
      real_images = label

    ######### UPDATE D NETWORK ##########
    model.discriminator.train()
    model.set_requires_grad(model.discriminator, True)
    model.discriminator_optimizer.zero_grad()

    # Real batch: D should output real_label.
    output_D_real = model.discriminator(real_images)
    label_D_real_images = torch.full_like(output_D_real, real_label)
    errorD_real = model.criterion(output_D_real, label_D_real_images)

    generator_output = model.generator(x)
    if model.im_type == "Lab":
      generated_images = torch.cat([x, generator_output], dim=1)
    else:
      generated_images = generator_output

    # Fake batch: detach() so the discriminator loss does not backpropagate into the generator.
    output_D_fake = model.discriminator(generated_images.detach())
    label_D_fake_images = torch.full_like(output_D_fake, fake_label)
    errorD_fake = model.criterion(output_D_fake, label_D_fake_images)

    error_D = 0.5 * (errorD_real + errorD_fake)
    error_D_item = error_D.item()
    error_D.backward()
    model.discriminator_optimizer.step()

    ####### UPDATE G NETWORK ######
    model.generator.train()
    # Freeze D so the generator loss only produces gradients for the generator.
    model.set_requires_grad(model.discriminator, False)
    model.generator_optimizer.zero_grad()
    output_D_fake = model.discriminator(generated_images)
    # The generator is rewarded when D labels its outputs as real.
    loss_generated_images = model.criterion(output_D_fake, label_D_real_images)
    loss_L1 = model.L1loss(generator_output, label)

    error_G = loss_generated_images + loss_L1 * model.lambdaL1
    error_G_item = loss_generated_images.item() + loss_L1.item() * model.lambdaL1
    error_G.backward()
    model.generator_optimizer.step()

    return error_G_item, error_D_item

def train_GAN(model, train_dl, epochs, checkpoint=None, save_path="./checkpoint.pt"):
    """Train `model` for `epochs` (additional) epochs and return the per-epoch loss log.

    After every epoch the mean losses are printed, appended to the log and sent to
    wandb, and a checkpoint is written to `save_path`.

    :param model: GANModel to train
    :param train_dl: DataLoader yielding (x, label) batches
    :param epochs: number of epochs to run in this call
    :param checkpoint: optional checkpoint path to resume from (epoch counter and log continue)
    :param save_path: where the per-epoch checkpoint is written
    :return: dict with 'tr_generator_loss' and 'tr_discriminator_loss' lists
    """
    real_label = REAL_LABELS
    fake_label = FAKE_LABELS

    if checkpoint:
        curr_ep, log = model.load(checkpoint)
    else:
        log = {'tr_generator_loss' : [], 'tr_discriminator_loss' : []}
        curr_ep = 0
    # When resuming, epochs are numbered continuously after the checkpoint's epoch.
    total_epochs = curr_ep + epochs

    print("Starting TRAIN LOOP")

    for ep_num in range(curr_ep + 1, total_epochs + 1):
      epoch_G_losses = []
      epoch_D_losses = []

      model.generator.train()
      model.discriminator.train()

      for batch in tqdm(train_dl, desc=f"Epoch {ep_num} / {total_epochs}"):
        G_loss, D_loss = train_step(model, batch, real_label, fake_label)
        epoch_G_losses.append(G_loss)
        epoch_D_losses.append(D_loss)

      avg_G_loss = np.mean(epoch_G_losses)
      avg_D_loss = np.mean(epoch_D_losses)
      print(f"Epoch {ep_num}/{total_epochs}, Avg Loss G: {avg_G_loss}, Avg Loss D: {avg_D_loss}")
      log['tr_discriminator_loss'].append(avg_D_loss)
      log['tr_generator_loss'].append(avg_G_loss)

      wandb.log({"epoch":ep_num,
                 "generator_loss": avg_G_loss,
                 "discriminator_loss": avg_D_loss})

      model.save(epoch=ep_num, log=log, path=save_path, download=False)
      print(f"SAVED CHECKPOINT EPOCH {ep_num}")

    return log


In [None]:
def plot_losses(generator_losses, discriminator_losses, filename='epochs_loss.png',
                save_dir='C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/'):
    """Plot per-epoch generator and discriminator losses, save the figure as PNG and show it.

    :param generator_losses: one generator loss value per epoch
    :param discriminator_losses: one discriminator loss value per epoch
    :param filename: output file name; '.png' is appended only if it is not already present
    :param save_dir: directory the figure is written to
    """
    epochs = range(1, len(generator_losses) + 1)

    fig = plt.figure(figsize=(10, 5))
    plt.plot(epochs, generator_losses, label='Generator Loss', marker='o')
    plt.plot(epochs, discriminator_losses, label='Discriminator Loss', marker='o')

    plt.title('Generator and Discriminator Losses Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    # Avoid 'epochs_loss.png.png' when the name already carries the extension.
    if not filename.endswith('.png'):
        filename += '.png'
    plt.savefig(os.path.join(save_dir, filename))
    plt.show()
    plt.close(fig)




In [None]:
criterion = nn.BCEWithLogitsLoss()

betas = [0.5, 0.999]
lr = [2e-4, 2e-4]        # [generator lr, discriminator lr]
LAMBDA = 100.            # weight of the L1 term (value used in the article)
GEN_TYPE = 'ResNet'      # 'Unet', 'Autoencoder' or 'ResNet'
EPOCHS_l = [0]           # one independent training run per entry (number of epochs)

IM_TYPE = 'RGB'          # 'RGB' or 'Lab'

# Suffix shared by the saved file names and the wandb run notes.
_run_suffix = ('_cars_' if cars else '_') + GEN_TYPE + '_' + IM_TYPE
_results_dir = 'C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/'

for ep in EPOCHS_l:
    model_saved_name = f"model_E{ep}_I{SAMPLE_TRAIN}{_run_suffix}"
    run = wandb.init(project='bw2rgb', notes=f"images in the training set: {SAMPLE_TRAIN}{_run_suffix}")

    # Plot both losses against the epoch counter instead of wandb's internal step.
    wandb.define_metric("epoch")
    for metric_name in ("generator_loss", "discriminator_loss"):
        wandb.define_metric(metric_name, step_metric="epoch")

    model = GANModel(device, criterion, lr, betas, im_type=IM_TYPE, lambda_L1=LAMBDA, gen_type=GEN_TYPE)

    if IM_TYPE != "RGB":
        train_dl, test_dl = train_dl_LAB, test_dl_LAB
    else:
        train_dl, test_dl = train_dl_RGB, test_dl_RGB
    log = train_GAN(model, train_dl, epochs=ep)

    G_losses, D_losses = log['tr_generator_loss'], log['tr_discriminator_loss']
    plot_losses(G_losses, D_losses, filename=model_saved_name)

    with open(_results_dir + model_saved_name + '.txt', 'w') as output:
        output.write(str(log))
    torch.save(model, _results_dir + model_saved_name + '.pt')

    wandb.finish()
    


## Visualizing results ##

In [None]:
# To inspect a previously trained model instead of the one just trained:
#model = torch.load("C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/NNDL/model_E45_I10000_cars_Unet_Lab.pt")
#model.eval()

# Use the dataloaders that match the colour space of the model.
if IM_TYPE != "RGB":
    train_dl, test_dl = train_dl_LAB, test_dl_LAB
else:
    train_dl, test_dl = train_dl_RGB, test_dl_RGB

# Show the first five test batches.
for i, batch in enumerate(test_dl):
    #save_outputs(model, batch, IM_TYPE, save = True)
    visualize(model, batch, IM_TYPE, save=False, filename=f"results_{i}")
    if i == 4:
        break


## Computing final losses ##

In [None]:
def test_loss(model, test_dl, L1_coeff = 0.0):
  """Return the generator loss (adversarial + L1_coeff * L1) for every batch of `test_dl`.

  The loss has the same form as the generator objective in train_step:
  * the adversarial term asks the discriminator to label generated images as real
    (target 1.0);
  * the L1 term compares the generator output with the target channels. For "Lab"
    that means only the ab channels, as in training, so the copied L channel does not
    dilute it.
  Both networks run in eval mode without gradients. Their previous train/eval modes
  are restored afterwards.

  :param model: object exposing generator, discriminator, criterion, im_type and device
  :param test_dl: iterable of (x, label) batches
  :param L1_coeff: weight of the L1 term
  :return: list with one float loss per batch
  """
  gen_was_training = model.generator.training
  disc_was_training = model.discriminator.training
  model.generator.eval()
  model.discriminator.eval()
  l1_loss = nn.L1Loss()
  batch_losses = []
  try:
    with torch.no_grad():
      for x, label in test_dl:
        x = x.to(model.device)
        label = label.to(model.device)
        output = model.generator(x)

        if model.im_type == "Lab":
          # The discriminator expects 3-channel images: prepend the L channel to the predicted ab.
          generated_images = torch.cat([x, output], dim=1)
        else:
          generated_images = output

        discriminator_output = model.discriminator(generated_images)
        GAN_loss = model.criterion(discriminator_output, torch.ones_like(discriminator_output))
        loss = GAN_loss + L1_coeff * l1_loss(output, label)
        batch_losses.append(loss.item())
  finally:
    model.generator.train(gen_was_training)
    model.discriminator.train(disc_was_training)

  return batch_losses

# NOTE: the two triple-quoted strings below are disabled evaluation code kept for
# reference. As bare string expressions they do nothing at runtime.
# Disabled: evaluation of the runs on the generic (COCO) dataset, plotting test loss vs epochs.
"""
#For the generic dataset:

root_path = "C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/generic_dataset/"
EPOCHS_fl = [15,30,45]
im_count_l = ["10000", "20000"]

losses_f = {}

for im_count in im_count_l:
    losses = np.array([])
    for ep in EPOCHS_fl:
        model_loc = root_path + "E" + str(ep) + "_I" + im_count + "/model_E" + str(ep) + "_I" + im_count +".pt"
        model = torch.load(model_loc)
        model.eval()
    
        test_losses = test_loss(model, test_dl, L1_coeff = LAMBDA)
        average_test_loss = np.mean(test_losses)
        losses = np.append(losses, average_test_loss)
    losses_f[im_count] = losses

print(losses_f)

for im_count, loss_values in losses_f.items():
    plt.plot(EPOCHS_fl, loss_values, label=f"Images {im_count}")

plt.xlabel('Epochs')
plt.ylabel('Average Test Loss')
plt.title('Test Loss vs Epochs for different dataset sizes')
plt.legend()
plt.savefig('C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/'+'test_losses'+'.png')
plt.show()
"""
# Disabled: evaluation of the runs on the cars dataset.
"""
#For cars dataset:
root_path = "C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/cars_dataset/"
EPOCHS_fl = [15, 30, 45]
im_count_l = ["10000"]

losses_f = {}

for im_count in im_count_l:
    losses = np.array([])
    for ep in EPOCHS_fl:
        model_loc = root_path + "E" + str(ep) + "_I" + im_count + "_cars/model_E" + str(ep) + "_I" + im_count +"_cars.pt"
        model = torch.load(model_loc)
        model.eval()
    
        test_losses = test_loss(model, test_dl, L1_coeff = LAMBDA)
        average_test_loss = np.mean(test_losses)
        losses = np.append(losses, average_test_loss)
    losses_f[im_count] = losses

print("losses for "+ GEN_TYPE + "_" + IM_TYPE + str(losses_f))
"""


#for different architectures
# Evaluate the saved models of the current GEN_TYPE / IM_TYPE configuration on the test
# dataloader, for every (epochs, training-set size) combination listed below.
root_path = "C:/Users/alber/Desktop/MSc ICT Internet Multimedia Engineering/NNDL/"
EPOCHS_fl = [45]
im_count_l = ["10000"]

# Same naming scheme as the training cell: "_cars_" only when the cars dataset is used.
dataset_tag = "_cars_" if cars else "_"

losses_f = {}

for im_count in im_count_l:
    losses = np.array([])
    for ep in EPOCHS_fl:
        model_loc = root_path + "model_E" + str(ep) + "_I" + im_count + dataset_tag + GEN_TYPE + "_" + IM_TYPE + ".pt"
        # These files hold whole pickled GANModel objects (torch.save(model)), which cannot be
        # loaded with weights_only=True (the default since PyTorch 2.6); only load trusted files.
        # map_location lets GPU-trained models be evaluated on a CPU-only machine.
        model = torch.load(model_loc, map_location=device, weights_only=False)
        # The pickled model remembers the device it was trained on. Point it at the current
        # device so test_loss sends batches to where the weights now are.
        model.device = device
        model.eval()

        test_losses = test_loss(model, test_dl, L1_coeff = LAMBDA)
        average_test_loss = np.mean(test_losses)
        losses = np.append(losses, average_test_loss)
    losses_f[im_count] = losses

print("losses for "+ GEN_TYPE + "_" + IM_TYPE + str(losses_f))
