# Image Colorization with U-Net and GAN 

In [None]:
pip install -U albumentations

# Loading Image Paths

In [None]:
import os
import glob
import time
import numpy as np
from PIL import Image
import cv2
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
# Train on the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
#!pip install fastai==2.4

The commented-out cell below would download the ~20,000-image COCO sample via fastai. This notebook instead reads the full COCO 2017 dataset from the Kaggle input directory and trains on **every image found in `train2017`**, validating on `val2017`. Any other dataset that covers varied scenes and locations, such as ImageNet, would also work.

In [None]:
# from fastai.data.external import untar_data, URLs
# coco_path = untar_data(URLs.COCO_SAMPLE)
# coco_path = str(coco_path) + "/train_sample"
# use_colab = True

In [None]:
# Root of the Kaggle COCO 2017 mirror and the folder name of each split.
data_set_root = '/kaggle/input/coco-2017-dataset/coco2017'
train_set, validation_set, test_set = 'train2017', 'val2017', 'test2017'

# Full directory path of each split (plain strings).
train_path, val_path, test_path = (
    os.path.join(data_set_root, split_dir)
    for split_dir in (train_set, validation_set, test_set)
)

In [None]:
# Recursively collect the image files of each split. Sorting makes the order (and
# therefore any index such as train_image_path[1]) reproducible: Path.rglob yields
# entries in filesystem order, which is not guaranteed to be stable across machines.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def _list_images(root):
    """Return the image files under ``root`` (recursive), sorted by path."""
    return sorted(p for p in Path(root).rglob("*") if p.suffix.lower() in IMAGE_SUFFIXES)


train_image_path = _list_images(train_path)
val_image_path = _list_images(val_path)
test_image_path = _list_images(test_path)

# Fail fast on a wrong data_set_root instead of silently building empty DataLoaders.
assert train_image_path, f"no training images found under {train_path}"
assert val_image_path, f"no validation images found under {val_path}"

print(len(train_image_path), len(val_image_path), len(test_image_path))

In [None]:
# Sanity check: load one training image with OpenCV and display it.
# Pass a str: some OpenCV builds reject pathlib.Path arguments to imread.
sample_path = str(train_image_path[1])
img = cv2.imread(sample_path)
# cv2.imread returns None (rather than raising) when it cannot read a file.
assert img is not None, f"cv2 could not read {sample_path}"
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
img = cv2.resize(img, (224, 224))  # preview size only; training uses image_size
plt.imshow(img)
plt.axis("off")
img.shape

# Parameters

In [None]:
import random

# Training images are resized to image_size x image_size before Lab conversion.
image_size = 256

batch_size = 64

# Seed the Python, NumPy and PyTorch RNGs so that weight initialization, dropout and
# DataLoader shuffling repeat across runs. Full determinism on GPU would additionally
# need deterministic cuDNN settings.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Making Datasets and DataLoaders

In [None]:
import albumentations as A

In [None]:

