In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# MNIST training split (28x28 images of digits 0-9).
# download=True fetches the data into `root`; ToTensor() turns each image into a tensor.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='.',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=64,   # images per mini-batch
    shuffle=True,    # reshuffle the samples every epoch
)

In [None]:
# Pick the compute device: use CUDA when it is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Autoencoder Model
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened inputs.

    The encoder maps an ``input_dim``-wide vector through hidden layers of
    128 and 64 units down to ``dim_latent_space`` values. The decoder mirrors
    that path (64 -> 128 -> ``input_dim``) and ends with a sigmoid, so every
    reconstructed value lies between 0 and 1.

    Args:
        dim_latent_space: Width of the bottleneck (latent) vector.
        input_dim: Number of features per flattened sample. Defaults to 784,
            which matches a flattened 28x28 image.
    """

    def __init__(self, dim_latent_space=5, input_dim=784):
        super().__init__()

        # Kept as attributes so callers can inspect the configured sizes
        self.input_dim = input_dim
        self.dim_latent_space = dim_latent_space

        # Encoder: input_dim -> 128 -> 64 -> latent
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, dim_latent_space)
        )

        # Decoder: latent -> 64 -> 128 -> input_dim
        self.decoder = nn.Sequential(
            nn.Linear(dim_latent_space, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, input_dim),
            nn.Sigmoid()  # to ensure output is between 0 and 1
        )

    def encode(self, x):
        """Encode ``x`` into the latent space.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tensor whose last dimension has ``dim_latent_space`` features.
        """
        z = self.encoder(x)
        return z

    def decode(self, z):
        """Decode a latent tensor ``z`` back to the original space.

        Args:
            z: Tensor whose last dimension has ``dim_latent_space`` features.

        Returns:
            Reconstruction with ``input_dim`` features, each value in [0, 1].
        """
        x_hat = self.decoder(z)
        return x_hat

    def forward(self, x):
        """Run a full pass: encode ``x``, then decode the latent code.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Reconstruction of ``x`` with the same shape, values in [0, 1].
        """
        z = self.encode(x)      # pass through encoder
        x_hat = self.decode(z)  # pass through decoder
        return x_hat

# Build the autoencoder and place it on the selected device
# (nn.Module.to returns the module itself, so the call can be chained)
dim_latent_space = 5
model = Autoencoder(dim_latent_space=dim_latent_space).to(device)

# Loss function: mean squared error
criterion = nn.MSELoss()

# Optimizer: Adam over all model parameters
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Training loop
num_epochs = 20
model.train()  # explicit training mode; keeps the loop correct if dropout/batch-norm layers are ever added
for epoch in range(num_epochs):

    # Accumulate the loss over the whole epoch so the report reflects every
    # batch instead of only the last (noisy, often smaller) one.
    # Kept as a tensor on `device` to avoid a host sync on every step.
    running_loss = torch.zeros((), device=device)
    num_samples = 0

    # Train for one epoch
    for imgs, _ in train_loader:

        batch_size = imgs.size(0)
        imgs = imgs.view(batch_size, -1)         # flatten 28×28 to 784
        imgs = imgs.to(device)                   # move to GPU if available

        outputs = model(imgs)                    # forward pass
        loss = criterion(outputs, imgs)          # compute loss (reconstruction vs. input)

        optimizer.zero_grad()   # clear old gradients
        loss.backward()         # backpropagation
        optimizer.step()        # update model parameters

        # Weight each batch mean by its size so the epoch figure is a true per-sample mean
        running_loss += loss.detach() * batch_size
        num_samples += batch_size

    # Print mean loss every epoch (max(..., 1) guards against an empty loader)
    epoch_loss = running_loss.item() / max(num_samples, 1)
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")



In [None]:
import matplotlib.pyplot as plt

# Get one batch
imgs, _ = next(iter(train_loader))
imgs = imgs.view(imgs.size(0), -1).to(device)  # flatten and move to GPU if available

# Get reconstructions (inference only, so no gradients are tracked)
model.eval()
with torch.no_grad():
    reconstruction = model(imgs)

# Move back to CPU for plotting
imgs = imgs.cpu()
reconstruction = reconstruction.cpu()

# Show original and reconstructed images
num_tests = min(10, imgs.size(0))  # never index past the end of the batch
fig, axes = plt.subplots(2, num_tests, figsize=(10, 4), squeeze=False)
for i in range(num_tests):
    # Original (top row). vmin/vmax pin the gray scale to [0, 1] so originals and
    # reconstructions are directly comparable instead of being auto-scaled per image.
    axes[0, i].imshow(imgs[i].view(28, 28), cmap="gray", vmin=0, vmax=1)
    axes[0, i].axis("off")

    # Reconstruction (bottom row)
    axes[1, i].imshow(reconstruction[i].view(28, 28), cmap="gray", vmin=0, vmax=1)
    axes[1, i].axis("off")

# Label the rows so the figure stands on its own
axes[0, 0].set_title("Original", loc="left")
axes[1, 0].set_title("Reconstruction", loc="left")

plt.show()
