In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk


In [2]:
class CIFARSegColorDataset(Dataset):
    def __init__(self, train=True):
        self.dataset = datasets.CIFAR10(root='./data', train=train, download=True)
        self.to_tensor = transforms.ToTensor()
        self.to_gray = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        gray = self.to_gray(img)
        color = self.to_tensor(img)

        # Simulated foreground mask: 16x16 center
        mask = torch.zeros(1, 32, 32)
        mask[:, 8:24, 8:24] = 1.0  # foreground

        return gray, color, mask


In [3]:
class SegmentColorNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
        )
        self.middle = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 3, 3, padding=1), nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = CIFARSegColorDataset(train=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = SegmentColorNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(15):
    model.train()
    total_loss = 0
    for gray, color, mask in loader:
        gray, color, mask = gray.to(device), color.to(device), mask.to(device)

        output = model(gray)

        # Apply mask: only compute loss where mask == 1 (target region)
        masked_output = output * mask
        masked_color = color * mask

        loss = criterion(masked_output, masked_color)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss/len(loader):.4f}")


Epoch 1 Loss: 0.0021
Epoch 2 Loss: 0.0016
Epoch 3 Loss: 0.0015
Epoch 4 Loss: 0.0015
Epoch 5 Loss: 0.0015


In [None]:
torch.save(model.state_dict(), 'segment_color_model.pth')

In [None]:

from sklearn.metrics import confusion_matrix, precision_score, recall_score

def evaluate_model(model, dataloader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for gray, color, mask in dataloader:
            gray, color, mask = gray.to(device), color.to(device), mask.to(device)
            output = model(gray)
            pred_mask = (output > 0.5).float()
            gt_mask = (mask > 0.5).float()
            y_true.extend(gt_mask.cpu().numpy().flatten())
            y_pred.extend(pred_mask.cpu().numpy().flatten())

    cm = confusion_matrix(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    print("Confusion Matrix:\n", cm)
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")


evaluate_model(model, val_loader)


In [None]:
def visualize_targeted_colorization(model, dataset, region="foreground"):
    model.eval()
    gray, color, mask = dataset[0]
    with torch.no_grad():
        pred = model(gray.unsqueeze(0).to(device)).squeeze(0).cpu()

    mask_np = mask.squeeze().numpy()
    pred_np = pred.permute(1, 2, 0).numpy()
    gray_np = gray.squeeze().numpy()
    color_np = color.permute(1, 2, 0).numpy()

    if region == "foreground":
        final = pred_np * mask_np[..., None] + gray_np[..., None] * (1 - mask_np[..., None])
    else:
        final = pred_np * (1 - mask_np[..., None]) + gray_np[..., None] * mask_np[..., None]

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(gray_np, cmap='gray')
    plt.title("Input Grayscale")
    plt.subplot(1, 3, 2)
    plt.imshow(color_np)
    plt.title("Ground Truth")
    plt.subplot(1, 3, 3)
    plt.imshow(final)
    plt.title(f"Targeted Colorization ({region})")
    plt.show()


In [None]:
def launch_gui(model, dataset):
    def update(region):
        canvas.delete("all")
        visualize_targeted_colorization(model, dataset, region=region)

    root = tk.Tk()
    root.title("Targeted Colorization Viewer")

    ttk.Button(root, text="Colorize Foreground", command=lambda: update("foreground")).pack()
    ttk.Button(root, text="Colorize Background", command=lambda: update("background")).pack()

    canvas = tk.Canvas(root, width=300, height=300)
    canvas.pack()

    root.mainloop()