class ColorizationDataset(Dataset):
    """Turn RGB image files into normalized L*a*b* tensors for colorization training.

    Each item is a dict:
      * ``'L'``  - lightness channel, shape (1, H, W), mapped from [0, 100] to [-1, 1];
      * ``'ab'`` - chroma channels, shape (2, H, W), divided by 110 (roughly [-1, 1]).

    Args:
        paths: sequence of image paths (anything ``PIL.Image.open`` accepts).
        split: ``'train'`` applies resize plus random flips, 90-degree rotations and
            brightness/contrast/gamma jitter; ``'val'`` applies resize only.

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``. Without this
            check, ``self.transforms`` would stay unset and ``__getitem__`` would fail later.
    """

    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = A.Compose([
                A.Resize(image_size, image_size),
                A.HorizontalFlip(p=0.4),
                A.VerticalFlip(p=0.4),
                A.RandomRotate90(p=0.5),
                A.RandomBrightnessContrast(p=0.2),
                A.RandomGamma(gamma_limit=(70, 130), p=0.2),
            ])
        elif split == 'val':
            self.transforms = A.Compose([
                A.Resize(image_size, image_size)
            ])
        else:
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = image_size
        self.paths = paths

    def __getitem__(self, idx):
        """Load, augment and convert image ``idx``; return ``{'L': ..., 'ab': ...}``."""
        # Context manager so the underlying file handle is released promptly.
        with Image.open(self.paths[idx]) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        img = self.transforms(image=img)['image']

        img_lab = rgb2lab(img).astype("float32")  # HWC; L in [0, 100]
        img_lab = transforms.ToTensor()(img_lab)  # HWC -> CHW; float input is not rescaled
        L = img_lab[[0], ...] / 50. - 1.  # [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.  # approximately [-1, 1]

        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, shuffle=None, **kwargs):
    """Build a ``ColorizationDataset`` and wrap it in a ``DataLoader``.

    Args:
        batch_size: samples per batch.
        n_workers: number of DataLoader worker processes.
        pin_memory: pin host memory for faster host-to-GPU copies.
        shuffle: reshuffle every epoch. ``None`` (default) shuffles only the
            ``'train'`` split, so each training epoch sees a fresh order while the
            validation order stays fixed. (``DataLoader`` itself defaults to no
            shuffling, so training data was previously always seen in the same order.)
        **kwargs: forwarded to ``ColorizationDataset`` (e.g. ``paths``, ``split``).

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ColorizationDataset(**kwargs)
    if shuffle is None:
        # ColorizationDataset uses split='train' when none is given.
        shuffle = kwargs.get('split', 'train') == 'train'
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=n_workers, pin_memory=pin_memory)

In [None]:
train_dl = make_dataloaders(batch_size=batch_size, paths=train_image_path, split='train')
val_dl = make_dataloaders(batch_size=batch_size, paths=val_image_path, split='val')

# Pull one batch to sanity-check tensor shapes: L is (B, 1, H, W) and ab is (B, 2, H, W).
data = next(iter(train_dl))
Ls = data['L']
abs_ = data['ab']
print(Ls.shape, abs_.shape)
# Number of batches per epoch for each split.
print(len(train_dl), len(val_dl))

## Model

### Generator

In [None]:
class UnetBlock(nn.Module):
    """One level of a U-Net: downsample, run an optional nested block, upsample.

    The block is recursive: ``submodule`` is the next-deeper ``UnetBlock``. Every
    non-outermost block returns ``cat([x, model(x)], dim=1)``, which forms the skip
    connection, so its output has twice the channels its parent's up-convolution sees
    per side (hence ``ni * 2`` inputs to ``upconv`` when a submodule is present).

    Args:
        nf: channels of this block's input (unless ``input_c`` is given) and of its
            up-convolution output.
        ni: channels produced by the down-convolution.
        submodule: nested deeper block; unused when ``innermost=True``.
        input_c: overrides the input channel count (used by the outermost block).
        dropout: append ``Dropout(0.5)`` after the up path (middle blocks only).
        innermost: bottleneck block with no submodule and no down-side BatchNorm.
        outermost: top block; no activation/norm before the down-conv, ends in
            ``Tanh``, and returns its output without the skip concatenation.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        # 4x4 stride-2 conv halves the spatial size; bias omitted (followed by BatchNorm except at the ends).
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)  # negative_slope=0.2, inplace=True
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)  # inplace=True
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            # Input is the submodule's concatenated output (ni * 2 channels); keeps its bias
            # because no BatchNorm follows.
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Bottleneck: no submodule, so the up-conv sees only the ni down-conv channels.
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the block; non-outermost blocks concatenate the input as a skip connection."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock``s, built innermost first.

    Args:
        input_c: channels of the input image (1 for the L channel).
        output_c: channels produced by the final Tanh layer (2 for ab).
        n_down: total number of stride-2 downsamplings (one per block). The input side
            length should presumably be divisible by ``2 ** n_down`` so the skip
            connections line up.
        num_filters: channel width of the outermost down-convolution; deeper levels
            double it up to ``num_filters * 8``.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8
        # Bottleneck, then (n_down - 5) equally wide blocks with dropout.
        block = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)
        # Three blocks that halve the channel width on the way out.
        width = widest
        for _ in range(3):
            block = UnetBlock(width // 2, width, submodule=block)
            width //= 2
        # Outermost block maps input_c channels to output_c channels.
        self.model = UnetBlock(output_c, width, input_c=input_c, submodule=block, outermost=True)

    def forward(self, x):
        return self.model(x)

### Discriminator

In [None]:
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator: a conv stack that outputs a grid of raw (pre-sigmoid) scores.

    Args:
        input_c: channels of the image to score.
        num_filters: width of the first conv layer; doubled at each of the n_down blocks.
        n_down: number of intermediate conv blocks; all use stride 2 except the last (stride 1).
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # First layer: no normalization.
        blocks = [self.get_layers(input_c, num_filters, norm=False)]
        for i in range(n_down):
            in_ch = num_filters * 2 ** i
            stride = 1 if i == n_down - 1 else 2  # last intermediate block keeps resolution
            blocks.append(self.get_layers(in_ch, in_ch * 2, s=stride))
        # Output layer: single channel, no normalization and no activation.
        blocks.append(self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Conv (bias only when not normalized) -> optional BatchNorm -> optional LeakyReLU(0.2)."""
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            layers.append(nn.BatchNorm2d(nf))
        if act:
            layers.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

Let's inspect the discriminator's blocks:

In [None]:
# Display the layer structure of a discriminator for 3-channel (L + ab) inputs.
PatchDiscriminator(input_c=3)

And its output shape for a dummy batch:

In [None]:
# Feed a random batch through a fresh discriminator; the result is a grid of scores, one per patch.
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256)  # (batch_size, channels, height, width)
out = discriminator(dummy_input)
out.shape

### GAN Loss

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss of discriminator scores against a constant real/fake target.

    Args:
        gan_mode: ``'vanilla'`` (``BCEWithLogitsLoss`` on raw scores) or ``'lsgan'``
            (``MSELoss``).
        real_label: target value used when ``target_is_real`` is True.
        fake_label: target value used when ``target_is_real`` is False.

    Raises:
        ValueError: for any other ``gan_mode``. Without this check, ``self.loss`` would
            stay undefined until the first call.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Registered as buffers so the labels follow the module across .to(device).
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise ValueError(f"gan_mode must be 'vanilla' or 'lsgan', got {gan_mode!r}")

    def get_labels(self, preds, target_is_real):
        """Return the real or fake label broadcast (without copying) to ``preds``' shape."""
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    # Implement forward (not __call__) so nn.Module.__call__ and its hooks still run;
    # callers keep using ``criterion(preds, target_is_real)``.
    def forward(self, preds, target_is_real):
        """Loss of ``preds`` against an all-real or all-fake target."""
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)

### Model Initialization

In [None]:
def init_weights(net, init='norm', gain=0.02):
    """Re-initialize the conv and BatchNorm2d layers of ``net`` in place.

    Modules whose class name contains ``'Conv'`` (including ConvTranspose) get:
      * ``'norm'``    - weights ~ Normal(0, gain);
      * ``'xavier'``  - Xavier-normal weights with the given ``gain``;
      * ``'kaiming'`` - Kaiming-normal weights (fan_in, a=0); ``gain`` is not used;
    plus zeroed biases. Affine ``BatchNorm2d`` layers get weight ~ Normal(1, gain), bias = 0.

    Args:
        net: module to initialize (modified in place).
        init: one of ``'norm'``, ``'xavier'``, ``'kaiming'``.
        gain: std / gain described above.

    Returns:
        ``net`` itself, for chaining.

    Raises:
        ValueError: for an unknown ``init``. Previously an unknown value silently left
            conv weights at their defaults.
    """
    if init not in ('norm', 'xavier', 'kaiming'):
        raise ValueError(f"init must be 'norm', 'xavier' or 'kaiming', got {init!r}")

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            # nn.init functions already run under torch.no_grad().
            if init == 'norm':
                nn.init.normal_(m.weight, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight, gain=gain)
            else:  # 'kaiming'
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')

            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0.0)
        elif 'BatchNorm2d' in classname and m.weight is not None:
            # affine=False BatchNorm has weight/bias set to None; nothing to initialize.
            nn.init.normal_(m.weight, 1., gain)
            nn.init.constant_(m.bias, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device, init='norm', gain=0.02):
    """Move ``model`` to ``device`` and initialize its weights with ``init_weights``.

    Args:
        model: network to prepare.
        device: target device.
        init, gain: forwarded to ``init_weights``; the defaults match its own defaults.

    Returns:
        The moved and initialized model.
    """
    model = model.to(device)
    return init_weights(model, init=init, gain=gain)

