In [None]:
# Importing functions for color space conversion and image saving from scikit-image
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab, gray2rgb
from skimage.io import imsave

# Importing a metric for evaluating model performance from scikit-learn
from sklearn.metrics import mean_squared_error

# Importing image processing transforms from PyTorch
from torchvision import transforms

# Importing pre-trained models and model-related functionalities from torchvision and timm
import torchvision.models as models
import timm

# Importing the Python Imaging Library (PIL) for image handling
from PIL import Image

# Importing NumPy for numerical operations
import numpy as np

# Importing operating system functions for file and directory operations
import os

# Importing the random module for generating random numbers
import random

# Importing PyTorch for deep learning functionalities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Importing PyTorch's DataLoader and TensorDataset for handling data in batches
from torch.utils.data import TensorDataset, DataLoader

# Importing additional image processing transforms from PyTorch
from torchvision import transforms

In [None]:
# Read every file found under 'Train/' as an RGB array and stack the results
# into a single float array (one entry per file).
X = np.array(
    [np.array(Image.open('Train/' + filename).convert('RGB'))
     for filename in os.listdir('Train/')],
    dtype=float,
)

# Scale 8-bit pixel values down to [0, 1]
Xtrain = X * (1.0/255)

# Fetch pretrained weights through torch.hub from the 'pytorch/vision:v0.10.0'
# repository tag. NOTE: the entry point requested is 'inception_v3', even
# though the variable is named after Inception-ResNet-v2; the name is kept
# because later cells refer to it.
inception_resnet_v2 = torch.hub.load(
    'pytorch/vision:v0.10.0', 'inception_v3', pretrained=True
)

# Switch the network to evaluation mode (inference behaviour)
inception_resnet_v2.eval()


In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Encoder class definition
class Encoder(nn.Module):
    """Eight-layer convolutional encoder for a single-channel input.

    Every layer is a 3x3 convolution with padding 1 followed by ReLU.
    Layers 1, 3 and 5 use stride 2, so height and width are halved three
    times; the final layer produces 256 channels.
    """

    # (in_channels, out_channels, stride) for conv1 .. conv8, in order
    _LAYER_SPECS = (
        (1, 64, 2),
        (64, 128, 1),
        (128, 128, 2),
        (128, 256, 1),
        (256, 256, 2),
        (256, 512, 1),
        (512, 512, 1),
        (512, 256, 1),
    )

    def __init__(self):
        super().__init__()
        # Register the layers as conv1 .. conv8 so parameter names stay stable
        for index, (c_in, c_out, stride) in enumerate(self._LAYER_SPECS, start=1):
            layer = nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1)
            setattr(self, f'conv{index}', layer)

    def forward(self, x):
        """Apply conv1 .. conv8 in order, each followed by ReLU."""
        for index in range(1, len(self._LAYER_SPECS) + 1):
            x = F.relu(getattr(self, f'conv{index}')(x))
        return x


# Fusion class definition
class Fusion(nn.Module):
    """Fuse a spatial feature map with a per-image embedding vector.

    The embedding is broadcast over every spatial position of the feature
    map, concatenated along the channel axis, and mixed back down to
    ``feature_channels`` channels by a 1x1 convolution followed by ReLU.

    The 1x1 convolution is created once here in ``__init__`` so that it is
    registered as a submodule: its weights are then returned by
    ``parameters()``, updated by the optimizer, and reused on every call.
    (Constructing the layer inside ``forward`` would draw fresh random
    weights on each call and never train them.)

    Args:
        feature_channels (int): Channels of the incoming feature map and of
            the output. Defaults to 256.
        embed_dim (int): Length of the embedding vector. Defaults to 1000.
    """

    def __init__(self, feature_channels=256, embed_dim=1000):
        super(Fusion, self).__init__()
        self.embed_dim = embed_dim
        self.conv = nn.Conv2d(feature_channels + embed_dim, feature_channels, kernel_size=1)

    def forward(self, x, embed_input):
        """
        Args:
            x (Tensor): Feature map of shape (N, feature_channels, H, W).
            embed_input (Tensor): Embeddings with N * embed_dim elements.

        Returns:
            Tensor: Fused feature map of shape (N, feature_channels, H, W).
        """
        # Reshape to (N, embed_dim, 1, 1), then broadcast over H x W
        embed_input = embed_input.view(-1, self.embed_dim, 1, 1)
        embed_input = embed_input.expand(-1, -1, x.size(2), x.size(3))
        # Concatenate the features with the broadcast embedding
        x = torch.cat([x, embed_input], dim=1)
        # Mix channels with the registered (trainable) 1x1 convolution
        return F.relu(self.conv(x))


