In [None]:
!pip install deeplake --quiet

In [None]:
import deeplake
print(deeplake.__version__)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Query at most 20,000 rows of the Places205 dataset hosted on the Activeloop hub.
places205_ds = deeplake.query('SELECT * FROM "hub://activeloop/places205" LIMIT 20000')

# Wrap the query result in a PyTorch loader using the loader's default settings.
places205_train_loader = places205_ds.pytorch()

# Sanity check: print the shape of, and display, the first 10 samples.
for i, sample in enumerate(places205_train_loader):
    if i >= 10:
        break
    img = sample['images']
    print(f"Sample {i+1}:")
    print(f"Image shape: {img.shape}")
    # NOTE(review): imshow presumably expects an (H, W[, C]) array — confirm the loader
    # does not prepend a batch dimension.
    plt.imshow(img)
    plt.axis('off')
    plt.show()

In [None]:
# Report the length of the Places205 loader.
print(len(places205_train_loader))

In [None]:
 coco_ds = deeplake.query('SELECT * FROM "hub://activeloop/coco-train" LIMIT 70000')

coco_train_loader = coco_ds.pytorch()

for i, sample in enumerate(coco_train_loader):
    if i >= 10:
        break
    img = sample['images']
    print(f"Sample {i+1}:")
    print(f"Image shape: {img.shape}")
    plt.imshow(img)
    plt.axis('off')
    plt.show()


In [None]:
# Report the length of the COCO loader.
print(len(coco_train_loader))

In [None]:
import numpy as np
import cv2
import torch
from torchvision import transforms
from PIL import Image

In [None]:
to_tensor = transforms.ToTensor()
to_pil = transforms.ToPILImage()
resize = transforms.Resize((256, 256))


def preprocess_image(img):
    """Convert an image into normalised LAB tensors for colorization.

    Args:
        img: a ``torch.Tensor`` (converted with ``ToPILImage``), a NumPy array
            (squeezed and cast to uint8), or anything ``transforms.Resize`` accepts.

    Returns:
        ``(L, ab)``: float tensors in channel-first layout with 1 and 2 channels.
        ``L`` is the first LAB channel divided by 255; ``ab`` is the remaining
        two channels shifted by 128 and divided by 128.

    Raises:
        ValueError: if a NumPy input is neither 2-D nor 3-D after squeezing.
    """
    if isinstance(img, torch.Tensor):
        img = to_pil(img)
    elif isinstance(img, np.ndarray):
        squeezed = np.squeeze(img)
        if squeezed.ndim not in (2, 3):
            raise ValueError(f"Unsupported image shape for PIL conversion: {squeezed.shape}")
        img = Image.fromarray(squeezed.astype(np.uint8))

    # Resize to 256x256 and go back to an array for OpenCV.
    pixels = np.array(resize(img))

    # Promote single-channel images to three identical channels.
    if pixels.ndim == 2:
        pixels = np.stack([pixels, pixels, pixels], axis=-1)
    elif pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)

    lab = cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB).astype(np.float32)
    lightness = lab[:, :, :1] / 255.0
    chroma = (lab[:, :, 1:3] - 128) / 128.0

    def _channels_first(plane):
        # (H, W, C) array -> (C, H, W) float tensor
        return torch.from_numpy(plane).permute(2, 0, 1).float()

    return _channels_first(lightness), _channels_first(chroma)


In [None]:
from torch.utils.data import Dataset

In [None]:
from PIL import ImageOps

class ColorizationDataset(Dataset):
    """Dataset yielding LAB training pairs (and optional style features).

    Each item of ``deeplake_dataset`` must be indexable with ``'images'``.
    With ``test == 0`` random augmentations are applied; otherwise the image
    is only resized to 256x256.

    Args:
        deeplake_dataset: indexable source of samples; also provides ``len``.
        style_encoder: optional callable applied to the augmented RGB tensor.
        device: device the RGB tensor is moved to before style encoding.
        test: ``0`` selects the augmenting pipeline, anything else the plain one.
    """

    def __init__(self, deeplake_dataset, style_encoder=None, device='cpu', test=0):
        self.dataset = deeplake_dataset
        self.style_encoder = style_encoder
        self.device = device
        self.to_tensor = transforms.ToTensor()
        self.resize_224 = transforms.Resize((224, 224))

        if test == 0:
            # Training pipeline: geometric and colour augmentation.
            pipeline = [
                transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
            ]
        else:
            # Evaluation pipeline: deterministic resize only.
            pipeline = [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'L', 'ab', 'style_feats'}`` for sample ``idx``.

        ``style_feats`` is ``None`` when no style encoder was supplied.
        """
        raw = self.dataset[idx]['images']

        # Tensors are treated as channel-first and converted to channel-last arrays.
        if isinstance(raw, torch.Tensor):
            pixels = raw.permute(1, 2, 0).cpu().numpy()
        else:
            pixels = np.array(raw)
        pixels = np.squeeze(pixels)

        # Promote single-channel images to three channels.
        if pixels.ndim == 2:
            pixels = np.stack([pixels] * 3, axis=-1)
        elif pixels.ndim == 3 and pixels.shape[2] == 1:
            pixels = np.repeat(pixels, 3, axis=2)

        pil_image = ImageOps.exif_transpose(Image.fromarray(pixels.astype(np.uint8)))
        augmented = self.transform(pil_image)
        rgb_batch = augmented.unsqueeze(0).to(self.device)

        pooled = None
        if self.style_encoder:
            with torch.no_grad():
                encoded = self.style_encoder(rgb_batch)
                # Average over the two spatial dimensions.
                pooled = torch.mean(encoded, dim=[2, 3])

        L, ab = preprocess_image(augmented)

        return {
            'L': L,
            'ab': ab,
            'style_feats': None if pooled is None else pooled.squeeze(0),
        }


In [None]:
import torch.nn as nn
from torchvision.models import resnet50
from einops import rearrange

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torchvision.models as models

# Pretrained ResNet-50 that is later passed to ColorizationDataset as `style_encoder`.
resnet_style = models.resnet50(pretrained=True)
# Keep every child module except the last one and chain them into a Sequential.
resnet_style = torch.nn.Sequential(*list(resnet_style.children())[:-1])
# Switch to eval mode and move to `device`; as the cell's last expression this also echoes the module.
resnet_style.eval().to(device)

In [None]:
class ColorQueryDecoder(nn.Module):
    def __init__(self, feature_dim, num_queries=64):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, feature_dim))
        self.transformer = nn.Transformer(d_model=feature_dim, batch_first=True, dropout=0.1)
        self.pos_encoding = nn.Parameter(torch.randn(64, 1, feature_dim))
        self.query_norm = nn.LayerNorm(feature_dim)

    def forward(self, features):
        B, C, H, W = features.shape
        x = features.flatten(2).permute(0, 2, 1)
        x = x + self.pos_encoding[:x.size(0)]
        queries = self.queries.unsqueeze(0).expand(B, -1, -1)

        color_features = self.transformer(queries, x)
        color_features = self.query_norm(color_features)
        color_features = F.dropout(color_features, p=0.1, training=self.training)

        return color_features

In [None]:
class SPADE(nn.Module):
    """Spatially modulated normalisation driven by a conditioning map.

    ``x`` is batch-normalised without affine parameters, then scaled and
    shifted by per-pixel ``gamma``/``beta`` maps predicted from ``segmap``.

    Args:
        norm_nc: number of channels in the tensor being normalised.
        label_nc: number of channels in the conditioning map.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        hidden_channels = 128
        self.norm = nn.BatchNorm2d(norm_nc, affine=False)
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden_channels, 3, padding=1),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(hidden_channels, norm_nc, 3, padding=1)
        self.mlp_beta = nn.Conv2d(hidden_channels, norm_nc, 3, padding=1)

    def forward(self, x, segmap):
        """Return ``norm(x) * (1 + gamma) + beta`` with gamma/beta predicted from ``segmap``."""
        normalized = self.norm(x)

        # Bring the conditioning map onto x's device and spatial size.
        conditioning = nn.functional.interpolate(
            segmap.to(x.device), size=x.size()[2:], mode='nearest'
        )
        hidden = self.mlp_shared(conditioning)
        scale = 1 + self.mlp_gamma(hidden)
        shift = self.mlp_beta(hidden)
        return normalized * scale + shift