### Main Model

In [None]:
class MainModel(nn.Module):
    """Conditional GAN for colorization: U-Net generator plus PatchGAN discriminator.

    The generator maps the 1-channel L input to 2 predicted ab channels. The
    discriminator scores the 3-channel concatenation (L, ab). The generator loss is
    the adversarial loss plus ``lambda_L1`` times the L1 distance to the true ab channels.

    Args:
        device: device for both networks and the GAN criterion.
        net_G: optional pre-built generator. If None, a ``Unet(1 -> 2)`` is created
            and initialized with ``init_model``.
        lr_G, lr_D: Adam learning rates for the generator and the discriminator.
        beta1, beta2: Adam betas shared by both optimizers.
        lambda_L1: weight of the L1 term in the generator loss.
        weight_decay: Adam weight decay for both optimizers.

    After ``optimize()``, the attributes ``loss_D_fake``, ``loss_D_real``, ``loss_D``,
    ``loss_G_GAN``, ``loss_G_L1`` and ``loss_G`` hold the last batch's loss tensors.
    """

    def __init__(self, device, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100., weight_decay=1e-5):
        super().__init__()

        self.device = device
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2), weight_decay=weight_decay)
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2), weight_decay=weight_decay)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Move a batch dict's ``'L'`` and ``'ab'`` tensors onto ``self.device``."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Predict ab channels from the stored L channel into ``self.fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Discriminator loss on detached fakes and on real images, then backprop."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator step must not push gradients into the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Generator loss (fool D, plus weighted L1 to the ground truth), then backprop."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """One training step: a discriminator update followed by a generator update."""
        # Switch the generator to training mode *before* its forward pass, so dropout and
        # BatchNorm behave as in training even if the model was left in eval mode (e.g.
        # by a validation pass). Previously .train() was only called after forward().
        self.net_G.train()
        self.forward()

        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        # Freeze D so the generator step does not compute gradients for D's parameters.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

