<h1>Freeform Colorization with U-Net</h1>

In [1]:
# Satvika Eda, Divya Sri Bandaru & Dhriti Anjaria
# 21st April 2025
# This code is for freeform image colorization with U-Net


In [None]:
import os
import shutil
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import cv2
import torch
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# List of specific ImageNet class IDs to include in the subset

subset_class_ids = [
    "n01440764", "n01491361", "n01498041", "n01514859", "n01530575", "n01616318", "n01622779", "n01629819", "n01641577", "n01664065",
    "n01729322", "n01734418", "n01742172", "n01774750", "n01795545", "n01806143", "n01833805", "n01882714", "n01914609", "n01945685",
    "n01978455", "n01985128", "n02002556", "n02006656", "n02011460", "n02058221", "n02071294", "n02085782", "n02088466", "n02099601",
    "n02100583", "n02104029", "n02106550", "n02110063", "n02110958", "n02112137", "n02113023", "n02113799", "n02123045", "n02128757",
    "n02132136", "n02165456", "n02190166", "n02206856", "n02226429", "n02233338", "n02256656", "n02279972", "n02317335", "n02346627",
    "n02364673", "n02391049", "n02395406", "n02403003", "n02412080", "n02415577", "n02423022", "n02437312", "n02480495", "n02504458",
    "n02509815", "n02640242", "n02692877", "n02708093", "n02786058", "n02795169", "n02814533", "n02823428", "n02837789", "n02879718",
    "n02951358", "n02966193", "n03014705", "n03047690", "n03085013", "n03100240", "n03126707", "n03160309", "n03179701", "n03255030",
    "n03272010", "n03314780", "n03379051", "n03417042", "n03424325", "n03444034", "n03478589", "n03494278", "n03584829", "n03633091",
    "n03666591", "n03770439", "n03793489", "n03814639", "n03888257", "n03933933", "n03976657", "n04037443", "n04118538", "n04552348"
]

imagenet_train_dir = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train"
subset_dir = "/kaggle/working/imagenet_subset/train"


def copy_class_subset(src_root, dst_root, class_ids):
    """Copy every file of the selected class folders from ``src_root`` into ``dst_root``.

    A file that already exists at the destination with the same size as the
    source is skipped. Re-running the cell (e.g. after a kernel restart)
    therefore does not re-copy the whole subset. A partially written file has a
    different size and is copied again.

    Args:
        src_root: directory containing one sub-folder per class id.
        dst_root: directory to mirror the selected class folders into (created if missing).
        class_ids: iterable of class-folder names to copy.

    Returns:
        Tuple ``(copied, skipped)`` with the number of files copied and skipped.

    Raises:
        FileNotFoundError: if a requested class folder does not exist under ``src_root``.
    """
    os.makedirs(dst_root, exist_ok=True)
    copied = skipped = 0
    for class_id in class_ids:
        src_class_dir = os.path.join(src_root, class_id)
        dst_class_dir = os.path.join(dst_root, class_id)
        os.makedirs(dst_class_dir, exist_ok=True)

        for img_name in sorted(os.listdir(src_class_dir)):
            src_img_path = os.path.join(src_class_dir, img_name)
            dst_img_path = os.path.join(dst_class_dir, img_name)
            already_copied = (
                os.path.exists(dst_img_path)
                and os.path.getsize(dst_img_path) == os.path.getsize(src_img_path)
            )
            if already_copied:
                skipped += 1
                continue
            shutil.copyfile(src_img_path, dst_img_path)
            copied += 1
    return copied, skipped


copy_class_subset(imagenet_train_dir, subset_dir, subset_class_ids)