In [None]:
class InstanceFusion(nn.Module):
    """Fuse global and instance feature maps with a 3x3 convolution and ReLU.

    Args:
        in_channels: channel count of each of the two inputs and of the output.
    """

    def __init__(self, in_channels):
        super().__init__()
        merge = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, padding=1)
        self.conv = nn.Sequential(merge, nn.ReLU())

    def forward(self, global_feats, instance_feats):
        """Concatenate both inputs along the channel axis and project back to ``in_channels``."""
        stacked = torch.cat((global_feats, instance_feats), dim=1)
        return self.conv(stacked)

In [None]:
class PixelDecoder(nn.Module):
    """Upsampling decoder that predicts two tanh-bounded output channels.

    Three stride-2 transposed convolutions each double the resolution; the
    1024- and 512-channel skip tensors are concatenated before the second and
    third stage respectively.

    Args:
        in_channels: channel count of the deepest feature map ``x4``.
    """

    def __init__(self, in_channels):
        super().__init__()
        upsample = dict(kernel_size=4, stride=2, padding=1)
        self.up4 = nn.ConvTranspose2d(in_channels, 512, **upsample)
        self.up3 = nn.ConvTranspose2d(512 + 1024, 256, **upsample)
        self.up2 = nn.ConvTranspose2d(256 + 512, 128, **upsample)
        self.out = nn.Conv2d(128, 2, 3, padding=1)

    def forward(self, x4, x3, x2):
        """Decode ``x4`` with skips ``x3`` (1024 ch) and ``x2`` (512 ch); output lies in [-1, 1]."""
        stage3_in = torch.cat([self.up4(x4), x3], dim=1)
        stage2_in = torch.cat([self.up3(stage3_in), x2], dim=1)
        return torch.tanh(self.out(self.up2(stage2_in)))

In [None]:
import torch.nn.functional as F