In [None]:
def create_model(device, device_ids=None):
    """Build a ``MainModel`` and wrap it in ``nn.DataParallel``.

    Args:
        device: device the model is moved to.
        device_ids: GPU ids for ``DataParallel``. ``None`` (default) lets DataParallel
            use every visible GPU. The previous hard-coded ``[0, 1]`` failed on machines
            with a single GPU. Per the PyTorch docs, on a CPU-only machine the wrapper
            simply holds the module.

    Returns:
        The wrapped model; the underlying ``MainModel`` is at ``.module``.

    NOTE(review): the training/validation loops call ``model.module.optimize()`` and
    related methods directly. This bypasses ``DataParallel.forward``, so batches are
    not actually split across GPUs. The wrapper currently only provides ``.module``.
    """
    model = MainModel(device)
    model = nn.DataParallel(model, device_ids=device_ids)
    return model.to(device)

In [None]:
# Instantiate the GAN (generator + discriminator + optimizers) on the selected device.
model = create_model(device=device)

### Utility functions

In [None]:
class AverageMeter:
    """Running, count-weighted mean of a stream of scalar values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the accumulated count, sum and average."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh ``avg``."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a fresh ``AverageMeter`` for each tracked loss, keyed by loss attribute name."""
    loss_names = ('loss_D_fake', 'loss_D_real', 'loss_D',
                  'loss_G_GAN', 'loss_G_L1', 'loss_G')
    return {name: AverageMeter() for name in loss_names}

