In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets
from torch.utils.data import DataLoader
from skimage.color import rgb2lab, lab2rgb
from PIL import Image
import os
from torchvision.models import vgg16
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import gradio as gr

In [32]:
# Select the compute device once for the whole notebook: CUDA when a usable GPU
# is reported by PyTorch, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [33]:
class LabColorTransform:
    """Convert an RGB image into normalised CIE-Lab tensors for colorization.

    Accepted inputs:
        * a PIL image or H x W x 3 array with 0-255 values (divided by 255 here);
        * a torch tensor. A 3-D tensor whose first dimension is 3 is treated as
          channel-first (C x H x W) and moved to channel-last, because rgb2lab
          expects the colour channel on the last axis. Floating-point tensors are
          assumed to already be in [0, 1] (the ``ToTensor()`` convention) and are
          not rescaled; integer tensors are divided by 255.

    Returns:
        tuple: ``(L, ab)`` where ``L`` is a 1 x H x W float32 tensor holding
        lightness / 100 and ``ab`` is a 2 x H x W float32 tensor holding the
        a/b channels / 128.
    """

    def __call__(self, img):
        if torch.is_tensor(img):
            rgb = img.detach().cpu().numpy()
            if rgb.ndim == 3 and rgb.shape[0] == 3:
                # C x H x W -> H x W x C so rgb2lab sees the channels last.
                rgb = rgb.transpose(1, 2, 0)
            if not img.is_floating_point():
                rgb = rgb / 255.0
        else:
            rgb = np.array(img) / 255.0

        lab_img = rgb2lab(rgb).astype("float32")
        L = lab_img[:, :, 0] / 100.0  # Normalize L to [0, 1]
        ab = lab_img[:, :, 1:] / 128.0  # Normalize ab to [-1, 1]
        return torch.tensor(L).unsqueeze(0), torch.tensor(ab).permute(2, 0, 1)

class LabPairDataset(torch.utils.data.Dataset):
    """Wrap a labelled image dataset so that it yields only ``(L, ab)`` pairs.

    CIFAR10 returns ``(transform(img), label)``. Because the transform itself
    returns an ``(L, ab)`` tuple, samples come back as ``((L, ab), label)`` and a
    loop written as ``for L, ab in loader`` would receive a list and the class
    labels instead of two tensors. The class label is not used for colorization,
    so it is dropped here.
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        (L, ab), _label = self.base[index]
        return L, ab


# LabColorTransform is applied straight to the resized PIL image: it performs the
# 0-255 -> 0-1 scaling itself and hands rgb2lab an H x W x 3 array, so no
# ToTensor() step (channel-first, already scaled) may precede it.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    LabColorTransform()
])

# Load CIFAR-10 (or your custom dataset)
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# The loaders iterate over (L, ab) batches; the underlying CIFAR10 objects stay
# available as train_dataset / test_dataset.
train_loader = DataLoader(LabPairDataset(train_dataset), batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(LabPairDataset(test_dataset), batch_size=32, shuffle=False, num_workers=2)

In [34]:
class ColorizationNet(nn.Module):
    """Convolutional encoder-decoder predicting ab chroma from a lightness channel.

    Input:  N x 1 x H x W tensor.
    Output: N x 2 x H x W tensor, squashed by ``tanh`` into [-1, 1].

    The encoder halves the resolution twice (two stride-2 convolutions) and the
    decoder doubles it twice, so the raw decoder output has size 4 * ceil(H / 4).
    When H or W is not a multiple of 4 the result is resized back to the input
    size, so the prediction always lines up with the lightness channel.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # Output in [-1, 1] range for ab channels
        )

    def forward(self, x):
        input_size = tuple(x.shape[-2:])
        x = self.encoder(x)
        x = self.decoder(x)
        if tuple(x.shape[-2:]) != input_size:
            # Only reached for sizes that are not multiples of 4; bilinear
            # resampling keeps the values inside the tanh range.
            x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)
        return x


