In [1]:
import os
import glob
import numpy as np
from PIL import Image
from skimage import color
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
import timm

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

EPOCHS = 6
LR = 1e-4
BATCH_SIZE = 16
DATA_PATH = "/kaggle/input/cartoon-classification"



Using device: cuda


In [2]:
class CartoonColorizationDataset(Dataset):
    def __init__(self, root_dir, split='TRAIN', image_size=256):
        self.image_paths = glob.glob(os.path.join(root_dir, 'cartoon_classification', split, '**', '*.jpg')) + \
                           glob.glob(os.path.join(root_dir, 'cartoon_classification', split, '**', '*.png'))
        self.image_size = image_size
        self.transforms = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(), 
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert("RGB")
        img = img.resize((self.image_size, self.image_size), Image.Resampling.BILINEAR)
        img_np = np.array(img)

        # RGB to Lab
        img_lab = color.rgb2lab(img_np)
        
        # Normalize L to [0, 1]
        img_l = img_lab[:, :, 0] / 100.0 
        
        # Normalize ab to [-1, 1]
        img_ab = img_lab[:, :, 1:] / 128.0 

        img_l = torch.from_numpy(img_l).unsqueeze(0).float() 
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1))).float() 

        return img_l, img_ab

In [3]:
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Load VGG16 pretrained on ImageNet
        vgg = models.vgg16(pretrained=True).features
        
        # We use the first 16 layers (up to relu3_3) to capture texture/structure
        self.slice = nn.Sequential()
        for i in range(16):
            self.slice.add_module(str(i), vgg[i])
            
        # Freeze parameters (we don't train VGG)
        for param in self.slice.parameters():
            param.requires_grad = False
            
    def forward(self, pred, target):
        # Input must be 3 channels. 
        # pred/target: [B, 3, H, W]
        pred_feat = self.slice(pred)
        target_feat = self.slice(target)
        return F.mse_loss(pred_feat, target_feat)

In [4]:
class Backbone(nn.Module):
    def __init__(self, model_name='convnext_tiny'):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=True, features_only=True, in_chans=3)
        
        # IMPROVEMENT: Learnable adapter from 1 channel (L) to 3 channels
        # This replaces the hardcoded .repeat(1,3,1,1)
        self.input_adapter = nn.Conv2d(1, 3, kernel_size=1) 
        
        # Determine channel counts
        with torch.no_grad():
            dummy = torch.randn(1, 3, 256, 256)
            feats = self.model(dummy)
        self.chans = [f.shape[1] for f in feats] 
    
    def forward(self, x):
        # x: [B, 1, H, W] -> Adapter -> [B, 3, H, W]
        x = self.input_adapter(x) 
        return self.model(x)

In [5]:
class PixelDecoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c * 4, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.act(self.pixel_shuffle(self.conv(x)))

class PixelDecoder(nn.Module):
    def __init__(self, encoder_chans, dec_chans=[256, 128, 64]):
        super().__init__()
        self.dec_chans = dec_chans
        self.up4 = PixelDecoderBlock(encoder_chans[3], dec_chans[0]) 
        self.up3 = PixelDecoderBlock(dec_chans[0] + encoder_chans[2], dec_chans[1])
        self.up2 = PixelDecoderBlock(dec_chans[1] + encoder_chans[1], dec_chans[2])
        self.up1 = PixelDecoderBlock(dec_chans[2] + encoder_chans[0], dec_chans[2]) 
        self.final = nn.Conv2d(dec_chans[2], 256, kernel_size=1) 

    def forward(self, feats):
        x = self.up4(feats[3]) 
        if x.shape != feats[2].shape: x = F.interpolate(x, size=feats[2].shape[2:])
        x = torch.cat([x, feats[2]], dim=1)
        x = self.up3(x)

        if x.shape != feats[1].shape: x = F.interpolate(x, size=feats[1].shape[2:])
        x = torch.cat([x, feats[1]], dim=1)
        x = self.up2(x)
        
        if x.shape != feats[0].shape: x = F.interpolate(x, size=feats[0].shape[2:])
        x = torch.cat([x, feats[0]], dim=1)
        x = self.up1(x)
        return self.final(x)

In [6]:
class ColorDecoder(nn.Module):
    def __init__(self, in_channels, num_queries=100, hidden_dim=256, n_heads=4):
        super().__init__()
        self.num_queries = num_queries
        self.projection = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
        self.queries = nn.Parameter(torch.zeros(1, num_queries, hidden_dim))
        self.transformer_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim, nhead=n_heads, dim_feedforward=hidden_dim*4, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(self.transformer_layer, num_layers=3)
        
    def forward(self, img_features):
        img_features = self.projection(img_features)
        B, C, H, W = img_features.shape
        memory = img_features.view(B, C, -1).permute(0, 2, 1) 
        queries = self.queries.repeat(B, 1, 1) 
        out = self.transformer(queries, memory) 
        return out