def update_losses(model, loss_meter_dict, count):
    """Feed each meter the same-named scalar-tensor attribute of ``model``, weighted by ``count``."""
    for name in loss_meter_dict:
        current = getattr(model, name).item()
        loss_meter_dict[name].update(current, count=count)

def lab_to_rgb(L, ab):
    """Convert a batch of normalized Lab tensors back to RGB images.

    Undoes the dataset normalization (L: [-1, 1] -> [0, 100]; ab: x110), concatenates
    along dim 1 (assumes L is (B, 1, H, W) and ab is (B, 2, H, W)), and converts each
    image with ``lab2rgb``.

    Returns:
        NumPy array of shape (B, H, W, 3).
    """
    lab = torch.cat([(L + 1.) * 50., ab * 110.], dim=1)
    lab_hwc = lab.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(single) for single in lab_hwc], axis=0)

def visualize(model, data, save=True):
    """Plot inputs, predictions and ground truth for the first 5 images of ``data``.

    Runs the generator in eval mode under ``torch.no_grad()``, then switches
    ``net_G`` back to train mode. Expects the unwrapped model (it accesses
    ``model.net_G`` directly). If ``save`` is True, also writes
    ``colorization_<timestamp>.png`` to the working directory.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()

    predicted_ab = model.fake_color.detach()
    true_ab = model.ab
    gray = model.L
    fake_imgs = lab_to_rgb(gray, predicted_ab)
    real_imgs = lab_to_rgb(gray, true_ab)

    fig = plt.figure(figsize=(15, 8))
    for col in range(5):
        # Row 0: grayscale input; row 1: prediction; row 2: ground truth.
        ax = plt.subplot(3, 5, col + 1)
        ax.imshow(gray[col][0].cpu(), cmap='gray')
        ax.axis("off")
        for row, imgs in enumerate((fake_imgs, real_imgs), start=1):
            ax = plt.subplot(3, 5, row * 5 + col + 1)
            ax.imshow(imgs[col])
            ax.axis("off")
    plt.show()
    if save:
        fig.savefig(f"colorization_{time.time()}.png")

def log_results(loss_meter_dict):
    """Print one ``name: average`` line (5 decimals) per meter, in insertion order."""
    for name in loss_meter_dict:
        print(f"{name}: {loss_meter_dict[name].avg:.5f}")

## Training 

### Set Up

In [None]:
# Learning rates are set through MainModel's lr_G / lr_D arguments (both 2e-4 by default).

# Upper bound on training epochs; early stopping in train_model may end training sooner.
epochs = 50

In [None]:
import wandb

In [None]:
PROJECT = "Colorizing"
RESUME = "allow"
# Never hard-code API keys in a notebook: they leak through version control and shared
# copies. Read the key from the environment instead (e.g. set WANDB_API_KEY through your
# platform's secrets manager). If it is unset, WANDB_KEY is None and wandb.login() falls
# back to its own credential lookup / interactive prompt.
# SECURITY: the key that was previously hard-coded here has been exposed and should be revoked.
WANDB_KEY = os.environ.get("WANDB_API_KEY")

In [None]:
# Authenticate with Weights & Biases using the configured key.
wandb.login(key=WANDB_KEY)

In [None]:
# Hyper-parameters recorded with the run.
run_config = {"epochs": epochs, "batch_size": batch_size}
wandb.init(project=PROJECT, resume=RESUME, name="GanColorization_init", config=run_config)
# Register the (wrapped) model with W&B so it is tracked during training.
wandb.watch(model)

In [None]:
import torch
import os

class EarlyStopping:
    """Flag training to stop once the validation loss stops improving; checkpoint the best model.

    Args:
        patience: consecutive non-improving calls tolerated before ``early_stop`` becomes True.
        delta: minimum decrease below the best loss that counts as an improvement.
        save_path: file that receives ``model.state_dict()`` on every improvement.
    """

    def __init__(self, patience=5, delta=0, save_path="best_model.pth"):
        self.patience, self.delta, self.save_path = patience, delta, save_path
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss``: save ``model`` if it improved, otherwise count toward patience."""
        if not val_loss < self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return
        self.best_loss = val_loss
        self.counter = 0
        print(f"Validation loss improved to {val_loss:.5f}. Saving model to {self.save_path}")
        torch.save(model.state_dict(), self.save_path)  # keep only the best checkpoint

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim

