In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class CrossDomainDataset(torch.utils.data.Dataset):
    def __init__(self, domain='sketch'):
        self.domain = domain
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor()
        ])
        self.dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)

    def __getitem__(self, index):
        img, _ = self.dataset[index]
        color = img

        if self.domain == 'sketch':
            gray = torch.mean(img, dim=0, keepdim=True)
        elif self.domain == 'satellite':
            gray = img[1].unsqueeze(0)  # use only green channel
        else:
            raise ValueError("Invalid domain")

        return gray, color

    def __len__(self):
        return len(self.dataset)

In [4]:
class UNetColorization(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1), nn.Tanh()  # Output 3 channels
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [5]:
def train_one_epoch(domain='sketch'):
    dataset = CrossDomainDataset(domain)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    model = UNetColorization().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    model.train()
    for gray, color in loader:
        gray, color = gray.to(device), color.to(device)

        output = model(gray)
        loss = criterion(output, color)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Domain: {domain} | Loss: {loss.item():.4f}")
    return model

In [None]:

torch.save(model_sketch.state_dict(), 'unet_sketch.pth')
torch.save(model_satellite.state_dict(), 'unet_satellite.pth')


with torch.no_grad():
    pred = model_sketch(gray.unsqueeze(0).to(device))
    accuracy = (pred.round() == color.to(device)).float().mean()
    print(f"Sketch model accuracy: {accuracy.item():.4f}")

In [6]:
def visualize_output(model, domain='sketch'):
    dataset = CrossDomainDataset(domain)
    gray, color = dataset[0]
    model.eval()

    with torch.no_grad():
        pred = model(gray.unsqueeze(0).to(device))
    
    plt.figure(figsize=(12,4))
    plt.subplot(1, 3, 1)
    plt.title("Grayscale Input")
    plt.imshow(gray.squeeze().cpu(), cmap='gray')
    
    plt.subplot(1, 3, 2)
    plt.title("Ground Truth Color")
    plt.imshow(color.permute(1,2,0).cpu())

    plt.subplot(1, 3, 3)
    plt.title("Predicted Color")
    plt.imshow(pred.squeeze().permute(1,2,0).cpu().clamp(0,1))
    plt.show()

In [None]:
model_sketch = train_one_epoch(domain='sketch')
visualize_output(model_sketch, domain='sketch')

model_satellite = train_one_epoch(domain='satellite')
visualize_output(model_satellite, domain='satellite')

Domain: sketch | Loss: 0.0064
