In [None]:
# Inspect the GPUs visible to this session (the model below is wrapped in nn.DataParallel)
!nvidia-smi

In [None]:
pip install -U albumentations

# Imports

In [None]:
import os
# Configure PyTorch's CUDA caching allocator. This cell runs ahead of `import torch`,
# presumably so the setting is in place before any CUDA memory is allocated.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import cv2
import math
from tqdm.notebook import trange, tqdm
import random

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F
#from torch.distributions import Categorical

import torchvision
from torchvision.utils import make_grid
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, random_split

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torch.cuda.amp import GradScaler, autocast

from PIL import Image
from skimage.color import rgb2lab, lab2rgb
from sklearn.metrics import roc_auc_score

from pathlib import Path

import time

import wandb

# Allow TensorFloat-32 for float32 matrix multiplies on CUDA (typically faster, at reduced matmul precision)
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
# Root of the COCO 2017 dataset on Kaggle and the names of its split sub-directories
data_set_root = '/kaggle/input/coco-2017-dataset/coco2017'
train_set, validation_set, test_set = 'train2017', 'val2017', 'test2017'

# Full directory of each split, in the order train / validation / test
train_path, val_path, test_path = (
    os.path.join(data_set_root, split_name)
    for split_name in (train_set, validation_set, test_set)
)

In [None]:
# File suffixes (lower-case) treated as images when scanning the dataset directories
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def list_image_files(root, suffixes=IMAGE_SUFFIXES):
    """Return all image files below ``root`` (recursive), sorted.

    Sorting makes the list order independent of the filesystem's directory
    order, so indices such as ``train_image_path[1]`` refer to the same file on
    every run. Only files whose lower-cased suffix is in ``suffixes`` are kept,
    so stray non-image files are never handed to the dataset.
    """
    return sorted(p for p in Path(root).rglob("*") if p.suffix.lower() in suffixes)


train_image_path = list_image_files(train_path)
val_image_path = list_image_files(val_path)
test_image_path = list_image_files(test_path)

print(len(train_image_path), len(val_image_path), len(test_image_path))