In [7]:
class DDColor(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone() # Now uses the improved Backbone
        self.pixel_decoder = PixelDecoder(self.backbone.chans)
        self.color_decoder = ColorDecoder(in_channels=self.backbone.chans[-1])
        self.fusion_conv = nn.Conv2d(100, 2, kernel_size=1)
        
    def forward(self, x_l):
        feats = self.backbone(x_l)
        img_emb = self.pixel_decoder(feats)
        color_queries = self.color_decoder(feats[-1]) 
        
        B, C, H, W = img_emb.shape
        img_emb_flat = img_emb.view(B, C, -1)
        attention_map = torch.bmm(color_queries, img_emb_flat) 
        attention_map = attention_map.view(B, 100, H, W)
        
        out_ab = self.fusion_conv(attention_map) 
        out_ab = F.interpolate(out_ab, size=(256, 256), mode='bilinear')
        return out_ab

In [8]:
train_dataset = CartoonColorizationDataset(DATA_PATH, split='TRAIN')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

In [9]:
model = DDColor().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)

criterion_l1 = nn.L1Loss()
criterion_perceptual = VGGPerceptualLoss().to(device)

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 250MB/s]


In [10]:
def train_one_epoch(model, loader, optimizer, epoch, total_epochs):
    model.train()
    total_loss_val = 0
    
    loop = tqdm(loader, desc=f"Epoch [{epoch+1}/{total_epochs}]", leave=True)
    
    for l_img, ab_gt in loop:
        l_img, ab_gt = l_img.to(device), ab_gt.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        ab_pred = model(l_img)
        
        # --- LOSS CALCULATION ---
        # 1. Pixel L1 Loss
        loss_l1 = criterion_l1(ab_pred, ab_gt)
        
        # 2. Perceptual Loss
        # Concatenate L (1ch) + AB (2ch) to get pseudo-RGB (3ch) for VGG
        pred_lab = torch.cat([l_img, ab_pred], dim=1)
        gt_lab = torch.cat([l_img, ab_gt], dim=1)
        
        # We assume L channel is "perfect" for both, so gradients mainly affect AB
        loss_perc = criterion_perceptual(pred_lab, gt_lab)
        
        # Combined Loss (Weight 0.1 is standard for perceptual loss)
        loss = loss_l1 + (0.1 * loss_perc)
        
        loss.backward()
        optimizer.step()
        
        total_loss_val += loss.item()
        loop.set_postfix(loss=loss.item())
        
    return total_loss_val / len(loader)

In [11]:
history = []

print("Starting training...")
for epoch in range(EPOCHS):
    avg_loss = train_one_epoch(model, train_loader, optimizer, epoch, EPOCHS)
    history.append(avg_loss)
    print(f"Epoch {epoch+1} Complete. Average Loss: {avg_loss:.5f}")
    
    # Save checkpoint every 5 epochs
    if (epoch+1) % 2 == 0:
        torch.save(model.state_dict(), f'ddcolor_cartoon_ep{epoch+1}.pth')
        print(f"Checkpoint saved: ddcolor_cartoon_ep{epoch+1}.pth")

Starting training...


Epoch [1/6]:   0%|          | 0/7477 [00:00<?, ?it/s]

Epoch 1 Complete. Average Loss: 0.14665


Epoch [2/6]:   0%|          | 0/7477 [00:00<?, ?it/s]

Epoch 2 Complete. Average Loss: 0.04831
Checkpoint saved: ddcolor_cartoon_ep2.pth


Epoch [3/6]:   0%|          | 0/7477 [00:00<?, ?it/s]

Epoch 3 Complete. Average Loss: 0.03242


Epoch [4/6]:   0%|          | 0/7477 [00:00<?, ?it/s]

Epoch 4 Complete. Average Loss: 0.02630
Checkpoint saved: ddcolor_cartoon_ep4.pth


Epoch [5/6]:   0%|          | 0/7477 [00:00<?, ?it/s]

Epoch 5 Complete. Average Loss: 0.02307


Epoch [6/6]:   0%|          | 0/7477 [00:00<?, ?it/s]

Epoch 6 Complete. Average Loss: 0.02095
Checkpoint saved: ddcolor_cartoon_ep6.pth


In [12]:
torch.save(model.state_dict(), 'ddcolor_cartoon_last.pth')
print("FINAL MODEL SAVED: ddcolor_cartoon_last.pth")

FINAL MODEL SAVED: ddcolor_cartoon_last.pth
