In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# MNIST training split: 28x28 grayscale images of the digits 0-9.
# ToTensor() turns each image into a float tensor of shape (1, 28, 28) with
# values scaled to [0, 1], matching the sigmoid output range of the decoders.
# Batches are reshuffled every time the loader is iterated.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='.',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=128,
    shuffle=True,
)

In [None]:
# Select the compute device: the default CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Autoencoder Model (fully connected)
class Autoencoder(nn.Module):
    """Multilayer-perceptron autoencoder.

    The encoder maps ``dim_input`` features to a ``dim_latent_space``-dimensional
    code (dim_input -> 128 -> 64 -> dim_latent_space). The decoder mirrors it and
    ends in a sigmoid, so reconstructions lie in [0, 1].

    Inputs may be flat ``(batch, dim_input)``, unbatched ``(dim_input,)``, or
    image-shaped such as ``(batch, 1, 28, 28)``. Non-batch dimensions are
    flattened before the linear layers. ``forward`` returns the reconstruction
    in the same shape as its input, so this model is interchangeable with the
    convolutional one in the training and plotting code.

    Args:
        dim_latent_space: Size of the latent code ``z``.
        dim_input: Number of features per sample after flattening
            (784 = 28 * 28 for MNIST).
    """

    def __init__(self, dim_latent_space=5, dim_input=784):
        super().__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(dim_input, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, dim_latent_space)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(dim_latent_space, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, dim_input),
            nn.Sigmoid()  # to ensure output is between 0 and 1
        )

    def encode(self, x):
        """Map ``x`` to its latent code.

        Every dimension after the batch dimension is flattened, so image-shaped
        input ``(batch, 1, 28, 28)`` reaches the first linear layer as
        ``(batch, 784)``. Flat ``(batch, dim_input)`` input is left unchanged.
        A 1-D unbatched sample ``(dim_input,)`` is passed through as-is.

        Returns:
            Tensor of shape ``(batch, dim_latent_space)``, or
            ``(dim_latent_space,)`` for unbatched input.
        """
        if x.dim() > 1:
            x = torch.flatten(x, start_dim=1)
        z = self.encoder(x)
        return z

    def decode(self, z):
        """Map latent codes to flat reconstructions of shape ``(..., dim_input)``."""
        x_hat = self.decoder(z)
        return x_hat

    def forward(self, x):
        """Encode then decode ``x``.

        The flat decoder output is reshaped to ``x.shape`` so it can be compared
        element-wise with the input (e.g. by ``nn.MSELoss``) and displayed as an
        image when the input was image-shaped.

        Returns:
            Tuple ``(x_hat, z)``:
            - ``x_hat``: the reconstruction, with the same shape as ``x``.
            - ``z``: the latent code.
        """
        z = self.encode(x)                       # pass through encoder
        x_hat = self.decode(z).reshape(x.shape)  # pass through decoder, restore input shape
        return x_hat, z


