In [None]:
import os
import shutil
import random


# Source ImageNet training directory (one sub-folder per class)
imagenet_train_dir = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/train"
# Path to new subset directory
subset_dir = "/kaggle/working/imagenet_subset/train"

# Number of classes and images per class
num_classes = 50
images_per_class = 1000
# Fixed seed for the class draw: re-running this cell selects the same classes
# instead of adding a fresh random set of folders to subset_dir each time.
subset_seed = 42

# Keep only directories so a stray file can never be sampled as a "class"
all_classes = sorted(
    entry
    for entry in os.listdir(imagenet_train_dir)
    if os.path.isdir(os.path.join(imagenet_train_dir, entry))
)
# Draw the classes from a dedicated, seeded generator (the global `random`
# state is left untouched)
selected_classes = random.Random(subset_seed).sample(all_classes, num_classes)

os.makedirs(subset_dir, exist_ok=True)

for class_name in selected_classes:
    src_class_dir = os.path.join(imagenet_train_dir, class_name)
    dst_class_dir = os.path.join(subset_dir, class_name)
    os.makedirs(dst_class_dir, exist_ok=True)

    # First `images_per_class` files in sorted name order (fewer if the class is smaller)
    selected_images = sorted(os.listdir(src_class_dir))[:images_per_class]

    for img_name in selected_images:
        src_img_path = os.path.join(src_class_dir, img_name)
        dst_img_path = os.path.join(dst_class_dir, img_name)
        shutil.copyfile(src_img_path, dst_img_path)


In [None]:
from torch.utils.data import Dataset, DataLoader# from torchvision import transforms
from PIL import Image
import os

class ImageNetColorizationDataset(Dataset):
    """Dataset of (L, ab) tensor pairs built from JPEG files under a directory tree.

    Every file below ``root_dir`` whose name ends in ``.jpg`` / ``.jpeg``
    (case-insensitive) is indexed. Paths are sorted so that a given index
    always refers to the same file.

    Args:
        root_dir: Directory that is walked recursively for images.
        transform: Stored on the instance as ``self.transform``; it is not
            applied anywhere in this class.
    """

    # File-name suffixes (lower-cased) that count as images
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg')
    # How many consecutive files to try before giving up on one __getitem__ call
    MAX_LOAD_ATTEMPTS = 10

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.walk yields entries in arbitrary order; sort for a stable index -> file mapping
        self.image_paths = sorted(
            os.path.join(subdir, file)
            for subdir, _, files in os.walk(root_dir)
            for file in files
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def _load_lab(self, img_path):
        """Load one image and return ``(L, ab)`` float tensors of shape [1, 256, 256] and [2, 256, 256]."""
        # Context manager closes the underlying file as soon as the pixels are converted
        with Image.open(img_path) as src:
            img = np.array(src.convert("RGB").resize((256, 256)))

        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        L = lab[:, :, 0] / 255.0
        ab = lab[:, :, 1:] / 128.0

        L = torch.from_numpy(L).unsqueeze(0).float()            # [1, H, W]
        ab = torch.from_numpy(ab.transpose((2, 0, 1))).float()  # [2, H, W]
        return L, ab

    def __getitem__(self, idx):
        """Return the ``(L, ab)`` pair for ``idx``.

        If the file at ``idx`` cannot be loaded, the failure is printed and the
        following files (wrapping around) are tried instead, so a batch never
        contains ``None`` -- which the default DataLoader collate cannot handle.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if ``MAX_LOAD_ATTEMPTS`` consecutive files all fail.
        """
        # Index outside the try block so an out-of-range idx surfaces as IndexError
        self.image_paths[idx]

        num_images = len(self.image_paths)
        attempts = min(self.MAX_LOAD_ATTEMPTS, num_images)
        for offset in range(attempts):
            img_path = self.image_paths[(idx + offset) % num_images]
            try:
                return self._load_lab(img_path)
            except Exception as e:
                print(f"Failed to load {img_path}: {e}")

        raise RuntimeError(
            f"Could not load any of {attempts} consecutive images starting at index {idx}"
        )



In [None]:
from torchvision import transforms

# Dataset over the copied ImageNet subset. `transform` is handed to the dataset
# constructor as in the original setup.
transform = transforms.ToTensor()
imagenet_dataset = ImageNetColorizationDataset(
    root_dir='/kaggle/working/imagenet_subset/train',
    transform=transform,
)

# DataLoader settings gathered in one place so they are easy to tune
loader_options = dict(
    batch_size=32,            # raise if memory allows
    shuffle=True,
    num_workers=8,            # rule of thumb used here: 4 * num_GPUs
    pin_memory=True,          # kept for GPU transfer efficiency
    prefetch_factor=2,        # batches prefetched per worker
    persistent_workers=True,  # keep workers alive between epochs
    drop_last=True,           # drop the final incomplete batch
)
train_loader = DataLoader(imagenet_dataset, **loader_options)


In [None]:
import cv2
import numpy as np

def rgb_to_lab(img_rgb):
    """Convert an RGB image to scaled L, a and b planes.

    The image is converted with ``cv2.COLOR_RGB2LAB`` and split into its three
    channels; the first is divided by 255 and the other two by 128.

    Args:
        img_rgb: RGB image array accepted by ``cv2.cvtColor``.

    Returns:
        Tuple ``(L, a, b)`` of the scaled channel arrays.
    """
    lab_planes = cv2.split(cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB))
    divisors = (255.0, 128.0, 128.0)
    return tuple(plane / divisor for plane, divisor in zip(lab_planes, divisors))


