In [None]:
import os
import shutil

# The 100 ImageNet class IDs copied into the training subset. Each ID is also the
# name of a per-class folder under the CLS-LOC train directory (see the copy loop).
subset_class_ids = [
    "n01440764", "n01491361", "n01498041", "n01514859", "n01530575", "n01616318", "n01622779", "n01629819", "n01641577", "n01664065",
    "n01729322", "n01734418", "n01742172", "n01774750", "n01795545", "n01806143", "n01833805", "n01882714", "n01914609", "n01945685",
    "n01978455", "n01985128", "n02002556", "n02006656", "n02011460", "n02058221", "n02071294", "n02085782", "n02088466", "n02099601",
    "n02100583", "n02104029", "n02106550", "n02110063", "n02110958", "n02112137", "n02113023", "n02113799", "n02123045", "n02128757",
    "n02132136", "n02165456", "n02190166", "n02206856", "n02226429", "n02233338", "n02256656", "n02279972", "n02317335", "n02346627",
    "n02364673", "n02391049", "n02395406", "n02403003", "n02412080", "n02415577", "n02423022", "n02437312", "n02480495", "n02504458",
    "n02509815", "n02640242", "n02692877", "n02708093", "n02786058", "n02795169", "n02814533", "n02823428", "n02837789", "n02879718",
    "n02951358", "n02966193", "n03014705", "n03047690", "n03085013", "n03100240", "n03126707", "n03160309", "n03179701", "n03255030",
    "n03272010", "n03314780", "n03379051", "n03417042", "n03424325", "n03444034", "n03478589", "n03494278", "n03584829", "n03633091",
    "n03666591", "n03770439", "n03793489", "n03814639", "n03888257", "n03933933", "n03976657", "n04037443", "n04118538", "n04552348"
]

# Source (read-only Kaggle input) and destination (writable working dir) roots.
imagenet_train_dir = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train"
subset_dir = "/kaggle/working/imagenet_subset/train"

# Create target root directory
os.makedirs(subset_dir, exist_ok=True)

# Copy every image of each selected class into subset_dir/<class_id>/.
# A file that already exists at the destination with the same size as the
# source is skipped. Re-running this cell (warm kernel, or after an interrupted
# run) therefore copies only what is missing. A partially written file has a
# smaller size, so it is copied again.
copied = 0
skipped = 0
for class_id in subset_class_ids:
    src_class_dir = os.path.join(imagenet_train_dir, class_id)
    dst_class_dir = os.path.join(subset_dir, class_id)
    os.makedirs(dst_class_dir, exist_ok=True)

    # os.listdir raises FileNotFoundError if a class folder is missing, which
    # is preferable to silently building a smaller subset.
    for img_name in sorted(os.listdir(src_class_dir)):
        src_img_path = os.path.join(src_class_dir, img_name)
        dst_img_path = os.path.join(dst_class_dir, img_name)
        if (os.path.exists(dst_img_path)
                and os.path.getsize(dst_img_path) == os.path.getsize(src_img_path)):
            skipped += 1
            continue
        shutil.copyfile(src_img_path, dst_img_path)
        copied += 1

print(f"Copied {copied} images, skipped {skipped} already present.")


