In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Set device: the default CUDA device when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset
# ToTensor() converts each image to a float tensor with pixels in [0, 1].
# No Normalize((0.5,), (0.5,)) here: that would map pixels to [-1, 1], a range
# the Sigmoid decoder output can never reach and that nn.BCELoss does not
# support as targets (BCELoss expects targets in [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# download=True is a no-op when the files already exist, and it keeps this
# line working on its own if the train split was never fetched.
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training data; the test order stays fixed so the displayed
# batch is reproducible.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

# Define the Autoencoder model
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a single hidden (bottleneck) layer.

    The encoder maps ``input_dim`` features to a ``latent_dim`` code
    (Linear + ReLU). The decoder maps the code back to ``input_dim`` features
    (Linear + Sigmoid), so reconstructions lie in [0, 1].

    Args:
        input_dim: Number of features per (flattened) input sample.
            Defaults to 28*28, a flattened MNIST image.
        latent_dim: Size of the bottleneck code. Defaults to 128.
    """

    def __init__(self, input_dim=28 * 28, latent_dim=128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.ReLU(),
        )
        # Sigmoid keeps outputs in [0, 1], which nn.BCELoss requires of its input.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x``.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features,
                e.g. shape (batch, input_dim).

        Returns:
            Reconstruction with the same shape as ``x`` and values in [0, 1].
        """
        code = self.encoder(x)
        return self.decoder(code)

# Initialize the model, loss function, and optimizer
autoencoder = Autoencoder().to(device)
# nn.BCELoss expects both predictions and targets in [0, 1]. The Sigmoid decoder
# guarantees this for predictions; the dataset transform must guarantee it for
# the targets (which are the inputs themselves).
criterion = nn.BCELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

# Training loop
epochs = 10

autoencoder.train()
for epoch in range(epochs):
    # Accumulate on the model's device to avoid a host sync on every batch.
    running_loss = 0.0
    seen = 0
    for inputs, _ in train_loader:
        # Flatten each image to a vector for the fully connected layers.
        inputs = inputs.view(inputs.size(0), -1).to(device)

        optimizer.zero_grad()
        outputs = autoencoder(inputs)
        loss = criterion(outputs, inputs)  # target is the input itself
        loss.backward()
        optimizer.step()

        # BCELoss uses reduction='mean' by default, so weight each batch by its
        # size to get an exact per-sample mean even if the last batch is smaller.
        running_loss += loss.detach() * inputs.size(0)
        seen += inputs.size(0)

    # Report the mean loss over the whole epoch, not just the final batch.
    print(f"Epoch [{epoch + 1}/{epochs}], Loss: {float(running_loss) / max(seen, 1):.4f}")

# Test the Autoencoder: reconstruct one test batch and show originals (top row)
# above their reconstructions (bottom row).
autoencoder.eval()
with torch.no_grad():
    # Only one batch is displayed, so fetch it directly instead of looping and breaking.
    inputs, _ = next(iter(test_loader))
    inputs = inputs.view(inputs.size(0), -1).to(device)
    outputs = autoencoder(inputs)

# Number of digits to display, capped at the batch size so indexing cannot fail.
n = min(10, inputs.size(0))
plt.figure(figsize=(20, 4))
for i in range(n):
    # Row 0: original images; row 1: reconstructed images.
    for row, images in enumerate((inputs, outputs)):
        ax = plt.subplot(2, n, i + 1 + row * n)
        plt.imshow(images[i].cpu().view(28, 28).numpy(), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

plt.show()