In [None]:
class ColorizationNet(nn.Module):
    """Colorization network: ResNet-50 encoder, query decoder and pixel decoder.

    Args:
        device: device every submodule is moved to at construction time.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        resnet = resnet50(pretrained=True)
        resnet = resnet.to(self.device)
        # "layer1" bundles the first five ResNet children; the rest are the later stages.
        self.encoder = nn.ModuleDict({
            "layer1": nn.Sequential(*list(resnet.children())[:5]),
            "layer2": resnet.layer2,
            "layer3": resnet.layer3,
            "layer4": resnet.layer4,
        })

        self.encoder.to(self.device)

        self.pixel_decoder = PixelDecoder(2048).to(self.device)
        self.color_query_decoder = ColorQueryDecoder(2048).to(self.device)
        self.spade = SPADE(2048, 1).to(self.device)
        self.instance_fusion = InstanceFusion(2048).to(self.device)
        self.guidance_proj = nn.Linear(2048, 2048).to(self.device)
        # Moved to the target device like every other submodule. It used to stay on
        # the CPU, which broke CUDA forwards unless the caller also ran model.to(device).
        self.instance_gate = nn.Conv2d(2048, 2048, kernel_size=1).to(self.device)
        self.query_gate = nn.Conv2d(2048, 2048, kernel_size=1).to(self.device)

    def forward(self, x_gray, segmap, instance_feats=None, style_feats=None):
        """Predict colour channels for a grayscale input.

        Args:
            x_gray: input batch fed to the ResNet stem (three channels expected by the stem).
            segmap: single-channel conditioning map for the SPADE block.
            instance_feats: optional instance map; a single channel is repeated to
                match ``x4`` and the map is resized to ``x4``'s spatial size.
            style_feats: accepted for interface compatibility; not used by this forward pass.

        Returns:
            ``(ab_channels, color_query)``: the pixel decoder output and the raw
            colour queries from the query decoder.
        """
        x1 = self.encoder["layer1"](x_gray)
        x2 = self.encoder["layer2"](x1)
        x3 = self.encoder["layer3"](x2)
        x4 = self.encoder["layer4"](x3)

        x4 = self.spade(x4, segmap)

        if instance_feats is not None:
            if instance_feats.shape[1] == 1:
                instance_feats = instance_feats.repeat(1, x4.shape[1], 1, 1)
            if instance_feats.shape[2:] != x4.shape[2:]:
                instance_feats = F.interpolate(instance_feats, size=x4.shape[2:], mode='bilinear', align_corners=False)
            # Gate the instance features with a sigmoid mask computed from x4.
            gate = torch.sigmoid(self.instance_gate(x4))
            x4 = self.instance_fusion(x4, instance_feats * gate)

        color_query = self.color_query_decoder(x4)

        # Reinterpret the projected queries as a feature map. The view only succeeds
        # when the number of queries equals x4's H * W.
        query_map = self.guidance_proj(color_query)
        query_map = query_map.permute(0, 2, 1).contiguous()
        query_map = query_map.view(x4.size(0), x4.size(1), x4.size(2), x4.size(3))

        # Unit-normalise along channels, then regularise with dropout.
        query_map = query_map / (query_map.norm(dim=1, keepdim=True) + 1e-6)
        query_map = F.dropout(query_map, p=0.3, training=self.training)

        # Small gated residual injection of the query guidance.
        gate = torch.sigmoid(self.query_gate(x4))
        x4 = x4 + 0.02 * gate * query_map

        ab_channels = self.pixel_decoder(x4, x3, x2)
        return ab_channels, color_query

In [None]:
from torchvision.models import vgg16
from torchvision.models.feature_extraction import create_feature_extractor

In [None]:
class CombinedLoss(nn.Module):
    """Weighted sum of L1, perceptual, colorfulness, query, centering and histogram terms.

    Each constructor argument is the weight applied to the term of the same name.
    Perceptual features come from four intermediate nodes of a pretrained VGG-16
    whose parameters are frozen.
    """

    def __init__(self, perceptual_weight=0.7, l1_weight=0.8, colorfulness_weight=0.4, query_loss_weight=0.3, centering_weight=0.01, histogram_weight=0.2):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.perceptual_weight = perceptual_weight
        self.l1_weight = l1_weight
        self.colorfulness_weight = colorfulness_weight
        self.query_loss_weight = query_loss_weight
        self.centering_weight = centering_weight
        self.histogram_weight = histogram_weight

        backbone = vgg16(pretrained=True).features.eval()
        tapped_nodes = {
            "3": "relu1_2",
            "8": "relu2_2",
            "15": "relu3_3",
            "22": "relu4_3",
        }
        self.perceptual_extractor = create_feature_extractor(backbone, return_nodes=tapped_nodes)
        # The extractor only provides features; it is never trained.
        self.perceptual_extractor.requires_grad_(False)

    def compute_histogram_loss(self, pred_ab, target_ab, bins=32):
        """Sum, over both channels, of the MSE between per-sample normalised histograms.

        Histograms use ``bins`` bins over [-1, 1].
        NOTE(review): ``torch.histc`` presumably does not propagate gradients, in which
        case this term would not influence training — confirm.
        """
        batch_size, _, _, _ = pred_ab.shape

        def _normalized_hist(values):
            counts = torch.histc(values.flatten().float(), bins=bins, min=-1.0, max=1.0)
            return counts / (counts.sum() + 1e-6)

        total = 0.0
        for channel in range(2):
            pred_hists = torch.stack(
                [_normalized_hist(pred_ab[n, channel]) for n in range(batch_size)]
            )
            target_hists = torch.stack(
                [_normalized_hist(target_ab[n, channel]) for n in range(batch_size)]
            )
            total += nn.functional.mse_loss(pred_hists, target_hists)

        return total

    def forward(self, pred_ab, target_ab, input_l, color_query=None, target_query=None):
        """Return the combined scalar loss.

        ``pred_ab`` and ``target_ab`` are resized to ``input_l``'s spatial size for the
        L1 and perceptual terms; the remaining terms use them as given. The query term
        is added only when both ``color_query`` and ``target_query`` are provided.
        """
        spatial_size = input_l.shape[2:]
        pred_full = nn.functional.interpolate(pred_ab, size=spatial_size, mode='bilinear', align_corners=False)
        target_full = nn.functional.interpolate(target_ab, size=spatial_size, mode='bilinear', align_corners=False)

        # Pixel-wise L1 term.
        l1_term = self.l1_loss(pred_full, target_full)
        total = self.l1_weight * l1_term

        # Perceptual term: compare extractor features of the two stacked 3-channel images.
        pred_feats = self.perceptual_extractor(torch.cat([input_l, pred_full], dim=1))
        target_feats = self.perceptual_extractor(torch.cat([input_l, target_full], dim=1))
        perceptual_term = sum(
            nn.functional.mse_loss(pred_feats[node], target_feats[node]) for node in pred_feats
        )
        total += self.perceptual_weight * perceptual_term

        # Colorfulness term: penalise a mean per-channel std below 0.5.
        channel_std = pred_ab.view(pred_ab.size(0), 2, -1).std(dim=2).mean(dim=1)
        colorfulness_term = torch.relu(0.5 - channel_std).mean()
        total += self.colorfulness_weight * colorfulness_term

        # Optional query-consistency term.
        query_term = 0.0
        if not (color_query is None or target_query is None):
            query_term = nn.functional.mse_loss(color_query, target_query)
            total += self.query_loss_weight * query_term

        # Centering term: absolute mean of the first predicted channel.
        centering_term = torch.abs(pred_ab[:, 0].mean())
        total += self.centering_weight * centering_term

        histogram_term = self.compute_histogram_loss(pred_ab, target_ab)
        total += self.histogram_weight * histogram_term

        # Dump every component when the total is NaN to help locate the culprit.
        if torch.isnan(total):
            print("=== NaN detected in loss ===")
            print("L1 Loss:", l1_term.item())
            print("Perceptual Loss:", perceptual_term.item())
            print("Colorfulness Loss:", colorfulness_term.item())
            print("Query Loss:", query_term.item() if isinstance(query_term, torch.Tensor) else query_term)
            print("Centering Loss:", centering_term.item())
            print("Histogram Loss:", histogram_term.item())
            print("Total Loss: NaN")
            print("===========================")

        return total

In [None]:
from tqdm import tqdm
import gzip
import pickle
import os

In [None]:
# Export resized Places205 images to disk as gzip-compressed pickles, one file per sample.
os.makedirs("places_images_only", exist_ok=True)

# Upper bound on how many loader samples are visited.
places_samples = 20000


for i, sample in enumerate(tqdm(places205_train_loader, total=places_samples, desc="Saving resized Places images")):
    if i >= places_samples:
        break

    img = sample['images']

    # Tensors are treated as channel-first and converted to a channel-last NumPy array.
    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).cpu().numpy()

    img = img.astype(np.uint8)

    # Reject anything that is not a 3-D array at least 10 px on each side with 1 or 3 channels.
    # A skipped index gets no file, so the sample_{i} numbering can have gaps;
    # DiskImageDataset builds its path from the index and would fail on such a gap.
    if (
        img.ndim != 3 or
        img.shape[0] < 10 or img.shape[1] < 10 or
        img.shape[2] not in [1, 3]
    ):
        print(f"Skipping invalid image at index {i} with shape {img.shape}")
        continue

    # Single-channel arrays are squeezed to 2-D before the PIL conversion.
    try:
        img_pil = Image.fromarray(img.squeeze() if img.shape[2] == 1 else img)
    except Exception as e:
        print(f"Failed to convert image at index {i} with shape {img.shape}: {e}")
        continue

    img_resized = img_pil.resize((256, 256))
    img_resized_np = np.array(img_resized)

    # Despite the ".pt" in the name, the payload is a pickled dict, not a torch.save file.
    with gzip.open(f"places_images_only/sample_{i}.pt.gz", 'wb') as f:
        pickle.dump({'images': img_resized_np}, f)

In [None]:
# Export the first batch of resized COCO images to disk (files sample_0 .. sample_19999 at most).
os.makedirs("coco_images_only", exist_ok=True)

# Upper bound on how many loader samples are visited in this pass.
coco_samples = 20000


for i, sample in enumerate(tqdm(coco_train_loader, total=coco_samples, desc="Saving resized COCO images")):
    if i >= coco_samples:
        break

    img = sample['images']

    # Tensors are treated as channel-first and converted to a channel-last NumPy array.
    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).cpu().numpy()

    img = img.astype(np.uint8)

    # Reject anything that is not a 3-D array at least 10 px on each side with 1 or 3 channels.
    # A skipped index gets no file, so the sample_{i} numbering can have gaps.
    if (
        img.ndim != 3 or
        img.shape[0] < 10 or img.shape[1] < 10 or
        img.shape[2] not in [1, 3]
    ):
        print(f"Skipping invalid image at index {i} with shape {img.shape}")
        continue

    # Single-channel arrays are squeezed to 2-D before the PIL conversion.
    try:
        img_pil = Image.fromarray(img.squeeze() if img.shape[2] == 1 else img)
    except Exception as e:
        print(f"Failed to convert image at index {i} with shape {img.shape}: {e}")
        continue

    img_resized = img_pil.resize((256, 256))
    img_resized_np = np.array(img_resized)

    # Despite the ".pt" in the name, the payload is a pickled dict, not a torch.save file.
    with gzip.open(f"coco_images_only/sample_{i}.pt.gz", 'wb') as f:
        pickle.dump({'images': img_resized_np}, f)

In [None]:
# Second COCO slice: up to 20,000 rows starting after the first 20,000.
coco_ds_offset = deeplake.query('SELECT * FROM "hub://activeloop/coco-train" LIMIT 20000 OFFSET 20000')
coco_train_loader_2 = coco_ds_offset.pytorch()

In [None]:
# Export the second COCO slice; file numbering continues at `start_index`.
start_index = 20000
coco_samples = 20000

for i, sample in enumerate(tqdm(coco_train_loader_2, total=coco_samples, desc="Saving resized COCO images")):
    if i >= coco_samples:
        break

    img = sample['images']

    # Tensors are treated as channel-first and converted to a channel-last NumPy array.
    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).cpu().numpy()

    img = img.astype(np.uint8)

    # Reject anything that is not a 3-D array at least 10 px on each side with 1 or 3 channels.
    # The messages below report the loop index `i`, not the on-disk index `start_index + i`.
    if (
        img.ndim != 3 or
        img.shape[0] < 10 or img.shape[1] < 10 or
        img.shape[2] not in [1, 3]
    ):
        print(f"Skipping invalid image at index {i} with shape {img.shape}")
        continue

    # Single-channel arrays are squeezed to 2-D before the PIL conversion.
    try:
        img_pil = Image.fromarray(img.squeeze() if img.shape[2] == 1 else img)
    except Exception as e:
        print(f"Failed to convert image at index {i} with shape {img.shape}: {e}")
        continue

    img_resized = img_pil.resize((256, 256))
    img_resized_np = np.array(img_resized)

    # Offset the file index so this slice does not overwrite the first export.
    save_index = start_index + i
    with gzip.open(f"coco_images_only/sample_{save_index}.pt.gz", 'wb') as f:
        pickle.dump({'images': img_resized_np}, f)

In [None]:
# Third COCO slice: up to 10,000 rows starting after the first 40,000.
# Rebinds `coco_ds_offset`, which previously held the second slice.
coco_ds_offset = deeplake.query('SELECT * FROM "hub://activeloop/coco-train" LIMIT 10000 OFFSET 40000')
coco_train_loader_3 = coco_ds_offset.pytorch()

In [None]:
# Export the third COCO slice; file numbering continues at `start_index`.
start_index = 40000
coco_samples = 10000

for i, sample in enumerate(tqdm(coco_train_loader_3, total=coco_samples, desc="Saving resized COCO images")):
    if i >= coco_samples:
        break

    img = sample['images']

    # Tensors are treated as channel-first and converted to a channel-last NumPy array.
    if isinstance(img, torch.Tensor):
        img = img.permute(1, 2, 0).cpu().numpy()

    img = img.astype(np.uint8)

    # Reject anything that is not a 3-D array at least 10 px on each side with 1 or 3 channels.
    # The messages below report the loop index `i`, not the on-disk index `start_index + i`.
    if (
        img.ndim != 3 or
        img.shape[0] < 10 or img.shape[1] < 10 or
        img.shape[2] not in [1, 3]
    ):
        print(f"Skipping invalid image at index {i} with shape {img.shape}")
        continue

    # Single-channel arrays are squeezed to 2-D before the PIL conversion.
    try:
        img_pil = Image.fromarray(img.squeeze() if img.shape[2] == 1 else img)
    except Exception as e:
        print(f"Failed to convert image at index {i} with shape {img.shape}: {e}")
        continue

    img_resized = img_pil.resize((256, 256))
    img_resized_np = np.array(img_resized)

    # Offset the file index so this slice does not overwrite the earlier exports.
    save_index = start_index + i
    with gzip.open(f"coco_images_only/sample_{save_index}.pt.gz", 'wb') as f:
        pickle.dump({'images': img_resized_np}, f)

In [None]:
# Datasets reading straight from the deeplake query results.
# Both names are rebound later in the notebook to the disk-backed versions.
places_dataset = ColorizationDataset(places205_ds, style_encoder=resnet_style, device=device)
coco_dataset = ColorizationDataset(coco_ds, style_encoder=resnet_style, device=device)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Concatenate both datasets and batch them; `combined_dataset` and `train_loader`
# are both rebound later in the notebook once the disk-backed datasets exist.
combined_dataset = torch.utils.data.ConcatDataset([places_dataset, coco_dataset])
train_loader = DataLoader(combined_dataset, batch_size=8, shuffle=True)

In [None]:
class DiskImageDataset(Dataset):
    """Dataset over gzip-compressed pickle files named ``sample_<idx>.pt.gz``.

    Args:
        cache_dir: directory containing the sample files.
        num_samples: value reported by ``len``; the files themselves are not counted.
    """

    def __init__(self, cache_dir, num_samples):
        self.cache_dir, self.num_samples = cache_dir, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Unpickle and return the object stored for index ``idx``."""
        sample_file = os.path.join(self.cache_dir, f"sample_{idx}.pt.gz")
        with gzip.open(sample_file, 'rb') as handle:
            payload = pickle.load(handle)
        return payload