In [1]:
import torch
import torch.nn as nn

class UNetColor(nn.Module):
    """Convolutional encoder-decoder mapping a 1-channel input to 2 output channels.

    The encoder applies three conv -> ReLU -> 2x max-pool stages
    (1 -> 64 -> 128 -> 256 channels); the decoder applies three stride-2
    transposed convolutions (256 -> 128 -> 64 -> 32) followed by a 1x1
    convolution to 2 channels. ``forward`` simply chains the two; there are no
    skip connections between them.
    """

    def __init__(self):
        super().__init__()

        # Down-sampling path: each stage halves the spatial resolution
        down_layers = []
        for c_in, c_out in ((1, 64), (64, 128), (128, 256)):
            down_layers += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Up-sampling path: each transposed conv doubles the spatial resolution
        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            up_layers += [nn.ConvTranspose2d(c_in, c_out, 2, stride=2), nn.ReLU()]
        up_layers.append(nn.Conv2d(32, 2, 1))  # 2 output channels: a and b
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder."""
        return self.decoder(self.encoder(x))


In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Baseline encoder-decoder, created and placed on the chosen device
model = UNetColor().to(device)

# Mean-squared-error loss (applied to predicted vs. target ab channels in training)
criterion = nn.MSELoss()

# Adam over every parameter of the baseline model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
def lab_to_rgb(L, a, b):
    """Rebuild an RGB image from scaled L, a and b planes.

    Each plane is rescaled (L by 255, a and b by 128), rounded to the nearest
    integer and clipped to [0, 255] before the cast to ``uint8``. A bare
    ``astype(np.uint8)`` truncates toward zero and wraps out-of-range values
    (e.g. 256 -> 0, -1 -> 255), which shows up as speckled colors.

    Args:
        L, a, b: Arrays of equal shape holding the scaled channels.

    Returns:
        ``uint8`` RGB image produced by ``cv2.COLOR_LAB2RGB``.
    """
    def _to_uint8(plane, scale):
        # Round + clip so the uint8 cast can neither truncate nor wrap
        return np.clip(np.rint(np.asarray(plane) * scale), 0, 255).astype(np.uint8)

    lab = cv2.merge([_to_uint8(L, 255.0), _to_uint8(a, 128.0), _to_uint8(b, 128.0)])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)


In [None]:
from tqdm import tqdm

num_epochs = 10

# Optimize the baseline `model` on `train_loader`, reporting the mean batch loss per epoch
for epoch in range(num_epochs):
    model.train()
    batch_losses = []

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for L, ab in progress:
        L = L.to(device)
        ab = ab.to(device)

        optimizer.zero_grad()
        output_ab = model(L)
        loss = criterion(output_ab, ab)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Total loss over the epoch; the report below divides by the number of batches
    running_loss = sum(batch_losses)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch

def visualize_colorization(model, dataloader, device):
    """Plot ground truth, grayscale input and model colorization for one sample.

    Takes the first batch from ``dataloader``, runs the model on its L channel
    and shows the first sample of that batch in three panels. Nothing is
    plotted if the loader yields no batches. The model is left in eval mode.

    Args:
        model: Network mapping an L tensor to predicted ab channels.
        dataloader: Iterable of ``(L, ab)`` batches; assumed to be shaped
            [B, 1, H, W] and [B, 2, H, W] with L scaled by 1/255 and ab by
            1/128, as the indexing and rescaling below imply.
        device: Device the batch is moved to before inference.
    """
    def _lab_planes_to_rgb(L_plane, ab_planes):
        # Stack to [H, W, 3]; round and clip so out-of-range network outputs
        # saturate instead of wrapping around in the uint8 cast.
        lab = np.concatenate([L_plane[:, :, None], ab_planes], axis=2)
        lab_u8 = np.clip(np.rint(lab), 0, 255).astype(np.uint8)
        return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2RGB)

    model.eval()
    with torch.no_grad():
        for L, ab in dataloader:
            L, ab = L.to(device), ab.to(device)
            output_ab = model(L)

            # First sample in the batch, rescaled back to the 0-255 LAB encoding
            L_img = L[0].cpu().numpy()[0] * 255.0                            # [H, W]
            ab_gt = ab[0].cpu().numpy().transpose(1, 2, 0) * 128.0           # [H, W, 2]
            ab_pred = output_ab[0].cpu().numpy().transpose(1, 2, 0) * 128.0  # [H, W, 2]

            rgb_gt = _lab_planes_to_rgb(L_img, ab_gt)
            rgb_pred = _lab_planes_to_rgb(L_img, ab_pred)

            # Grayscale panel is derived from the ground-truth RGB image
            gray_img = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY)

            panels = [
                (rgb_gt, "Original", None),
                (gray_img, "Input Grayscale", 'gray'),
                (rgb_pred, "Reconstructed", None),
            ]
            plt.figure(figsize=(12, 4))
            for position, (image, title, cmap) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.imshow(image, cmap=cmap)
                plt.title(title)
                plt.axis('off')

            plt.tight_layout()
            plt.show()
            break  # Show only one batch


In [None]:
# Re-resolve the device so this cell can be run on its own, then preview one
# training sample colorized by the baseline model.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
visualize_colorization(model=model, dataloader=train_loader, device=device)


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class ResNetUNetColor(nn.Module):
    """Colorization network with a ResNet18 encoder and a skip-connected decoder.

    A 1-channel input goes through a freshly created stem convolution and the
    pretrained ResNet18 stages; the decoder upsamples with transposed
    convolutions, concatenating encoder features from layer3, layer2 and
    layer1 along the channel axis. The output has 2 channels (a and b).
    Shape comments assume a 256x256 input.
    """

    def __init__(self):
        super().__init__()

        # Pretrained ResNet18 provides the encoder stages
        backbone = models.resnet18(pretrained=True)

        # Stem convolution taking 1 input channel, used in place of the
        # backbone's RGB first convolution
        self.input_conv = nn.Conv2d(
            in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False
        )

        # Remaining stem pieces and residual stages are reused from the backbone
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        self.layer1 = backbone.layer1  # [64, 64, 64]
        self.layer2 = backbone.layer2  # [128, 32, 32]
        self.layer3 = backbone.layer3  # [256, 16, 16]
        self.layer4 = backbone.layer4  # [512, 8, 8]

        # Decoder: input channels of up2..up4 include the concatenated skip tensor
        self.up1 = self.up_block(512, 256)
        self.up2 = self.up_block(256 + 256, 128)  # skip from layer3
        self.up3 = self.up_block(128 + 128, 64)   # skip from layer2
        self.up4 = self.up_block(64 + 64, 64)     # skip from layer1
        # Final upsampling to 256x256 and projection to the ab channels
        self.up5 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 2, kernel_size=1),
        )

    def up_block(self, in_ch, out_ch):
        """Build one decoder stage: 2x transposed-conv upsample, ReLU, 3x3 conv, ReLU."""
        stage = [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Encode ``x`` with the ResNet stages and decode it to 2 output channels."""
        # Encoder
        stem = self.relu(self.bn1(self.input_conv(x)))  # [64, 128, 128]
        feat1 = self.layer1(self.maxpool(stem))         # [64, 64, 64]
        feat2 = self.layer2(feat1)                      # [128, 32, 32]
        feat3 = self.layer3(feat2)                      # [256, 16, 16]
        feat4 = self.layer4(feat3)                      # [512, 8, 8]

        # Decoder: upsample, then repeatedly concatenate the matching encoder
        # feature map before the next upsampling stage
        decoded = self.up1(feat4)                       # [256, 16, 16]
        skip_stages = (
            (self.up2, feat3),                          # -> [128, 32, 32]
            (self.up3, feat2),                          # -> [64, 64, 64]
            (self.up4, feat1),                          # -> [64, 128, 128]
        )
        for up_stage, skip in skip_stages:
            decoded = up_stage(torch.cat([decoded, skip], dim=1))

        return self.up5(decoded)                        # [2, 256, 256]


