In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


In [2]:
# Prefer the GPU whenever PyTorch can see one; everything below reuses this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
class ImageNetColorizationDataset(Dataset):
    """Grayscale-to-color training pairs built from a folder tree of RGB images.

    Each item is ``(L, target)`` where ``L`` is a float tensor of shape
    ``(1, img_size, img_size)`` holding the CIE-Lab lightness channel, and
    ``target`` depends on ``mode``:

    * ``'classification'``: an int64 tensor of shape ``(img_size // 4, img_size // 4)``
      holding, for each pixel of a 4x-downsampled ab image, the index of the
      nearest ab cluster center loaded from ``cluster_path``.
    * ``'regression'``: a float tensor of shape ``(2, img_size, img_size)`` with
      the raw a and b channels.
    """

    _MODES = ('classification', 'regression')

    def __init__(self, img_dir, split='train', img_size=224, mode='classification', cluster_path='data/pts_in_hull.npy'):
        """
        img_dir: Directory scanned recursively for .jpg/.jpeg/.png files (extension match is case-insensitive).
        split: 'train' enables random horizontal flips; any other value (e.g. 'val') disables augmentation.
        img_size: Side length every image is resized to (aspect ratio is not preserved).
        mode: 'classification' for cluster-index targets, 'regression' for direct ab targets.
        cluster_path: .npy file with the ab cluster centers (one (a, b) row per bin); only read in
            classification mode.

        Raises ValueError if ``mode`` is not one of the supported values.
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.split = split
        self.mode = mode
        self.img_size = img_size

        # Sorted so indices are stable across runs and platforms.
        self.image_files = sorted(
            os.path.join(root, fname)
            for root, _, files in os.walk(img_dir)
            for fname in files
            if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

        if self.mode == 'classification':
            # Presumably (313, 2) for pts_in_hull.npy; any (K, 2) array works with _quantize_ab.
            self.ab_clusters = np.load(cluster_path).astype(np.float32)

    def __len__(self):
        return len(self.image_files)

    def _quantize_ab(self, ab_small):
        """Map each ab pixel of an (h, w, 2) array to the index of its nearest cluster center.

        Returns an int64 array of shape (h, w).
        """
        h, w = ab_small.shape[:2]
        pixels = ab_small.reshape(-1, 2).astype(np.float32)
        diffs = pixels[:, None, :] - self.ab_clusters[None, :, :]
        # Squared Euclidean distance has the same argmin as the norm and skips the sqrt.
        sq_dists = (diffs * diffs).sum(axis=2)
        return sq_dists.argmin(axis=1).astype(np.int64).reshape(h, w)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        with Image.open(img_path) as img:
            img = img.convert('RGB').resize((self.img_size, self.img_size))
            # Augment only the training split. torch's RNG is reseeded per DataLoader worker,
            # so parallel workers do not draw identical flip decisions.
            if self.split == 'train' and torch.rand(1).item() < 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            img_np = np.asarray(img, dtype=np.float32) / 255.0

        lab = rgb2lab(img_np)
        L = lab[:, :, 0]
        L_tensor = torch.from_numpy(L).unsqueeze(0).float()

        if self.mode == 'regression':
            ab_tensor = torch.from_numpy(lab[:, :, 1:]).permute(2, 0, 1).float()
            return L_tensor, ab_tensor

        # Classification targets are built at 1/4 resolution (INTER_AREA averages each 4x4 patch).
        H, W = L.shape
        ab_small = cv2.resize(lab[:, :, 1:].astype(np.float32), (W // 4, H // 4),
                              interpolation=cv2.INTER_AREA)
        class_map = self._quantize_ab(ab_small)
        return L_tensor, torch.from_numpy(class_map)

In [4]:
def visualize_colorization_samples(model, data_loader, device, num_images=4, save_path=None):
    """Show grayscale inputs (top row) above the model's colorized outputs (bottom row).

    Runs ``model`` on the first batch of ``data_loader`` and converts each L channel
    plus the predicted ab channels back to RGB with ``lab2rgb``. At most
    ``num_images`` samples are drawn (fewer if the batch is smaller).

    Args:
        model: module mapping an L batch of shape (B, 1, H, W) to ab predictions of
            shape (B, 2, H, W). It is moved to ``device`` and left in eval mode.
        data_loader: iterable yielding ``(L, target)`` batches; ``target`` is ignored.
        device: device the forward pass runs on.
        num_images: maximum number of samples (figure columns) to plot.
        save_path: if given, the figure is written there and closed instead of shown.
    """
    model.eval()
    model.to(device)

    with torch.no_grad():
        sample_L, _ = next(iter(data_loader))  # only the first batch is visualized
        pred_ab = model(sample_L.to(device))

    L_batch = sample_L.cpu().numpy()
    ab_batch = pred_ab.cpu().numpy()

    fig = plt.figure(figsize=(num_images * 3, 6))
    for i in range(min(num_images, L_batch.shape[0])):
        L_chan = L_batch[i, 0]
        lab = np.concatenate([L_chan[:, :, np.newaxis], ab_batch[i].transpose(1, 2, 0)], axis=2)
        rgb = lab2rgb(lab)

        ax_gray = fig.add_subplot(2, num_images, i + 1)
        # Pin the display range to Lab lightness [0, 100] so imshow does not
        # auto-stretch contrast differently for every image.
        ax_gray.imshow(L_chan, cmap='gray', vmin=0, vmax=100)
        ax_gray.axis('off')
        ax_gray.set_title("Grayscale")

        ax_color = fig.add_subplot(2, num_images, i + 1 + num_images)
        ax_color.imshow(rgb)
        ax_color.axis('off')
        ax_color.set_title("Colorized")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        # A saved figure is never shown; close it so repeated calls do not pile up open figures.
        plt.close(fig)
        print(f"Visualization saved to {save_path}")
    else:
        plt.show()

In [5]:
# Example usage: ImageNet-style folders (scanned recursively for image files).
train_dir = "imagenet/train"
val_dir = "imagenet/val"

# Both splits must quantize ab with the SAME cluster centers, otherwise their class
# indices refer to different colors. Keep the path (and shared sizes) in one place.
CLUSTER_PATH = "./data/pts_in_hull.npy"
IMG_SIZE = 224
BATCH_SIZE = 64

# Create dataset and dataloader
train_dataset = ImageNetColorizationDataset(train_dir, split='train', img_size=IMG_SIZE,
                                            mode='classification', cluster_path=CLUSTER_PATH)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)

val_dataset = ImageNetColorizationDataset(val_dir, split='val', img_size=IMG_SIZE,
                                          mode='classification', cluster_path=CLUSTER_PATH)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

In [6]:
class ECCVGenerator(nn.Module):
    """Colorization network: lightness (L) in, ab channels out.

    ``layer1`` .. ``layer8`` produce 313-way per-pixel logits at 1/4 of the input
    resolution. ``forward`` converts them to ab values by applying a softmax, then
    the 1x1 ``model_out`` conv (a probability-weighted sum of its weight columns),
    then 4x bilinear upsampling, then scaling by ``ab_norm``.

    NOTE: ``model_out`` sits after the logits. A training loop that optimizes a loss
    on the ``layer8`` output only (bypassing ``forward``) never updates it, so
    ``forward`` would keep decoding with random weights. Pass ``ab_clusters`` (the
    (313, 2) ab bin centers used to build the classification targets), or call
    ``set_ab_clusters``, to fix ``model_out`` to the expected ab value under the
    predicted distribution. ``load_state_dict`` overwrites ``model_out.weight``, so
    call ``set_ab_clusters`` again after loading a checkpoint whose ``model_out``
    was never trained.
    """

    def __init__(self, norm_layer=nn.BatchNorm2d, ab_clusters=None):
        super().__init__()
        self.l_cent = 50.0
        self.l_norm = 100.0
        self.ab_norm = 110.0

        # Encoder. Sequential layouts are part of the state_dict key names, so keep them stable.
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 2, 1), nn.ReLU(True),
            norm_layer(64)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(128, 128, 3, 2, 1), nn.ReLU(True),
            norm_layer(128)
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 2, 1), nn.ReLU(True),
            norm_layer(256)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            norm_layer(512)
        )
        # Dilated blocks: padding=2 with dilation=2 keeps the spatial size unchanged.
        self.layer5 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 2, dilation=2), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 2, dilation=2), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 2, dilation=2), nn.ReLU(True),
            norm_layer(512)
        )
        self.layer6 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 2, dilation=2), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 2, dilation=2), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 2, dilation=2), nn.ReLU(True),
            norm_layer(512)
        )
        self.layer7 = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True),
            norm_layer(512)
        )
        # Decoder head: one 2x upsampling, then 313-way logits (1/4 of the input resolution).
        self.layer8 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(256, 313, 1, 1, 0)
        )

        self.softmax = nn.Softmax(dim=1)
        self.model_out = nn.Conv2d(313, 2, 1, 1, 0, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

        if ab_clusters is not None:
            self.set_ab_clusters(ab_clusters)

    def set_ab_clusters(self, ab_clusters):
        """Fix ``model_out`` so ``forward`` returns the probability-weighted mean of ``ab_clusters``.

        ab_clusters: array-like of shape (313, 2) with one (a, b) center per bin.
        The weights are frozen (``requires_grad=False``).

        Raises ValueError if the shape does not match ``model_out``.
        """
        weight = self.model_out.weight
        centers = torch.as_tensor(ab_clusters, dtype=weight.dtype)
        expected = (self.model_out.in_channels, self.model_out.out_channels)
        if tuple(centers.shape) != expected:
            raise ValueError(f"ab_clusters must have shape {expected}, got {tuple(centers.shape)}")
        with torch.no_grad():
            # Divide by ab_norm because forward multiplies the decoded ab by it afterwards.
            weight.copy_((centers.t() / self.ab_norm).reshape(weight.shape))
        weight.requires_grad_(False)

    def normalize_l(self, L):
        return (L - self.l_cent) / self.l_norm

    def unnormalize_ab(self, ab):
        return ab * self.ab_norm

    def forward(self, input_l):
        conv1 = self.layer1(self.normalize_l(input_l))
        conv2 = self.layer2(conv1)
        conv3 = self.layer3(conv2)
        conv4 = self.layer4(conv3)
        conv5 = self.layer5(conv4)
        conv6 = self.layer6(conv5)
        conv7 = self.layer7(conv6)
        conv8 = self.layer8(conv7)

        ab_prob = self.softmax(conv8)
        ab_pred = self.model_out(ab_prob)
        ab_pred_upsampled = self.upsample4(ab_pred)
        ab_pred_unnorm = self.unnormalize_ab(ab_pred_upsampled)
        return ab_pred_unnorm

In [7]:
def get_logits(model, L):
    """Return the raw per-pixel class logits of an ECCVGenerator-style model.

    Applies ``normalize_l`` and then ``layer1`` through ``layer8`` in order,
    stopping before the softmax / ``model_out`` / upsampling done in ``forward``.
    An ``nn.DataParallel`` wrapper is unwrapped first, so the layers are called
    directly on the inner module.
    """
    net = model.module if isinstance(model, nn.DataParallel) else model
    features = net.normalize_l(L)
    stages = (net.layer1, net.layer2, net.layer3, net.layer4,
              net.layer5, net.layer6, net.layer7, net.layer8)
    for stage in stages:
        features = stage(features)
    return features


def load_best_checkpoint(model, checkpoint_path):
    """Load a saved state_dict into ``model`` (or into the module inside an ``nn.DataParallel``).

    Tensors are mapped to CPU while reading; ``load_state_dict`` then copies them
    into the model's own parameters. Loading is strict (the ``load_state_dict``
    default), so missing or unexpected keys raise.

    Returns True if the file existed and was loaded, False if it does not exist.
    """
    if not os.path.exists(checkpoint_path):
        print(f"✗ No checkpoint found at: {checkpoint_path}")
        return False

    state_dict = torch.load(checkpoint_path, map_location='cpu')
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)
    print(f"✓ Loaded checkpoint from: {checkpoint_path}")
    return True


def evaluate(model, val_loader, criterion, use_classification, device):
    """Return the per-sample mean loss of ``model`` over every batch of ``val_loader``.

    With ``use_classification`` the loss is taken on ``get_logits(model, L)``;
    otherwise it is taken on ``model(L)``. Per-batch losses are weighted by batch
    size, so a smaller final batch does not skew the average. The weighting
    assumes ``criterion`` uses mean reduction over samples of equal size.

    Raises ValueError if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for L, target in val_loader:
            L, target = L.to(device), target.to(device)
            if use_classification:
                loss = criterion(get_logits(model, L), target)
            else:
                loss = criterion(model(L), target)
            batch_size = L.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("val_loader yielded no samples; cannot compute a mean loss")
    return total_loss / n_samples


