In [1]:
# Operating system and file operations
import os

# Numerical operations and array processing
import numpy as np

import pickle

# Data visualization
import seaborn as sns
import matplotlib.pyplot as plt

# Random number generation
import random

# Image processing and computer vision
import cv2

# Deep learning and neural networks
import torch
import torch.nn as nn
import torch.nn.functional as F

# Data handling for deep learning
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# Displaying model information
from torchsummary import summary

# Image manipulation
from PIL import Image

# Type annotations
from typing import List

# Progress bar for Jupyter Notebooks
from tqdm.notebook import tqdm

# Color space conversions
from colour import sRGB_to_XYZ, XYZ_to_Lab, Lab_to_XYZ, XYZ_to_sRGB

# Image quality assessment
from skimage.metrics import structural_similarity as ssim

# Memory management
import gc

# Evaluation
from evaluation import calculate_ssim
from evaluation import calculate_colourfulness

# Handling warnings
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

  warn(*args, **kwargs)  # noqa: B028


In [2]:
# Pick the PyTorch device: use the GPU when CUDA reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Show which device the rest of the notebook will run on.
print(device)

cpu


In [3]:
def lab_to_rgb(L, ab, device):
    """
    Convert a batch of normalised Lab tensors to RGB.

    ``L`` is multiplied by 100 and ``ab`` is shifted by 0.5 and multiplied by
    256 before the two are concatenated on dim 1 and converted, one sample at
    a time, through ``Lab_to_XYZ`` and ``XYZ_to_sRGB``.

    Args:
        L: lightness tensor; assumed to be (N, 1, H, W) with values in 0-1
            - TODO confirm against the caller.
        ab: chroma tensor; assumed to be (N, 2, H, W) with values in 0-1
            - TODO confirm against the caller.
        device: device the returned tensor is moved to.

    Returns:
        Tensor laid out channels-first (N, C, H, W) on ``device``. Values are
        not clipped here.
    """
    lightness = L * 100       # undo the 0-1 scaling of the L channel
    chroma = (ab - 0.5) * 256  # undo the shift/scale applied to a and b

    # Channels-last NumPy copy, detached from autograd and moved to host memory.
    lab_nhwc = (
        torch.cat([lightness, chroma], dim=1)
        .permute(0, 2, 3, 1)
        .detach()
        .cpu()
        .numpy()
    )

    # Convert each sample separately: Lab -> XYZ -> sRGB.
    converted = [XYZ_to_sRGB(Lab_to_XYZ(sample)) for sample in lab_nhwc]
    rgb_nhwc = np.stack(converted, axis=0)

    # Back to a channels-first tensor on the requested device.
    return torch.tensor(rgb_nhwc).permute(0, 3, 1, 2).to(device)