In [None]:
# Resolve the device first: the original cell moved the model to `device` on the
# line before defining it, which only worked when an earlier cell had set it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ResNet-based colorization model, moved to the device once
model2 = ResNetUNetColor().to(device)

# NOTE: this rebinds the name `optimizer`, so from here on it refers to model2's
# optimizer rather than the one created for the baseline model.
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-4)

In [None]:
from tqdm import tqdm

num_epochs = 10

# Optimize the ResNet-based `model2` on `train_loader`, reporting the mean batch loss per epoch
for epoch in range(num_epochs):
    model2.train()
    batch_losses = []

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for L, ab in progress:
        L = L.to(device)
        ab = ab.to(device)

        optimizer.zero_grad()
        output_ab = model2(L)
        loss = criterion(output_ab, ab)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Total loss over the epoch; the report below divides by the number of batches
    running_loss = sum(batch_losses)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")


In [None]:
# Re-resolve the device so this cell can be run on its own, then preview one
# training sample colorized by the ResNet-based model.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
visualize_colorization(model=model2, dataloader=train_loader, device=device)


In [None]:
# Persist the ResNet-based model's parameters (state_dict only, not the module object)
resnet_unet_weights_path = "/kaggle/working/resnet_unet_colorization_model_weights.pth"
torch.save(model2.state_dict(), resnet_unet_weights_path)


