<a href="https://colab.research.google.com/github/tomonari-masada/course2025-stats2/blob/main/vae_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
import torch
import torch.nn as nn

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Download MNIST as plain NumPy arrays (as_frame=False): X holds the flattened
# images, y the class labels.
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# Rows before n_train form the training set; all remaining rows form the test set
# (the conventional MNIST train/test boundary).
n_train = 60000
X_train, y_train = X[:n_train], y[:n_train]
X_test, y_test = X[n_train:], y[n_train:]

In [None]:
# Rescale raw pixel values (0-255) into [0, 1] so they are valid targets for the
# binary cross-entropy used in vae_loss.
X_train, X_test = X_train / 255.0, X_test / 255.0

In [None]:
# Move the data onto `device`: float32 image features and int64 class labels.
X_train = torch.tensor(X_train, dtype=torch.float32, device=device)
X_test = torch.tensor(X_test, dtype=torch.float32, device=device)
# The labels presumably arrive from fetch_openml as strings (the original parsed
# each one with int()); a single vectorized astype replaces the per-element
# Python loop and intermediate list.
y_train = torch.tensor(y_train.astype("int64"), dtype=torch.int64, device=device)
y_test = torch.tensor(y_test.astype("int64"), dtype=torch.int64, device=device)

In [None]:
# Architecture sizes. The encoder emits 2 * LATENT_DIM values per image, which
# vae_forward splits (chunk(2)) into the posterior mean and log-variance.
INPUT_DIM = 784      # flattened 28 x 28 image
HIDDEN_DIM = 400
LATENT_DIM = 100

# A ReLU sits between the Linear layers: without a nonlinearity, two stacked
# Linear layers compose into a single affine map, so the "MLP" would be linear.
encoder = nn.Sequential(
    nn.Linear(INPUT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, 2 * LATENT_DIM),
).to(device)
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    # Raw logits: no sigmoid here because vae_loss uses
    # binary_cross_entropy_with_logits, which applies it internally.
    nn.Linear(HIDDEN_DIM, INPUT_DIM),
).to(device)

In [None]:
def vae_forward(x, return_latents=False):
    """Encode ``x``, draw one reparameterized latent sample, and decode it.

    Args:
        x: input batch passed straight to the global ``encoder``.
        return_latents: when true, also return the sampled latent ``z``.

    Returns:
        ``(x_reconstructed, mu, logvar)``, or ``(x_reconstructed, mu, logvar, z)``
        when ``return_latents`` is true. ``x_reconstructed`` is the raw output of
        the global ``decoder`` (treated as logits by vae_loss).
    """
    stats = encoder(x)
    # First half of the last dimension is the mean, second half the log-variance.
    mu, logvar = stats.chunk(2, dim=-1)
    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), which keeps
    # the sample differentiable w.r.t. mu and logvar.
    sigma = torch.exp(logvar / 2)
    eps = torch.randn_like(mu)
    z = mu + sigma * eps
    x_reconstructed = decoder(z)
    outputs = (x_reconstructed, mu, logvar)
    return outputs + (z,) if return_latents else outputs

In [None]:
def vae_loss(x, x_reconstructed, mu, logvar, kl_weight=1.0):
    """Return the (weighted) negative ELBO summed over the batch.

    Args:
        x: target pixel values in [0, 1].
        x_reconstructed: decoder logits, same shape as ``x``.
        mu, logvar: mean and log-variance of the diagonal Gaussian posterior.
        kl_weight: multiplier applied to the KL term.

    Returns:
        Scalar tensor: Bernoulli reconstruction loss + kl_weight * KL divergence.
    """
    bce = nn.functional.binary_cross_entropy_with_logits
    # Bernoulli likelihood on each pixel; the sigmoid is applied inside the loss.
    recon_term = bce(x_reconstructed, x, reduction='sum')
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over all elements.
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_term + kl_weight * kl_term

In [None]:
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

batch_size = 1000
# Both terms of vae_loss are per-example quantities summed over the same minibatch,
# so they already share one scale. The KL here is NOT a dataset-level penalty and
# must not be rescaled by batch_size / N. That weighting (about 1/60 here) silently
# trained a heavily down-weighted beta-VAE. kl_weight = 1.0 is the standard ELBO;
# change it only to deliberately train a beta-VAE.
kl_weight = 1.0