# Decoder class definition
class Decoder(nn.Module):
    """Decode a 256-channel feature map into a 2-channel map.

    Five 3x3 convolutions (padding 1) are interleaved with three
    nearest-neighbour 2x upsampling steps, so height and width grow by a
    factor of 8. The last convolution is followed by tanh, which bounds the
    output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 256 -> 128 channels, then 2x upsample
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        # Stage 2: 128 -> 64 channels, then 2x upsample
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        # Stage 3: 64 -> 32 -> 16 -> 2 channels, then the final 2x upsample
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(16, 2, kernel_size=3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        """Run the three decoding stages and return the upsampled 2-channel map."""
        stage1 = self.upsample1(F.relu(self.conv1(x)))
        stage2 = self.upsample2(F.relu(self.conv2(stage1)))
        hidden = F.relu(self.conv4(F.relu(self.conv3(stage2))))
        bounded = F.tanh(self.conv5(hidden))
        return self.upsample3(bounded)


# ColorizationModel class definition
class ColorizationModel(nn.Module):
    """Chain the Encoder, Fusion and Decoder modules into one network."""

    def __init__(self):
        super().__init__()
        # The three stages, in the order they are applied in forward()
        self.encoder = Encoder()
        self.fusion = Fusion()
        self.decoder = Decoder()

    def forward(self, encoder_input, embed_input):
        """
        Args:
            encoder_input: Input passed to the encoder.
            embed_input: Embedding passed to the fusion stage together with
                the encoder output.

        Returns:
            The decoder output for the fused features.
        """
        features = self.encoder(encoder_input)
        fused = self.fusion(features, embed_input)
        return self.decoder(fused)


In [None]:
def preprocess_batch(images):
    """Run the fixed preprocessing pipeline on ``images`` and return the result.

    The pipeline converts the input to a PIL image, resizes it to 299x299,
    converts it back to a tensor and normalizes the three channels with the
    mean/std constants listed below.

    Note: despite the name, the other cells of this notebook call this
    function once per image rather than on a whole batch.
    """
    steps = [
        transforms.ToPILImage(),        # tensor/array -> PIL image
        transforms.Resize((299, 299)),  # fixed 299x299 spatial size
        transforms.ToTensor(),          # PIL image -> tensor
        # Per-channel normalization constants
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(images)

def create_inception_embedding(grayscaled_rgb_resized):
    """Compute embeddings for a batch of images with the pretrained network.

    Each image is preprocessed individually with ``preprocess_batch`` (whose
    pipeline starts with ``ToPILImage`` and therefore handles one image at a
    time), the results are stacked into a single batch tensor, and the batch
    is passed through ``inception_resnet_v2`` with gradients disabled.

    The preprocessing output is already a channel-first tensor, so it is fed
    to the network as-is; no extra ``torch.tensor(...)`` wrapping or axis
    permutation is applied.

    Args:
        grayscaled_rgb_resized: Iterable of images, each in a form accepted
            by ``preprocess_batch`` (presumably a 3-channel CHW tensor or
            HWC array -- TODO confirm against callers).

    Returns:
        Tensor: Output of ``inception_resnet_v2`` for the stacked batch.
    """
    # Preprocess image by image, then build the batch dimension
    input_tensor = torch.stack(
        [preprocess_batch(image) for image in grayscaled_rgb_resized]
    )

    # Inference only: no gradients are needed for the embedding network
    with torch.no_grad():
        output = inception_resnet_v2(input_tensor)

    return output


In [None]:
# Convert the normalized RGB training images to Lab
lab_train = rgb2lab(Xtrain)

# L channel (luminance) as network input. A trailing channel axis is added
# with np.newaxis instead of reshaping to a hard-coded (10, 256, 256, 1), so
# this cell works for any number of training images and any image size.
X_train = lab_train[:, :, :, 0][..., np.newaxis]  # (N, H, W, 1)

# AB channels (chrominance) as target, divided by 128. Slicing 1: already
# yields shape (N, H, W, 2), so no reshape is needed.
Y_train = lab_train[:, :, :, 1:] / 128

# Convert to float32 tensors and reorder NHWC -> NCHW
Xtrain_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)
Ytrain_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)

# Instantiate the model and set up the loss function and optimizer
model = ColorizationModel()
criterion = nn.MSELoss()  # Mean Squared Error loss
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)  # RMSprop optimizer

# Data augmentation using torchvision.transforms
transform = transforms.Compose([
    transforms.RandomAffine(degrees=20, shear=[-5, 5], scale=(0.8, 1.2)),  # Random affine transformations
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.RandomVerticalFlip(),  # Random vertical flip
    transforms.RandomRotation(20),  # Random rotation
])

# Define a custom dataset class for handling input data
class ImageDataset(torch.utils.data.Dataset):
    """Dataset yielding (input image, embedding, target) triples as dicts.

    When a transform is supplied and both the image and the target are
    tensors with matching trailing (spatial) dimensions, the two are
    concatenated along the first (channel) axis and transformed in a single
    call, then split again. A random geometric transform therefore draws
    one set of parameters for the pair, keeping the input and its target
    spatially aligned. Transforming only the input would pair a
    rotated/flipped image with an untouched target.

    The stored embedding is returned unchanged; it is not recomputed for
    the transformed image.
    """

    def __init__(self, X, embed, Y, transform=None):
        """
        Args:
            X (list): Input data (grayscale images).
            embed (list): Embeddings.
            Y (list): Target data (colorized images).
            transform (torchvision.transforms.Compose): Optional data transformations.
        """
        self.X = X
        self.Y = Y
        self.embed = embed
        self.transform = transform

    def __len__(self):
        """
        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.X)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample.

        Returns:
            dict: Dictionary with keys 'image', 'embed' and 'target'.
        """
        image = self.X[idx]
        target = self.Y[idx]

        if self.transform:
            can_pair = (
                torch.is_tensor(image)
                and torch.is_tensor(target)
                and image.dim() >= 3
                and image.dim() == target.dim()
                and image.shape[1:] == target.shape[1:]
            )
            if can_pair:
                # One transform call over the stacked channels so that image
                # and target receive identical random parameters.
                n_image_channels = image.shape[0]
                joint = self.transform(torch.cat([image, target], dim=0))
                image = joint[:n_image_channels]
                target = joint[n_image_channels:]
            else:
                # Fallback for inputs that cannot be stacked: transform the
                # input image only.
                image = self.transform(image)

        return {'image': image, 'embed': self.embed[idx], 'target': target}


In [None]:
# Mini-batch size handed to the DataLoader below
batch_size = 2

# Collapse the training images to grayscale, replicate that single channel
# back to three channels, and store the result as a float32 tensor with the
# axes reordered from NHWC to NCHW.
grayscaled_rgb = torch.tensor(
    gray2rgb(rgb2gray(Xtrain)), dtype=torch.float32
).permute(0, 3, 1, 2)

# Run preprocess_batch on each image separately and stack the outputs
transformed_batch = torch.stack(list(map(preprocess_batch, grayscaled_rgb)))

# One forward pass of the pretrained network (the object loaded earlier as
# inception_resnet_v2) over all images, with gradient tracking switched off
with torch.no_grad():
    embed = inception_resnet_v2(transformed_batch)

# Bundle L-channel inputs, embeddings and AB targets, with augmentation
dataset = ImageDataset(Xtrain_tensor, embed, Ytrain_tensor, transform=transform)

# Shuffled mini-batches for training
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)


In [None]:
# Optimise the colorization model for a fixed number of passes over the data
num_epochs = 100
for epoch in range(num_epochs):
    # Sum of per-batch losses for this epoch
    total_loss = 0.0

    for batch in dataloader:
        # Each batch is a dict produced by ImageDataset
        inputs = batch['image']
        targets = batch['target']
        em = batch['embed']

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass and loss against the targets
        outputs = model(inputs, em)
        loss = criterion(outputs, targets)

        # Backward pass, then parameter update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the batch losses, reported once per epoch
    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")

In [None]:
# Load and preprocess test images
color_me = []

# Iterate through the files in 'Test/' in sorted order so that the index used
# in the output file name maps to the same input file on every run
# (os.listdir returns entries in arbitrary order).
for filename in sorted(os.listdir('Test/')):
    # The context manager closes the underlying file once the pixels are copied
    with Image.open('Test/' + filename) as img:
        color_me.append(np.array(img.convert('RGB'), dtype=float))

color_me = np.array(color_me, dtype=float)  # (N, H, W, 3), values in [0, 255]
x = color_me  # Keep the raw RGB array for later use

# Scale to [0, 1] once; both the Lab conversion and the embedding branch use it
rgb_unit = 1.0/255 * x

# L channel of the Lab representation. A trailing channel axis is added with
# np.newaxis rather than reshaping to a hard-coded (8, 256, 256, 1), so any
# number of test images and any image size is accepted.
color_me = rgb2lab(rgb_unit)[:, :, :, 0]
color_me = color_me[..., np.newaxis]

# Convert numpy array to PyTorch tensor, NHWC -> NCHW
color_me_tensor = torch.tensor(color_me, dtype=torch.float32).permute(0, 3, 1, 2)

# Build the embedding-network input exactly as the training cell does:
# grayscale replicated to three channels, values in [0, 1]. Feeding the raw
# 0-255 colour array instead would hand out-of-range floats to the
# ToPILImage step of preprocess_batch and would not match what the
# colorization model saw during training.
grayscaled_rgb_test = gray2rgb(rgb2gray(rgb_unit))
grayscaled_rgb_test = torch.tensor(grayscaled_rgb_test, dtype=torch.float32).permute(0, 3, 1, 2)
transformed_test = torch.stack([preprocess_batch(image) for image in grayscaled_rgb_test])

# Compute embeddings for the test images with the pretrained network
with torch.no_grad():
    embed_test = inception_resnet_v2(transformed_test)

# Set the model to evaluation mode
model.eval()

# Test the model
with torch.no_grad():
    output = model(color_me_tensor, embed_test)  # Forward pass to obtain predicted AB channels
    output = output.cpu().numpy()  # Convert PyTorch tensor to numpy array

# Undo the /128 scaling applied to the AB targets during training
output = output * 128

# Make sure the output directory exists before writing into it
os.makedirs('result', exist_ok=True)

# Output colorizations
for i in range(len(output)):
    # Lab canvas sized from the input image instead of a fixed 256x256
    cur = np.zeros(color_me[i].shape[:2] + (3,))
    cur[:, :, 0] = color_me[i][:, :, 0]  # Copy the L channel to the output
    cur[:, :, 1:] = output[i].transpose(1, 2, 0)  # Copy the predicted AB channels to the output
    output_img = lab2rgb(cur)  # Convert Lab back to RGB
    output_img = (output_img * 255).astype(np.uint8)  # Rescale pixel values to 0-255
    # Save the output image
    imsave("result/img_"+str(i)+".png", output_img)