In [35]:
class CombinedLoss(nn.Module):
    """Pixel-wise MSE plus a feature-space ("perceptual") MSE.

    Args:
        vgg: feature extractor called on N x 3 x 224 x 224 tensors; its outputs
            for prediction and target are compared with MSE.

    ``forward(output, target)`` returns ``mse + 0.1 * perceptual`` as a scalar
    tensor. Two-channel (ab) inputs are padded with a leading all-zero channel so
    the extractor receives three channels; three-channel inputs are used as-is.

    Raises:
        ValueError: if the inputs have neither 2 nor 3 channels.
    """

    def __init__(self, vgg):
        super().__init__()
        self.vgg = vgg
        self.mse = nn.MSELoss()

    @staticmethod
    def _to_three_channels(tensor):
        """Return a 3-channel version of ``tensor`` suitable for the extractor."""
        channels = tensor.shape[1]
        if channels == 2:
            return torch.cat([torch.zeros_like(tensor[:, :1]), tensor], dim=1)
        if channels == 3:
            return tensor
        raise ValueError(
            f"CombinedLoss expects 2- or 3-channel tensors, got {channels} channels"
        )

    def forward(self, output, target):
        # MSE Loss on ab channels
        mse_loss = self.mse(output, target)

        # Perceptual Loss
        output_rgb = self._to_three_channels(output)
        target_rgb = self._to_three_channels(target)

        # align_corners=False is the bilinear default; stating it keeps the result
        # identical while avoiding the "default behaviour" warning of older PyTorch.
        output_rgb = F.interpolate(output_rgb, size=(224, 224), mode='bilinear', align_corners=False)
        target_rgb = F.interpolate(target_rgb, size=(224, 224), mode='bilinear', align_corners=False)
        output_features = self.vgg(output_rgb)
        target_features = self.vgg(target_rgb)
        perceptual_loss = self.mse(output_features, target_features)

        return mse_loss + 0.1 * perceptual_loss

# Feature extractor for the perceptual loss: the first nine layers of a
# pretrained VGG-16, switched to eval mode and placed on the active device.
vgg_model = vgg16(pretrained=True).features[:9]
vgg_model = vgg_model.eval().to(device)
# The extractor is never trained, so none of its weights need gradients.
for frozen_weight in vgg_model.parameters():
    frozen_weight.requires_grad_(False)



In [37]:
# Training objects shared by the cells below: network, loss, optimiser and
# learning-rate schedule.
model = ColorizationNet()
model = model.to(device)
criterion = CombinedLoss(vgg=vgg_model)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# StepLR multiplies the learning rate by 0.5 after every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

In [38]:
def train_model(epochs=10):
    """Train the notebook-level ``model`` for ``epochs`` passes over ``train_loader``.

    Uses the globals ``model``, ``train_loader``, ``criterion``, ``optimizer``,
    ``scheduler`` and ``device``. After each epoch the scheduler is stepped, the
    mean batch loss is printed and the weights are written to
    ``colorization_model_epoch<N>.pth``. When all epochs are done the weights are
    saved once more as ``final_colorization_model.pth``.
    """
    model.train()

    for epoch_number in range(1, epochs + 1):
        running_loss = 0

        for lightness, chroma in train_loader:
            lightness = lightness.to(device)
            chroma = chroma.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(lightness), chroma)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        scheduler.step()
        mean_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch_number}/{epochs}], Loss: {mean_loss:.4f}")
        # Per-epoch checkpoint (state dict only).
        torch.save(model.state_dict(), f"colorization_model_epoch{epoch_number}.pth")

    print("Training complete!")
    torch.save(model.state_dict(), "final_colorization_model.pth")

In [39]:
def evaluate_model(model, dataloader):
    """Compute mean PSNR and SSIM of colorized images against the ground truth.

    Args:
        model: network mapping an N x 1 x H x W lightness batch to an
            N x 2 x H x W ab batch.
        dataloader: iterable of ``(L, ab)`` batches, with L scaled by 1/100 and
            ab by 1/128 (both are rescaled here before the Lab -> RGB conversion).

    Returns:
        tuple: ``(average_psnr, average_ssim)`` over all images. Both are
        ``nan`` when the dataloader yields no images.

    Uses the notebook-level ``device``.
    """
    model.eval()
    total_psnr = 0.0
    total_ssim = 0.0
    count = 0

    with torch.no_grad():
        for L, ab in dataloader:
            L, ab = L.to(device), ab.to(device)
            outputs = model(L)

            # Convert to numpy for evaluation
            outputs_np = outputs.cpu().numpy()
            ab_np = ab.cpu().numpy()
            L_np = L.cpu().numpy()

            # Calculate metrics for each image in batch
            for i in range(outputs_np.shape[0]):
                # Keep the channel axis (1 x H x W) so it can be stacked with the
                # 2 x H x W chroma along axis 0.
                lightness = (L_np[i] * 100).astype('float32')

                # Reconstruct LAB images as H x W x 3
                pred_lab = np.concatenate([
                    lightness,
                    (outputs_np[i] * 128).astype('float32')
                ], axis=0).transpose(1, 2, 0)

                target_lab = np.concatenate([
                    lightness,
                    (ab_np[i] * 128).astype('float32')
                ], axis=0).transpose(1, 2, 0)

                # Convert to RGB for the image-space metrics
                pred_rgb = lab2rgb(pred_lab)
                target_rgb = lab2rgb(target_lab)

                total_psnr += psnr(target_rgb, pred_rgb, data_range=1.0)
                # channel_axis replaces the older multichannel=True keyword.
                total_ssim += ssim(target_rgb, pred_rgb, channel_axis=-1, data_range=1.0)
                count += 1

    if count == 0:
        # Nothing to average; avoid a ZeroDivisionError and signal "undefined".
        print("No images were evaluated; PSNR and SSIM are undefined.")
        return float('nan'), float('nan')

    average_psnr = total_psnr / count
    average_ssim = total_ssim / count
    print(f"Average PSNR: {average_psnr:.4f}")
    print(f"Average SSIM: {average_ssim:.4f}")
    return average_psnr, average_ssim