In [4]:
def load_model(model_file):
    """
    Load a pickled object (here: a trained model) from ``model_file``.

    Args:
        model_file: path to the pickle file.

    Returns:
        The unpickled object.

    Raises:
        OSError: if the file cannot be opened.
        pickle.UnpicklingError (or other exceptions): if the payload is invalid.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only call this on files you created or otherwise trust.
    # The context manager guarantees the handle is closed even if unpickling fails.
    with open(model_file, 'rb') as model_handle:
        return pickle.load(model_handle)

In [5]:
class ConvBlock(nn.Module):
    """
    Residual block: two (3x3 conv -> norm -> ReLU) stages plus a convolutional shortcut.

    The shortcut stored in ``self.identity`` is itself a 3x3 convolution (not a
    true identity), so it can change the channel count and stride to match the
    main path before the two are summed.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both the main path and the shortcut.
        stride: stride of the first main-path conv and of the shortcut conv.
        norm_layer: callable building the normalisation layer from a channel count.
    """

    def __init__(self, in_channels, out_channels, stride=1, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # Main path. Attribute names and layer order are kept unchanged so that
        # existing parameter names (block.0.weight, ...) stay valid.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            norm_layer(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            norm_layer(out_channels),
            nn.ReLU(inplace=True)
        )

        # Shortcut projection that matches the main path's channels and stride.
        self.identity = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.relu = nn.ReLU(inplace=True)  # applied after the residual sum

    def forward(self, x):
        """Return ``relu(block(x) + identity(x))``."""
        # The input is fed to the main path directly. The previous
        # ``x.detach().clone()`` cut this path out of the autograd graph, so
        # upstream layers only received gradient through the shortcut, and it
        # copied the whole input on every call. No copy is needed: the first
        # layer of ``self.block`` is a Conv2d, which does not modify its input.
        residual = self.identity(x)
        out = self.block(x) + residual
        return self.relu(out)

In [6]:
class EncoderBlock(nn.Module):
    """
    Encoder stage: max-pool the input, then run it through a ``ConvBlock``.

    Args:
        in_chans: channels of the incoming feature map.
        out_chans: channels produced by the convolutional block.
        sampling_factor: kernel size handed to ``nn.MaxPool2d``.
    """

    def __init__(self, in_chans, out_chans, sampling_factor=2):
        super().__init__()
        downsample = nn.MaxPool2d(sampling_factor)  # shrink the spatial dimensions
        convolve = ConvBlock(in_chans, out_chans)    # then extract features
        # Kept as a two-element Sequential named ``block`` (pool at index 0, conv at index 1).
        self.block = nn.Sequential(downsample, convolve)

    def forward(self, x):
        """Apply pooling followed by the convolutional block."""
        encoded = self.block(x)
        return encoded
    
class DecoderBlock(nn.Module):
    """
    Decoder stage: upsample, concatenate with a skip connection, then convolve.

    Args:
        in_chans: channels of the feature map coming from the deeper stage.
        out_chans: channels of the skip connection and of the output.
        sampling_factor: spatial upsampling factor. It was previously accepted
            but ignored (the factor was hard-coded to 2); it is now honoured.
            The default of 2 keeps existing behaviour.
    """

    def __init__(self, in_chans, out_chans, sampling_factor=2):
        super().__init__()
        # Bilinear upsampling by the requested factor.
        self.upsample = nn.Upsample(scale_factor=sampling_factor, mode="bilinear", align_corners=True)

        # The conv block sees the upsampled map (in_chans) stacked with the
        # skip connection (out_chans) along the channel axis.
        self.block = ConvBlock(in_chans + out_chans, out_chans)

    def forward(self, x, skip):
        """
        Args:
            x: feature map from the deeper stage.
            skip: encoder feature map; its spatial size must equal that of the
                upsampled ``x`` for the concatenation to succeed.
        """
        x = self.upsample(x)
        x = torch.cat([x, skip], dim=1)  # channel-wise concatenation
        x = self.block(x)
        return x

In [7]:
class UNet(nn.Module):
    """
    Four-level U-Net built from ``ConvBlock`` / ``EncoderBlock`` / ``DecoderBlock``.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the prediction (default 2).
        dropout_rate: probability passed to ``nn.Dropout2d``; dropout is applied
            to the output of every encoder stage, including the tensors reused
            as skip connections.
    """

    def __init__(self, in_channels=1, out_channels=2, dropout_rate=0.1):
        super().__init__()
        widths = (64, 128, 256, 512)

        # Encoder: a full-resolution stem followed by three downsampling stages.
        stem = ConvBlock(in_channels, widths[0])
        down_stages = [
            EncoderBlock(narrow, wide)
            for narrow, wide in zip(widths[:-1], widths[1:])
        ]
        self.encoder = nn.ModuleList([stem, *down_stages])

        # Decoder: mirrors the encoder widths, 512 -> 256 -> 128 -> 64.
        reversed_widths = widths[::-1]
        self.decoder = nn.ModuleList([
            DecoderBlock(wide, narrow)
            for wide, narrow in zip(reversed_widths[:-1], reversed_widths[1:])
        ])

        # Channel-wise dropout shared by all encoder stages.
        self.dropout = nn.Dropout2d(dropout_rate)
        # 1x1 convolution mapping the last decoder features to the output channels.
        self.logits = nn.Conv2d(in_channels=widths[0], out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder and return sigmoid-activated predictions in (0, 1)."""
        skips = []
        for stage in self.encoder:
            x = self.dropout(stage(x))
            skips.append(x)

        # The deepest encoder output is ``x`` itself, not a skip connection; drop it.
        skips.pop()

        # Consume the remaining skips from deepest to shallowest.
        for stage in self.decoder:
            x = stage(x, skips.pop())

        return F.sigmoid(self.logits(x))