num_train = X_train.size(0)
num_epochs = 100
for epoch in range(num_epochs):
    # Generate the permutation on the data's device so indexing needs no transfer.
    shuffle_indices = torch.randperm(num_train, device=X_train.device)
    epoch_loss = 0.0
    for i in range(0, num_train, batch_size):
        indices = shuffle_indices[i:i + batch_size]
        X_batch = X_train[indices]
        X_reconstructed, mu, logvar = vae_forward(X_batch)
        loss = vae_loss(X_batch, X_reconstructed, mu, logvar, kl_weight=kl_weight)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate on-device (detached) to avoid a host sync every step.
        epoch_loss += loss.detach()
    if (epoch + 1) % 10 == 0:
        # Report the epoch-average loss per example rather than the last batch's sum,
        # which was noisy and scaled with batch_size.
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss per example: {float(epoch_loss) / num_train:.4f}')

In [None]:
# Inference only: no_grad avoids building an autograd graph over the whole test set.
with torch.no_grad():
    X_reconstructed, _, _, z = vae_forward(X_test, return_latents=True)
    # The decoder emits Bernoulli logits (vae_loss uses
    # binary_cross_entropy_with_logits); convert them to pixel probabilities in
    # [0, 1] so the images below show actual intensities.
    X_reconstructed = torch.sigmoid(X_reconstructed)
X_reconstructed = X_reconstructed.cpu().numpy()

In [None]:
# Display the reconstruction of the first test image as a 28 x 28 grayscale picture.
first_reconstruction = X_reconstructed[0].reshape(28, 28)
plt.imshow(first_reconstruction, cmap='gray')
plt.show()

In [None]:
def visualize_reconstruction(X_test, X_reconstructed, test_image_index=0):
    """Plot one test image next to its reconstruction in a 1 x 2 figure.

    Args:
        X_test: original images, one flattened 784-vector per row.
        X_reconstructed: reconstructed images, row-aligned with ``X_test``.
        test_image_index: row to display from both arrays.
    """
    panels = (
        (X_test, 'Original Image'),
        (X_reconstructed, 'Reconstructed Image'),
    )
    plt.figure(figsize=(10, 5))
    for position, (images, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(images[test_image_index].reshape(28, 28), cmap='gray')
        plt.title(title)
    plt.show()

In [None]:
# Test image #0 beside its reconstruction (X_test lives on `device`, so copy it to the CPU).
visualize_reconstruction(X_test.cpu(), X_reconstructed, 0)

In [None]:
# Test image #100 beside its reconstruction.
visualize_reconstruction(X_test.cpu(), X_reconstructed, 100)

In [None]:
# interpolate between two points in latent space
def interpolate_latent_space(z1, z2, num_steps=20, decoder_net=None):
    """Decode evenly spaced points on the straight line from ``z1`` to ``z2``.

    Args:
        z1, z2: latent codes of the two endpoints (same shape, e.g. one latent vector each).
        num_steps: number of points, endpoints included (alpha runs from 0 to 1).
        decoder_net: module mapping latents to pixel logits; defaults to the
            global ``decoder``.

    Returns:
        A list of ``num_steps`` NumPy arrays, each of shape ``(1, D)``, holding
        pixel probabilities in [0, 1]. The decoder outputs logits (the training
        loss is binary_cross_entropy_with_logits), so a sigmoid is applied here.
    """
    net = decoder if decoder_net is None else decoder_net
    with torch.no_grad():  # visualization only; no autograd graph needed
        # Column of mixing weights on the codes' device, so broadcasting against
        # z1/z2 yields one interpolated code per row in a single batch.
        alphas = torch.linspace(0, 1, num_steps, device=z1.device).unsqueeze(1)
        z_interp = (1 - alphas) * z1 + alphas * z2
        images = torch.sigmoid(net(z_interp)).cpu().numpy()
    # Preserve the original return type: one (1, D) array per interpolation step.
    return [image[None, :] for image in images]

In [None]:
# Interpolate between the sampled latent codes of test images 0 and 100.
interpolated_images = interpolate_latent_space(z1=z[0], z2=z[100])

In [None]:
# Lay the interpolation frames out left to right in a single row, without axes.
n_panels = len(interpolated_images)
for panel, image in enumerate(interpolated_images, start=1):
    plt.subplot(1, n_panels, panel)
    plt.imshow(image.reshape(28, 28), cmap='gray')
    plt.axis('off')
plt.show()