In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

def color_confusion_matrix(model, dataloader):
    """Plot and return an 8 x 8 confusion matrix of quantised ab values.

    Every ab value (true and predicted) is mapped to one of 8 bins via
    ``(value * 4 + 4)`` clamped to [0, 7] and truncated to an integer. Entry
    ``[i, j]`` counts values whose true bin is ``i`` and predicted bin is ``j``.

    Counts are accumulated batch by batch, so memory use stays constant instead
    of growing with one Python list entry per pixel value.

    Args:
        model: network mapping a lightness batch to an ab batch.
        dataloader: iterable of ``(L, ab)`` batches.

    Returns:
        numpy.ndarray: integer array of shape (8, 8), rows = true bins,
        columns = predicted bins.

    Uses the notebook-level ``device``.
    """
    model.eval()
    num_bins = 8
    pair_counts = torch.zeros(num_bins * num_bins, dtype=torch.long)

    with torch.no_grad():
        for L, ab in dataloader:
            L, ab = L.to(device), ab.to(device)
            outputs = model(L)

            # Convert to color bins (8 bins per channel)
            true_bins = (ab * 4 + 4).clamp(0, 7).long().flatten()
            pred_bins = (outputs * 4 + 4).clamp(0, 7).long().flatten()

            # Encode each (true, pred) pair as one index in [0, 64) and count them.
            pair_ids = true_bins * num_bins + pred_bins
            pair_counts += torch.bincount(pair_ids, minlength=num_bins * num_bins).cpu()

    cm = pair_counts.reshape(num_bins, num_bins).numpy()

    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Color Value Confusion Matrix')
    plt.xlabel('Predicted Color Bins')
    plt.ylabel('True Color Bins')
    plt.show()

    return cm

In [None]:
from sklearn.metrics import precision_score, recall_score

def calculate_precision_recall(model, dataloader):
    """Macro-averaged precision and recall of quantised ab predictions.

    ab values are mapped to 8 bins via ``(value * 4 + 4)`` clamped to [0, 7] and
    truncated to an integer. (true bin, predicted bin) counts are accumulated
    batch by batch into an 8 x 8 table, so memory use does not grow with the
    number of pixels.

    The averages follow the macro convention with zero-division treated as 0:
    the mean is taken over the bins that occur among the true or predicted
    values, and a bin that is never predicted (precision) or never true (recall)
    contributes 0.

    Args:
        model: network mapping a lightness batch to an ab batch.
        dataloader: iterable of ``(L, ab)`` batches.

    Returns:
        tuple: ``(precision, recall)`` as floats; ``(0.0, 0.0)`` when the
        dataloader yields nothing.

    Uses the notebook-level ``device``.
    """
    model.eval()
    num_bins = 8
    pair_counts = torch.zeros(num_bins * num_bins, dtype=torch.long)

    with torch.no_grad():
        for L, ab in dataloader:
            L, ab = L.to(device), ab.to(device)
            outputs = model(L)

            # Convert to color bins
            true_bins = (ab * 4 + 4).clamp(0, 7).long().flatten()
            pred_bins = (outputs * 4 + 4).clamp(0, 7).long().flatten()

            # Encode each (true, pred) pair as one index in [0, 64) and count them.
            pair_ids = true_bins * num_bins + pred_bins
            pair_counts += torch.bincount(pair_ids, minlength=num_bins * num_bins).cpu()

    counts = pair_counts.reshape(num_bins, num_bins).numpy().astype(np.float64)
    true_positive = np.diag(counts)
    predicted_total = counts.sum(axis=0)  # column sums: how often each bin was predicted
    actual_total = counts.sum(axis=1)     # row sums: how often each bin was the truth
    present = (predicted_total + actual_total) > 0

    if present.any():
        per_bin_precision = np.divide(
            true_positive, predicted_total,
            out=np.zeros(num_bins), where=predicted_total > 0
        )
        per_bin_recall = np.divide(
            true_positive, actual_total,
            out=np.zeros(num_bins), where=actual_total > 0
        )
        precision = float(per_bin_precision[present].mean())
        recall = float(per_bin_recall[present].mean())
    else:
        precision, recall = 0.0, 0.0

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")

    return precision, recall