In [None]:
# Disk-backed sources over the exported files. The sizes match the export cells
# (20,000 Places; 20,000 + 20,000 + 10,000 COCO) provided no sample was skipped.
places_disk = DiskImageDataset("places_images_only", 20000)
coco_disk = DiskImageDataset("coco_images_only", 50000)

# Training datasets (default test=0, i.e. the augmenting pipeline); these replace
# the deeplake-backed datasets bound to the same names earlier.
places_dataset = ColorizationDataset(places_disk, style_encoder=resnet_style, device=device)
coco_dataset = ColorizationDataset(coco_disk, style_encoder=resnet_style, device=device)

In [None]:
# Places samples come first, followed by the COCO samples.
combined_dataset = torch.utils.data.ConcatDataset([places_dataset, coco_dataset])

In [None]:
# Same disk-backed sources with test=True, which selects the resize-only pipeline (no augmentation).
places_visualization = ColorizationDataset(places_disk, style_encoder=resnet_style, device=device, test=True)
coco_visualization = ColorizationDataset(coco_disk, style_encoder=resnet_style, device=device, test=True)

In [None]:
from torch.utils.data import random_split

# 90% / 5% / remainder split; the test size absorbs any rounding leftovers.
train_size = int(0.9 * len(combined_dataset))
val_size = int(0.05 * len(combined_dataset))
test_size = len(combined_dataset) - train_size - val_size