def save_sample_images(model, val_loader, device, epoch, num_images=4, out_dir="."):
    """Colorize the first batch of ``val_loader`` and write up to ``num_images`` PNG files.

    Files are named ``sample_epoch{epoch}_img{i}.png`` and written inside ``out_dir``
    (created if missing; defaults to the working directory). ``model`` is left in
    eval mode.
    """
    model.eval()
    sample_L, _ = next(iter(val_loader))
    with torch.no_grad():
        pred_ab = model(sample_L.to(device))

    os.makedirs(out_dir, exist_ok=True)
    # One device-to-host transfer for the whole batch instead of one per image.
    L_batch = sample_L.cpu().numpy()
    ab_batch = pred_ab.cpu().numpy()

    for i in range(min(num_images, ab_batch.shape[0])):
        lab_img = np.concatenate(
            [L_batch[i].transpose(1, 2, 0), ab_batch[i].transpose(1, 2, 0)], axis=2
        )
        rgb = np.clip(lab2rgb(lab_img), 0.0, 1.0)
        # Round to nearest rather than truncating, which would bias every pixel darker.
        rgb_u8 = np.round(rgb * 255.0).astype(np.uint8)
        Image.fromarray(rgb_u8).save(os.path.join(out_dir, f"sample_epoch{epoch}_img{i}.png"))