In [None]:
# Persist the baseline model's parameters (state_dict only, not the module object)
unet_weights_path = "/kaggle/working/unet_colorization_model_weights.pth"
torch.save(model.state_dict(), unet_weights_path)


In [None]:
import requests
from PIL import Image
from io import BytesIO

# Load image from URL
url = "https://huggingface.co/takuma104/controlnet_dev/resolve/main/bird_512x512.png"
# A timeout keeps the cell from hanging forever on a stalled connection
response = requests.get(url, timeout=30)
# Fail here on an HTTP error instead of handing an error page to PIL
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert("RGB").resize((256, 256))


In [None]:
def colorize_image_from_url(model, url, device):
    """Download an image, colorize its L channel with ``model`` and plot the result.

    The image is resized to 256x256, converted to LAB, and only the L channel
    (scaled by 1/255) is fed to the model. The predicted ab channels are
    rescaled by 128, recombined with L and converted back to RGB. The model is
    left in eval mode.

    Args:
        model: Network mapping a [1, 1, 256, 256] L tensor to ab predictions.
        url: HTTP(S) address of the image.
        device: Device used for inference.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the download does not respond within 30 seconds.
    """
    model.eval()

    # Download with a timeout and an explicit status check so a failed request
    # is reported as such rather than as an image-decoding error.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert("RGB").resize((256, 256))
    img_np = np.array(img)

    # Convert RGB to LAB and extract L
    lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    L = lab[:, :, 0] / 255.0

    # Convert to tensor: [1, 1, H, W]
    L_tensor = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)

    # Predict ab channels: [H, W, 2]
    with torch.no_grad():
        ab_pred = model(L_tensor)[0].cpu().numpy().transpose(1, 2, 0) * 128.0

    # Reconstruct LAB. Network outputs are unbounded, so round and clip before
    # the uint8 cast; a bare astype would wrap values outside [0, 255].
    lab_pred = np.concatenate([(L * 255.0)[:, :, None], ab_pred], axis=2)
    lab_pred_u8 = np.clip(np.rint(lab_pred), 0, 255).astype(np.uint8)
    rgb_pred = cv2.cvtColor(lab_pred_u8, cv2.COLOR_LAB2RGB)

    # Plot
    plt.figure(figsize=(8, 4))
    for position, (image, title) in enumerate(
        ((img_np, "Original Input"), (rgb_pred, "Colorized Output")), start=1
    ):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Colorize the sample bird image with the ResNet-based model