# Seed the split so train/val/test membership is identical after a kernel restart;
# otherwise a model resumed from a checkpoint would be evaluated on samples it trained on.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    combined_dataset, [train_size, val_size, test_size], generator=split_generator
)

# NOTE(review): all three subsets share combined_dataset's augmenting transform, so
# validation and test batches are randomly augmented too — confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [None]:
# Number of batches per loader (batch_size=8 each).
print(len(train_loader))
print(len(val_loader))
print(len(test_loader))

In [None]:
from torch.amp import GradScaler, autocast

In [None]:
# Instantiate the network, the combined loss and an Adam optimizer over all model parameters.
model = ColorizationNet(device=device)
loss_fn = CombinedLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
# train() creates its own GradScaler internally, so this instance is not the one used there.
scaler = GradScaler('cuda')

In [None]:
from torchvision.models.segmentation import deeplabv3_resnet50

# Pretrained DeepLabV3 in eval mode; the training and evaluation code takes the argmax of its 'out' tensor as `segmap`.
deeplab = deeplabv3_resnet50(pretrained=True).eval().to(device)

In [None]:
from torchvision.models.detection import maskrcnn_resnet50_fpn
# Pretrained Mask R-CNN in eval mode; passed to extract_instance_feats to build instance maps.
mask_rcnn = maskrcnn_resnet50_fpn(pretrained=True).eval().to(device)