def _compute_validation_losses(core):
    """Set ``core``'s loss_* attributes for the current batch without any backward pass.

    Uses the same formulas as MainModel.backward_D / backward_G, so validation meters
    measure the same quantities as training. Must be called after ``core.setup_input``
    and ``core.forward``.
    """
    fake_image = torch.cat([core.L, core.fake_color], dim=1)
    real_image = torch.cat([core.L, core.ab], dim=1)
    fake_preds = core.net_D(fake_image)
    real_preds = core.net_D(real_image)
    core.loss_D_fake = core.GANcriterion(fake_preds, False)
    core.loss_D_real = core.GANcriterion(real_preds, True)
    core.loss_D = (core.loss_D_fake + core.loss_D_real) * 0.5
    core.loss_G_GAN = core.GANcriterion(fake_preds, True)
    core.loss_G_L1 = core.L1criterion(core.fake_color, core.ab) * core.lambda_L1
    core.loss_G = core.loss_G_GAN + core.loss_G_L1


def validate(model, val_dl, loss_meter_dict):
    """Evaluate the GAN on ``val_dl`` and return (mean PSNR, mean SSIM) on RGB images.

    For every batch this runs the generator, recomputes all six losses on the validation
    batch (no backward pass, no optimizer step), and accumulates them in
    ``loss_meter_dict``. Previously only ``forward`` ran, so the meters recorded the stale
    losses left over from the last *training* step.

    PSNR/SSIM compare predicted and ground-truth colorizations after Lab -> RGB
    conversion. They use ``data_range=1.0`` because the RGB images lie in [0, 1]; the
    former ``2.0`` inflated PSNR by about 6 dB.

    Args:
        model: MainModel, optionally wrapped in ``nn.DataParallel``.
        val_dl: DataLoader yielding ``{'L': ..., 'ab': ...}`` batches.
        loss_meter_dict: meters from ``create_loss_meters()``; updated in place.

    Returns:
        ``(mean_psnr, mean_ssim)`` as floats; ``0.0`` each when ``val_dl`` is empty.
    """
    core = getattr(model, "module", model)
    model.eval()
    psnr_vals = []
    ssim_vals = []
    with torch.no_grad():
        for data in tqdm(val_dl, desc="Validation", leave=False):
            core.setup_input(data)
            core.forward()
            _compute_validation_losses(core)
            update_losses(core, loss_meter_dict, count=data['L'].size(0))

            # (B, H, W, 3) float arrays in [0, 1].
            fake_imgs = lab_to_rgb(core.L, core.fake_color)
            real_imgs = lab_to_rgb(core.L, core.ab)

            for real_np, fake_np in zip(real_imgs, fake_imgs):
                psnr_vals.append(psnr(real_np, fake_np, data_range=1.0))
                ssim_vals.append(ssim(real_np, fake_np, data_range=1.0, channel_axis=2))

    mean_psnr = float(np.mean(psnr_vals)) if psnr_vals else 0.0
    mean_ssim = float(np.mean(ssim_vals)) if ssim_vals else 0.0
    return mean_psnr, mean_ssim

