N.N and deep learning
Image recoloring
Shahla Sadeghzadeh - Shayan Sharifi - Mohammad Vanaei

1. The first step is importing the required libraries.

In [37]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from fastai.data.external import untar_data, URLs

from torch.utils.tensorboard import SummaryWriter

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Downloading images from COCO
We randomly sample 3,000 COCO images (with a fixed random seed) and use 2,400 of them for training and 600 for validation.


In [38]:

# # coco_path = untar_data(URLs.COCO_SAMPLE)
# # coco_path = str(coco_path) + "/train_sample"
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
paths=glob.glob("/content/drive/MyDrive/Deep/coco/*.jpg")
#paths = glob.glob("./data/coco_sample/train_sample/*.jpg")
np.random.seed(100)
paths_subset = np.random.choice(paths, 3_000, replace=False)  # randomly choose 3,000 images
rand_idxs = np.random.permutation(3_000)
train_idxs = rand_idxs[:2400]
val_idxs = rand_idxs[2400:]
train_paths = paths_subset[train_idxs]
val_paths = paths_subset[val_idxs]


print(len(train_paths), len(val_paths))

_, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, img_path in zip(axes.flatten(), train_paths):
    ax.imshow(Image.open(img_path))
    ax.axis("off")

The code below defines a dataset and data loaders for the colorization task in PyTorch. Each image is resized, converted to the Lab color space, and split into its luminance channel (L) and its color channels (ab). The data loaders then serve these pairs in batches for training and validation.
The Lab color space has three channels: L (lightness, from 0 to 100), a (the green-red component), and b (the blue-yellow component). The L channel is essentially the grayscale version of the image and is the model input, while the a and b channels carry the color information the model has to predict. Both are rescaled to roughly [-1, 1]: L with L / 50 - 1 and ab with ab / 110.
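The scaling step can be checked on its own. The NumPy sketch below applies the same constants as ColorizationDataset (L / 50 - 1 and ab / 110) and the inverse used in lab_to_rgb. The helper names normalize_lab and denormalize_lab are hypothetical, and the input values are made up for illustration rather than taken from real pixels.

```python
import numpy as np

def normalize_lab(lab):
    """Split an H x W x 3 Lab array into scaled L and ab (same constants as the dataset)."""
    L = lab[..., :1] / 50.0 - 1.0   # L in [0, 100] -> [-1, 1]
    ab = lab[..., 1:] / 110.0       # a, b roughly within [-110, 110] -> about [-1, 1]
    return L, ab

def denormalize_lab(L, ab):
    """Undo normalize_lab, mirroring what lab_to_rgb does before calling lab2rgb."""
    return np.concatenate([(L + 1.0) * 50.0, ab * 110.0], axis=-1)

# Illustrative Lab values: L at both ends of its range, large |a| and |b|
lab = np.array([[[0.0, -86.2, 94.5], [100.0, 98.2, -107.9]]], dtype=np.float32)
L, ab = normalize_lab(lab)
print(L.ravel())                # lightness mapped to -1 and 1
print(np.abs(ab).max())         # largest scaled |a| or |b|, below 1
print(np.allclose(denormalize_lab(L, ab), lab, atol=1e-4))  # round trip recovers the input
```

The [-1, 1] range also matches the Tanh at the end of the generator, so predicted ab values can be un-scaled in exactly the same way.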

In [40]:
SIZE = 256
class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE)),
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE))
        else:
            raise ValueError(f"unknown split: {split}")

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # RGB -> Lab

        # ToTensor reorders H x W x C into C x H x W. It rescales to [0, 1] only for
        # uint8 input, so these float32 Lab values keep their original ranges.
        img_lab = transforms.ToTensor()(img_lab)

        L = img_lab[[0], ...] / 50. - 1.   # L in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.   # a, b -> roughly [-1, 1]
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)
#batch_size=16
def make_dataloaders(batch_size=32, n_workers=4, pin_memory=True, **kwargs):
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=n_workers, pin_memory=pin_memory)
    return dataloader


train_dl = make_dataloaders(paths=train_paths, split='train')
val_dl = make_dataloaders(paths=val_paths, split='val')

data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
Ls_numpy = Ls.numpy()[:][0]


UNet
This code defines the U-Net used as the generator. U-Net was designed for image segmentation, but here it maps the L channel of an image to its two ab channels. It is built from nested downsampling (encoding) and upsampling (decoding) blocks joined by skip connections, which let the decoder reuse fine-grained spatial details from the encoder.

The parameters of UnetBlock are:

nf: the number of channels produced by the block's up-convolution; unless input_c is given, also the number of channels entering the block.
ni: the number of filters of the block's down-convolution, i.e. the number of channels passed to the nested submodule.
submodule: the block nested inside this one.
input_c: the number of input channels of the down-convolution (defaults to nf).
dropout: whether to add a Dropout(0.5) layer at the end of the upsampling path.
innermost: whether this is the innermost (bottleneck) block, which has no submodule.
outermost: whether this is the outermost block, which receives the network input and ends with Tanh.

Inside the UnetBlock class:

The constructor builds the block's layers from these parameters: a stride-2 convolution for downsampling, a stride-2 transposed convolution for upsampling, LeakyReLU and ReLU activations, and batch normalization.
The layer sequence differs for the innermost, outermost, and middle blocks.
The forward method runs the block. Every block except the outermost concatenates its input with its output along the channel dimension; this concatenation is the skip connection.

Inside the Unet class:

This class assembles the whole U-Net. Its parameters are the number of input channels (input_c), the number of output channels (output_c), the number of downsampling steps (n_down), and the base number of filters (num_filters).
The constructor starts with the innermost block, wraps it in n_down - 5 middle blocks with num_filters * 8 filters and dropout, then in three middle blocks that halve the number of filters each time, and finally in the outermost block, giving n_down downsampling steps in total.
The forward method passes the input through the nested blocks and returns the predicted ab channels.
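To see how n_down and num_filters determine the layout, the hypothetical helper unet_plan below retraces the UnetBlock(nf, ni, ...) calls that the Unet constructor makes, without building any layers. It is a sketch that assumes SIZE = 256 and the default arguments; its down-convolution channel counts agree with the printed model summary below (1 to 64, 64 to 128, 128 to 256, 256 to 512).

```python
def unet_plan(input_c=1, output_c=2, n_down=8, num_filters=64, size=256):
    """Trace the UnetBlock arguments created by the Unet constructor.

    Returns (kind, down_in, down_out, up_out, spatial) tuples ordered from the
    outermost to the innermost block; spatial is the feature-map size after
    that block's stride-2 down-convolution.
    """
    built = []  # construction order: innermost first
    built.append(("innermost", num_filters * 8, num_filters * 8, num_filters * 8))
    for _ in range(n_down - 5):
        built.append(("middle, dropout", num_filters * 8, num_filters * 8, num_filters * 8))
    out_filters = num_filters * 8
    for _ in range(3):
        # UnetBlock(nf=out_filters // 2, ni=out_filters): input_c defaults to nf
        built.append(("middle", out_filters // 2, out_filters, out_filters // 2))
        out_filters //= 2
    # UnetBlock(output_c, out_filters, input_c=input_c, outermost=True)
    built.append(("outermost", input_c, out_filters, output_c))

    plan, spatial = [], size
    for kind, down_in, down_out, up_out in reversed(built):
        spatial //= 2  # kernel 4, stride 2, padding 1 halves an even size
        plan.append((kind, down_in, down_out, up_out, spatial))
    return plan

for row in unet_plan():
    print(row)
```

With n_down = 8, a 256 x 256 image is reduced to a 1 x 1 bottleneck, so n_down cannot be increased further at this image size.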

In [41]:

class UnetBlock(nn.Module):
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        return self.model(x)






Unet()

Unet(
  (model): UnetBlock(
    (model): Sequential(
      (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetBlock(
        (model): Sequential(
          (0): LeakyReLU(negative_slope=0.2, inplace=True)
          (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (3): UnetBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace=True)
              (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): UnetBlock(
                (model): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace=True)
                  (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
 

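Discriminator
The discriminator is a WGAN critic. It receives a 3-channel Lab image (the L channel concatenated with either the real or the generated ab channels) and outputs unbounded scores rather than real/fake probabilities, so it has no final sigmoid. For the 256 x 256 images used here it returns a 7 x 7 grid of scores, each rating one region of the image, much like the PatchGAN discriminators used in conditional GANs and image-to-image translation.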
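The size of the score map follows from the standard convolution output formula, floor((n + 2 * padding - kernel) / stride) + 1. The small helpers below (a sketch, not part of the notebook) apply it to the Critic layers: a 64 x 64 input gives a single score, whereas the 256 x 256 images used here give a 7 x 7 grid.

```python
def conv_out(n, kernel, stride, padding):
    """Output size of a Conv2d along one spatial dimension (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

def critic_output_size(size):
    """Spatial size of the score map Critic produces for a size x size input."""
    n = conv_out(size, 4, 2, 1)      # first Conv2d
    for _ in range(3):               # the three _block layers
        n = conv_out(n, 4, 2, 1)
    return conv_out(n, 4, 2, 0)      # final Conv2d to one channel

print(critic_output_size(64))    # 1: a single score per image
print(critic_output_size(256))   # 7: a 7 x 7 grid of scores
```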

In [42]:
class Critic(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Critic, self).__init__()
        self.disc = nn.Sequential(
            # input: N x channels_img x 256 x 256 (SIZE); the sizes below assume this input
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            # After the three _blocks a 256 x 256 input is 16 x 16; the Conv2d below turns it into a 7 x 7 score map (a 64 x 64 input would give 1 x 1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)

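GAN loss
A custom GAN loss class, which computes either the 'vanilla' or the 'lsgan' GAN loss depending on the mode chosen at initialization, is not used in this WGAN version: MainModel computes the critic and generator losses directly from the critic scores, and its GANcriterion calls are commented out.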
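For comparison, the sketch below shows what such a loss class typically computes, next to the Wasserstein critic loss used in MainModel.backward_D. It is written with NumPy so it runs on its own; the names GANLossSketch and wgan_critic_loss are hypothetical. In 'vanilla' mode it applies binary cross-entropy to raw discriminator logits, in 'lsgan' mode the squared error of least-squares GAN, with target 1 for real and 0 for fake.

```python
import numpy as np

class GANLossSketch:
    """Hypothetical NumPy version of a GAN loss with 'vanilla' and 'lsgan' modes."""

    def __init__(self, gan_mode="vanilla", real_label=1.0, fake_label=0.0):
        if gan_mode not in ("vanilla", "lsgan"):
            raise ValueError(f"unsupported gan_mode: {gan_mode}")
        self.gan_mode = gan_mode
        self.real_label = real_label
        self.fake_label = fake_label

    def __call__(self, preds, target_is_real):
        preds = np.asarray(preds, dtype=np.float64)
        target = np.full_like(preds, self.real_label if target_is_real else self.fake_label)
        if self.gan_mode == "vanilla":
            # Binary cross-entropy on logits, in the numerically stable form
            loss = np.maximum(preds, 0.0) - preds * target + np.log1p(np.exp(-np.abs(preds)))
        else:
            # Least-squares GAN: squared distance to the target label
            loss = (preds - target) ** 2
        return float(loss.mean())

def wgan_critic_loss(real_scores, fake_scores):
    """Critic loss as in MainModel.backward_D: -(E[D(real)] - E[D(fake)])."""
    return -(np.mean(real_scores) - np.mean(fake_scores))

vanilla, lsgan = GANLossSketch("vanilla"), GANLossSketch("lsgan")
print(vanilla(np.zeros(4), True))           # log(2), about 0.693, for logits of 0
print(lsgan(np.ones(4), False))             # 1.0: predictions of 1 vs. the fake label 0
print(wgan_critic_loss([2.0, 4.0], [1.0]))  # -(3.0 - 1.0) = -2.0
```

Unlike the first two losses, the Wasserstein loss uses the critic scores directly, without labels or a sigmoid.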

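Model Initialization
We initialize the weights of the convolutional and transposed-convolutional layers from a normal distribution with mean 0.0 and standard deviation 0.02, and set their biases to 0. Batch-normalization weights are drawn from a normal distribution with mean 1.0 and standard deviation 0.02, and their biases are set to 0.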

In [43]:
def init_weights(net, init='norm', gain=0.02):

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model

In [44]:
model = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), device)

model initialized with norm initialization


### **WGAN**
WGAN main model
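The losses computed in backward_D and backward_G below are the Wasserstein GAN objectives. Writing D for the critic net_D, G for the generator net_G, (L, ab) for a real Lab image, and c for WEIGHT_CLIP, the critic and the generator minimize:

```latex
\mathcal{L}_D = -\Big(\mathbb{E}_{(L,\,ab)}\big[D(L, ab)\big] - \mathbb{E}_{L}\big[D\big(L, G(L)\big)\big]\Big),
\qquad
\mathcal{L}_G = -\,\mathbb{E}_{L}\big[D\big(L, G(L)\big)\big],
\qquad
w \leftarrow \operatorname{clip}(w, -c, c) \ \text{for every critic parameter } w .
```

The expectations are estimated by batch means (torch.mean in the code). The critic takes CRITIC_ITERATIONS steps per generator step, each followed by weight clipping. Because the L1 term weighted by lambda_L1 is commented out in backward_G, the generator here is trained only through the adversarial loss.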

In [45]:
### WGAN Main Model


class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=5e-5, lr_D=5e-5, lambda_L1=100., WEIGHT_CLIP = 0.01):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        self.WEIGHT_CLIP = WEIGHT_CLIP

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(Critic(3, 64), self.device)

        self.opt_G = optim.RMSprop(self.net_G.parameters(), lr=lr_G)
        self.opt_D = optim.RMSprop(self.net_D.parameters(), lr=lr_D)

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        self.loss_D_fake = self.net_D(fake_image.detach())
        #print("type loss_D_fake:",type(self.loss_D_fake))
        #print("loss_D_fake:", self.loss_D_fake)
        #self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        self.loss_D_real = self.net_D(real_image)
        #self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D = -(torch.mean(self.loss_D_real) - torch.mean(self.loss_D_fake))
        #print("type loss_D:",type(self.loss_D))
        #print("loss_D:", self.loss_D)
        self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        #self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        #self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        #self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G = -torch.mean(fake_preds)
        self.loss_G.backward()

    def optimize_Critic(self):
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # clip critic weights to [-WEIGHT_CLIP, WEIGHT_CLIP] (default 0.01) to enforce the WGAN Lipschitz constraint
        for p in self.net_D.parameters():
            p.data.clamp_(-self.WEIGHT_CLIP, self.WEIGHT_CLIP)

    def optimize_G(self):
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

Utility functions

In [46]:
epochs_loss_G = []
epochs_loss_D = []

class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D = AverageMeter()
    loss_G = AverageMeter()

    return {'loss_D': loss_D,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)

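def lab_to_rgb(L, ab):
    """
    Convert a batch of normalized L (N x 1 x H x W) and ab (N x 2 x H x W)
    tensors back to an N x H x W x 3 array of RGB images.
    """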

    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)

def visualize(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 8))

    num_images = min(5, L.shape[0])
    for i in range(num_images):
        ax = plt.subplot(3, 5, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 5)
        ax.imshow(fake_imgs[i])
        ax.axis("off")
        ax = plt.subplot(3, 5, i + 1 + 10)
        ax.imshow(real_imgs[i])
        ax.axis("off")
    plt.show()

    if save:
        fig.savefig(f"colorization_{time.time()}.png")

def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        if loss_name == 'loss_D':
            epochs_loss_D.append(loss_meter.avg)
        if loss_name == 'loss_G':
            epochs_loss_G.append(loss_meter.avg)
        print(f"{loss_name}: {loss_meter.avg:.5f}")

In [47]:
#Metrics------
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import torch.nn.functional as F

#Metrics-------
def calculate_metrics(model, dt):

    # Set the model to evaluation mode and disable gradient computation
    model.net_G.eval() #model.eval()

    with torch.no_grad():
        # Move input data to the specified device (GPU or CPU)
        model.setup_input(dt)

        # Forward pass to generate fake colorized images
        model.forward()
        model.net_G.train()

        # Convert torch tensors to numpy arrays
        fake_color = model.fake_color.cpu().detach().numpy()
        ab = model.ab.cpu().detach().numpy()#real_color

        # Convert to RGB format for PSNR calculation
        fake_rgb = np.concatenate((model.L.cpu().detach().numpy(), fake_color), axis=1)
        ab_rgb = np.concatenate((model.L.cpu().detach().numpy(), ab), axis=1)

        # Calculate PSNR and append to the list
        psnr = peak_signal_noise_ratio(ab_rgb, fake_rgb)

        # Calculate SSIM and append to the list
        fake_color = fake_color.transpose((0, 2, 3, 1))  # Change to (batch, height, width, channels) format
        ab = ab.transpose((0, 2, 3, 1))
        ssim = structural_similarity(ab, fake_color, multichannel=True)

    return psnr, ssim
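PSNR compares two images through their mean squared error: PSNR = 10 * log10(data_range^2 / MSE), in decibels, where data_range is the span of possible pixel values (1.0 for RGB images scaled to [0, 1]). The NumPy sketch below spells out that formula; psnr_sketch is a hypothetical helper meant only to show what peak_signal_noise_ratio computes, not to replace it.

```python
import numpy as np

def psnr_sketch(reference, test, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    reference = np.asarray(reference, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((reference - test) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))

reference = np.zeros((8, 8, 3))
test = reference + 0.1                          # every pixel off by 0.1, so MSE = 0.01
print(round(psnr_sketch(reference, test), 6))   # 20.0 dB
```

A higher PSNR means a smaller pixel-wise error. Because PSNR ignores image structure, it is reported together with SSIM in the training loop.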

Training function

## **WGAN Train**

In [None]:
### WGAN Train

epochs_times = []

CRITIC_ITERATIONS = 5

def train_model(model, train_dl, epochs, display_every=7):
    # TensorBoard writers (created for logging, not used further below)
    writer_real = SummaryWriter("logs/real")
    writer_fake = SummaryWriter("logs/fake")

    start_time = time.time()
    for e in range(epochs):
        t1 = time.time()
        loss_meter_dict = create_loss_meters()

        # Metrics: PSNR and SSIM values collected during this epoch
        psnr_values = []
        ssim_values = []

        for i, data in enumerate(tqdm(train_dl), start=1):
            model.setup_input(data)

            # Train the critic: max E[critic(real)] - E[critic(fake)]
            for _ in range(CRITIC_ITERATIONS):
                model.optimize_Critic()

            # Train the generator: max E[critic(fake)]
            model.optimize_G()

            update_losses(model, loss_meter_dict, count=data['L'].size(0))

            # Metrics on the current training batch
            psnr, ssim = calculate_metrics(model, data)
            psnr_values.append(psnr)
            ssim_values.append(ssim)
            print(f'PSNR{i}: {psnr:.4f}, SSIM{i}: {ssim:.4f}')

            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict)
                visualize(model, data, save=False)
                print(f'PSNR{i}: {psnr:.4f}, SSIM{i}: {ssim:.4f}')

        # Metrics: average, maximum and minimum PSNR and SSIM of this epoch
        avg_psnr = sum(psnr_values) / len(psnr_values)
        avg_ssim = sum(ssim_values) / len(ssim_values)
        max_psnr = max(psnr_values)
        min_psnr = min(psnr_values)
        max_ssim = max(ssim_values)
        min_ssim = min(ssim_values)
        print(f'Average PSNR: {avg_psnr:.4f}, Maximum PSNR: {max_psnr:.4f}, Minimum PSNR: {min_psnr:.4f}')
        print(f'Average SSIM: {avg_ssim:.4f}, Maximum SSIM: {max_ssim:.4f}, Minimum SSIM: {min_ssim:.4f}')

        t2 = time.time()
        epochs_times.append(t2 - t1)
    t_end = time.time()
    print(f"Total training time: {t_end - start_time:.1f} s")



model = MainModel()
train_model(model, train_dl, 20)