In [None]:
# Dataset that turns ImageNet images into (L, ab) Lab-colour training pairs
class ImageNetColorizationDataset(Dataset):
    """(lightness, colour) training pairs built from a folder tree of JPEG images.

    Each item is ``(L, ab)``. The image is resized to ``image_size x image_size``
    and converted with OpenCV's 8-bit ``COLOR_RGB2LAB``:

    * ``L``  is a float tensor of shape ``(1, H, W)``: the Lab lightness channel / 255.
    * ``ab`` is a float tensor of shape ``(2, H, W)``: the a/b channels / 128.
      In 8-bit Lab, OpenCV stores a/b offset by +128. These values therefore lie
      in ``[0, 255/128]`` and are not centred on zero. Inference code must invert
      the scaling with ``* 128``.

    Args:
        root_dir: directory searched recursively for image files.
        transform: accepted for API compatibility but currently NOT applied;
            tensors are built directly in ``__getitem__``.
        image_size: side length images are resized to (default 256).
        extensions: file suffixes to include, matched case-insensitively.

    Raises:
        FileNotFoundError: if ``root_dir`` is not an existing directory.
    """

    def __init__(self, root_dir, transform=None, image_size=256,
                 extensions=(".jpeg", ".jpg")):
        # Fail early: an empty dataset would otherwise surface later as an
        # obscure DataLoader/sampler error.
        if not os.path.isdir(root_dir):
            raise FileNotFoundError(f"Dataset root does not exist: {root_dir}")
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = image_size
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.walk order depends on the filesystem; sort so indices are reproducible.
        self.image_paths = sorted(
            os.path.join(subdir, name)
            for subdir, _, files in os.walk(root_dir)
            for name in files
            if name.lower().endswith(suffixes)
        )

    def __len__(self):
        """Number of image files found under ``root_dir``."""
        return len(self.image_paths)

    def _load_pair(self, img_path):
        """Read one image and return its ``(L, ab)`` tensors (see class docstring)."""
        with Image.open(img_path) as im:
            rgb = np.array(im.convert("RGB").resize((self.image_size, self.image_size)))
        lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
        L = torch.from_numpy(lab[:, :, 0] / 255.0).unsqueeze(0).float()
        ab = torch.from_numpy((lab[:, :, 1:] / 128.0).transpose((2, 0, 1))).float()
        return L, ab

    def __getitem__(self, idx):
        """Return ``(L, ab)`` for image ``idx``.

        An unreadable image is reported and replaced by the next readable image,
        wrapping around at the end. Returning ``None`` instead would crash the
        DataLoader's default collate function.

        Raises:
            IndexError: if ``idx`` is out of range (keeps plain iteration working).
            RuntimeError: if no image in the dataset can be loaded.
        """
        n = len(self.image_paths)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        start = idx % n
        for offset in range(n):
            img_path = self.image_paths[(start + offset) % n]
            try:
                return self._load_pair(img_path)
            except Exception as e:
                print(f"Failed to load {img_path}: {e}")
        raise RuntimeError(f"No loadable image found in {self.root_dir}")



In [None]:
# Create a DataLoader for the ImageNet colorization dataset

# NOTE: ImageNetColorizationDataset builds its own tensors; `transform` is passed
# for API compatibility only and is not applied by the dataset.
transform = transforms.ToTensor()
imagenet_dataset = ImageNetColorizationDataset('/kaggle/working/imagenet_subset/train', transform=transform)

# Never request more workers than the machine has CPUs (8 is an upper bound).
# At least one worker is kept because prefetch_factor/persistent_workers
# require num_workers > 0.
num_workers = min(8, os.cpu_count() or 1)

train_loader = DataLoader(
    imagenet_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # pinned memory only helps host->GPU copies
    prefetch_factor=2,
    persistent_workers=True,
    drop_last=True,  # every batch has 32 samples; len(train_loader) counts full batches only
)