In [None]:
def extract_instance_feats(model, image_tensor, device, threshold=0.5):
    """Build a single-channel instance-count map from a detection model's masks.

    Args:
        model: detection model called with a list of images; it must return a list
            of dicts that may contain a ``"masks"`` tensor of shape (N, 1, H, W).
        image_tensor: one image whose last two dimensions are (H, W).
        device: device the image is moved to before inference.
        threshold: mask values above this count as belonging to an instance.

    Returns:
        Float tensor of shape (1, H, W): per pixel, the number of masks covering it.
        All zeros when nothing is detected.
    """
    model.eval()
    with torch.no_grad():
        image_tensor = image_tensor.to(device)
        detections = model([image_tensor])[0]

        if "masks" not in detections or len(detections["masks"]) == 0:
            # Match the input's spatial size rather than assuming 256x256, so the
            # result can be stacked with maps from images that did have detections.
            height, width = image_tensor.shape[-2:]
            return torch.zeros((1, height, width)).to(device)

        masks = detections["masks"].squeeze(1)  # (N, 1, H, W) -> (N, H, W)

        binary_masks = (masks > threshold).float()
        instance_map = binary_masks.sum(0, keepdim=True)

        return instance_map

In [None]:
def lab_to_rgb_opencv(L, ab):
    """Convert normalised LAB planes back to an 8-bit RGB image.

    Args:
        L: array of shape (H, W, 1) with lightness scaled to [0, 1].
        ab: array of shape (H, W, 2) with chroma scaled to [-1, 1].

    Returns:
        uint8 RGB array of shape (H, W, 3).
    """
    # Round and clip before the uint8 cast. A bare cast truncates (values that went
    # through a /255 then *255 round trip can drop by one) and wraps out-of-range
    # values, e.g. ab == 1.0 maps to 256.
    L_u8 = np.clip(np.rint(L * 255), 0, 255).astype(np.uint8)
    ab_u8 = np.clip(np.rint((ab * 128) + 128), 0, 255).astype(np.uint8)
    lab = np.concatenate([L_u8, ab_u8], axis=2)
    rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
    return np.clip(rgb, 0, 255).astype(np.uint8)

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def compute_metrics(pred_ab, target_ab, input_l):
    """Compute PSNR and SSIM between predicted and target colorizations of one image.

    Args:
        pred_ab: predicted chroma tensor, channel-first (2, H, W).
        target_ab: ground-truth chroma tensor, channel-first (2, H, W).
        input_l: lightness tensor with a leading singleton channel, (1, H, W).

    Returns:
        ``(psnr_val, ssim_val)`` computed on the reconstructed 8-bit RGB images.
    """
    pred_ab_np = pred_ab.detach().cpu().permute(1, 2, 0).numpy()
    target_ab_np = target_ab.detach().cpu().permute(1, 2, 0).numpy()
    L_np = input_l.detach().cpu().squeeze(0).numpy()[..., np.newaxis]

    # Both images share the same lightness; only the chroma differs.
    pred_rgb = lab_to_rgb_opencv(L_np, pred_ab_np)
    target_rgb = lab_to_rgb_opencv(L_np, target_ab_np)

    psnr_val = psnr(target_rgb, pred_rgb, data_range=255)
    ssim_val = ssim(target_rgb, pred_rgb, channel_axis=-1, data_range=255)

    # Previously returned the misspelled name `ssim_va`, which raised NameError.
    return psnr_val, ssim_val

In [None]:
from google.colab import drive
# Mount Google Drive (Colab only) so checkpoints are written outside the runtime's local disk.
drive.mount('/content/drive')

# Directory where train() writes its checkpoints; created if it does not exist.
SAVE_DIR = '/content/drive/MyDrive/colorization_models/'
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from tqdm import tqdm
import os

def _prepare_colorization_batch(batch, device):
    """Move a collated batch to ``device`` and derive the guidance inputs for the model.

    Returns ``(input_l, target_ab, input_l_3ch, segmap, instance_feats, style_feats)``.
    ``style_feats`` is ``None`` when the batch carries no such entry.
    """
    input_l = batch['L'].to(device)
    target_ab = batch['ab'].to(device)
    input_l_3ch = input_l.repeat(1, 3, 1, 1) if input_l.shape[1] == 1 else input_l

    # Style features come from the dataset; they were previously read from an
    # undefined name, which raised NameError on the first batch.
    style_feats = batch.get('style_feats')
    if isinstance(style_feats, torch.Tensor):
        style_feats = style_feats.to(device)

    # The segmentation and detection networks only provide guidance; no gradients needed.
    with torch.no_grad():
        seg_output = deeplab(input_l_3ch)['out']
        segmap = torch.argmax(seg_output, dim=1, keepdim=True).float()

        instance_feats = torch.stack([
            extract_instance_feats(mask_rcnn, img.cpu(), device)
            for img in input_l_3ch
        ]).to(device)

    return input_l, target_ab, input_l_3ch, segmap, instance_feats, style_feats


