In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

# Per-sample transform: turns each PIL image into a float tensor.
transform = transforms.ToTensor()

# MNIST training split, fetched under root "." if it is not already present.
train_ds = MNIST(
    root=".",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128 (image, label) pairs consumed by the training loop.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True,
)

In [None]:
import torch.nn as nn


class AutoEncoder(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder compresses each flattened 784-value image through
    256 -> 128 -> 64 hidden units down to a ``latent_dim``-sized code; the
    decoder mirrors that path back to 784 values and reshapes them into an
    image.

    Args:
        latent_dim: Size of the bottleneck code. Defaults to 32.
    """

    def __init__(self, latent_dim: int = 32):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            # Bound reconstructions to [0, 1], the same range as the
            # ToTensor()-scaled pixels they are compared against with MSE.
            # Sigmoid has no parameters, so state_dict keys are unchanged.
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and decode a batch of images.

        Args:
            x: Batch whose first dimension is the batch size and whose remaining
                dimensions hold 784 values per sample. This is presumably
                (B, 1, 28, 28) from the MNIST loader; confirm for other callers.

        Returns:
            Reconstructions of shape (B, 1, 28, 28) with values in [0, 1].
        """
        batch_size = x.size(0)
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. permuted
        # or transposed tensors, copying only when it has to.
        flat = x.reshape(batch_size, -1)
        latent = self.encoder(flat)
        output = self.decoder(latent)
        return output.view(batch_size, 1, 28, 28)

In [None]:
SEED = 0
LEARNING_RATE = 1e-3

# Seed torch's global RNG before building the model so that weight
# initialisation is reproducible on Restart & Run All. DataLoader shuffling
# without an explicit generator also derives its seed from this global RNG,
# so the batch order becomes reproducible too.
torch.manual_seed(SEED)

model = AutoEncoder()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

In [None]:
NUM_EPOCHS = 10

model.train()
for epoch in range(NUM_EPOCHS):
    epoch_loss_sum = 0.0
    n_samples = 0
    for x, _ in train_dl:
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, x)
        loss.backward()
        optimizer.step()
        # MSELoss() averages over the batch by default, so weight each batch by
        # its size to get an exact per-sample mean even when the last batch is
        # smaller than the rest.
        epoch_loss_sum += loss.item() * x.size(0)
        n_samples += x.size(0)
    # Report the mean over the whole epoch. The previous code printed only the
    # final mini-batch's loss, which is noisy and not representative.
    print(f"Epoch {epoch} Loss: {epoch_loss_sum / n_samples:.6f}")

In [None]:
# Visualize pairs of original and reconstructed images
import matplotlib.pyplot as plt

N_EXAMPLES = 5

# Reconstruct an explicit, fixed set of samples instead of reusing the `x`/`out`
# left behind by the training loop. Those come from a random final batch, and
# `out` was computed before the last optimizer step and still carries autograd
# history.
# NOTE: these are training images, so this shows fit, not generalisation.
sample_images = torch.stack([train_ds[i][0] for i in range(N_EXAMPLES)])
model.eval()
with torch.no_grad():
    sample_recons = model(sample_images)
model.train()

# squeeze=False keeps `axes` 2-D even when N_EXAMPLES == 1.
fig, axes = plt.subplots(2, N_EXAMPLES, figsize=(3 * N_EXAMPLES, 6), squeeze=False)
for i in range(N_EXAMPLES):
    # Pin the grey scale to the [0, 1] pixel range so both rows are rendered
    # comparably, rather than each image being auto-stretched to its own min/max.
    axes[0, i].imshow(sample_images[i].squeeze().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    axes[0, i].set_title("Original")
    axes[0, i].axis("off")
    axes[1, i].imshow(sample_recons[i].squeeze().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    axes[1, i].set_title("Reconstructed")
    axes[1, i].axis("off")

plt.show()