# Convolutional Autoencoder
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    Encoder path:
    - Two stride-2 3x3 convolutions (28x28 -> 14x14 -> 7x7).
    - A 7x7 convolution that collapses the 7x7 map to a 64-channel 1x1 map.
    - A linear projection to the latent code.

    The decoder runs the same path in reverse with transpose convolutions. It
    ends in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        dim_latent_space: Size of the latent code ``z``.
    """

    def __init__(self, dim_latent_space=5):
        super().__init__()

        # Convolutional feature extractor: (batch, 1, 28, 28) -> (batch, 64, 1, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(True),
            nn.Conv2d(32, 64, kernel_size=7),                       # 7x7 -> 1x1
            nn.ReLU(True),
        )

        # Linear bottleneck between the 64 pooled conv features and the latent code
        self.fc_enc = nn.Linear(64, dim_latent_space)
        self.fc_dec = nn.Linear(dim_latent_space, 64)

        # Transpose-convolution upsampler: (batch, 64, 1, 1) -> (batch, 1, 28, 28)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=7),                                          # 1x1 -> 7x7
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),  # output in [0,1]
        )

    def encode(self, x):
        """Map images ``(batch, 1, 28, 28)`` to latent codes ``(batch, dim_latent_space)``."""
        feature_map = self.encoder(x)                                    # (batch, 64, 1, 1)
        flat_features = feature_map.reshape(feature_map.shape[0], -1)   # (batch, 64)
        return self.fc_enc(flat_features)

    def decode(self, z):
        """Map latent codes ``(batch, dim_latent_space)`` to images ``(batch, 1, 28, 28)``."""
        projected = self.fc_dec(z)                                      # (batch, 64)
        seed_map = projected.reshape(projected.shape[0], 64, 1, 1)      # 1x1 spatial seed for upsampling
        return self.decoder(seed_map)

    def forward(self, x):
        """Return ``(reconstruction, latent_code)`` for a batch of images."""
        code = self.encode(x)
        return self.decode(code), code

# Hyperparameters
dim_latent_space = 3   # 3 dimensions so the latent space can be plotted directly in 3D
learning_rate = 1e-4
num_epochs = 30
lambda_reg = 1e-4      # latent regularization strength
seed = 0               # RNG seed for reproducible weight init and batch shuffling
use_conv = True        # True -> ConvAutoencoder, False -> fully connected Autoencoder

# Seed PyTorch's global RNG (CPU and CUDA) before building the model, so the
# initial weights are identical across kernel restarts. The shuffling DataLoader
# also seeds each epoch's permutation from this generator when it is iterated,
# so the batch order becomes reproducible too.
torch.manual_seed(seed)

# Instantiate model (selected by the `use_conv` flag instead of commented-out code)
model_cls = ConvAutoencoder if use_conv else Autoencoder
model = model_cls(dim_latent_space)

# Move model to GPU if available
model.to(device)

# Reconstruction loss: mean squared error averaged over every pixel in the batch
criterion = nn.MSELoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
# Training loop
# Enter training mode explicitly so re-running this cell is correct even if an
# earlier cell left the model in eval mode.
model.train()

for epoch in range(num_epochs):

    total_recon_loss = 0.0
    total_reg_loss = 0.0
    num_samples = 0  # samples actually seen this epoch (correct under drop_last / custom samplers)

    # Train for one epoch
    for imgs, _ in train_loader:

        imgs = imgs.to(device)  # keep shape (batch, 1, 28, 28)
        batch_size = imgs.size(0)

        # Forward pass: reconstruction and latent code
        outputs, z = model(imgs)

        # Reconstruction loss (batch mean)
        recon_loss = criterion(outputs, imgs)

        # Latent regularization: lambda * mean squared L2 norm of the codes,
        # which pulls z toward the origin
        reg_loss = lambda_reg * torch.mean(torch.sum(z**2, dim=1))

        # Total loss
        loss = recon_loss + reg_loss

        optimizer.zero_grad()   # clear old gradients
        loss.backward()         # backpropagation
        optimizer.step()        # update model parameters

        # Both losses are batch means; weight by batch size to get an exact
        # per-sample mean over the epoch even when the last batch is smaller
        total_recon_loss += recon_loss.item() * batch_size
        total_reg_loss += reg_loss.item() * batch_size
        num_samples += batch_size

    # max(..., 1) avoids ZeroDivisionError if the loader yielded no batches
    avg_recon_loss = total_recon_loss / max(num_samples, 1)
    avg_reg_loss = total_reg_loss / max(num_samples, 1)

    # Print loss every epoch
    print(f"Epoch {epoch+1}/{num_epochs}, "
          f"Recon Loss: {avg_recon_loss:.4f}, "
          f"Reg Loss: {avg_reg_loss:.6f}, "
          f"Total Loss: {avg_recon_loss + avg_reg_loss:.4f}")


In [None]:
import matplotlib.pyplot as plt

# Get one batch
imgs, _ = next(iter(train_loader))
imgs = imgs.to(device)  # keep shape (batch, 1, 28, 28)

# Get reconstructions in eval mode without building an autograd graph, then
# restore whatever mode the model was in so a later training run is unaffected
was_training = model.training
model.eval()
with torch.no_grad():
    reconstruction, _ = model(imgs)  # forward returns (x_hat, z)
model.train(was_training)

# Move back to CPU for plotting
imgs = imgs.cpu()
reconstruction = reconstruction.cpu()

# Show originals (top row) above their reconstructions (bottom row);
# never request more columns than the batch contains
num_tests = min(10, imgs.size(0))
fig, axes = plt.subplots(2, num_tests, figsize=(10, 4), squeeze=False)
for i in range(num_tests):
    axes[0, i].imshow(imgs[i].squeeze(), cmap="gray")            # remove channel dimension
    axes[1, i].imshow(reconstruction[i].squeeze(), cmap="gray")
    axes[0, i].axis("off")
    axes[1, i].axis("off")
fig.suptitle("MNIST: originals (top) vs. reconstructions (bottom)")

plt.show()


In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import animation
import torch
import numpy as np

def collect_latents(model, loader, device):
    """Encode every batch of ``loader`` and return the codes and labels.

    Runs the model in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        model: Module whose forward returns ``(x_hat, z)``.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        Tuple ``(all_z, all_labels)`` of NumPy arrays:
        - ``all_z`` has shape ``(N, dim_latent)``.
        - ``all_labels`` has shape ``(N,)``.
        Rows stay aligned even if the loader shuffles.
    """
    was_training = model.training
    model.eval()
    z_batches, label_batches = [], []
    try:
        with torch.no_grad():
            for imgs, labels in loader:
                _, z = model(imgs.to(device))
                z_batches.append(z.cpu())
                label_batches.append(labels)
    finally:
        model.train(was_training)
    return torch.cat(z_batches, dim=0).numpy(), torch.cat(label_batches, dim=0).numpy()


# Collect all latent vectors and labels
all_z, all_labels = collect_latents(model, train_loader, device)

# The 3D plot needs at least three latent dimensions; any beyond the third are not shown
if all_z.shape[1] < 3:
    raise ValueError(f"3D latent plot needs a latent size of at least 3, got {all_z.shape[1]}")

# Subsample for faster plotting; fixed seed so the same points are drawn on every run
num_points = min(5000, all_z.shape[0])
rng = np.random.default_rng(0)
indices = rng.choice(all_z.shape[0], num_points, replace=False)
subset_z = all_z[indices]
subset_labels = all_labels[indices]

# Create figure and 3D axis
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot (drawn once; the animation only moves the camera)
scatter = ax.scatter(subset_z[:, 0], subset_z[:, 1], subset_z[:, 2],
                     c=subset_labels, cmap='tab10', s=10)
ax.set_xlabel('z1')
ax.set_ylabel('z2')
ax.set_zlabel('z3')
ax.set_title('3D Latent Space of MNIST')
ax.legend(*scatter.legend_elements(), title="Digits")

num_frames = 120       # number of frames in the video
total_rotation = 360   # full rotation


def update(frame):
    """Rotate the camera azimuth in proportion to the frame index (one full turn over the clip)."""
    azim = (frame / num_frames) * total_rotation
    ax.view_init(elev=30, azim=azim)
    return ax,


try:
    # Create animation
    anim = animation.FuncAnimation(fig, update, frames=num_frames, interval=50, blit=False)

    # Save video: MP4 via ffmpeg when installed, otherwise an animated GIF via Pillow
    if animation.writers.is_available('ffmpeg'):
        anim.save('latent_space_rotation_360deg.mp4', writer='ffmpeg', fps=15)
    else:
        anim.save('latent_space_rotation_360deg.gif', writer='pillow', fps=15)
finally:
    # Always release the figure, even if saving fails
    plt.close(fig)
