In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Must run before the download cell below, which calls into the same
# `kagglehub` module for a competition and a dataset.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Each call stores the value returned by kagglehub (presumably the local
# download location - confirm against the kagglehub documentation).
# NOTE(review): in the part of the notebook reviewed, later cells read
# hard-coded /kaggle/input/... paths and never use these two variables.
imagenet_object_localization_challenge_path = kagglehub.competition_download('imagenet-object-localization-challenge')
dipitgolechha7_imagenetsubsub_path = kagglehub.dataset_download('dipitgolechha7/imagenetsubsub')

print('Data source import complete.')


In [None]:
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
Path.ls = lambda x: list(x.iterdir())

import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split

# Define the path to the training data
color_path = '/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train'

# Get all class directories. os.listdir returns entries in arbitrary order, so
# they are sorted; without this the fixed seed below does not pin the sample.
class_dirs = sorted(
    d for d in os.listdir(color_path) if os.path.isdir(os.path.join(color_path, d))
)

max_images = 10000
all_images = []
np.random.seed(100)

# Gather up to `max_images` images in total.
# NOTE(review): every image of a class is taken until the cap is reached, so
# the subset is drawn from the first classes only - confirm this is intended.
for class_dir in class_dirs:
    if len(all_images) >= max_images:
        break
    class_path_pattern = os.path.join(color_path, class_dir, '*.JPEG')
    # glob order is filesystem dependent as well; sort before the seeded shuffle.
    class_images = sorted(glob.glob(class_path_pattern))
    np.random.shuffle(class_images)  # Shuffle to get a random sample from this class
    needed = max_images - len(all_images)
    to_add = class_images[:needed]
    all_images.extend(to_add)

print("Total images collected:", len(all_images))

# Split into training and validation sets (80% train, 20% val)
train_paths, val_paths = train_test_split(all_images, test_size=0.2, random_state=123)

print("Total training images:", len(train_paths))
print("Total validation images:", len(val_paths))

# Display a few training images
_, axes = plt.subplots(4, 4, figsize=(12, 12))
for ax, img_path in zip(axes.flatten(), train_paths[:16]):
    # Context manager releases the file handle once the image has been drawn.
    with Image.open(img_path) as img:
        ax.imshow(img)
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
SIZE = 256

class ColorizationDataset(Dataset):
    """Dataset yielding normalised Lab channels for each image path.

    Every item is a dict with ``'L'`` (first Lab channel divided by 50, minus 1)
    and ``'ab'`` (remaining two channels divided by 110). Images are resized to
    ``SIZE`` x ``SIZE``; the ``'train'`` split additionally applies a random
    horizontal flip.
    """

    def __init__(self, paths, split='train'):
        resize = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),
            ])
        else:
            self.transforms = resize

        self.paths = paths
        self.split = split
        self.size = SIZE

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        rgb = np.array(self.transforms(Image.open(self.paths[idx]).convert("RGB")))
        # ToTensor moves the channel axis first (H, W, C) -> (C, H, W).
        lab = transforms.ToTensor()(rgb2lab(rgb).astype("float32"))
        lightness = lab[[0], ...] / 50. - 1.
        chroma = lab[[1, 2], ...] / 110.
        return {'L': lightness, 'ab': chroma}

def make_dataloaders(paths, split='train', batch_size=16, n_workers=4, pin_memory=True):
    """Wrap ``paths`` in a ColorizationDataset and return a DataLoader for it.

    Batches are shuffled only when ``split`` is ``'train'``.
    """
    return DataLoader(
        ColorizationDataset(paths, split=split),
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=n_workers,
        pin_memory=pin_memory,
    )

# Build both loaders; only the 'train' split is shuffled and randomly flipped.
train_dl = make_dataloaders(train_paths, split='train', batch_size=16)
val_dl = make_dataloaders(val_paths, split='val', batch_size=16)

# Sanity check: pull one training batch and report tensor shapes and batch counts.
data = next(iter(train_dl))
Ls, abs_ = data['L'], data['ab']
print("Train batch shapes:", Ls.shape, abs_.shape)
print("Number of batches:", len(train_dl), "train,", len(val_dl), "val")

In [None]:
class UnetBlock(nn.Module):
    """One nesting level of the U-Net.

    A block is a stride-2 down-convolution, an optional inner ``submodule``, and
    a stride-2 transposed convolution back up. Except for the outermost block,
    ``forward`` concatenates the block input with the block output along the
    channel axis (skip connection).

    Args:
        nf: Output channels of the transposed convolution; also the input
            channels of the down-convolution when ``input_c`` is None.
        ni: Output channels of the down-convolution.
        submodule: Inner block placed between the down and up paths (unused
            when ``innermost`` is True).
        input_c: Input channels of the down-convolution; defaults to ``nf``.
        dropout: Append ``Dropout(0.5)`` to the up path (intermediate blocks only).
        innermost: Build the bottleneck variant, which has no submodule.
        outermost: Build the top-level variant, which ends in ``Tanh`` and
            returns no skip concatenation.
    """

    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            # ni * 2 input channels: the submodule returns its input concatenated
            # with its output (assumes the submodule's nf equals this block's ni).
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No submodule here, so the up-convolution sees only ni channels.
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the block; non-outermost blocks return ``cat([x, model(x)], dim=1)``."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)