In [40]:
def colorize_image(input_image):
    """Colorize one image with the notebook-level ``model``.

    Args:
        input_image: PIL image or array with 0-255 values. Grayscale and RGBA
            inputs are accepted: PIL images are converted to RGB, 2-D arrays are
            replicated to three channels and an alpha channel is dropped.
            ``None`` (e.g. no upload in the GUI) is returned unchanged.

    Returns:
        numpy.ndarray | None: H x W x 3 uint8 RGB image built from the input's
        lightness and the predicted ab channels.

    Uses the notebook-level ``model`` and ``device``.
    """
    if input_image is None:
        return None

    if hasattr(input_image, "convert"):
        # PIL image: force three channels so rgb2lab gets an H x W x 3 array.
        input_image = input_image.convert("RGB")

    # Convert to LAB
    img_np = np.array(input_image) / 255.0
    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)
    elif img_np.shape[2] == 4:
        img_np = img_np[:, :, :3]
    lab_img = rgb2lab(img_np).astype("float32")
    lightness = lab_img[:, :, :1]  # kept as H x W x 1 for the concatenation below

    # Convert to tensor (N x 1 x H x W, L normalised to [0, 1])
    L_tensor = torch.tensor(lightness[:, :, 0] / 100.0).unsqueeze(0).unsqueeze(0).float().to(device)

    # Predict
    with torch.no_grad():
        ab_pred = model(L_tensor)
        if tuple(ab_pred.shape[-2:]) != tuple(L_tensor.shape[-2:]):
            # Down/up-sampling networks can round the spatial size; realign the
            # prediction with the input before merging the channels.
            ab_pred = F.interpolate(
                ab_pred, size=tuple(L_tensor.shape[-2:]),
                mode='bilinear', align_corners=False
            )

    # Convert back to RGB; squeeze only the batch axis so 1-pixel-wide images survive.
    ab_pred = ab_pred.squeeze(0).cpu().numpy() * 128
    pred_lab = np.concatenate([
        lightness,
        ab_pred.transpose(1, 2, 0)
    ], axis=2)
    pred_rgb = (lab2rgb(pred_lab) * 255).astype('uint8')

    return pred_rgb
        

In [None]:
if __name__ == "__main__":
    # End-to-end driver: train, evaluate, optionally save a metrics checkpoint,
    # then start the gradio demo.

    # Train the model
    train_model(epochs=10)

    # Evaluate
    print("\nEvaluating model...")
    avg_psnr, avg_ssim = evaluate_model(model, test_loader)

    # Confusion Matrix and Metrics
    print("\nCalculating confusion matrix...")
    color_confusion_matrix(model, test_loader)

    print("\nCalculating precision/recall...")
    precision, recall = calculate_precision_recall(model, test_loader)

    # Save model if metrics are good
    # NOTE(review): train_model() already writes a bare state_dict to this same
    # path; when the thresholds pass, the file is replaced by a dict holding
    # 'model_state_dict' and 'metrics'. Loaders must handle both layouts -
    # consider a separate filename for the metrics checkpoint.
    if avg_psnr > 20 and avg_ssim > 0.7 and precision > 0.7:
        torch.save({
            'model_state_dict': model.state_dict(),
            'metrics': {
                'psnr': avg_psnr,
                'ssim': avg_ssim,
                'precision': precision,
                'recall': recall
            }
        }, "final_colorization_model.pth")
        print("\nModel saved with good performance!")

    # Launch GUI with proper examples
    print("\nLaunching GUI...")
    # Only offer example images that exist on disk; passing paths to missing
    # files can make the interface fail or show broken example entries.
    example_paths = ["examples/grayscale1.jpg", "examples/grayscale2.jpg"]
    examples = [[path] for path in example_paths if os.path.exists(path)]

    iface = gr.Interface(
        fn=colorize_image,
        inputs=gr.Image(type="pil"),
        outputs="image",
        title="Image Colorization",
        examples=examples or None  # None (the default) hides the examples section
    )
    iface.launch()