In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np


In [3]:
class CIFARGrayscaleColorization(Dataset):
    """CIFAR-10 served as (grayscale input, RGB target) pairs for colorization.

    Each item is ``(gray, color)``:
      * ``color`` is the PIL image converted with ``transforms.ToTensor()``
        (a 3 x H x W float tensor);
      * ``gray`` is ``transforms.Grayscale()`` applied to that same tensor
        (1 x H x W).
    The CIFAR-10 class label is discarded.

    Args:
        train: Load the training split if True, otherwise the test split.
        root: Directory CIFAR-10 is read from / downloaded into.
        download: Download the dataset into ``root`` if it is missing.
    """

    def __init__(self, train=True, root='./data', download=True):
        self.dataset = datasets.CIFAR10(root=root, train=train, download=download)
        # Public pipelines kept for backward compatibility; `transform_input`
        # is equivalent to `transform_target` followed by `to_gray`.
        self.transform_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Grayscale(),
        ])
        self.transform_target = transforms.Compose([
            transforms.ToTensor()
        ])
        self.to_gray = transforms.Grayscale()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]  # class label is not needed for colorization
        # Convert PIL -> tensor once and derive the grayscale input from it,
        # rather than running ToTensor twice per sample.
        color = self.transform_target(img)
        gray = self.to_gray(color)
        return gray, color


In [4]:
class ColorizationNet(nn.Module):
    """Small convolutional encoder-decoder predicting RGB from one grayscale channel.

    The encoder halves the spatial size twice (two stride-2 convolutions) while
    widening to 256 channels. The decoder mirrors this with two 2x upsamplings
    and ends in a sigmoid. An input of shape (N, 1, H, W), with H and W
    divisible by 4, therefore yields an output of shape (N, 3, H, W) in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # The list order fixes the Sequential indices and hence the state_dict
        # keys (encoder.0/2/4, decoder.0/3/6); reordering breaks checkpoints.
        down = [
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        up = [
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),  # squashes the output to RGB in [0, 1]
        ]
        self.encoder = nn.Sequential(*down)
        self.decoder = nn.Sequential(*up)

    def forward(self, x):
        features = self.encoder(x)
        return self.decoder(features)


In [5]:
# Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
epochs = 20
lr = 1e-3
seed = 42  # makes weight init and shuffle order repeatable across Restart & Run All

# Seed torch's global RNG before the model is built so its weight
# initialization is repeatable (GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

# Dataset and loader
train_dataset = CIFARGrayscaleColorization(train=True)
# A dedicated, seeded generator makes the shuffle order independent of any
# other code that consumes the global RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)

# Model, loss, optimizer
model = ColorizationNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


In [None]:
# Train for `epochs` passes over the training set, reporting the per-sample
# mean MSE of each epoch.
for epoch in range(epochs):
    # Ensure training mode: other cells (e.g. evaluation/visualization) may
    # have switched the model to eval() before this cell is re-run.
    model.train()
    total_loss = 0.0
    for gray, color in train_loader:
        gray, color = gray.to(device), color.to(device)
        optimizer.zero_grad()
        output = model(gray)
        loss = criterion(output, color)
        loss.backward()
        optimizer.step()
        # `loss` is the mean over this batch; weight it by the batch size so the
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * gray.size(0)
    epoch_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")


Epoch 1/20 - Loss: 0.0108
Epoch 2/20 - Loss: 0.0073
Epoch 3/20 - Loss: 0.0067
Epoch 4/20 - Loss: 0.0064
Epoch 5/20 - Loss: 0.0061
Epoch 6/20 - Loss: 0.0060
Epoch 7/20 - Loss: 0.0059
Epoch 8/20 - Loss: 0.0058
Epoch 9/20 - Loss: 0.0057
Epoch 10/20 - Loss: 0.0057
Epoch 11/20 - Loss: 0.0056
Epoch 12/20 - Loss: 0.0056
Epoch 13/20 - Loss: 0.0055
Epoch 14/20 - Loss: 0.0055
Epoch 15/20 - Loss: 0.0054


In [None]:
def show_colorization_results(model, dataset, num_images=5):
    """Plot grayscale input, ground truth and model prediction side by side.

    Args:
        model: Colorization network; it is fed ``gray.unsqueeze(0)`` and its
            output is displayed as an H x W x C image (expected to be RGB in
            [0, 1], as produced by ``ColorizationNet``).
        dataset: Indexable returning ``(gray, color)`` tensor pairs, such as
            ``CIFARGrayscaleColorization``; items 0..num_images-1 are shown.
        num_images: Number of rows, one sample per row.

    The model's train/eval mode is restored before returning.
    """
    # Run on whatever device the model's parameters live on, instead of
    # relying on a notebook-global `device`.
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    # squeeze=False keeps `axs` 2-D, so axs[i, j] also works when num_images == 1.
    fig, axs = plt.subplots(num_images, 3, figsize=(10, num_images * 3), squeeze=False)

    try:
        for i in range(num_images):
            gray, real = dataset[i]
            with torch.no_grad():
                pred = model(gray.unsqueeze(0).to(model_device)).squeeze(0).cpu()
            # Pin the grayscale range to [0, 1]; otherwise imshow stretches each
            # image to its own min/max and exaggerates contrast.
            axs[i, 0].imshow(gray.squeeze(0).numpy(), cmap='gray', vmin=0, vmax=1)
            axs[i, 0].set_title("Grayscale")
            axs[i, 1].imshow(real.permute(1, 2, 0).numpy())
            axs[i, 1].set_title("Ground Truth")
            axs[i, 2].imshow(pred.permute(1, 2, 0).numpy())
            axs[i, 2].set_title("Predicted")
            for j in range(3):
                axs[i, j].axis("off")
    finally:
        # Do not leak eval mode into later training cells.
        model.train(was_training)

    plt.tight_layout()
    plt.show()

# Visualize on held-out test images, so the panel reflects generalization
# rather than how well the model fits images it was trained on.
test_dataset = CIFARGrayscaleColorization(train=False)
show_colorization_results(model, test_dataset)