def _unwrapped_state_dict(model):
    """Return ``model``'s state_dict, unwrapping ``nn.DataParallel`` so keys carry no ``module.`` prefix."""
    inner = model.module if isinstance(model, nn.DataParallel) else model
    return inner.state_dict()


def _run_train_epoch(model, train_loader, criterion, optimizer, use_classification,
                     device, epoch, log_interval):
    """Run one optimization pass over ``train_loader``, printing the mean loss of every ``log_interval`` batches."""
    model.train()
    running_loss = 0.0
    for batch_idx, (L, target) in enumerate(train_loader):
        L, target = L.to(device), target.to(device)
        optimizer.zero_grad()

        if use_classification:
            # NOTE(review): get_logits unwraps nn.DataParallel, so this branch does not
            # split the batch across GPUs even when the model is wrapped.
            loss = criterion(get_logits(model, L), target)
        else:
            loss = criterion(model(L), target)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if (batch_idx + 1) % log_interval == 0:
            avg = running_loss / log_interval
            print(f"[Epoch {epoch}] Batch {batch_idx + 1}/{len(train_loader)} | Loss: {avg:.4f}")
            running_loss = 0.0


def train(model, train_loader, val_loader, use_classification=True,
          num_epochs=50, log_interval=10, save_interval=5,
          prefix="colorization_model", lr=1e-4, checkpoint_dir="checkpoints"):
    """Train ``model`` with Adam, keeping the best checkpoint by validation loss.

    Uses cross-entropy on ``get_logits`` output when ``use_classification`` is true,
    otherwise MSE on ``model(L)``. If ``{checkpoint_dir}/{prefix}_best.pth`` exists,
    its weights are loaded first and its validation loss becomes the bar to beat.
    Only the weights are resumed; the optimizer, LR schedule (halved every 10
    epochs) and epoch numbering start fresh.

    Every ``save_interval`` epochs (and after the last one) a
    ``{prefix}_epoch{N}.pth`` snapshot is written and sample images are saved.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    os.makedirs(checkpoint_dir, exist_ok=True)
    best_path = os.path.join(checkpoint_dir, f"{prefix}_best.pth")
    resumed = load_best_checkpoint(model, best_path)

    criterion = nn.CrossEntropyLoss() if use_classification else nn.MSELoss()
    # Frozen parameters (requires_grad=False) are left out of the optimizer.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    best_val_loss = (evaluate(model, val_loader, criterion, use_classification, device)
                     if resumed else float('inf'))

    for epoch in range(1, num_epochs + 1):
        _run_train_epoch(model, train_loader, criterion, optimizer, use_classification,
                         device, epoch, log_interval)
        scheduler.step()

        val_loss = evaluate(model, val_loader, criterion, use_classification, device)
        print(f"[Epoch {epoch}] Validation Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(_unwrapped_state_dict(model), best_path)
            print(f"New best model saved to {best_path}")

        if epoch % save_interval == 0 or epoch == num_epochs:
            save_path = os.path.join(checkpoint_dir, f"{prefix}_epoch{epoch}.pth")
            torch.save(_unwrapped_state_dict(model), save_path)
            print(f"Checkpoint saved to {save_path}")
            save_sample_images(model, val_loader, device, epoch)


def test_model(model, test_loader, use_classification=True,
               checkpoint_path="checkpoints/colorization_model_best.pth"):
    """Load ``checkpoint_path`` into ``model`` and report its mean loss on ``test_loader``.

    Uses the same loss as training (cross-entropy on logits for classification,
    MSE on ``model(L)`` otherwise) and the same loop as ``evaluate``, so test and
    validation numbers are directly comparable.

    Returns the mean loss (also printed), or None if the checkpoint is missing.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    if not os.path.exists(checkpoint_path):
        print(f"No checkpoint found at: {checkpoint_path}")
        return None

    # Load best model
    load_best_checkpoint(model, checkpoint_path)

    criterion = nn.CrossEntropyLoss() if use_classification else nn.MSELoss()
    avg_loss = evaluate(model, test_loader, criterion, use_classification, device)
    print(f"Test Loss: {avg_loss:.4f}")
    return avg_loss

In [None]:
# training: fresh generator, optimized with the classification (cross-entropy) objective
model = ECCVGenerator()
train(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    use_classification=True,
)

In [9]:
# testing: reload the best checkpoint into a fresh generator and score it on val_loader
model = ECCVGenerator()
test_model(
    model,
    test_loader=val_loader,
    use_classification=True,
)

✓ Loaded checkpoint from: checkpoints/colorization_model_best.pth
Test Loss: 2.6048


In [None]:
# visualization: colorize one validation batch with the best saved weights
model = ECCVGenerator()
load_best_checkpoint(model, "checkpoints/colorization_model_best.pth")

vis_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
visualize_colorization_samples(model, val_loader, device=vis_device)