In [8]:
def colorize_images_in_folder(model, input_folder, output_folder, device='cuda'):
    """
    Colorize every PNG/JPEG image in ``input_folder`` and save the results.

    Each image is converted to grayscale, cropped to its top half, resized to
    256x128, passed through ``model`` to predict the two chroma channels, and
    converted Lab -> XYZ -> sRGB. Results are written to ``output_folder`` as
    ``colorized_<original name>``.

    Args:
        model: network mapping a (1, 1, H, W) tensor to a (1, 2, H, W) tensor
            of values in 0-1. It is moved to ``device`` and put in eval mode.
        input_folder: directory scanned (non-recursively) for .png/.jpg/.jpeg files.
        output_folder: directory for the results; created if missing.
        device: device used for inference.

    Returns:
        list[str]: paths of the saved images, in processing order (sorted by
        file name). The function previously returned ``None``.
    """
    # Create output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((256, 128), antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=0, std=0.5)
    ])

    # Inputs are moved to `device` below; the weights must live there too,
    # otherwise the forward pass fails with a device mismatch.
    model = model.to(device)
    model.eval()

    # Sort so the processing order does not depend on the file system.
    image_files = sorted(
        f for f in os.listdir(input_folder)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    saved_paths = []
    for image_file in tqdm(image_files, desc="Colorizing images"):
        input_path = os.path.join(input_folder, image_file)
        output_path = os.path.join(output_folder, f"colorized_{image_file}")

        # Load as grayscale; the context manager releases the file handle promptly.
        with Image.open(input_path) as source_image:
            gray_image = source_image.convert('L')
        w, h = gray_image.size
        gray_image = gray_image.crop((0, 0, w, h // 2))  # keep the top half only

        # Add the batch dimension and move to the inference device.
        input_tensor = transform(gray_image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)

        # NOTE(review): Normalize(mean=0, std=0.5) doubles the ToTensor output,
        # so `input_tensor * 100` can exceed the nominal 0-100 range of Lab L.
        # Left unchanged because the training-time preprocessing is not visible
        # here - confirm against the training pipeline.
        L = input_tensor * 100
        ab = (output - 0.5) * 256
        # Index the single batch element explicitly; a bare squeeze() would
        # also drop any other size-1 dimension.
        Lab = torch.cat([L, ab], dim=1)[0].permute(1, 2, 0).cpu().numpy()

        # Convert LAB to RGB
        rgb_image = XYZ_to_sRGB(Lab_to_XYZ(Lab))

        # Clip values to [0, 1] range and convert to uint8
        rgb_image = np.clip(rgb_image, 0, 1)
        rgb_image = (rgb_image * 255).astype(np.uint8)

        # Convert to PIL Image and save
        Image.fromarray(rgb_image).save(output_path)
        saved_paths.append(output_path)

    print(f"Colorized images saved in {output_folder}")
    return saved_paths

In [9]:
# Load the pickled generator from gen_model.pkl (path relative to the working directory).
# NOTE(review): unpickling presumably relies on the network classes defined in the
# cells above being in scope under the same names - confirm against how the file was saved.
# Only load pickle files from a trusted source.
G = load_model('gen_model.pkl')

In [10]:
# Colorize every supported image in dataset/test_black and write the results to
# output55/, running on the device selected earlier in the notebook.
Generated = colorize_images_in_folder(G, "dataset/test_black", "output55/", device)

Colorizing images:   0%|          | 0/13 [00:00<?, ?it/s]

Colorized images saved in output55/
