In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt


In [3]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
class GrayscaleCIFAR10(Dataset):
    """CIFAR-10 exposed as (grayscale input, RGB target) pairs.

    Wraps ``torchvision.datasets.CIFAR10`` with a ``ToTensor`` transform and
    discards the class labels, so each item is suitable for training a
    colorization model.

    Args:
        train: Forwarded to ``CIFAR10``; selects the training split when
            True and the test split otherwise.
        root: Directory the CIFAR-10 files are read from / downloaded to.
        download: Forwarded to ``CIFAR10``; pass False to forbid network
            access when the data is already present under ``root``.
    """

    def __init__(self, train=True, root='./data', download=True):
        self.dataset = torchvision.datasets.CIFAR10(
            root=root,
            train=train,
            download=download,
            transform=transforms.ToTensor(),
        )

    def __len__(self):
        """Return the number of images in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(gray, img)`` for the sample at ``idx``.

        ``img`` is the RGB tensor produced by the wrapped dataset and
        ``gray`` is its single-channel conversion; the label is dropped.
        """
        img, _ = self.dataset[idx]  # shape: [3, 32, 32]
        gray = transforms.functional.rgb_to_grayscale(img)  # shape: [1, 32, 32]
        return gray, img  # input, target


In [5]:
# Mini-batches of 64 (grayscale, RGB) pairs; only the training split is shuffled.
train_loader = DataLoader(
    dataset=GrayscaleCIFAR10(train=True),
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=GrayscaleCIFAR10(train=False),
    batch_size=64,
    shuffle=False,
)


In [6]:
class ColorizationNet(nn.Module):
    """Small convolutional encoder/decoder mapping 1-channel input to 3 channels.

    The encoder applies two 3x3 convolutions (1 -> 64 -> 128 channels) and
    halves the spatial size with max-pooling; the decoder applies a 3x3
    convolution (128 -> 64), doubles the spatial size again, and finishes
    with a 3x3 convolution to 3 channels squashed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Layers are instantiated encoder-first so parameter creation order
        # (and therefore seeded initialisation) is unchanged.
        encoder_layers = [
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        decoder_layers = [
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back into a 3-channel prediction."""
        return self.decoder(self.encoder(x))


In [None]:

# Hyperparameters are named once so the loop bound and the progress message
# cannot drift apart when the epoch count is changed.
NUM_EPOCHS = 5  # increase for better results
LEARNING_RATE = 0.001

model = ColorizationNet().to(device)

# Pixel-wise mean squared error between predicted and true RGB images.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch
    for gray, color in train_loader:
        gray, color = gray.to(device), color.to(device)

        optimizer.zero_grad()
        outputs = model(gray)
        loss = criterion(outputs, color)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Reported value is the average of the per-batch losses.
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {running_loss/len(train_loader):.4f}")


Epoch [1/5], Loss: 0.0081


In [None]:
import matplotlib.pyplot as plt

NUM_EXAMPLES = 5  # maximum number of test images to display

model.eval()
with torch.no_grad():  # inference only; no autograd bookkeeping needed
    # Only the first test batch is visualised. `next(..., None)` replaces the
    # for/break construct and turns an empty loader into a no-op.
    first_batch = next(iter(test_loader), None)

    if first_batch is not None:
        gray, color = first_batch
        gray = gray.to(device)
        output = model(gray)

        # Move everything to host NumPy arrays for matplotlib.
        gray = gray.cpu().numpy()
        output = output.cpu().numpy()
        color = color.numpy()

        # Clamp so a batch shorter than NUM_EXAMPLES does not raise IndexError.
        for i in range(min(NUM_EXAMPLES, len(output))):
            fig, axs = plt.subplots(1, 3, figsize=(9, 3))
            axs[0].imshow(gray[i][0], cmap='gray')
            axs[0].set_title('Grayscale Input')
            # Arrays are channel-first; imshow expects channel-last.
            axs[1].imshow(np.transpose(output[i], (1, 2, 0)))
            axs[1].set_title('Colorized Output')
            axs[2].imshow(np.transpose(color[i], (1, 2, 0)))
            axs[2].set_title('Ground Truth')
            for ax in axs:
                ax.axis('off')
            plt.show()