def train(model, loss_fn, train_loader, val_loader, optimizer, device, epochs=10):
    """Train ``model`` with mixed precision, validating and checkpointing every epoch.

    Relies on the notebook globals ``deeplab``, ``mask_rcnn``, ``extract_instance_feats``,
    ``compute_metrics`` and ``SAVE_DIR``. Training stops early after 3 consecutive
    epochs without a lower validation loss.

    Args:
        model: network returning ``(pred_ab, color_query)``.
        loss_fn: callable ``(pred_ab, target_ab, input_l, color_query, target_query)``.
        train_loader: yields dict batches with keys ``'L'``, ``'ab'`` and optionally ``'style_feats'``.
        val_loader: same batch format, used for validation.
        optimizer: optimizer over ``model``'s parameters.
        device: device used for all tensors.
        epochs: maximum number of epochs.
    """
    model.to(device)
    scaler = GradScaler()

    best_val_loss = float('inf')
    patience = 3
    epochs_no_improve = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        avg_loss = 0.0  # stays defined even if the loader yields nothing
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)

        # Count steps explicitly: tqdm only refreshes its `n` attribute periodically,
        # so it is not a reliable divisor for the running average.
        for step, batch in enumerate(progress_bar, start=1):
            (input_l, target_ab, input_l_3ch,
             segmap, instance_feats, style_feats) = _prepare_colorization_batch(batch, device)

            optimizer.zero_grad(set_to_none=True)

            with autocast():
                pred_ab, color_query = model(input_l_3ch, segmap, instance_feats=instance_feats, style_feats=style_feats)
                # Second pass on the same inputs supplies the target for the query-consistency term.
                _, target_query = model(input_l_3ch, segmap, instance_feats=instance_feats)
                loss = loss_fn(pred_ab, target_ab, input_l, color_query, target_query)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            avg_loss = running_loss / step
            progress_bar.set_postfix(train_loss=avg_loss)

        print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}")

        model.eval()
        val_loss = 0.0
        psnr_list = []
        ssim_list = []

        with torch.no_grad():
            for batch in val_loader:
                (input_l, target_ab, input_l_3ch,
                 segmap, instance_feats, style_feats) = _prepare_colorization_batch(batch, device)

                pred_ab, color_query = model(input_l_3ch, segmap, instance_feats=instance_feats, style_feats=style_feats)
                _, target_query = model(input_l_3ch, segmap, instance_feats=instance_feats)

                loss = loss_fn(pred_ab, target_ab, input_l, color_query, target_query)
                val_loss += loss.item()

                # Metrics are computed at the input resolution.
                pred_ab_upsampled = F.interpolate(pred_ab, size=input_l.shape[2:], mode='bilinear', align_corners=False)
                for i in range(pred_ab_upsampled.size(0)):
                    psnr_val, ssim_val = compute_metrics(pred_ab_upsampled[i], target_ab[i], input_l[i])
                    psnr_list.append(psnr_val)
                    ssim_list.append(ssim_val)

        # Report NaN instead of dividing by zero when the validation loader is empty;
        # NaN never compares lower than the best loss, so no checkpoint is written for it.
        num_val_batches = len(val_loader)
        avg_val_loss = val_loss / num_val_batches if num_val_batches else float('nan')
        avg_psnr = sum(psnr_list) / len(psnr_list) if psnr_list else float('nan')
        avg_ssim = sum(ssim_list) / len(ssim_list) if ssim_list else float('nan')

        print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'best_model_epoch.pth'))
            # Full checkpoint so training can be resumed, not just evaluated.
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'best_val_loss': best_val_loss,
                'epochs_no_improve': epochs_no_improve
            }
            torch.save(checkpoint, os.path.join(SAVE_DIR, 'checkpoint_epoch.pth'))
            print("Saved new best model.")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")
            if epochs_no_improve >= patience:
                print("Early stopping triggered.")
                break

In [None]:
# Run training for at most 5 epochs (train() may stop earlier via its early-stopping rule).
train(model, loss_fn, train_loader, val_loader, optimizer, device, epochs=5)