def train(model, train_dl, loss_meter_dict, display_every=3):
    """Run one training epoch over ``train_dl``.

    Each batch does one discriminator step and then one generator step via the
    model's ``optimize()``, and accumulates the six losses in ``loss_meter_dict``.

    Args:
        model: MainModel, optionally wrapped in ``nn.DataParallel``. The wrapper's
            ``.module`` is used directly.
        train_dl: DataLoader yielding ``{'L': ..., 'ab': ...}`` batches.
        loss_meter_dict: meters from ``create_loss_meters()``; updated in place.
        display_every: print running averages every this many batches. A falsy value
            (0/None) disables printing instead of raising ZeroDivisionError.
    """
    core = getattr(model, "module", model)
    model.train()
    for step, data in enumerate(tqdm(train_dl, desc="Training", leave=False), start=1):
        core.setup_input(data)
        core.optimize()
        update_losses(core, loss_meter_dict, count=data['L'].size(0))
        if display_every and step % display_every == 0:
            log_results(loss_meter_dict)

def train_model(model, train_dl, val_dl, epochs, display_every=1, patience=5, save_dir="models"):
    """Train the colorization GAN with per-epoch validation, W&B logging and early stopping.

    Each epoch runs ``train`` over ``train_dl`` and then ``validate`` on ``val_dl``. The
    validation generator loss (``loss_G``) drives both a ReduceLROnPlateau scheduler on
    the generator optimizer and early stopping. The best weights are written to
    ``<save_dir>/best_model.pth``, and the weights after the final epoch to
    ``<save_dir>/last_epoch_model.pth``.

    Args:
        model: MainModel, optionally wrapped in ``nn.DataParallel``.
        train_dl, val_dl: training / validation DataLoaders.
        epochs: maximum number of epochs.
        display_every: forwarded to ``train`` (batches between loss printouts).
        patience: epochs without validation improvement before stopping.
        save_dir: directory for checkpoints; created if missing.
    """
    core = getattr(model, "module", model)
    os.makedirs(save_dir, exist_ok=True)
    early_stopping = EarlyStopping(patience=patience, save_path=os.path.join(save_dir, "best_model.pth"))
    # Only the generator's learning rate is scheduled; the discriminator keeps its initial LR.
    schedulerG = torch.optim.lr_scheduler.ReduceLROnPlateau(core.opt_G, mode='min', patience=3)

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_loss_meters = create_loss_meters()
        train(model, train_dl, train_loss_meters, display_every)
        print("Training Losses:")
        log_results(train_loss_meters)

        val_loss_meters = create_loss_meters()
        val_psnr, val_ssim = validate(model, val_dl, val_loss_meters)
        print("Validation Losses:")
        log_results(val_loss_meters)

        # The validation generator loss drives both the LR scheduler and early stopping.
        val_loss = val_loss_meters["loss_G"].avg
        schedulerG.step(val_loss)

        wandb.log({
            "train/loss_D": train_loss_meters["loss_D"].avg,
            "train/loss_G": train_loss_meters["loss_G"].avg,
            "val/loss_D": val_loss_meters["loss_D"].avg,
            "val/loss_G": val_loss_meters["loss_G"].avg,
            "val/PSNR": val_psnr,
            "val/SSIM": val_ssim,
            # Generator LR after this epoch's scheduler step, to make LR drops visible.
            "train/lr_G": core.opt_G.param_groups[0]["lr"],
        })
        print(f"Validation PSNR: {val_psnr:.4f}, SSIM: {val_ssim:.4f}")

        early_stopping(val_loss, core)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    # Save the weights from the final epoch that ran (which may differ from the best one).
    torch.save(core.state_dict(), os.path.join(save_dir, "last_epoch_model.pth"))
    print("Model saved at last epoch.")


In [None]:
# Launch training: print running losses every 100 batches and write checkpoints to Kaggle's output dir.
train_model(model, train_dl, val_dl, epochs, display_every=100, save_dir='/kaggle/working/')