In [None]:
# Encoder-decoder colorization model (U-Net-style down/up-sampling, but without skip connections)
class UNetColor(nn.Module):
    """Convolutional encoder-decoder predicting 2 chrominance channels from 1 lightness channel.

    Encoder: three (3x3 conv -> ReLU -> 2x2 max-pool) stages, going
    1 -> 64 -> 128 -> 256 channels. Each stage halves the spatial size.
    Decoder: three (2x2 stride-2 transposed conv -> ReLU) stages, going
    256 -> 128 -> 64 -> 32 channels. Each stage doubles the spatial size.
    A final 1x1 conv maps to 2 output channels, with no activation.

    Despite the name, there are no skip connections between encoder and decoder.
    The output matches the input's height and width only when both are
    divisible by 8.
    """

    def __init__(self):
        super().__init__()

        down = []
        for c_in, c_out in ((1, 64), (64, 128), (128, 256)):
            down.extend([nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)])

        up = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            up.extend([nn.ConvTranspose2d(c_in, c_out, 2, stride=2), nn.ReLU()])
        up.append(nn.Conv2d(32, 2, 1))  # 1x1 projection to the two (a, b) channels

        # Attribute names and layer order are fixed so state_dict keys stay
        # "encoder.N.*" / "decoder.N.*" and parameters are created in the same order.
        self.encoder = nn.Sequential(*down)
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        """Map a (N, 1, H, W) lightness batch to (N, 2, H, W) predicted ab channels."""
        return self.decoder(self.encoder(x))


In [None]:
# Create the model on the GPU if available (CPU otherwise), with MSE loss and the Adam optimizer
# Seed every RNG the notebook imports. Weight initialisation (below) and the
# DataLoader's shuffling (drawn when training iterates it) then become reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

criterion = nn.MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNetColor().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
# Training loop: learn to predict the ab (colour) channels from the L (lightness) channel
num_epochs = 10

for epoch_idx in range(num_epochs):
    model.train()
    loss_total = 0.0

    # Each batch: grayscale input -> predicted ab, compared with the true ab via MSE.
    for gray, target_ab in tqdm(train_loader, desc=f"Epoch {epoch_idx + 1}"):
        gray = gray.to(device)
        target_ab = target_ab.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(gray), target_ab)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()

    # Mean of the per-batch losses over this epoch.
    mean_loss = loss_total / len(train_loader)
    print(f"Epoch [{epoch_idx + 1}/{num_epochs}], Loss: {mean_loss:.4f}")


In [None]:
# Colorize a single image with the model in evaluation mode and plot it beside the original

def colorize_single_image(model, image_path, device, image_size=256):
    """Colorize one image with ``model`` and plot it next to the original.

    The image is resized to ``image_size x image_size`` and converted to 8-bit
    Lab with OpenCV. Only its lightness channel, scaled to [0, 1], is fed to the
    model. The prediction is multiplied by 128 to undo the ``ab / 128`` target
    scaling used by the training dataset. It is then recombined with the
    original lightness, clipped to the valid uint8 range and converted back to
    RGB for display.

    Args:
        model: module mapping a (1, 1, H, W) lightness tensor to (1, 2, H, W) ab predictions.
        image_path: path of the image file to colorize.
        device: torch device the model lives on.
        image_size: side length the image is resized to (default 256).

    The model's previous train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()

    with Image.open(image_path) as src:
        img = src.convert("RGB").resize((image_size, image_size))
    img_np = np.array(img)

    # Convert to LAB and keep only the lightness channel as the model input.
    lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    L = lab[:, :, 0] / 255.0

    L_tensor = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)
    try:
        with torch.no_grad():
            ab_pred = model(L_tensor)[0].cpu().numpy().transpose(1, 2, 0) * 128.0
    finally:
        model.train(was_training)

    lab_pred = np.empty((image_size, image_size, 3), dtype=np.float32)
    lab_pred[:, :, 0] = L * 255.0
    lab_pred[:, :, 1:] = ab_pred
    # Predictions are unbounded floats. Without the clip, astype(uint8) wraps
    # out-of-range values. Without rint, it truncates, so x/255*255 can become x-1.
    lab_u8 = np.clip(np.rint(lab_pred), 0, 255).astype(np.uint8)
    rgb_pred = cv2.cvtColor(lab_u8, cv2.COLOR_LAB2RGB)

    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(img_np)
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(rgb_pred)
    plt.title("Colorized Output")
    plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Qualitatively check the trained model on an image from the ImageNet test split

# To try your own picture, upload it to the Kaggle input folder and point image_path at it.
image_path = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/test/ILSVRC2012_test_00000003.JPEG"
colorize_single_image(model=model, image_path=image_path, device=device)