url = "https://huggingface.co/takuma104/controlnet_dev/resolve/main/bird_512x512.png"
colorize_image_from_url(model=model2, url=url, device=device)


In [None]:
from PIL import Image
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def colorize_single_image(model, image_path, device):
    """Colorize the L channel of an image file with ``model`` and plot the result.

    The image is resized to 256x256, converted to LAB, and only the L channel
    (scaled by 1/255) is fed to the model. The predicted ab channels are
    rescaled by 128, recombined with L and converted back to RGB. The model is
    left in eval mode.

    Args:
        model: Network mapping a [1, 1, 256, 256] L tensor to ab predictions.
        image_path: Path of the image file to load.
        device: Device used for inference.
    """
    model.eval()

    # Load and resize image; the context manager closes the file handle promptly
    with Image.open(image_path) as src:
        img_np = np.array(src.convert("RGB").resize((256, 256)))

    # Convert to LAB and extract L channel
    lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    L = lab[:, :, 0] / 255.0  # Normalize L channel

    # Convert to tensor
    L_tensor = torch.from_numpy(L).unsqueeze(0).unsqueeze(0).float().to(device)  # [1, 1, 256, 256]

    # Forward pass: [H, W, 2]
    with torch.no_grad():
        ab_pred = model(L_tensor)[0].cpu().numpy().transpose(1, 2, 0) * 128.0

    # Reconstruct LAB. Network outputs are unbounded, so round and clip before
    # the uint8 cast; a bare astype would wrap values outside [0, 255].
    lab_pred = np.concatenate([(L * 255.0)[:, :, None], ab_pred], axis=2)
    lab_pred_u8 = np.clip(np.rint(lab_pred), 0, 255).astype(np.uint8)
    rgb_pred = cv2.cvtColor(lab_pred_u8, cv2.COLOR_LAB2RGB)

    # Show result
    plt.figure(figsize=(8, 4))
    for position, (image, title) in enumerate(
        ((img_np, "Original Image"), (rgb_pred, "Colorized Output")), start=1
    ):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Colorize one ImageNet test image with the ResNet-based model
# (point `image_path` at any other file in the input folder to try your own image)
image_path = "/kaggle/input/imagenet-object-localization-challenge/ILSVRC/Data/CLS-LOC/test/ILSVRC2012_test_00000003.JPEG"
colorize_single_image(model=model2, image_path=image_path, device=device)
