In [None]:
# Import functions for color space conversion from scikit-image
from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab

# Import function for saving images from scikit-image
from skimage.io import imsave

# Import mean_squared_error from scikit-learn to measure performance
from sklearn.metrics import mean_squared_error

# Import image processing transforms from PyTorch
from torchvision import transforms

# Import the Image module from Python Imaging Library (PIL) for image handling
from PIL import Image

# Import NumPy for numerical operations
import numpy as np

# Import operating system functions for file and directory operations
import os

# Import random module for generating random numbers
import random

# Import PyTorch for deep learning functionalities
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Import PyTorch's DataLoader and TensorDataset for handling data in batches
from torch.utils.data import TensorDataset, DataLoader

# Import additional image processing transforms from PyTorch
from torchvision import transforms


In [None]:
# Load every file in 'Train/' as an RGB array.
# os.listdir() returns entries in arbitrary, platform-dependent order, so the
# names are sorted to make the train/held-out split below reproducible.
X = [
    np.array(Image.open(os.path.join('Train/', filename)).convert('RGB'))
    for filename in sorted(os.listdir('Train/'))
]

# Stack the images into one float array. The later cells index this array with
# four axes, so all images must share the same height and width.
X = np.array(X, dtype=float)

# Set up train and test data
# Index separating the first 95% of the images (training) from the rest
split = int(0.95 * len(X))

# Training images, with pixel values rescaled from [0, 255] to [0, 1]
Xtrain = X[:split]
Xtrain = 1.0 / 255 * Xtrain


In [None]:
# Import necessary PyTorch modules
import torch.nn as nn
import torch.nn.functional as F

# Define a custom neural network model for image processing
class CustomModel(nn.Module):
    """Fully convolutional network mapping a 1-channel image to a 2-channel image.

    The stack consists of 3x3 convolutions (padding 1) followed by ReLU. Three
    of the convolutions use stride 2 and three bilinear x2 upsampling layers
    follow later, so the output has the same height and width as the input
    whenever both are divisible by 8. The last convolution is followed by
    Tanh, which bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_channels, out_channels, stride=1):
            # One 3x3 convolution + in-place ReLU pair.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
                nn.ReLU(inplace=True),
            ]

        def upsample():
            # Doubles height and width with bilinear interpolation.
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # (in_channels, out_channels, stride) of the convolutions that precede
        # the first upsampling layer.
        leading_convs = [
            (1, 64, 1),
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 256, 1),
            (256, 128, 1),
        ]

        layers = []
        for in_channels, out_channels, stride in leading_convs:
            layers.extend(conv_relu(in_channels, out_channels, stride))

        layers.append(upsample())
        layers.extend(conv_relu(128, 64))
        layers.append(upsample())
        layers.extend(conv_relu(64, 32))

        # Final projection to 2 channels, squashed by Tanh, then upsampled
        # back to the input resolution.
        layers.append(nn.Conv2d(32, 2, kernel_size=3, padding=1))
        layers.append(nn.Tanh())
        layers.append(upsample())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` (batch, 1, H, W) through the sequential stack."""
        return self.model(x)

In [None]:
# Convert the RGB training images to LAB colour space and split them into the
# network input (L channel) and the target (a/b channels).
lab_train = rgb2lab(Xtrain)

# Slice with 0:1 (rather than index with 0) so the L channel keeps a trailing
# channel axis. This replaces a reshape that hard-coded the number of images
# and their size, and therefore broke for any other dataset.
X_train = lab_train[:, :, :, 0:1]

# Scale a/b by 1/128 so the targets are on the same scale as the model's Tanh
# output, which is bounded to [-1, 1].
Y_train = lab_train[:, :, :, 1:] / 128

# Convert to PyTorch tensors and move channels first: (N, H, W, C) -> (N, C, H, W)
Xtrain_tensor = torch.tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)
Ytrain_tensor = torch.tensor(Y_train, dtype=torch.float32).permute(0, 3, 1, 2)

# Instantiate the model and set up the optimizer and loss function
model = CustomModel()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)

# Data augmentation using torchvision.transforms
transform = transforms.Compose([
    transforms.RandomAffine(degrees=20, shear=[-5, 5], scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
])

# Define a custom dataset class for the training data
class ImageDataset(torch.utils.data.Dataset):
    """Paired (input, target) image tensors with optional joint augmentation.

    Args:
        X: Indexable collection of input tensors. Each item is expected to be
            channel-first, i.e. (C_in, H, W).
        Y: Indexable collection of target tensors aligned with ``X``; each item
            is expected to be (C_out, H, W) with the same dtype, H and W as the
            matching input.
        transform: Optional callable that takes one channel-first tensor and
            returns a tensor with the same number of channels (torchvision's
            geometric transforms accept tensors). ``None`` disables
            augmentation.
    """

    def __init__(self, X, Y, transform=None):
        self.X = X
        self.Y = Y
        self.transform = transform

    def __len__(self):
        """Number of (input, target) pairs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair, augmented when a transform is set."""
        x = self.X[idx]
        y = self.Y[idx]

        if self.transform is not None:
            # Stack input and target along the channel axis and transform them
            # in a single call. Calling a random transform separately on x and
            # y would draw different random parameters for each, leaving the
            # target geometrically misaligned with its input.
            #
            # The tensors are transformed directly instead of going through
            # PIL images: that round trip quantises to uint8, which corrupts
            # float data outside [0, 1] (such as LAB channels).
            n_input_channels = x.shape[0]
            stacked = self.transform(torch.cat([x, y], dim=0))
            x = stacked[:n_input_channels]
            y = stacked[n_input_channels:]

        return x, y

In [None]:
# Wrap the training tensors in the augmenting dataset and serve them in
# shuffled mini-batches.
batch_size = 2
train_dataset = ImageDataset(X=Xtrain_tensor, Y=Ytrain_tensor, transform=transform)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Training loop
num_epochs = 1000

# Put the model in training mode explicitly: the evaluation cells further down
# call model.eval(), so re-running this cell afterwards would otherwise train
# with the model still in evaluation mode.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    num_samples = 0

    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a smaller final batch does not
        # skew the epoch average.
        running_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)

    # Report the mean loss over the whole epoch rather than the loss of
    # whichever mini-batch happened to come last.
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / max(num_samples, 1)}')

In [None]:
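# Save model
# NOTE: the commented-out snippet below uses the Keras API (to_json /
# save_weights), which a PyTorch nn.Module does not provide. The PyTorch
# equivalent would be: torch.save(model.state_dict(), "model.pth")
#model_json = model.to_json()
#with open("model.json", "w") as json_file:
#    json_file.write(model_json)
#model.save_weights("model.h5")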

In [None]:
# Test images
# Evaluate only on the images after `split`, i.e. the ones that were NOT part
# of Xtrain; scoring on all of X would mostly re-score training images.
# The LAB conversion is done once and reused for both input and target.
lab_test = rgb2lab(1.0 / 255 * X[split:])

# Slice with 0:1 so the L channel keeps its channel axis without a reshape
# that hard-codes the number of images or their size.
X_test = lab_test[:, :, :, 0:1]
Y_test = lab_test[:, :, :, 1:] / 128

# Convert numpy arrays to PyTorch tensors and move channels first:
# (N, H, W, C) -> (N, C, H, W)
Xtest = torch.tensor(X_test, dtype=torch.float32).permute(0, 3, 1, 2)
Ytest = torch.tensor(Y_test, dtype=torch.float32).permute(0, 3, 1, 2)

# Switch the model to evaluation mode for inference
model.eval()

# Disable gradient computation during testing
with torch.no_grad():
    # Forward pass through the model to get predictions
    predictions = model(Xtest)
    # Compare the predicted a/b channels with the ground truth using the same
    # loss that was used for training
    test_loss = criterion(predictions, Ytest).item()

print(f'Test loss on {len(Xtest)} held-out image(s): {test_loss}')

In [None]:
# Load and preprocess test images
color_me = []

# Iterate over the files in 'Test/' in sorted order, so that output index i
# always corresponds to the same input file (os.listdir order is arbitrary).
for filename in sorted(os.listdir('Test/')):
    # Open each image file, convert to RGB format, and append to the list
    img = Image.open(os.path.join('Test/', filename)).convert('RGB')
    color_me.append(np.array(img, dtype=float))

# Stack into one float array; all images must share the same height and width
color_me = np.array(color_me, dtype=float)

# Convert to LAB and keep only the L channel. Slicing with 0:1 preserves the
# channel axis without hard-coding the number of images or their size.
color_me = rgb2lab(1.0/255 * color_me)[:, :, :, 0:1]

# Convert the numpy array to a PyTorch tensor and move channels first:
# (N, H, W, 1) -> (N, 1, H, W)
color_me_tensor = torch.tensor(color_me, dtype=torch.float32).permute(0, 3, 1, 2)

# Test the model
model.eval()

# Disable gradient computation during testing
with torch.no_grad():
    # Forward pass through the model to get the predicted a/b channels,
    # then move them to a numpy array
    output = model(color_me_tensor)
    output = output.cpu().numpy()

# Undo the 1/128 scaling applied to the a/b targets during training
output = output * 128

# imsave does not create missing directories, so make sure 'result/' exists
os.makedirs("result", exist_ok=True)

# Output colorizations
for i, (lightness, ab) in enumerate(zip(color_me, output)):
    # Reassemble a LAB image from the original L channel and the predicted
    # a/b channels. The model returns a/b at the input resolution only when
    # height and width are divisible by 8 (three stride-2 convolutions
    # followed by three x2 upsamplings).
    cur = np.zeros(lightness.shape[:2] + (3,))
    cur[:, :, 0] = lightness[:, :, 0]
    cur[:, :, 1:] = ab.transpose(1, 2, 0)
    output_img = lab2rgb(cur)
    output_img = (output_img * 255).astype(np.uint8)

    # Save the output image
    imsave("result/img_" + str(i) + ".png", output_img)