class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` instances.

    Built inside-out: one innermost block, ``n_down - 5`` dropout blocks at
    ``num_filters * 8`` channels, three blocks that halve the channel count
    each time, and finally the outermost block mapping ``input_c`` input
    channels to ``output_c`` output channels.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        # Bottleneck block.
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        # Constant-width blocks with dropout around the bottleneck.
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        # Each wrapping level halves the channel count: 8x -> 4x -> 2x -> 1x.
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        """Apply the outermost block (and with it the whole nested stack) to ``x``."""
        return self.model(x)

class PatchDiscriminator(nn.Module):
    """Convolutional discriminator producing a grid of per-patch scores.

    Layout: one un-normalised stem block, ``n_down`` blocks that double the
    channel count (all stride 2 except the last, which is stride 1), and a final
    single-channel convolution with neither normalisation nor activation.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        stages = [self.get_layers(input_c, num_filters, norm=False)]
        for depth in range(n_down):
            stride = 1 if depth == (n_down - 1) else 2
            stages.append(
                self.get_layers(num_filters * 2 ** depth,
                                num_filters * 2 ** (depth + 1),
                                s=stride)
            )
        stages.append(self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False))
        self.model = nn.Sequential(*stages)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Return Conv2d [+ BatchNorm2d] [+ LeakyReLU(0.2)] as one ``nn.Sequential``.

        The convolution carries a bias only when no normalisation follows it.
        """
        stage = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm:
            stage.append(nn.BatchNorm2d(nf))
        if act:
            stage.append(nn.LeakyReLU(0.2, True))
        return nn.Sequential(*stage)

    def forward(self, x):
        return self.model(x)

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss that builds its own target tensor.

    Args:
        gan_mode: ``'vanilla'`` (BCE with logits) or ``'lsgan'`` (MSE).
        real_label: Target value used when ``target_is_real`` is True.
        fake_label: Target value used when ``target_is_real`` is False.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        # Buffers follow the module across .to(device) calls.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Previously an unknown mode left `self.loss` undefined and only
            # failed later, inside forward().
            raise ValueError(
                f"Unsupported gan_mode {gan_mode!r}; expected 'vanilla' or 'lsgan'"
            )

    def get_labels(self, preds, target_is_real):
        """Return the real/fake label broadcast to the shape of ``preds``."""
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def forward(self, preds, target_is_real):
        """Compute the configured loss between ``preds`` and the matching labels."""
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss

def init_weights(net, init='norm', gain=0.02):
    """Re-initialise convolution and BatchNorm2d parameters of ``net`` in place.

    Convolution weights use the scheme named by ``init`` (``'norm'``,
    ``'xavier'`` or ``'kaiming'``; any other name leaves them untouched) and
    their biases are zeroed. BatchNorm2d weights are drawn from N(1, gain) and
    their biases zeroed. Returns ``net``.
    """
    conv_schemes = {
        'norm': lambda w: nn.init.normal_(w, mean=0.0, std=gain),
        'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
    }

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            scheme = conv_schemes.get(init)
            if scheme is not None:
                scheme(m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
            return
        if 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    """Move ``model`` to ``device``, then apply ``init_weights`` with its defaults."""
    return init_weights(model.to(device))

class AverageMeter:
    """Running weighted average: tracks ``count``, ``sum`` and ``avg``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all three statistics."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the average."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    """Return a dict with one fresh ``AverageMeter`` per tracked loss name."""
    loss_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
    )
    return {loss_name: AverageMeter() for loss_name in loss_names}

def update_losses(model, loss_meter_dict, count):
    """Feed each meter the current value of the same-named attribute on ``model``.

    The attribute must expose ``.item()``; ``count`` is the weight passed to
    the meter.
    """
    for loss_name in loss_meter_dict:
        current = getattr(model, loss_name).item()
        loss_meter_dict[loss_name].update(current, count=count)

In [None]:
def lab_to_rgb(L, ab):
    """Convert batched, normalised L and ab tensors to a stacked array of RGB images.

    Undoes the dataset scaling (``L`` back via ``(L + 1) * 50``, ``ab`` via
    ``* 110``), moves channels last, and converts each image with ``lab2rgb``.
    """
    lab_batch = torch.cat([(L + 1.) * 50., ab * 110.], dim=1)
    lab_batch = lab_batch.permute(0, 2, 3, 1).cpu().numpy()
    return np.stack([lab2rgb(lab_img) for lab_img in lab_batch], axis=0)

def visualize(model, data, save=True):
    """Plot grayscale input, generated colours and ground truth for up to 5 images.

    Runs the generator in eval mode without gradients on ``data``, restores
    train mode, and draws a 3 x 5 grid. When ``save`` is True the figure is
    written to ``images_new/colorization_<timestamp>.png``; the directory is
    created if it does not exist.
    """
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    fig = plt.figure(figsize=(15, 12))  # Adjusted height for row titles
    rows, cols = 3, 5  # 3 rows (Grayscale, Model generated, Actual), up to 5 columns

    # Add row titles
    plt.subplot(rows, cols, 1).set_title("Grayscale Image", fontsize=16, loc='left')
    plt.subplot(rows, cols, cols + 1).set_title("Model Generated Image", fontsize=16, loc='left')
    plt.subplot(rows, cols, 2 * cols + 1).set_title("Actual Image", fontsize=16, loc='left')

    for i in range(min(5, L.size(0))):
        # Grayscale Image (Row 1)
        ax = plt.subplot(rows, cols, i + 1)
        ax.imshow(L[i][0].cpu(), cmap='gray')
        ax.axis("off")

        # Model Generated Image (Row 2)
        ax = plt.subplot(rows, cols, i + 1 + cols)
        ax.imshow(fake_imgs[i])
        ax.axis("off")

        # Actual Image (Row 3)
        ax = plt.subplot(rows, cols, i + 1 + 2 * cols)
        ax.imshow(real_imgs[i])
        ax.axis("off")

    plt.tight_layout()
    if save:
        # Save before show(): the output directory is not created anywhere else,
        # and some backends discard the figure once it has been shown.
        os.makedirs("images_new", exist_ok=True)
        fig.savefig(f"images_new/colorization_{time.time()}.png")
    plt.show()
    # Called repeatedly during training; release the figure explicitly.
    plt.close(fig)

def log_results(loss_meter_dict, log_file, epoch, iteration):
    """
    Log the training results to a CSV file and print them to the console.

    Appends one row ``[epoch, iteration, loss_name, average]`` per meter to
    ``log_file`` and prints each average with five decimals.
    """
    # Imported here because the notebook only imports `csv` in a later cell;
    # the function must not depend on that cell having run first.
    import csv

    with open(log_file, mode='a', newline='') as file:
        writer = csv.writer(file)
        for loss_name, loss_meter in loss_meter_dict.items():
            print(f"{loss_name}: {loss_meter.avg:.5f}")  # Print to console
            # Append the results to the CSV file
            writer.writerow([epoch, iteration, loss_name, loss_meter.avg])

In [None]:
class MainModel(nn.Module):
    """Generator/discriminator pair trained with a GAN loss plus a weighted L1 loss.

    ``setup_input`` stores a batch on the model, ``forward`` runs the generator
    on it, and ``optimize`` performs one discriminator update followed by one
    generator update. The individual loss terms are kept as attributes
    (``loss_D_fake``, ``loss_D_real``, ``loss_D``, ``loss_G_GAN``, ``loss_G_L1``,
    ``loss_G``).
    """

    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        """
        Args:
            net_G: Generator to use. When None, a ``Unet`` is built and passed
                through ``init_model``; a supplied generator is only moved to
                the device and keeps its weights.
            lr_G, lr_D: Adam learning rates for generator and discriminator.
            beta1, beta2: Adam betas shared by both optimizers.
            lambda_L1: Multiplier applied to the L1 term of the generator loss.
        """
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        # The discriminator sees L and ab concatenated, hence input_c=3.
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        """Store the batch's ``'L'`` and ``'ab'`` tensors on the model's device."""
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        """Run the generator on the stored ``L`` and keep the result in ``fake_color``."""
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        """Compute the discriminator loss on fake and real pairs and backpropagate it."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        # detach(): the discriminator loss must not propagate into the generator.
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN term + ``lambda_L1`` * L1) and backpropagate it."""
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        # The generator is rewarded when the discriminator labels its output as real.
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        """One training step on the stored batch: update D first, then G."""
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        self.net_G.train()
        # Freeze D's parameters so the generator step accumulates no gradients in them.
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

In [None]:
import random

# Ensure DataLoader tensors are on the correct device
def move_batch_to_device(batch, device):
    """Return a new dict with every value of ``batch`` moved via ``.to(device)``."""
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved

import random
def train_model(model, train_dl, epochs, display_every=200, log_file="logs_training_final.csv"):
    """Train ``model`` for ``epochs`` passes over ``train_dl``.

    Truncates ``log_file`` and writes a CSV header, then for every batch calls
    ``model.setup_input`` / ``model.optimize`` and updates the loss meters.
    Running averages are logged every ``display_every`` iterations. Sample
    colorizations are shown and saved at an interval drawn at random once per
    epoch, using the notebook-level ``val_dl``.
    """
    # Imported here because the notebook only imports `csv` in a later cell;
    # the function must not depend on that cell having run first.
    import csv
    import random

    # Prepare the CSV file for logging
    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Write the header
        writer.writerow(["Epoch", "Iteration", "Loss Name", "Average Loss"])

    for e in range(epochs):
        loss_meter_dict = create_loss_meters()
        i = 0
        # Generate a random interval for image visualization for this epoch
        visualize_every = random.randint(1, 499)
        for batch_data in tqdm(train_dl):
            model.setup_input(batch_data)
            model.optimize()
            update_losses(model, loss_meter_dict, count=batch_data['L'].size(0))
            i += 1

            # Log results at fixed intervals
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict, log_file, epoch=e + 1, iteration=i)

            # Visualize images at random intervals
            if i % visualize_every == 0:
                print(f"Visualizing at random iteration {i} (interval: {visualize_every})")
                # A fresh iterator restarts at the beginning of val_dl; the
                # validation loader is built without shuffling, so this is
                # always its first batch, which keeps the pictures comparable.
                data = next(iter(val_dl))
                visualize(model, data, save=True)

In [None]:
def build_res_unet(n_input=1, n_output=2, size=256):
    """Build a fastai ``DynamicUnet`` on a ResNet-18 body and move it to the device.

    The backbone is created with ``pretrained=True`` and handed to
    ``create_body`` with ``n_in=n_input`` and ``cut=-2``.
    NOTE(review): passing an instantiated model to ``create_body`` depends on
    the installed fastai version - confirm.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = create_body(resnet18(pretrained=True), n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size)).to(target)


def pretrain_generator(net_G, train_dl, opt, criterion, epochs):
    """Train ``net_G`` alone to predict ``ab`` from ``L`` using ``criterion``.

    Batches are moved to the notebook-level ``device``. After each epoch the
    sample-weighted mean loss is printed.
    """
    for e in range(epochs):
        loss_meter = AverageMeter()
        for batch in tqdm(train_dl):
            L = batch['L'].to(device)
            ab = batch['ab'].to(device)

            loss = criterion(net_G(L), ab)
            opt.zero_grad()
            loss.backward()
            opt.step()

            # Weight by batch size so a short final batch does not skew the mean.
            loss_meter.update(loss.item(), L.size(0))

        print(f"Epoch {e + 1}/{epochs}")
        print(f"L1 Loss: {loss_meter.avg:.5f}")

In [None]:
# Pretrain the generator
# NOTE(review): this cell ends with a stray triple-quote line, which suggests
# it was meant to be disabled - confirm before running.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
opt_g = optim.Adam(net_G.parameters(), lr=1e-4)
criterion_L1 = nn.L1Loss()

print("Pretraining the generator with L1 loss...")
pretrain_generator(net_G, train_dl, opt_g, criterion_L1, epochs=20)

# Save the pretrained generator weights (relative path: current working directory)
torch.save(net_G.state_dict(), "res18-unet-v2.pt")
print("Pretraining complete and weights saved.")
"""

In [None]:
"""

In [None]:
# Define the path to the pretrained weights
# NOTE(review): the pretraining cell saves "res18-unet-v2.pt", while this path
# points at "res18-unet.pt" - confirm which file is intended.
pretrained_weights_path = "/kaggle/working/res18-unet.pt"  # Update this path if the weights are saved elsewhere

# Load the pretrained generator
print("Loading pretrained generator weights...")
net_G = build_res_unet(n_input=1, n_output=2, size=256)  # Build the generator architecture
net_G.load_state_dict(torch.load(pretrained_weights_path, map_location=device))  # Load weights
net_G = net_G.to(device)  # Move the model to the appropriate device (GPU/CPU)
net_G.eval()  # Set the model to evaluation mode (optional, for inference)

print("Pretrained generator loaded successfully!")
"""

In [None]:
import os
import csv
# Ensure the images directory exists
#os.makedirs("/kaggle/working/images", exist_ok=True)
#!rm /kaggle/working/images/colorization_1733694174.787311.png
#!mkdir images_new

In [None]:
# Build the generator architecture; its weights are loaded further down from a
# Kaggle input dataset (not from the working directory).
net_G = build_res_unet(n_input=1, n_output=2, size=256)

# Ensure the model is on the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net_G = net_G.to(device)

# Load the pretrained weights into the generator
pretrained_weights_path = "/kaggle/input/res18-v2/other/default/1/res18-unet-v2.pt"  # Path to the weights file
net_G.load_state_dict(torch.load(pretrained_weights_path, map_location=device))
print("Pretrained weights loaded into the generator.")

# Create the main model with the pretrained net_G
model = MainModel(net_G=net_G, lr_G=2e-4, lr_D=2e-4, beta1=0.5, beta2=0.999, lambda_L1=100.)
model = model.to(device)  # Ensure the full model is on the GPU
print("Main model created and moved to the GPU (if available).")

# Train the model
# NOTE(review): train_model calls visualize(..., save=True), which writes into
# "images_new/"; the mkdir for that directory in the previous cell is commented out.
print("Training the full model (GAN + L1) now...")
train_model(model, train_dl, epochs=50, display_every=200)
print("Training complete.")

In [None]:
# Load the pretrained generator
# NOTE(review): this repeats the previous training cell with a different weights
# file ("res18-unet.pt", relative path) and is followed by a stray triple-quote
# line, which suggests it was meant to be disabled - confirm.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))

# Create the main model with the pretrained net_G
model = MainModel(net_G=net_G, lr_G=2e-4, lr_D=2e-4, beta1=0.5, beta2=0.999, lambda_L1=100.)
print("Training the full model (GAN + L1) now...")
train_model(model, train_dl, epochs=50, display_every=200)
print("Training complete.")
"""

In [None]:
# Save the generator weights
generator_save_path = "/kaggle/working/net_G_final.pt"
torch.save(model.net_G.state_dict(), generator_save_path)
print(f"Generator weights saved to {generator_save_path}")

# Save a training checkpoint: state dicts of both networks and both optimizers.
# This is a plain dict, not a pickled MainModel; restoring it requires
# rebuilding the model and calling load_state_dict on each part.
main_model_save_path = "/kaggle/working/main_model_final.pt"
torch.save({
    'net_G': model.net_G.state_dict(),
    'net_D': model.net_D.state_dict(),
    'opt_G': model.opt_G.state_dict(),
    'opt_D': model.opt_D.state_dict()
}, main_model_save_path)
print(f"Full model saved to {main_model_save_path}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.color import lab2rgb
import os

# Ensure the directory for saving plots exists
os.makedirs("predicted_samples", exist_ok=True)

# Load the generator model
def load_model(generator_path, device):
    """Rebuild the ResNet-18 U-Net, load weights from ``generator_path``, return it in eval mode on ``device``."""
    # The architecture must match the one the checkpoint was saved from.
    net_G = build_res_unet(n_input=1, n_output=2, size=256)
    state = torch.load(generator_path, map_location=device)
    net_G.load_state_dict(state)
    net_G.eval()
    return net_G.to(device)

# Load the trained generator from a Kaggle input dataset for evaluation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator_path = "/kaggle/input/yashvidl/other/default/1/net_G_final.pt"
generator = load_model(generator_path, device)

# Function to convert LAB to RGB
def lab_to_rgb(L, ab):
    """Convert batched, normalised L and ab channels to a stacked array of RGB images.

    Accepts numpy arrays or torch tensors with channels on axis 1. This
    definition replaces the tensor-based ``lab_to_rgb`` defined earlier in the
    notebook, so it has to keep accepting tensors for ``visualize`` to work
    after this cell has run.
    """
    def _as_numpy(x):
        # Tensors may live on the GPU and carry autograd history.
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    L = (_as_numpy(L) + 1.) * 50.
    ab = _as_numpy(ab) * 110.
    Lab = np.concatenate([L, ab], axis=1).transpose(0, 2, 3, 1)
    rgb_imgs = [lab2rgb(img) for img in Lab]
    return np.stack(rgb_imgs)

# Metrics Calculation
from skimage.metrics import structural_similarity as ssim
import numpy as np

def calculate_metrics(color, predicted):
    """
    Return ``(mse, psnr, ssim_value)`` for two images with values on a 0..1 scale.

    PSNR is infinite when the MSE is not positive. SSIM is skipped (reported as
    0.0) when either of the first two dimensions is smaller than 7. Any
    exception is printed and reported as ``(inf, inf, 0.0)``.
    """
    try:
        mse = np.mean((color - predicted) ** 2)

        if mse > 0:
            psnr = 20 * np.log10(1.0 / np.sqrt(mse))
        else:
            psnr = float('inf')

        min_dim = min(color.shape[0], color.shape[1])

        if min_dim >= 7:
            # Largest odd window that fits, capped at 7.
            largest_odd = min_dim if min_dim % 2 == 1 else min_dim - 1
            ssim_value = ssim(color, predicted, data_range=1.0,
                              win_size=min(7, largest_odd), channel_axis=-1)
        else:
            print(f"Skipping SSIM for image with insufficient dimensions: {color.shape}")
            ssim_value = 0.0

        return mse, psnr, ssim_value
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return float('inf'), float('inf'), 0.0




# Visualization
def plot_images_with_metrics(color, grayscale, predicted, mse, psnr, ssim_value, sample_idx):
    """Show ground truth, grayscale input and prediction side by side and save the figure.

    The figure is written to ``predicted_samples/predicted_sample_<sample_idx>.png``
    (the directory must already exist). Plotting errors are printed, not raised.
    The figure is always closed afterwards, because this function is called
    once per image.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(15, 15))
        plt.subplot(1, 3, 1)
        plt.title('Color Image', color='green', fontsize=20)
        plt.imshow(color)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Grayscale Image', color='black', fontsize=20)
        plt.imshow(grayscale, cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title(f'Predicted Image\nMSE: {mse:.4f}\nPSNR: {psnr:.2f} dB\nSSIM: {ssim_value:.4f}',
                  color='red', fontsize=16)
        plt.imshow(predicted)
        plt.axis('off')

        fig.savefig(f'predicted_samples/predicted_sample_{sample_idx}.png', bbox_inches='tight')
        plt.show()
    except Exception as e:
        print(f"Error plotting images: {e}")
    finally:
        # Release the figure even when plotting or saving failed.
        if fig is not None:
            plt.close(fig)

# Evaluation loop
# results[k] holds [mse, psnr, ssim] for the k-th image seen, in loader order.
results = []
for i, data in enumerate(val_dl):  # Iterate over validation DataLoader
    if i >= 10:  # Test on first 10 batches (adjust as needed)
        break

    L = data['L'].to(device)
    ab_real = data['ab'].to(device)

    # Predict color channels
    with torch.no_grad():
        ab_predicted = generator(L).cpu().numpy()

    # Convert to RGB for visualization and metrics
    predicted_rgb = lab_to_rgb(L.cpu().numpy(), ab_predicted)
    real_rgb = lab_to_rgb(L.cpu().numpy(), ab_real.cpu().numpy())

    # Single pass per image: the metrics are computed once and reused for
    # printing, for the results list and for the plot.
    for idx in range(len(L)):
        real_img = real_rgb[idx]
        predicted_img = predicted_rgb[idx]

        print(f"Processing Image {idx}: Real shape: {real_img.shape}, Predicted shape: {predicted_img.shape}")

        # Calculate metrics
        mse, psnr, ssim_value = calculate_metrics(real_img, predicted_img)
        print(f"Metrics for Image {idx}: MSE={mse:.4f}, PSNR={psnr:.2f} dB, SSIM={ssim_value:.4f}")

        # Running index: stays unique even when a batch is shorter than the others.
        sample_idx = len(results)
        results.append([mse, psnr, ssim_value])

        # Plot and save images with metrics
        plot_images_with_metrics(
            real_img,
            L[idx][0].cpu().numpy(),  # Grayscale image
            predicted_img,
            mse,
            psnr,
            ssim_value,
            sample_idx=sample_idx
        )

# Calculate average metrics
if results:
    average_metrics = np.mean(results, axis=0)
    print(f"Average Metrics: MSE={average_metrics[0]:.4f}, PSNR={average_metrics[1]:.2f} dB, SSIM={average_metrics[2]:.4f}")
else:
    print("No results to average.")


In [None]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Rank results: lowest MSE first, ties broken by higher PSNR, then higher SSIM.
# Each entry is (original index, metrics) so the image can be looked up later.
sorted_results = list(enumerate(results))
sorted_results.sort(key=lambda entry: (entry[1][0], -entry[1][1], -entry[1][2]))

# Keep the best 100 entries
top_100_results = sorted_results[:100]

# Open the PDF file that will collect the plots
output_pdf_path = "/kaggle/working/top_100_predictions.pdf"
pdf = PdfPages(output_pdf_path)

# Function to add images to the PDF
def add_prediction_to_pdf(pdf, real_img, grayscale_img, predicted_img, mse, psnr, ssim_value, rank):
    """
    Append one page to ``pdf`` showing ground truth, grayscale input and prediction.

    The prediction panel is titled with the three metrics and the page carries
    a "Rank: <rank>" super-title. The figure is closed after saving.
    """
    fig = plt.figure(figsize=(15, 5))

    # Ground Truth Image
    truth_ax = fig.add_subplot(1, 3, 1)
    truth_ax.imshow(real_img)
    truth_ax.set_title("Ground Truth", fontsize=14)
    truth_ax.axis("off")

    # Grayscale Input Image
    gray_ax = fig.add_subplot(1, 3, 2)
    gray_ax.imshow(grayscale_img, cmap="gray")
    gray_ax.set_title("Grayscale Input", fontsize=14)
    gray_ax.axis("off")

    # Predicted Image
    pred_ax = fig.add_subplot(1, 3, 3)
    pred_ax.imshow(predicted_img)
    pred_ax.set_title(f"Prediction\nMSE: {mse:.4f}, PSNR: {psnr:.2f}, SSIM: {ssim_value:.4f}", fontsize=12)
    pred_ax.axis("off")

    fig.suptitle(f"Rank: {rank}", fontsize=16, color="blue")
    fig.tight_layout()

    # Save this figure as a new page, then release it
    pdf.savefig(fig)
    plt.close(fig)

# Add the top 100 predictions to the PDF
# NOTE(review): `index` is a position in `results`; using it as a dataset index
# assumes results were collected from an unshuffled val_dl starting at its
# first batch - confirm.
try:
    for rank, (index, metrics) in enumerate(top_100_results, start=1):
        mse, psnr, ssim_value, *extra_metrics = metrics  # Unpack the first three metrics

        # Retrieve the sample once (each dataset access decodes and converts the image)
        sample = val_dl.dataset[index]
        L = sample["L"].numpy()  # Grayscale input (normalized)
        ab_real = sample["ab"].numpy()  # Ground truth color

        # Inference only: no autograd graph needed
        with torch.no_grad():
            ab_predicted = generator(torch.tensor(L).unsqueeze(0).to(device)).cpu().numpy()[0]

        # Convert LAB to RGB
        real_img = lab_to_rgb(L[np.newaxis, ...], ab_real[np.newaxis, ...])[0]
        predicted_img = lab_to_rgb(L[np.newaxis, ...], ab_predicted[np.newaxis, ...])[0]
        grayscale_img = L[0]  # Grayscale image for visualization

        # Add the prediction to the PDF
        add_prediction_to_pdf(pdf, real_img, grayscale_img, predicted_img, mse, psnr, ssim_value, rank)
finally:
    # Close the PDF file even if a page failed, so the file is finalised
    pdf.close()

print(f"PDF saved to: {output_pdf_path}")


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.color import lab2rgb
import warnings
from scipy.linalg import sqrtm
from torchvision.models import resnet50
from torchvision.transforms import Resize

# Ignore warnings
warnings.filterwarnings("ignore")

# Load the generator model
def load_model(generator_path, device):
    """Rebuild the ResNet-18 U-Net, load weights from ``generator_path``, return it in eval mode on ``device``."""
    # The architecture must match the one the checkpoint was saved from.
    net_G = build_res_unet(n_input=1, n_output=2, size=256)
    state = torch.load(generator_path, map_location=device)
    net_G.load_state_dict(state)
    net_G.eval()
    return net_G.to(device)

# Load the trained generator from a Kaggle input dataset for evaluation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator_path = "/kaggle/input/yashvidl/other/default/1/net_G_final.pt"
generator = load_model(generator_path, device)

# Convert LAB to RGB
def lab_to_rgb(L, ab):
    """Convert batched, normalised L and ab channels to a stacked array of RGB images.

    Accepts numpy arrays or torch tensors with channels on axis 1. This
    definition replaces the tensor-based ``lab_to_rgb`` defined earlier in the
    notebook, so it has to keep accepting tensors for ``visualize`` to work
    after this cell has run.
    """
    def _as_numpy(x):
        # Tensors may live on the GPU and carry autograd history.
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    L = (_as_numpy(L) + 1.) * 50.
    ab = _as_numpy(ab) * 110.
    Lab = np.concatenate([L, ab], axis=1).transpose(0, 2, 3, 1)
    rgb_imgs = [lab2rgb(img) for img in Lab]
    return np.stack(rgb_imgs)

# FID Helper Functions
from scipy.linalg import sqrtm

def calculate_fid(real_features, fake_features):
    """
    Calculate the Frechet distance between two sets of feature vectors.

    Each input is a 2-D array with one row per sample. The result is
    ``|mu1 - mu2|^2 + trace(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.
    """
    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)

    # Compute mean and covariance for real and generated features
    mu1 = np.mean(real_features, axis=0)
    sigma1 = np.cov(real_features, rowvar=False)
    mu2 = np.mean(fake_features, axis=0)
    sigma2 = np.cov(fake_features, rowvar=False)

    # Compute the squared distance between the means
    diff = mu1 - mu2
    diff_squared = diff @ diff

    # Square root of the covariance product. Called without `disp`: that
    # argument is deprecated in recent SciPy, and without it sqrtm returns
    # just the matrix. The tuple check keeps older call conventions working.
    sqrt_result = sqrtm(sigma1 @ sigma2)
    covmean = sqrt_result[0] if isinstance(sqrt_result, tuple) else sqrt_result

    # A (near-)singular product can yield non-finite entries; retry with a
    # small ridge on both covariances.
    if not np.isfinite(covmean).all():
        ridge = np.eye(sigma1.shape[0]) * 1e-6
        sqrt_result = sqrtm((sigma1 + ridge) @ (sigma2 + ridge))
        covmean = sqrt_result[0] if isinstance(sqrt_result, tuple) else sqrt_result

    # Handle imaginary numbers from sqrtm (they can occur due to numerical instability)
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Calculate FID
    fid = diff_squared + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

def extract_features(images, model, batch_size=32):
    """
    Run ``model`` over ``images`` in batches and return the stacked outputs.

    ``images`` is indexed as (N, H, W, C); each batch is converted to float32,
    moved to channels-first layout on the notebook-level ``device`` and resized
    to 224 x 224 before the forward pass.

    Raises:
        ValueError: If ``images`` is empty.
    """
    features = []
    transform = Resize((224, 224))  # ResNet50 expects (224, 224) inputs
    for i in range(0, len(images), batch_size):
        # Cast explicitly: the RGB arrays may be float64, while the model
        # weights are float32, and mixing the two fails in the forward pass.
        batch = torch.as_tensor(images[i:i + batch_size], dtype=torch.float32)
        batch = batch.permute(0, 3, 1, 2).to(device)  # (B, H, W, C) -> (B, C, H, W)
        batch = transform(batch)  # Resize to (224, 224)
        with torch.no_grad():
            pred = model(batch)
        features.append(pred.cpu().numpy())
    if not features:
        raise ValueError("extract_features received no images")
    return np.concatenate(features, axis=0)

# Metrics Calculation
def calculate_metrics(color, predicted):
    """
    Return ``(mse, psnr, ssim_value)`` for two images with values on a 0..1 scale.

    PSNR is infinite when the MSE is not positive. SSIM is reported as 0.0
    when either of the first two dimensions is smaller than 7. Any exception
    is printed and reported as ``(inf, inf, 0.0)``.
    """
    try:
        mse = np.mean((color - predicted) ** 2)
        if mse > 0:
            psnr = 20 * np.log10(1.0 / np.sqrt(mse))
        else:
            psnr = float('inf')

        min_dim = min(color.shape[0], color.shape[1])
        if min_dim >= 7:
            # Largest odd window that fits, capped at 7.
            largest_odd = min_dim if min_dim % 2 == 1 else min_dim - 1
            ssim_value = ssim(color, predicted, data_range=1.0,
                              win_size=min(7, largest_odd), channel_axis=-1)
        else:
            ssim_value = 0.0

        return mse, psnr, ssim_value
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return float('inf'), float('inf'), 0.0

# Load ResNet50 for FID
# NOTE(review): the full classifier is used, so extract_features returns the
# network's final outputs rather than Inception pooling features; the score is
# therefore not comparable with published FID values - confirm this is intended.
resnet_model = resnet50(pretrained=True).to(device)
resnet_model.eval()

# Evaluation Loop
results = []  # Store MSE, PSNR, SSIM, FID for all images
all_mse, all_psnr, all_ssim = [], [], []
# Despite their names, these two lists first collect RGB image batches; they
# are rebound to feature arrays after the loop.
real_features, fake_features = [], []

for i, data in enumerate(val_dl):
    L = data['L'].to(device)
    ab_real = data['ab'].to(device)

    # Predict color channels
    with torch.no_grad():
        ab_predicted = generator(L).cpu().numpy()

    # Convert to RGB
    predicted_rgb = lab_to_rgb(L.cpu().numpy(), ab_predicted)
    real_rgb = lab_to_rgb(L.cpu().numpy(), ab_real.cpu().numpy())

    # Collect real and fake images for FID (every batch stays in memory until
    # the feature extraction below)
    real_features.append(real_rgb)
    fake_features.append(predicted_rgb)

    for idx in range(len(L)):
        real_img = real_rgb[idx]
        predicted_img = predicted_rgb[idx]

        # Calculate metrics
        mse, psnr, ssim_value = calculate_metrics(real_img, predicted_img)
        results.append([mse, psnr, ssim_value])
        all_mse.append(mse)
        all_psnr.append(psnr)
        all_ssim.append(ssim_value)

# FID Calculation
real_features_np = np.concatenate(real_features)
fake_features_np = np.concatenate(fake_features)

real_features = extract_features(real_features_np, resnet_model)
fake_features = extract_features(fake_features_np, resnet_model)

fid_value = calculate_fid(real_features, fake_features)
all_fid = [fid_value] * len(results)

# Add FID to results: the same set-level value is appended to every row, so
# each row becomes [mse, psnr, ssim, fid]
results = [result + [fid_value] for result in results]

# Calculate average metrics
average_metrics = np.mean(results, axis=0)
print(f"Average Metrics: MSE={average_metrics[0]:.4f}, PSNR={average_metrics[1]:.2f} dB, SSIM={average_metrics[2]:.4f}, FID={average_metrics[3]:.4f}")

# Plotting Metrics
def plot_metric(metric_values, metric_name):
    """
    Draw one line chart of a metric against the image index and show it.

    Args:
        metric_values: sequence of values, one per validation image, plotted in order.
        metric_name: text used for the legend entry, the y-axis label and the title.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(metric_values, label=metric_name)
    ax.set_xlabel("Image Index")
    ax.set_ylabel(metric_name)
    ax.set_title(f"{metric_name} Across Validation Images")
    ax.legend()
    fig.tight_layout()
    plt.show()

# One figure per metric. FID is a single global value, so it is repeated once per
# image and shows up as a flat line on the same x-axis as the other metrics.
for _metric_values, _metric_label in (
    (all_mse, "MSE"),
    (all_psnr, "PSNR (dB)"),
    (all_ssim, "SSIM"),
    ([fid_value] * len(all_mse), "FID"),
):
    plot_metric(_metric_values, _metric_label)


In [None]:
# Column-wise mean over every column of `results`; only the first three entries
# (MSE, PSNR, SSIM) are printed here.
# NOTE(review): the evaluation cell appends the global FID to each row of `results`,
# so when this cell runs after it the mean also contains a fourth (FID) entry that is
# computed but not printed. FID is reported separately below from `fid_value`.
average_metrics = np.mean(results, axis=0)
print(f"Average Metrics: MSE={average_metrics[0]:.4f}, PSNR={average_metrics[1]:.2f} dB, SSIM={average_metrics[2]:.4f}")

# Print the global FID value (one number for the whole validation set, not per image)
print(f"Global FID: {fid_value:.4f}")


In [None]:
import numpy as np

# `results` is the list of per-image rows built by the evaluation cell. Only the
# first three columns (MSE, PSNR, SSIM) are reported; any extra column is averaged
# but ignored by the print below.
if not results:
    print("Results list is empty. No metrics to average.")
else:
    # A 2-D array makes the per-column reduction a single call
    results_array = np.array(results)

    # Mean of each column, i.e. of each metric across all images
    average_metrics = results_array.mean(axis=0)

    print(f"Average Metrics: MSE={average_metrics[0]:.4f}, PSNR={average_metrics[1]:.2f} dB, SSIM={average_metrics[2]:.4f}")


In [None]:
import torch
from pathlib import Path

# Locations of the saved checkpoints.
# main_model_path is read below as a dict with the keys 'net_G', 'net_D', 'opt_G'
# and 'opt_D'; generator_weights_path is passed directly to net_G.load_state_dict.
# NOTE(review): both are absolute Kaggle input paths, so this cell only works where
# that dataset is mounted.
main_model_path = "/kaggle/input/yashvidl/other/default/1/main_model_final.pt"
generator_weights_path = "/kaggle/input/yashvidl/other/default/1/net_G_final.pt"

# Ensure the device is set (GPU if available); this rebinds the notebook-level
# `device` that build_res_unet and the loading code below read.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the generator architecture
def build_res_unet(n_input=1, n_output=2, size=256):
    """
    Assemble the generator: a DynamicUnet wrapped around a ResNet-18 body.

    Args:
        n_input: passed to create_body as `n_in` (input channel count).
        n_output: passed to DynamicUnet as its output channel count.
        size: side length; the U-Net is built for a (size, size) input.

    Returns:
        The DynamicUnet, already moved to the notebook-level `device`.
    """
    from fastai.vision.learner import create_body
    from fastai.vision.models.unet import DynamicUnet
    from torchvision.models.resnet import resnet18

    # pretrained=False: the weights come from a saved checkpoint loaded afterwards,
    # so fetching the ImageNet weights here would be wasted work.
    encoder = create_body(resnet18(pretrained=False), n_in=n_input, cut=-2)
    return DynamicUnet(encoder, n_output, (size, size)).to(device)

# Initialize the generator and discriminator
# PatchDiscriminator is defined in another cell of this notebook.
net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_D = PatchDiscriminator(input_c=3, n_down=3, num_filters=64).to(device)

# Load the saved generator weights
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
print("Loading generator weights...")
net_G.load_state_dict(torch.load(generator_weights_path, map_location=device))
print("Generator weights loaded successfully.")

# Load the full model state
# NOTE(review): checkpoint['net_G'] is loaded into net_G again here, overwriting
# the weights loaded just above; the two files are presumably expected to agree.
print("Loading full model state...")
checkpoint = torch.load(main_model_path, map_location=device)
net_G.load_state_dict(checkpoint['net_G'])
net_D.load_state_dict(checkpoint['net_D'])

# Recreate the optimizers with the hyper-parameters below, then restore their
# saved state so training could be resumed from this checkpoint.
opt_G = torch.optim.Adam(net_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(net_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_G.load_state_dict(checkpoint['opt_G'])
opt_D.load_state_dict(checkpoint['opt_D'])

print("Full model and optimizers loaded successfully.")


In [None]:
class ColorizationDatasetEval(Dataset):
    """
    Evaluation dataset that turns image files into normalised Lab tensors.

    Each item is a dict with:
        'L':    first Lab channel divided by 50 and shifted by -1,
        'ab':   second and third Lab channels divided by 110,
        'path': the source file path.

    An image that cannot be read or converted yields ``None`` instead of raising,
    after printing the error.
    NOTE(review): a ``None`` item must be filtered by the DataLoader's collate
    function; confirm the loader built on this dataset does that.
    """

    def __init__(self, paths, size=256):
        """
        Args:
            paths: iterable of candidate file paths; entries that are not existing
                regular files are dropped.
            size: images are resized to (size, size) with bicubic interpolation.
        """
        resize = transforms.Resize((size, size), Image.BICUBIC)
        self.transforms = transforms.Compose([resize])
        self.size = size
        # Filter valid paths up front so __len__ only counts files that exist
        self.paths = list(filter(os.path.isfile, paths))

    def __getitem__(self, idx):
        """Load, resize and convert the image at position `idx`; ``None`` on failure."""
        try:
            resized = self.transforms(Image.open(self.paths[idx]).convert("RGB"))
            lab_array = rgb2lab(np.array(resized)).astype("float32")
            lab = transforms.ToTensor()(lab_array)
            sample = {
                'L': lab[[0], ...] / 50. - 1.,
                'ab': lab[[1, 2], ...] / 110.,
                'path': self.paths[idx],
            }
            return sample
        except Exception as e:
            print(f"Error loading image {self.paths[idx]}: {e}")
            return None  # Signal a failed load to the caller instead of raising

    def __len__(self):
        """Number of paths that passed the existence filter."""
        return len(self.paths)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from skimage.color import lab2rgb
import warnings
from scipy.linalg import sqrtm
from torchvision.models import resnet50
from torchvision.transforms import Resize

# Ignore warnings: installs a process-wide "ignore" filter for every warning category.
# NOTE(review): this also hides deprecation and numerical warnings raised by the
# libraries used in the rest of the notebook.
warnings.filterwarnings("ignore")

# Load the generator model
def load_model(generator_path, device):
    """
    Build the generator and restore its weights from `generator_path`.

    Args:
        generator_path: file holding the generator's state dict.
        device: target device; used both as `map_location` for loading and as the
            final placement of the returned network.

    Returns:
        The generator in evaluation mode on `device`.
    """
    # Must match the architecture the checkpoint was saved from
    model = build_res_unet(n_input=1, n_output=2, size=256)
    state_dict = torch.load(generator_path, map_location=device)
    model.load_state_dict(state_dict)
    # eval() returns the module itself, so the calls can be chained
    return model.eval().to(device)

# Pick the GPU when available, then restore the generator used by the evaluation
# loop further down in this cell.
# NOTE(review): absolute Kaggle input path; requires that dataset to be mounted.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator_path = "/kaggle/input/yashvidl/other/default/1/net_G_final.pt"
generator = load_model(generator_path, device)

# Convert LAB to RGB
def lab_to_rgb(L, ab):
    """
    Turn batches of normalised L and ab channels into RGB images.

    The inputs are rescaled (L -> (L + 1) * 50, ab -> ab * 110), joined along
    axis 1 and moved to channels-last before each image goes through lab2rgb.

    Args:
        L: array with the lightness channel on axis 1, i.e. (N, 1, H, W).
        ab: array with the two colour channels on axis 1, i.e. (N, 2, H, W).

    Returns:
        Array of the converted images stacked along a new first axis, (N, H, W, 3).
    """
    lightness = (L + 1.) * 50.
    chroma = ab * 110.
    # (N, 3, H, W) -> (N, H, W, 3): lab2rgb wants the colour channels last
    lab_batch = np.concatenate([lightness, chroma], axis=1).transpose(0, 2, 3, 1)
    converted = []
    for lab_image in lab_batch:
        converted.append(lab2rgb(lab_image))
    return np.stack(converted)

# FID Helper Functions
def calculate_fid(real_features, fake_features, eps=1e-6):
    """
    Calculate FID between real and fake features.

    Fits a Gaussian (mean, covariance) to each feature set and returns the
    Frechet distance between them:

        ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r @ S_f))

    Args:
        real_features: array-like of shape (n_samples, n_features).
        fake_features: array-like with the same number of feature columns.
        eps: diagonal offset added to both covariances when the matrix square
            root of their product comes back non-finite (near-singular product).

    Returns:
        The scalar distance.
    """
    # Accumulate in float64: float32 features lose precision in the covariance.
    real_features = np.asarray(real_features, dtype=np.float64)
    fake_features = np.asarray(fake_features, dtype=np.float64)

    mu1 = np.mean(real_features, axis=0)
    mu2 = np.mean(fake_features, axis=0)
    # np.cov collapses a single feature column to a 0-d array, which `@` rejects;
    # atleast_2d keeps the covariances as (n_features, n_features) matrices.
    sigma1 = np.atleast_2d(np.cov(real_features, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(fake_features, rowvar=False))

    diff = mu1 - mu2
    covmean, _ = sqrtm(sigma1 @ sigma2, disp=False)

    # A (near-)singular product can yield inf/nan entries; retry with a small
    # ridge on both covariances so the trace term stays usable.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = sqrtm((sigma1 + offset) @ (sigma2 + offset), disp=False)

    # sqrtm may return a complex matrix with negligible imaginary parts
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

def extract_features(images, model, batch_size=32):
    """
    Run `model` over `images` in mini-batches and stack its outputs.

    Each mini-batch is converted to a float32 tensor, moved from channels-last
    to channels-first, placed on the notebook-level `device`, resized to
    224x224 and passed through `model` without gradient tracking.

    NOTE(review): no ImageNet mean/std normalisation is applied, and whatever
    `model(batch)` returns is used as the feature vector. Confirm that this is
    the representation intended for FID.

    Args:
        images: array-like indexed as (N, H, W, C).
        model: callable taking a (B, C, 224, 224) float tensor on `device`.
        batch_size: number of images per forward pass.

    Returns:
        np.ndarray with the model outputs concatenated along axis 0.

    Raises:
        ValueError: if `images` is empty.
    """
    if len(images) == 0:
        # Fail with a clear message instead of np.concatenate's generic one.
        raise ValueError("extract_features received no images.")

    transform = Resize((224, 224))  # ResNet50 expects (224, 224) inputs
    features = []
    for start in range(0, len(images), batch_size):
        # Force float32: float64 images would not match float32 model weights.
        batch = torch.tensor(images[start:start + batch_size], dtype=torch.float32)
        # (B, H, W, C) -> (B, C, H, W). permute with four indices already requires a
        # rank-4 tensor, so no separate rank check is needed afterwards.
        batch = transform(batch.permute(0, 3, 1, 2).to(device))
        with torch.no_grad():
            pred = model(batch)
        features.append(pred.cpu().numpy())
    return np.concatenate(features, axis=0)

# Metrics Calculation
def calculate_metrics(color, predicted):
    """
    Compare a reference image with a prediction.

    Args:
        color: reference image array, (H, W, C), values on a 0..1 scale.
        predicted: predicted image array of the same shape.

    Returns:
        (mse, psnr, ssim_value). PSNR is inf when the MSE is not positive, and
        SSIM is 0.0 when the smaller spatial side is below 7 pixels. If anything
        raises, the error is printed and (inf, inf, 0.0) is returned.
    """
    try:
        mse = np.mean((color - predicted) ** 2)
        if mse > 0:
            # Peak value is 1.0, so PSNR = 20 * log10(1 / RMSE)
            psnr = 20 * np.log10(1.0 / np.sqrt(mse))
        else:
            psnr = float('inf')

        height, width = color.shape[0], color.shape[1]
        if min(height, width) >= 7:
            # With both sides >= 7 the largest odd window capped at 7 is always 7
            ssim_value = ssim(color, predicted, data_range=1.0, win_size=7, channel_axis=-1)
        else:
            ssim_value = 0.0

        return mse, psnr, ssim_value
    except Exception as e:
        print(f"Error calculating metrics: {e}")
        return float('inf'), float('inf'), 0.0

# Load ResNet50 for FID
# The network is used as-is as the `model` argument of extract_features(), so the
# "features" compared by FID are whatever resnet_model(batch) returns.
resnet_model = resnet50(pretrained=True).to(device)
resnet_model.eval()

# Evaluation Loop
results = []  # One [mse, psnr, ssim] row per image; the FID column is appended after the loop
all_mse, all_psnr, all_ssim = [], [], []
# NOTE: despite the names, these two lists first collect RGB image batches; they are
# rebound to the extracted model outputs further down.
real_features, fake_features = [], []

# `val_dl` comes from an earlier cell; each batch is indexed by the keys 'L' and 'ab'.
for i, data in enumerate(val_dl):
    L = data['L'].to(device)
    ab_real = data['ab'].to(device)

    # Predict color channels
    with torch.no_grad():
        ab_predicted = generator(L).cpu().numpy()

    # Convert to RGB
    # Both conversions share the same L input, so the two images differ only in
    # the ab channels supplied (predicted vs. ground truth).
    predicted_rgb = lab_to_rgb(L.cpu().numpy(), ab_predicted)
    real_rgb = lab_to_rgb(L.cpu().numpy(), ab_real.cpu().numpy())

    # Keep every RGB batch in memory for the FID pass after the loop
    real_features.append(real_rgb)
    fake_features.append(predicted_rgb)

    for idx in range(len(L)):
        real_img = real_rgb[idx]
        predicted_img = predicted_rgb[idx]

        # Calculate metrics
        mse, psnr, ssim_value = calculate_metrics(real_img, predicted_img)
        results.append([mse, psnr, ssim_value])
        all_mse.append(mse)
        all_psnr.append(psnr)
        all_ssim.append(ssim_value)

# FID Calculation
# Merge the per-batch image arrays along the first axis.
# NOTE(review): np.concatenate raises if val_dl yielded no batches.
real_features_np = np.concatenate(real_features)
fake_features_np = np.concatenate(fake_features)

# From here on real_features / fake_features hold model outputs, not images.
real_features = extract_features(real_features_np, resnet_model)
fake_features = extract_features(fake_features_np, resnet_model)

fid_value = calculate_fid(real_features, fake_features)
# NOTE(review): all_fid is not read anywhere in the code visible here.
all_fid = [fid_value] * len(results)

# Add FID to results
# Every row becomes [mse, psnr, ssim, fid]; the same global FID is repeated per image.
results = [result + [fid_value] for result in results]

# Calculate average metrics
# Column-wise mean; column 3 is the constant FID, so its mean is just fid_value.
# NOTE(review): calculate_metrics yields inf for PSNR when the MSE is not positive and
# inf for both MSE and PSNR on its error path; a single inf entry makes the
# corresponding column mean inf.
average_metrics = np.mean(results, axis=0)
print(f"Average Metrics: MSE={average_metrics[0]:.4f}, PSNR={average_metrics[1]:.2f} dB, SSIM={average_metrics[2]:.4f}, FID={average_metrics[3]:.4f}")

# Plotting Metrics
def plot_metric(metric_values, metric_name):
    """
    Draw one line chart of a metric against the image index and show it.

    Args:
        metric_values: sequence of values, one per validation image, plotted in order.
        metric_name: text used for the legend entry, the y-axis label and the title.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(metric_values, label=metric_name)
    ax.set_xlabel("Image Index")
    ax.set_ylabel(metric_name)
    ax.set_title(f"{metric_name} Across Validation Images")
    ax.legend()
    fig.tight_layout()
    plt.show()

# One figure per metric. FID is a single global value, so it is repeated once per
# image and shows up as a flat line on the same x-axis as the other metrics.
for _metric_values, _metric_label in (
    (all_mse, "MSE"),
    (all_psnr, "PSNR (dB)"),
    (all_ssim, "SSIM"),
    ([fid_value] * len(all_mse), "FID"),
):
    plot_metric(_metric_values, _metric_label)