In [None]:
import os

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class ImageNetColorizationDataset(Dataset):
    """(L, ab) training pairs for colorization, built from a directory of images.

    Every ``.jpeg``/``.jpg`` file (case-insensitive) found recursively under
    ``root_dir`` is one sample. ``__getitem__`` does the following:

    1. Resizes the image to 256x256.
    2. Converts it with OpenCV's ``COLOR_RGB2LAB`` on uint8 data.
    3. Returns ``(L, ab)``:

       * ``L``: float tensor ``[1, 256, 256]``, the L channel divided by 255.
       * ``ab``: float tensor ``[2, 256, 256]``, the a/b channels divided by 128.

    For 8-bit images OpenCV stores a/b with a +128 offset, so neutral gray
    maps to about 1.0 here, not 0.0. Decoders must multiply by 128 to recover
    8-bit LAB values.

    Args:
        root_dir: Root directory to scan for images.
        transform: Stored on the instance but NOT applied by ``__getitem__``.
    """

    # Accepted file extensions, compared against the lower-cased file name.
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # Sorted so that a given index always refers to the same file, since
        # os.walk order is filesystem-dependent.
        self.image_paths = sorted(
            os.path.join(subdir, file)
            for subdir, _, files in os.walk(root_dir)
            for file in files
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # NOTE(review): kept for interface compatibility; __getitem__ does its
        # own PIL/OpenCV preprocessing and never calls this.
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the ``(L, ab)`` pair for ``idx``.

        If the file at ``idx`` cannot be decoded, the error is logged and the
        following images (wrapping around) are tried instead. Returning ``None``
        would crash the DataLoader's default ``collate_fn``.

        Raises:
            IndexError: ``idx`` is out of range. This preserves the sequence
                protocol used by plain ``for sample in dataset`` iteration.
            RuntimeError: no image in the dataset could be loaded.
        """
        self.image_paths[idx]  # bounds check only: raises IndexError like a list
        n = len(self.image_paths)
        for offset in range(n):
            img_path = self.image_paths[(idx + offset) % n]
            try:
                return self._load_pair(img_path)
            except Exception as e:
                print(f"Failed to load {img_path}: {e}")
        raise RuntimeError(f"No readable images found under {self.root_dir}")

    @staticmethod
    def _load_pair(img_path):
        """Load one image file and return its normalized ``(L, ab)`` tensors."""
        # ``with`` releases the file handle immediately instead of leaving it
        # to garbage collection in long-lived DataLoader workers.
        with Image.open(img_path) as im:
            img = np.array(im.convert("RGB").resize((256, 256)))

        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)  # uint8 LAB, shape [H, W, 3]
        L = lab[:, :, 0] / 255.0
        ab = lab[:, :, 1:] / 128.0

        L = torch.from_numpy(L).unsqueeze(0).float()             # [1, H, W]
        ab = torch.from_numpy(ab.transpose((2, 0, 1))).float()   # [2, H, W]
        return L, ab

In [None]:
import os

from torchvision import transforms

# NOTE: ImageNetColorizationDataset stores `transform` but its __getitem__ does
# its own preprocessing; ToTensor is passed only to satisfy the constructor.
transform = transforms.ToTensor()
imagenet_dataset = ImageNetColorizationDataset('/kaggle/working/imagenet_subset/train', transform=transform)

# Use up to 8 loader workers, but never more than the machine has CPUs. More
# workers than cores only adds contention, and PyTorch warns about it. The
# result is always >= 1, so prefetch_factor/persistent_workers below stay valid.
NUM_WORKERS = min(8, os.cpu_count() or 1)

train_loader = DataLoader(
    imagenet_dataset,
    batch_size=32,                  # Increase batch size if memory allows
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,                # Page-locked host memory for faster GPU transfer
    prefetch_factor=2,              # Batches prefetched per worker
    persistent_workers=True,        # Keep workers alive between epochs
    drop_last=True                  # Every batch has exactly batch_size samples
)


In [None]:
import cv2
import numpy as np

def rgb_to_lab(img_rgb):
    """Split an RGB image into normalized L, a and b channel arrays.

    The image is converted with OpenCV's ``COLOR_RGB2LAB``. L is then divided
    by 255 and a/b by 128, the same scaling ImageNetColorizationDataset uses.

    Returns:
        Tuple ``(L, a, b)`` of the scaled channel arrays.
    """
    lab_planes = cv2.split(cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB))
    divisors = (255.0, 128.0, 128.0)
    return tuple(plane / divisor for plane, divisor in zip(lab_planes, divisors))


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class ResNetUNetColor(nn.Module):
    """U-Net that predicts the two LAB chroma (ab) channels from an L input.

    Encoder: ImageNet-pretrained ResNet-18. Its 3-channel ``conv1`` is replaced
    by a 1-channel conv initialized from the pretrained kernel (see
    ``__init__``).

    Decoder: transposed-conv up blocks with skip connections from ``layer3``,
    ``layer2`` and ``layer1``, then a final x2 upsampling and a 1x1 conv.

    Shapes:
        input  ``[B, 1, H, W]``
        output ``[B, 2, H, W]``, unbounded (no output activation).

    H and W should be multiples of 32 (e.g. 256). Otherwise the skip
    concatenations can fail on a spatial size mismatch.
    """

    def __init__(self):
        super(ResNetUNetColor, self).__init__()

        resnet = self._load_pretrained_resnet18()

        # 1-channel replacement for resnet.conv1. Its kernel is the pretrained
        # RGB kernel summed over the input channels. For an image with R=G=B it
        # therefore gives exactly the pretrained conv1 output, so the
        # pretrained bn1 and layer1..4 see the kind of features they were
        # trained on, not the output of a random conv.
        # NOTE(review): the pretrained net expects ImageNet-normalized RGB,
        # while this notebook feeds L/255 in [0, 1]; fine-tuning absorbs the
        # difference.
        self.input_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.input_conv.weight.copy_(resnet.conv1.weight.sum(dim=1, keepdim=True))

        # Reuse the rest of the pretrained stem and the four residual stages.
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1  # [64, H/4, W/4]
        self.layer2 = resnet.layer2  # [128, H/8, W/8]
        self.layer3 = resnet.layer3  # [256, H/16, W/16]
        self.layer4 = resnet.layer4  # [512, H/32, W/32]

        # Decoder. Input channel counts include the concatenated skip tensor.
        self.up1 = self.up_block(512, 256)
        self.up2 = self.up_block(512, 128)  # 256 (up1) + 256 (layer3 skip)
        self.up3 = self.up_block(256, 64)   # 128 (up2) + 128 (layer2 skip)
        self.up4 = self.up_block(128, 64)   # 64 (up3) + 64 (layer1 skip)
        self.up5 = nn.Sequential(           # final x2 upsampling back to H x W
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, kernel_size=1)  # Output ab channels
        )

    @staticmethod
    def _load_pretrained_resnet18():
        """Return an ImageNet-pretrained ResNet-18 on old and new torchvision."""
        try:
            # The weights enum exists from torchvision 0.13; `pretrained=True`
            # is deprecated there.
            weights = models.ResNet18_Weights.DEFAULT
        except AttributeError:
            return models.resnet18(pretrained=True)
        return models.resnet18(weights=weights)

    def up_block(self, in_ch, out_ch):
        """x2 transposed-conv upsampling followed by a 3x3 conv, both with ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Map an L batch ``[B, 1, H, W]`` to predicted ab ``[B, 2, H, W]``."""
        # Encoder
        x1 = self.relu(self.bn1(self.input_conv(x)))  # [64, H/2, W/2]
        x2 = self.layer1(self.maxpool(x1))            # [64, H/4, W/4]
        x3 = self.layer2(x2)                          # [128, H/8, W/8]
        x4 = self.layer3(x3)                          # [256, H/16, W/16]
        x5 = self.layer4(x4)                          # [512, H/32, W/32]

        # Decoder with skip connections
        d1 = self.up1(x5)                             # [256, H/16, W/16]
        d2 = self.up2(torch.cat([d1, x4], dim=1))     # [128, H/8, W/8]
        d3 = self.up3(torch.cat([d2, x3], dim=1))     # [64, H/4, W/4]
        d4 = self.up4(torch.cat([d3, x2], dim=1))     # [64, H/2, W/2]
        out = self.up5(d4)                            # [2, H, W]

        return out


In [None]:

# Training setup. The model runs on the GPU when one is available and is
# optimized with Adam on a per-pixel MSE between predicted and true ab channels.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = ResNetUNetColor().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# tqdm.auto renders a notebook progress widget when available and falls back
# to the text bar otherwise.
from tqdm.auto import tqdm

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0

    for L, ab in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # The loader uses pin_memory=True. With pinned host memory,
        # non_blocking lets the host->GPU copy overlap with compute.
        L = L.to(device, non_blocking=True)
        ab = ab.to(device, non_blocking=True)

        optimizer.zero_grad()
        output_ab = model(L)
        loss = criterion(output_ab, ab)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    # Mean over the batches actually processed. The guard avoids
    # ZeroDivisionError when the loader yields nothing, e.g. fewer images than
    # one batch with drop_last=True.
    avg_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Save the model weights (state_dict only) after every epoch.
    checkpoint_path = f"/kaggle/working/resnet_colorization_model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch

def visualize_colorization(model, dataloader, device):
    """Plot original, grayscale and model-colorized versions of one image.

    Runs ``model`` on the first batch from ``dataloader`` and shows its first
    sample. Batches are ``(L, ab)`` pairs scaled as L/255 and ab/128 on
    OpenCV's 8-bit LAB values.

    Args:
        model: Network mapping ``[B, 1, H, W]`` L tensors to ``[B, 2, H, W]`` ab.
        dataloader: Iterable yielding ``(L, ab)`` batches. Nothing is shown if
            it is empty.
        device: Device to run the model on.
    """

    def to_rgb(l_chan, ab_chan):
        # Stack L [H, W] with ab [H, W, 2] into an 8-bit LAB image, then
        # convert it to RGB. Round and clip before casting. The network output
        # is unbounded, and a bare astype(np.uint8) does not saturate:
        # out-of-range floats wrap (or are undefined), e.g. 256 -> 0 and
        # -1 -> 255, which produces wildly wrong colors.
        lab = np.dstack([l_chan, ab_chan])
        lab_u8 = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
        return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2RGB)

    model.eval()
    with torch.no_grad():
        for L, ab in dataloader:
            L, ab = L.to(device), ab.to(device)
            output_ab = model(L)

            # First sample of the batch, scaled back to 8-bit LAB units.
            L_img = L[0, 0].cpu().numpy() * 255.0                           # [H, W]
            ab_gt = ab[0].cpu().numpy().transpose(1, 2, 0) * 128.0          # [H, W, 2]
            ab_pred = output_ab[0].cpu().numpy().transpose(1, 2, 0) * 128.0  # [H, W, 2]

            rgb_gt = to_rgb(L_img, ab_gt)
            rgb_pred = to_rgb(L_img, ab_pred)

            # Grayscale visualization
            gray_img = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY)

            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            panels = (
                (rgb_gt, "Original", None),
                (gray_img, "Input Grayscale", 'gray'),
                (rgb_pred, "Reconstructed", None),
            )
            for ax, (image, title, cmap) in zip(axes, panels):
                ax.imshow(image, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

            fig.tight_layout()
            plt.show()
            break  # Show only one batch


In [None]:
# Show one training image: ground truth, grayscale input and model prediction.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
visualize_colorization(model, train_loader, device)

In [None]:
from PIL import Image
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def colorize_single_image(model, image_path, device):
    """Colorize one image file with ``model`` and plot it next to the original.

    Steps:

    1. Resize the image to 256x256 and convert it to OpenCV 8-bit LAB.
    2. Feed the L channel (scaled by 1/255) to the model.
    3. Scale the predicted ab channels back by 128.
    4. Recombine them with the image's own L channel and convert to RGB.

    Args:
        model: Network mapping ``[B, 1, H, W]`` L tensors to ``[B, 2, H, W]`` ab.
        image_path: Path to any image PIL can open.
        device: Device to run the model on.
    """
    model.eval()

    # Load and resize. ``with`` closes the file handle right away.
    with Image.open(image_path) as im:
        img_np = np.array(im.convert("RGB").resize((256, 256)))

    # Convert to LAB and extract the normalized L channel
    lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)  # uint8 [256, 256, 3]
    L = lab[:, :, 0] / 255.0

    L_tensor = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)  # [1, 1, 256, 256]

    with torch.no_grad():
        ab_pred = model(L_tensor)[0].cpu().numpy().transpose(1, 2, 0) * 128.0  # [256, 256, 2]

    # Keep the exact 8-bit L channel. A float round trip (L/255*255) followed
    # by truncation can turn 254.999... into 254. Round and clip the unbounded
    # predicted ab before casting, because astype(np.uint8) does not saturate
    # out-of-range values.
    lab_pred = lab.copy()
    lab_pred[:, :, 1:] = np.clip(np.rint(ab_pred), 0, 255).astype(np.uint8)
    rgb_pred = cv2.cvtColor(lab_pred, cv2.COLOR_LAB2RGB)

    fig, (ax_orig, ax_pred) = plt.subplots(1, 2, figsize=(8, 4))
    ax_orig.imshow(img_np)
    ax_orig.set_title("Original Image")
    ax_orig.axis('off')

    ax_pred.imshow(rgb_pred)
    ax_pred.set_title("Colorized Output")
    ax_pred.axis('off')
    fig.tight_layout()
    plt.show()


In [None]:
# Colorize a held-out ImageNet test image. Point image_path at your own upload
# (e.g. under /kaggle/input) to try other pictures.
image_path = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/test/ILSVRC2012_test_00000004.JPEG"
colorize_single_image(model, image_path=image_path, device=device)