In [None]:
def evaluate_test(model, loss_fn, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print average PSNR, SSIM and loss.

    Relies on the notebook globals ``deeplab``, ``mask_rcnn``, ``extract_instance_feats``
    and ``compute_metrics``.

    Args:
        model: network returning ``(pred_ab, color_query)``.
        loss_fn: callable ``(pred_ab, target_ab, input_l, color_query, target_query)``.
        test_loader: yields dict batches with keys ``'L'``, ``'ab'`` and optionally ``'style_feats'``.
        device: device used for all tensors.
    """
    model.eval()
    model.to(device)

    total_loss = 0.0
    psnr_list = []
    ssim_list = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            input_l = batch['L'].to(device)
            target_ab = batch['ab'].to(device)
            input_l_3ch = input_l.repeat(1, 3, 1, 1)

            # Style features come from the batch; they were previously read from an
            # undefined name, which raised NameError on the first batch.
            style_feats = batch.get('style_feats')
            if isinstance(style_feats, torch.Tensor):
                style_feats = style_feats.to(device)

            seg_output = deeplab(input_l_3ch)['out']
            segmap = torch.argmax(seg_output, dim=1, keepdim=True).float()

            instance_feats = torch.stack([
                extract_instance_feats(mask_rcnn, img.cpu(), device)
                for img in input_l_3ch
            ]).to(device)

            _, target_query = model(input_l_3ch, segmap, instance_feats=instance_feats)

            pred_ab, color_query = model(input_l_3ch, segmap, instance_feats=instance_feats, style_feats=style_feats)

            loss = loss_fn(pred_ab, target_ab, input_l, color_query, target_query)
            total_loss += loss.item()

            # Metrics are computed at the input resolution.
            pred_ab_upsampled = F.interpolate(pred_ab, size=input_l.shape[2:], mode='bilinear', align_corners=False)

            for i in range(pred_ab_upsampled.size(0)):
                psnr_val, ssim_val = compute_metrics(
                    pred_ab_upsampled[i], target_ab[i], input_l[i]
                )
                psnr_list.append(psnr_val)
                ssim_list.append(ssim_val)

    # Report NaN instead of dividing by zero when the loader is empty.
    avg_psnr = sum(psnr_list) / len(psnr_list) if psnr_list else float('nan')
    avg_ssim = sum(ssim_list) / len(ssim_list) if ssim_list else float('nan')
    # These are test-set metrics; the label previously said "Validation".
    print(f"Test PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")
    num_batches = len(test_loader)
    avg_test_loss = total_loss / num_batches if num_batches else float('nan')
    print(f"Test Loss: {avg_test_loss:.4f}")

In [None]:
# Print PSNR, SSIM and loss on the held-out test split.
evaluate_test(model, loss_fn, test_loader, device)

In [None]:
import matplotlib.pyplot as plt
import torch.nn.functional as F
import cv2

def test_and_visualize(model, dataset, device, num_samples=30):
    """Plot grayscale input, predicted colorization and ground truth for the first samples.

    Relies on the notebook globals ``deeplab``, ``mask_rcnn``, ``extract_instance_feats``
    and ``lab_to_rgb_opencv``.

    Args:
        model: network returning ``(pred_ab, color_query)``.
        dataset: indexable; each item is a dict with ``'L'`` and ``'ab'`` tensors.
        device: device used for inference.
        num_samples: number of leading dataset items to display; nothing is drawn when <= 0.
    """
    model.eval()
    model.to(device)

    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[i, j] indexing below also works when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    for i in range(num_samples):
        sample = dataset[i]
        input_l = sample['L'].unsqueeze(0).to(device)
        target_ab = sample['ab'].unsqueeze(0).to(device)

        input_l_3ch = input_l.repeat(1, 3, 1, 1)

        with torch.no_grad():
            seg_output = deeplab(input_l_3ch)['out']
            segmap = torch.argmax(seg_output, dim=1, keepdim=True).float().to(device)

            img_for_detection = input_l_3ch[0].cpu()
            instance_feats = extract_instance_feats(mask_rcnn, img_for_detection, device)
            instance_feats = instance_feats.unsqueeze(0)  # add the batch dimension

            pred_ab, _ = model(input_l_3ch, segmap, instance_feats=instance_feats)

        # Bring the prediction up to the input resolution before converting to RGB.
        pred_ab_upsampled = F.interpolate(pred_ab, size=input_l.shape[2:], mode='bilinear', align_corners=False)

        L_np = input_l.squeeze().cpu().numpy()[..., np.newaxis]
        pred_ab_np = pred_ab_upsampled.squeeze().cpu().numpy().transpose(1, 2, 0)
        target_ab_np = target_ab.squeeze().cpu().numpy().transpose(1, 2, 0)

        pred_rgb = lab_to_rgb_opencv(L_np, pred_ab_np)
        target_rgb = lab_to_rgb_opencv(L_np, target_ab_np)
        gray_img = (L_np * 255).astype(np.uint8).squeeze()

        axes[i, 0].imshow(gray_img, cmap='gray')
        axes[i, 0].set_title('Grayscale Input')
        axes[i, 1].imshow(pred_rgb)
        axes[i, 1].set_title('Predicted Colorization')
        axes[i, 2].imshow(target_rgb)
        axes[i, 2].set_title('Ground Truth')

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Show the default number of samples (30) from the non-augmented COCO dataset.
test_and_visualize(model, coco_visualization, device)

In [None]:
import numpy as np
import cv2
from PIL import Image
import torch

def preprocess_external_image(path):
    """Load an image file and return its normalised LAB tensors plus the resized RGB array.

    The image is converted to RGB and resized to 256x256. Tensors are placed on the
    notebook-level ``device``.

    Returns:
        ``(L, ab, rgb)``: ``L`` with shape (1, 1, 256, 256), ``ab`` with shape
        (1, 2, 256, 256), and the uint8 RGB array of shape (256, 256, 3).
    """
    rgb = np.array(Image.open(path).convert("RGB").resize((256, 256)))
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB).astype(np.float32)

    def _as_batch(plane):
        # (H, W, C) array -> (1, C, H, W) float tensor on the target device
        return torch.from_numpy(plane).permute(2, 0, 1).unsqueeze(0).float().to(device)

    lightness = lab[:, :, 0:1] / 255.0
    chroma = (lab[:, :, 1:3] - 128) / 128.0

    return _as_batch(lightness), _as_batch(chroma), rgb


In [None]:
def test_external_images(model, paths, device):
    """Colorize image files and plot grayscale input, prediction and original side by side.

    Relies on the notebook globals ``deeplab``, ``mask_rcnn``, ``extract_instance_feats``,
    ``preprocess_external_image`` and ``lab_to_rgb_opencv``.

    Args:
        model: network returning ``(pred_ab, color_query)``.
        paths: sequence of image file paths; nothing is drawn when it is empty.
        device: device used for inference.
    """
    model.eval()
    model.to(device)

    if len(paths) == 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single path, so the
    # axes[i, j] indexing below also works for one image.
    fig, axes = plt.subplots(len(paths), 3, figsize=(12, 4 * len(paths)), squeeze=False)

    for i, path in enumerate(paths):
        input_l, target_ab, original_img = preprocess_external_image(path)
        input_l_3ch = input_l.repeat(1, 3, 1, 1)

        with torch.no_grad():
            seg_output = deeplab(input_l_3ch)['out']
            segmap = torch.argmax(seg_output, dim=1, keepdim=True).float().to(device)

            instance_feats = extract_instance_feats(mask_rcnn, input_l_3ch[0].cpu(), device).unsqueeze(0)

            pred_ab, _ = model(input_l_3ch, segmap, instance_feats=instance_feats)

            # Bring the prediction up to the input resolution before converting to RGB.
            pred_ab_upsampled = F.interpolate(pred_ab, size=input_l.shape[2:], mode='bilinear', align_corners=False)

        L_np = input_l.squeeze().cpu().numpy()[..., np.newaxis]
        pred_ab_np = pred_ab_upsampled.squeeze().cpu().numpy().transpose(1, 2, 0)
        target_ab_np = target_ab.squeeze().cpu().numpy().transpose(1, 2, 0)

        pred_rgb = lab_to_rgb_opencv(L_np, pred_ab_np)
        target_rgb = lab_to_rgb_opencv(L_np, target_ab_np)
        gray_img = (L_np * 255).astype(np.uint8).squeeze()

        axes[i, 0].imshow(gray_img, cmap='gray')
        axes[i, 0].set_title(f"{path} - Grayscale")
        axes[i, 1].imshow(pred_rgb)
        axes[i, 1].set_title("Predicted Colorization")
        # The third panel shows the resized original file, not the LAB round trip.
        axes[i, 2].imshow(original_img)
        axes[i, 2].set_title("Original RGB")

        for ax in axes[i]:
            ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Paths are relative to the notebook's working directory.
image_paths = ["1.png", "2.png", "3.png", "4.png", "5.png"]
test_external_images(model, image_paths, device)