# 🧠 UNet Autoencoder for Image-to-Image Translation
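This notebook trains a UNet-style model on CIFAR-10 for image reconstruction, framed as image-to-image translation. The input is a darkened copy of each image (pixel values multiplied by 0.5), and the target is the original image.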

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# UNet building block: (3x3 conv -> BatchNorm -> ReLU) applied twice
class UNetBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    ``padding=1`` keeps height and width unchanged; only the channel count
    changes, from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first conv maps in_ch -> out_ch, the second out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.block(x)

In [None]:
# Full UNet model
class UNet(nn.Module):
    """Small UNet: three encoder stages, two decoder stages with skip connections.

    Each max-pool halves the spatial size. Each decoder stage upsamples
    (nearest-neighbor) to the exact spatial size of the matching encoder
    feature map before concatenating. Inputs whose height/width are not
    divisible by 4 therefore work; with ``scale_factor=2`` they made
    ``torch.cat`` fail on a size mismatch.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the predicted image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.enc1 = UNetBlock(in_channels, 64)
        self.enc2 = UNetBlock(64, 128)
        self.enc3 = UNetBlock(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Decoder input = upsampled deeper features concatenated with the
        # skip connection, hence the summed channel counts.
        self.dec3 = UNetBlock(256 + 128, 128)
        self.dec2 = UNetBlock(128 + 64, 64)

        # 1x1 conv projects the 64 feature channels to the output channels.
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _up_cat(x, skip):
        """Upsample ``x`` to ``skip``'s height/width and concatenate on channels."""
        x = nn.functional.interpolate(x, size=skip.shape[-2:])
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Predict an image from ``x``; expects a 4-D batch (N, C, H, W).

        Returns a tensor with the same N, H, W and ``out_channels`` channels,
        squashed into (0, 1) by a sigmoid.
        """
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        d3 = self.dec3(self._up_cat(e3, e2))
        d2 = self.dec2(self._up_cat(d3, e1))
        return torch.sigmoid(self.final(d2))

In [None]:
# Load CIFAR10 dataset (training split) and build model, optimizer and loss
transform = transforms.ToTensor()  # PIL image -> float tensor in [0, 1]
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)

# Prefer the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = UNet().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.L1Loss()  # mean absolute error between prediction and target

In [None]:
# Training loop
NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    num_samples = 0
    for img, _ in train_loader:
        img = img.to(device)
        # Degraded input: the image darkened to half brightness. The model
        # learns to restore the original image from it.
        input_img = img * 0.5
        target_img = img

        output = model(input_img)
        loss = criterion(output, target_img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The loss is a per-batch mean; weight it by batch size so a smaller
        # final batch does not skew the epoch average.
        n = img.size(0)
        running_loss += loss.item() * n
        num_samples += n

    # max(..., 1) avoids dividing by zero if the loader yields nothing.
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(num_samples, 1):.4f}")

In [None]:
# Visualize results
def show_results(num_images=3, loader=None):
    """Plot input / target / model output side by side for one batch.

    Args:
        num_images: how many examples of the batch to plot; capped at the
            batch size so a small batch does not raise IndexError.
        loader: DataLoader to draw the batch from; defaults to the global
            ``train_loader``. Pass a held-out loader to inspect generalization.
    """
    if loader is None:
        loader = train_loader

    # Restore the caller's train/eval mode afterwards instead of leaving
    # the model in eval mode as a side effect.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            imgs, _ = next(iter(loader))
            input_img = imgs * 0.5  # same darkening as used during training
            output = model(input_img.to(device)).cpu()
    finally:
        model.train(was_training)

    for i in range(min(num_images, imgs.size(0))):
        _, ax = plt.subplots(1, 3)
        panels = (("Input", input_img[i]), ("Target", imgs[i]), ("Output", output[i]))
        for axis, (title, image) in zip(ax, panels):
            axis.imshow(image.permute(1, 2, 0))  # CHW -> HWC for matplotlib
            axis.set_title(title)
            axis.axis("off")
        plt.show()

# Run it
show_results()