In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """MLP encoder mapping a batch of inputs to a diagonal-Gaussian latent code.

    Each sample is flattened to ``input_dim`` features, passed through one
    ReLU hidden layer of width ``hidden_dim``, and projected by two separate
    linear heads to a mean and a log-variance of size ``latent_dim``.

    Args:
        input_dim: Features per sample after flattening (default 784 = 28 * 28,
            the MNIST image size used in this notebook).
        hidden_dim: Width of the hidden layer.
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim: int = 784, hidden_dim: int = 128, latent_dim: int = 32):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` into latent mean and log-variance.

        Args:
            x: Batch whose first dimension is the batch size; all remaining
                dimensions are flattened and must multiply to ``input_dim``.

        Returns:
            ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        # flatten(1) (unlike view) also accepts non-contiguous and empty batches.
        x = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(x))
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar


class Decoder(nn.Module):
    """MLP decoder mapping latent codes back to flattened outputs in (0, 1).

    One ReLU hidden layer followed by a sigmoid output layer. The sigmoid keeps
    every output in [0, 1], which is what the binary cross-entropy term of
    ``vae_loss`` requires.

    Args:
        latent_dim: Dimensionality of the latent input.
        hidden_dim: Width of the hidden layer.
        output_dim: Features per reconstructed sample (default 784 = 28 * 28).
    """

    def __init__(self, latent_dim: int = 32, hidden_dim: int = 128, output_dim: int = 784):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode latent codes ``x``.

        Args:
            x: Batch of latent codes; dimensions after the first are flattened
                and must multiply to ``latent_dim``.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with values in [0, 1].
        """
        # flatten(1) (unlike view) also accepts non-contiguous and empty batches.
        x = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))


class VAE(nn.Module):
    """Variational autoencoder built from the notebook's ``Encoder`` and ``Decoder``.

    ``forward`` returns the reconstruction along with the latent mean and
    log-variance so the caller can form the reconstruction + KL loss.
    """

    def __init__(self):
        super().__init__()
        # Registration order (encoder, then decoder) fixes parameter/init order.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x``, draw one latent sample, and decode it.

        Returns:
            ``(recon, mu, logvar)``: decoder output plus the encoder's mean and
            log-variance.
        """
        mean, log_variance = self.encoder(x)
        latent = self.reparametrize(mean, log_variance)
        return self.decoder(latent), mean, log_variance

    def reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps ``z`` differentiable w.r.t. ``mu``
        and ``logvar``; the randomness comes only from ``eps``.
        """
        std = (logvar * 0.5).exp()
        return mu + std * torch.randn_like(std)

In [None]:
def vae_loss(
    recon_x: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor
) -> torch.Tensor:
    """Batch-averaged VAE loss: reconstruction BCE plus Gaussian KL divergence.

    The loss has two terms:

    * Reconstruction: binary cross-entropy between ``recon_x`` and ``x``,
      summed over every element. ``recon_x`` must lie in [0, 1], e.g. a
      sigmoid output.
    * KL: the closed-form KL divergence of N(mu, exp(logvar)) from N(0, I),
      summed over all latent dimensions and samples.

    Args:
        recon_x: Reconstructions; dimensions after the first are flattened.
        x: Targets with the same number of elements per sample as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances (same shape as ``mu``).

    Returns:
        Scalar tensor: (BCE + KL) divided by the batch size ``x.size(0)``.
    """
    batch_size = x.size(0)
    flat_recon = recon_x.view(recon_x.size(0), -1)
    flat_target = x.view(batch_size, -1)
    reconstruction = F.binary_cross_entropy(flat_recon, flat_target, reduction="sum")
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return (reconstruction + kl) / batch_size

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# MNIST training split, downloaded into the working directory if missing and
# converted to tensors with ToTensor; batches are reshuffled every epoch.
train_ds = MNIST(
    root=".",
    train=True,
    download=True,
    transform=ToTensor(),
)
train_dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=512)

In [None]:
# Seed torch's global RNG so weight initialisation and later random draws
# (batch shuffling, reparametrisation noise) repeat across Restart & Run All.
# Note: bitwise determinism on GPU/MPS backends is not guaranteed.
SEED = 0
torch.manual_seed(SEED)

# Pick the best available accelerator. torch.backends.mps.is_available() is
# the documented MPS check; CUDA is tried first so NVIDIA machines don't
# silently fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

In [None]:
NUM_EPOCHS = 100

for epoch in range(NUM_EPOCHS):
    # The model has no dropout/batch-norm today; this just makes the mode explicit.
    model.train()
    epoch_loss_sum = 0.0
    n_seen = 0
    # `x` and `recon` intentionally stay bound after the loop: the plotting
    # cells below display the last training batch.
    for x, _ in train_dl:
        x = x.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar)
        loss.backward()
        optimizer.step()
        # vae_loss already averages over the batch, so re-weight by batch size;
        # the final batch is smaller and would otherwise be over-counted.
        epoch_loss_sum += loss.item() * x.size(0)
        n_seen += x.size(0)
    avg_loss = epoch_loss_sum / max(n_seen, 1)
    print(f"Epoch {epoch}  Avg loss/sample: {avg_loss:.4f}")

In [None]:
import matplotlib.pyplot as plt

# Hidden state: `recon` is the model output for the last training batch from
# the training cell; index 1 matches the `x[1]` shown in the next cell.
fig, ax = plt.subplots()
ax.imshow(recon[1].detach().cpu().reshape(28, 28), cmap="gray")
ax.set_title("Reconstruction (sample 1 of last training batch)")
ax.axis("off");

In [None]:
# Hidden state: `x` is the last training batch left over from the training
# cell; index 1 is the input whose reconstruction the previous cell shows.
fig, ax = plt.subplots()
ax.imshow(x[1].detach().cpu().reshape(28, 28), cmap="gray")
ax.set_title("Original input (sample 1 of last training batch)")
ax.axis("off");