In [None]:
# Preview one training image at the model's input resolution
preview_path = train_image_path[1]
# cv2.imread takes a string path and returns None (rather than raising) when it cannot read the file
img = cv2.imread(str(preview_path))
if img is None:
    raise FileNotFoundError(f"Could not read image: {preview_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
img = cv2.resize(img, (224, 224))
plt.imshow(img)
img.shape

# Parameters

In [None]:
# Square side length (pixels) of the network input, and number of samples per mini-batch
image_size, batch_size = 224, 64

# Data Processing

In [None]:
# Load the 313 quantized color bins: the (a, b) centres used for soft-encoding targets and decoding predictions
pts_in_hull = np.load('/kaggle/input/colorful-image-colorization-parameters/pts_in_hull.npy')  # shape (313, 2), each is an (a, b) pair

In [None]:
from sklearn.neighbors import NearestNeighbors

def soft_encode_ab(ab_image, pts_in_hull, sigma=5, n_neighbors=5):
    """Soft-encode an ab image over the quantized color bins.

    Each pixel's unit mass is spread over its ``n_neighbors`` nearest bin
    centres with Gaussian weights exp(-d^2 / (2 * sigma^2)) (distances clipped
    to 30), normalized to sum to 1. All other bins get 0.

    Args:
        ab_image: torch tensor of shape (2, H, W) holding the a and b channels.
        pts_in_hull: array-like of shape (Q, 2) with the ab bin centres (Q = 313 here).
        sigma: Gaussian kernel width, in ab units.
        n_neighbors: number of nearest bins that receive weight (capped at Q).

    Returns:
        float32 torch tensor of shape (H, W, Q).

    Raises:
        ValueError: if ``n_neighbors`` < 1.
    """
    C, H, W = ab_image.shape
    assert C == 2, "Expected input with 2 channels (a and b)"
    if n_neighbors < 1:
        raise ValueError(f"n_neighbors must be >= 1, got {n_neighbors}")

    centers = np.asarray(pts_in_hull, dtype=np.float64)
    n_bins = centers.shape[0]
    k = min(n_neighbors, n_bins)

    # (H, W, 2) -> (H*W, 2)
    ab_flat = ab_image.permute(1, 2, 0).reshape(-1, 2).cpu().numpy().astype(np.float64)

    # Brute-force kNN: the reference set is tiny (Q bins), so a full (N, Q)
    # distance matrix is cheap and avoids building a search tree per sample.
    dists_all = np.sqrt(((ab_flat[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2))  # (N, Q)
    inds = np.argpartition(dists_all, k - 1, axis=1)[:, :k]  # k nearest bins (unordered)
    dists = np.take_along_axis(dists_all, inds, axis=1)      # (N, k)

    # Gaussian kernel
    dists = np.clip(dists, 0, 30)
    weights = np.exp(-dists ** 2 / (2 * sigma ** 2))

    weights_sum = np.sum(weights, axis=1, keepdims=True)
    weights_sum[weights_sum == 0] = 1e-8  # prevent div by zero
    weights /= weights_sum  # normalize each pixel to unit mass

    # Scatter the k weights of every pixel into its row of the (N, Q) encoding;
    # indices within a row are distinct, so no weight is overwritten.
    soft_enc = np.zeros((ab_flat.shape[0], n_bins), dtype=np.float32)
    np.put_along_axis(soft_enc, inds, weights, axis=1)

    return torch.from_numpy(soft_enc).reshape(H, W, n_bins)


In [None]:
class ColorizationDataset(Dataset):
    """Image dataset producing Lab inputs/targets for classification-based colorization.

    Each item is a dict with:
      - 'L':       (1, H, W) float32 tensor, Lab lightness from skimage's rgb2lab (unnormalized)
      - 'ab':      (2, H, W) float32 tensor, ground-truth ab channels
      - 'soft_ab': (313, H/4, W/4) float32 tensor, ab soft-encoded by soft_encode_ab

    Args:
        paths: sequence of image file paths.
        Size: (height, width) every image is resized to before augmentation.
        transform: optional Albumentations transform called as ``transform(image=...)``;
            its output may be an HWC NumPy array or a CHW torch tensor (e.g. ToTensorV2).
        pts_in_hull: ab bin centres passed to soft_encode_ab.
    """

    def __init__(self, paths, Size=(224, 224), transform=None, pts_in_hull = None):
        self.paths = paths
        self.height, self.width = Size
        self.transform = transform
        self.pts_in_hull = pts_in_hull

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Load as 3-channel RGB; the context manager closes the underlying file handle.
        with Image.open(self.paths[idx]) as im:
            img = im.convert("RGB")
        # PIL's resize takes (width, height), not (height, width).
        img = img.resize((self.width, self.height), Image.BICUBIC)
        img = np.array(img)  # (H, W, 3) uint8 -- Albumentations expects NumPy input

        # Apply Albumentations transform if provided
        if self.transform:
            img = self.transform(image=img)["image"]

        # rgb2lab needs an HWC NumPy array. A transform ending in ToTensorV2 yields a
        # CHW torch tensor, so convert it back; likewise for a CHW NumPy array.
        if torch.is_tensor(img):
            img = img.permute(1, 2, 0).cpu().numpy()
        elif img.ndim == 3 and img.shape[0] == 3 and img.shape[-1] != 3:
            img = np.transpose(img, (1, 2, 0))

        # Convert RGB -> LAB
        img_lab = rgb2lab(img).astype("float32")  # (H, W, 3)

        # Split into L and ab channels as PyTorch tensors
        L = torch.tensor(img_lab[:, :, 0], dtype=torch.float32).unsqueeze(0)  # (1, H, W)
        ab = torch.tensor(img_lab[:, :, 1:], dtype=torch.float32).permute(2, 0, 1)  # (2, H, W)

        # ColorizationNet predicts at 1/4 of the input resolution, so the soft
        # targets are built from ab downsampled to H/4 x W/4.
        ab_downsampled = F.interpolate(
            ab.unsqueeze(0),  # add batch dim -> (1, 2, H, W)
            scale_factor=0.25,
            mode='bilinear',
            align_corners=False
        ).squeeze(0)  # (2, H/4, W/4)

        soft_encoded = soft_encode_ab(ab_image=ab_downsampled, pts_in_hull=self.pts_in_hull)  # (H/4, W/4, 313) tensor
        soft_encoded = soft_encoded.float().permute(2, 0, 1)  # (313, H/4, W/4)

        return {'L': L, 'ab': ab, 'soft_ab': soft_encoded}

In [None]:
# Training-time augmentations. The pipeline deliberately does not end with
# ToTensorV2: ColorizationDataset passes the augmented image to skimage's
# rgb2lab, which needs an (H, W, 3) NumPy array, so a CHW tensor would only be
# converted straight back.
# NOTE(review): vertical flips / 90-degree rotations create upside-down scenes
# that rarely occur in natural photos; confirm they help colorization.
transform = A.Compose([
    A.HorizontalFlip(p=0.4),
    A.VerticalFlip(p=0.4),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.RandomGamma(gamma_limit=(70, 130), p=0.2),
])

In [None]:
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=False, shuffle=False, **kwargs): # A handy function to make our dataloaders
    """Build a ColorizationDataset from ``kwargs`` and wrap it in a DataLoader.

    Args:
        batch_size: samples per batch.
        n_workers: number of DataLoader worker processes.
        pin_memory: forwarded to DataLoader.
        shuffle: reshuffle the sample order every epoch. The default (False)
            keeps DataLoader's sequential order; pass True for a training set.
        **kwargs: forwarded to ColorizationDataset (paths, Size, transform, pts_in_hull).

    Returns:
        A torch.utils.data.DataLoader over the new dataset.
    """
    dataset = ColorizationDataset(**kwargs)
    return DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                      pin_memory=pin_memory, shuffle=shuffle)

In [None]:
# The training loader applies the augmentation pipeline; the validation loader sees un-augmented images.
# NOTE(review): neither loader is shuffled -- DataLoader defaults to sequential order, so every epoch
# visits the training images in the same order. Shuffling the training set is usually wanted.
train_loader = make_dataloaders(batch_size = batch_size, pin_memory=True, paths=train_image_path, transform = transform, pts_in_hull = pts_in_hull)

val_loader = make_dataloaders(batch_size = batch_size, pin_memory=True, paths=val_image_path, pts_in_hull = pts_in_hull)

In [None]:
# Sanity check: fetch one training batch, show the L and soft-target shapes and the number of batches per loader
data = next(iter(train_loader))
Ls, abs_ = data['L'], data['soft_ab']  # abs_ holds the soft-encoded targets ('soft_ab'), not the raw ab channels
print(Ls.shape, abs_.shape)
print(len(train_loader), len(val_loader))

# Model

In [None]:
class BaseColor(nn.Module):
    """Base module holding the Lab normalization constants of the colorization net.

    L is shifted by 50 and divided by 100; ab is divided by 110.
    """

    def __init__(self):
        super().__init__()
        # Normalization constants: L centre / L scale / ab scale
        self.l_cent, self.l_norm, self.ab_norm = 50., 100., 110.

    def normalize_l(self, in_l):
        """Map L from [0, 100] to [-0.5, 0.5]."""
        shifted = in_l - self.l_cent
        return shifted / self.l_norm

    def unnormalize_l(self, in_l):
        """Inverse of ``normalize_l``."""
        scaled = in_l * self.l_norm
        return scaled + self.l_cent

    def normalize_ab(self, in_ab):
        """Divide ab by 110, bringing it roughly into [-1, 1]."""
        return in_ab / self.ab_norm

    def unnormalize_ab(self, in_ab):
        """Inverse of ``normalize_ab``."""
        return in_ab * self.ab_norm

In [None]:
class ColorizationNet(BaseColor):
    """Convolutional colorization network predicting logits over 313 ab bins.

    ``forward`` takes the L channel (B, 1, H, W) in raw L units (normalized
    internally with ``normalize_l``) and returns unnormalized logits of shape
    (B, 313, H/4, W/4): the stride-2 convolutions ending stages 1-3 downsample
    by 8 and the transposed convolution in ``model8`` upsamples by 2
    (e.g. 224 -> 56 for input sizes divisible by 8).
    """

    def __init__(self, norm_layer=nn.BatchNorm2d):
        super().__init__()

        def conv_stage(in_ch, out_ch, n_convs, last_stride=1, dilation=1):
            """Return [Conv3x3, ReLU] * n_convs followed by norm_layer(out_ch).

            Only the last convolution uses ``last_stride``; padding equals the
            dilation so stride-1 convolutions preserve the spatial size.
            """
            layers = []
            for i in range(n_convs):
                layers.append(nn.Conv2d(
                    in_ch if i == 0 else out_ch, out_ch, kernel_size=3,
                    stride=last_stride if i == n_convs - 1 else 1,
                    padding=dilation, dilation=dilation, bias=True,
                ))
                layers.append(nn.ReLU(True))
            layers.append(norm_layer(out_ch))
            return layers

        # Encoder: three downsampling stages (each ends with a stride-2 conv)
        self.model1 = nn.Sequential(*conv_stage(1, 64, 2, last_stride=2))
        self.model2 = nn.Sequential(*conv_stage(64, 128, 2, last_stride=2))
        self.model3 = nn.Sequential(*conv_stage(128, 256, 3, last_stride=2))
        # Full-resolution (for this depth) 512-channel stages; 5 and 6 are dilated
        self.model4 = nn.Sequential(*conv_stage(256, 512, 3))
        self.model5 = nn.Sequential(*conv_stage(512, 512, 3, dilation=2))
        self.model6 = nn.Sequential(*conv_stage(512, 512, 3, dilation=2))
        self.model7 = nn.Sequential(*conv_stage(512, 512, 3))
        # Decoder: 2x upsampling, two 3x3 convs, then a 1x1 projection to 313 bin logits
        self.model8 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(True),
            nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def forward(self, input_l):
        """Map raw L (B, 1, H, W) to bin logits (B, 313, H/4, W/4)."""
        features = self.normalize_l(input_l)
        for stage in (self.model1, self.model2, self.model3, self.model4,
                      self.model5, self.model6, self.model7, self.model8):
            features = stage(features)
        return features

In [None]:
def init_weights(m):
    """Initialize one module in place; meant to be used as ``model.apply(init_weights)``.

    - Conv2d / ConvTranspose2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias.
    - BatchNorm2d: weight 1, bias 0 (parameters skipped when the layer has none, i.e. affine=False).
    - Linear: Xavier-normal weights, zero bias.
    Any other module is left untouched.
    """
    # ConvTranspose2d does not subclass Conv2d, so it must be listed explicitly;
    # otherwise the decoder's upsampling layer silently keeps PyTorch's default init.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [None]:
def create_model(device, device_ids=None):
    """Build a ColorizationNet, initialize it, wrap it in nn.DataParallel and move it to ``device``.

    Args:
        device: target device for the parameters (DataParallel expects them on device_ids[0]).
        device_ids: GPU indices for DataParallel. None (default) lets DataParallel use
            every visible GPU, so the notebook also runs on single-GPU or CPU-only
            machines; a hard-coded [0, 1] fails whenever GPU 1 does not exist.

    Returns:
        The DataParallel-wrapped model (state_dict keys carry the ``module.`` prefix).
    """
    model = ColorizationNet()
    model.apply(init_weights)
    model = nn.DataParallel(model, device_ids=device_ids)
    return model.to(device)

In [None]:
# summary(model, (1, 224, 224))  # the network takes a single L channel at image_size x image_size

# Hyperparameters

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Optimizer step size
learning_rate = 3e-5

# Number of passes over the training set (presumably consumed by a training loop later in the notebook)
epochs = 160

# Checkpoint file path (presumably used to save/load the trained weights)
model_path = '/kaggle/working/model.pth'

# Setup

In [None]:
# Build the initialized, DataParallel-wrapped colorization network on `device`
model = create_model(device)

In [None]:
# Count the parameters the optimizer will update
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total trainable parameters: {total_params:,}')

In [None]:
# NOTE(review): Adam's weight_decay adds an L2 term to the gradient (coupled decay);
# torch.optim.AdamW applies decoupled decay, which is often what is intended -- confirm.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas = (0.9, 0.99), weight_decay = 1e-3)

scaler = GradScaler()  # Mixed Precision Training (loss scaling for float16 gradients)
# NOTE(review): newer PyTorch releases may deprecate torch.cuda.amp.GradScaler in favor of torch.amp.GradScaler("cuda").
#loss_fn = nn.CrossEntropyLoss()


In [None]:
import torch.nn.functional as F

def cross_entropy_with_soft_targets(pred_logits, soft_targets, pixel_weights=None):
    """Mean (optionally pixel-weighted) cross-entropy between logits and soft bin targets.

    Args:
        pred_logits: (B, Q, H, W) raw network outputs.
        soft_targets: (B, Q, H, W) target distributions over the Q bins.
        pixel_weights: optional (B, H, W) per-pixel multipliers (e.g. class rebalancing).

    Returns:
        Scalar tensor: the per-pixel loss -sum_q target_q * log p_q averaged over B, H, W.
    """
    # Log-probabilities over the bin axis, floored at -100 to keep them finite
    log_probs = torch.clamp(F.log_softmax(pred_logits, dim=1), min=-100)

    per_pixel = (soft_targets * log_probs).sum(dim=1).neg()  # (B, H, W)
    if pixel_weights is not None:
        per_pixel = per_pixel * pixel_weights
    return per_pixel.mean()

In [None]:
def compute_rebalance_weights(p, Q=313, lam=0.5):
    """Class-rebalancing weights w proportional to 1 / ((1 - lam) * p + lam / Q).

    The weights are scaled so that their expectation under p is 1, i.e. sum(w * p) == 1.
    lam = 0 gives pure inverse-frequency weights; lam = 1 gives uniform weights.

    Args:
        p: numpy array of shape (Q,), empirical distribution over the color bins.
        Q: number of bins used for the uniform mixing term.
        lam: weight of the uniform distribution in the mixture.

    Returns:
        numpy array of shape (Q,).
    """
    smoothed = (1 - lam) * p + lam / Q
    inverse = 1 / smoothed
    return inverse * (1 / np.sum(inverse * p))

In [None]:
# Load the empirical prior over the 313 ab bins and convert it into class-rebalancing weights
# (lam=0.5 mixes the prior with a uniform distribution); presumably used to weight the training
# loss per pixel via get_pixel_weight -- confirm against the training loop.
prior_probs = np.load('/kaggle/input/colorful-image-colorization-parameters/prior_probs.npy')  # shape: (313,)
rebalance_weights = compute_rebalance_weights(prior_probs)  # shape: (313,)
rebalance_weights = torch.tensor(rebalance_weights, dtype=torch.float32).to(device)

In [None]:
# Pure inverse-frequency weights (lam=0); passed as `inverse_weights` to validate / compute_class_balanced_auc
inverse_weights = compute_rebalance_weights(prior_probs, Q=313, lam=0)
inverse_weights = torch.tensor(inverse_weights, dtype=torch.float32).to(device)  # (313,)

In [None]:
def get_pixel_weight(soft_label, weights):
    """Look up, for every pixel, the weight of its most probable bin.

    Args:
        soft_label: (B, 313, H, W) soft encoding tensor.
        weights: (313,) per-bin weights, numpy array or torch tensor.

    Returns:
        (B, H, W) tensor of pixel weights, on soft_label's device.

    Raises:
        ValueError: if soft_label is not 4-D with 313 channels.
    """
    if not (soft_label.dim() == 4 and soft_label.shape[1] == 313):
        raise ValueError(f"Expected shape (B, 313, H, W), got {soft_label.shape}")

    # Index of the dominant bin per pixel (first maximum along the bin axis)
    dominant_bin = soft_label.argmax(dim=1)  # (B, H, W)

    lookup = torch.from_numpy(weights) if isinstance(weights, np.ndarray) else weights
    return lookup.to(soft_label.device)[dominant_bin]

In [None]:
def soft_to_ab(soft_map, color_bins):
    """Expected ab value per pixel under a distribution over quantized bins.

    Args:
        soft_map: (B, 313, H, W) probabilities over the bins.
        color_bins: (313, 2) ab value of each bin (array-like).

    Returns:
        (B, 2, H, W) tensor: sum_q soft_map[:, q] * color_bins[q].
    """
    _batch, _bins, _height, _width = soft_map.shape  # also rejects non-4-D input
    bins = torch.tensor(color_bins, dtype=soft_map.dtype, device=soft_map.device)  # (313, 2)

    # channels-last so matmul contracts the bin axis, then back to channels-first
    return soft_map.permute(0, 2, 3, 1).matmul(bins).permute(0, 3, 1, 2)

In [None]:
def compute_class_balanced_auc(pred_ab, true_ab, inverse_weights, pts_in_hull, max_thresh=150):
    """
    pred_ab, true_ab: (B, 2, H, W)
    color_freq: (313,) normalized class frequency
    """
    # Compute L2 distance per pixel
    l2_error = torch.linalg.norm(pred_ab - true_ab, dim=1)   # (B, H, W)
    error_flat = l2_error.flatten().cpu().numpy()

    # Inverse class weights
    weight = inverse_weights

    # Get the class for each pixel based on true_ab (use clustering or quantization)
    true_ab_exp = true_ab.permute(0, 2, 3, 1).reshape(-1, 2).cpu().numpy()
    
    # Compute L2 distance to cluster centers
    diff = true_ab_exp[:, None, :] - pts_in_hull[None, :, :]  # (47040, 313, 2)
    dist = np.linalg.norm(diff, axis=2)  # (47040, 313)

    bin_ids = np.argmin(dist, axis=1)  # Find the closest bin for each pixel

    # Use the class ids (bin_ids) to get the corresponding weight
    weighted_error = weight[bin_ids]
    weighted_error = weighted_error.cpu().numpy() 
    total_weight = weighted_error.sum() + 1e-6  # avoid divide-by-zero

    # Compute the CDF: cumulative distribution of errors weighted by class importance
    thresholds = np.arange(0, max_thresh + 1)
    cdf = []
    
    for t in thresholds:
        
        # Calculate cumulative weighted sum for pixels where error < threshold t
        cdf_value = weighted_error[error_flat < t].sum()
        cdf.append(cdf_value/ total_weight)

    # Normalize the CDF and compute AUC
    cdf = np.array(cdf)
    auc = np.trapz(cdf, thresholds) / max_thresh

    return auc, thresholds, cdf

In [None]:
def annealed_mean(prob, pts_in_hull, T=0.38):
    """Decode bin probabilities to ab values with the annealed (temperature-sharpened) mean.

    The probabilities are raised to 1/T, renormalized over the bin axis and
    used to average the bin centres. T = 1 gives the plain mean; smaller T
    moves the result towards the most probable bin.

    Args:
        prob: array-like of shape (..., Q) with the bin axis LAST, e.g. (H, W, 313)
            for one image, (B, H, W, 313) for a batch or (N, 313) for flat pixels.
            CPU torch tensors are accepted (converted with np.asarray).
        pts_in_hull: (Q, 2) ab bin centres.
        T: temperature (> 0).

    Returns:
        numpy array of shape (..., 2) with the decoded ab values.
    """
    prob = np.asarray(prob) ** (1 / T)
    # Renormalize over the bin axis; a fixed axis index would only be right for 3-D input.
    prob /= np.sum(prob, axis=-1, keepdims=True)
    return np.dot(prob, pts_in_hull)  # weighted mean of bin centres, shape (..., 2)


# Train

In [None]:
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr

def _stack_lab(L_img, ab_chw):
    """Build an (H, W, 3) float64 Lab image from L (H, W) and ab (2, H, W)."""
    lab = np.empty(L_img.shape + (3,), dtype=np.float64)
    lab[:, :, 0] = L_img
    lab[:, :, 1:] = ab_chw.transpose(1, 2, 0)
    return lab


def validate(model, data_loader, rebalance_weights, inverse_weights, pts_in_hull, device, T=0.38):
    """Mean SSIM and PSNR of colorized images against the ground truth.

    Per batch, the bin logits (B, Q, H/4, W/4) are turned into probabilities and
    decoded with ``annealed_mean`` at the network's output resolution. The
    2-channel ab map is then bilinearly upsampled to the ground-truth size.
    Upsampling ab instead of the Q-channel logits needs ~Q/2 times less memory.
    Predicted and ground-truth ab are each combined with the ground-truth L,
    converted to RGB with ``lab2rgb`` and compared image by image.

    Args:
        model: network mapping L (B, 1, H, W) to bin logits (B, Q, H/4, W/4).
        data_loader: yields dicts with 'L' (B, 1, H, W) and 'ab' (B, 2, H, W).
        rebalance_weights, inverse_weights: accepted for call compatibility; unused here.
        pts_in_hull: (Q, 2) ab bin centres used for decoding.
        device: device the model lives on.
        T: annealed-mean temperature.

    Returns:
        (mean_ssim, mean_psnr) as floats; NaN if the loader yields no images.

    The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    ssim_vals, psnr_vals = [], []

    try:
        with torch.no_grad():
            for batch in data_loader:
                L = batch['L'].to(device)           # (B, 1, H, W), unnormalized Lab lightness
                true_ab = batch['ab'].float()       # (B, 2, H, W), kept on the CPU

                logits = model(L).float()           # (B, Q, H/4, W/4)
                probs = torch.softmax(logits, dim=1)

                # annealed_mean takes NumPy with the bin axis last; decode one image at a time
                probs_np = probs.permute(0, 2, 3, 1).cpu().numpy()  # (B, H/4, W/4, Q)
                ab_low = np.stack([annealed_mean(p, pts_in_hull, T=T) for p in probs_np])  # (B, H/4, W/4, 2)

                pred_ab = F.interpolate(
                    torch.from_numpy(ab_low).permute(0, 3, 1, 2).float(),
                    size=tuple(true_ab.shape[-2:]),
                    mode='bilinear',
                    align_corners=False,
                )  # (B, 2, H, W)

                L_np = L.cpu().numpy()
                for L_img, pred_ab_img, true_ab_img in zip(L_np[:, 0], pred_ab.numpy(), true_ab.numpy()):
                    pred_rgb = lab2rgb(_stack_lab(L_img, pred_ab_img))
                    true_rgb = lab2rgb(_stack_lab(L_img, true_ab_img))
                    ssim_vals.append(ssim(true_rgb, pred_rgb, data_range=1.0, channel_axis=2))
                    psnr_vals.append(psnr(true_rgb, pred_rgb, data_range=1.0))
    finally:
        model.train(was_training)

    mean_ssim = float(np.mean(ssim_vals)) if ssim_vals else float('nan')
    mean_psnr = float(np.mean(psnr_vals)) if psnr_vals else float('nan')
    return mean_ssim, mean_psnr

In [None]:
# Evaluate colorization quality on the validation set: SSIM / PSNR between RGB reconstructions
# built from the ground-truth L with predicted vs. ground-truth ab
val_ssim, val_psnr = validate(model, val_loader, rebalance_weights, inverse_weights, pts_in_hull, device)
print(f"Validation SSIM: {val_ssim:.4f}, PSNR: {val_psnr:.4f}")