# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

class _GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise to the input, in training mode only.

    PyTorch has no built-in ``nn.GaussianNoise`` layer (that is a Keras name),
    so this module provides the equivalent. Fresh noise is sampled on every
    forward pass while ``self.training`` is True. In eval mode the input is
    returned unchanged.
    """

    def __init__(self, std=0.5):
        super().__init__()
        if std < 0:
            raise ValueError(f"noise std must be non-negative, got {std}")
        self.std = std

    def forward(self, x):
        # No corruption at inference time, so evaluation is deterministic.
        if not self.training or self.std == 0:
            return x
        return x + torch.randn_like(x) * self.std


class Autoencoder(nn.Module):
    """Stride-1 convolutional autoencoder, optionally denoising.

    The encoder is a stack of ``Conv2d -> LeakyReLU`` layers. It maps
    ``n_filters[0] -> n_filters[1] -> ... -> n_filters[-1]`` channels.

    The decoder mirrors it with ``ConvTranspose2d -> LeakyReLU`` layers back
    down to ``n_filters[0]`` channels. A final 3x3 ``Conv2d`` and a sigmoid
    follow, so outputs lie in [0, 1].

    Every layer uses stride 1 and ``padding = kernel_size // 2``. Spatial size
    is therefore preserved for odd kernel sizes; even kernel sizes change it.

    Args:
        n_filters: Channel counts, starting with the number of input
            channels. Must contain at least two entries.
        filter_sizes: Kernel size of each encoder layer. The decoder uses the
            same sizes in reverse order. Only the first
            ``len(n_filters) - 1`` entries are used.
        corruption: If True, add Gaussian noise to the input during training
            (denoising autoencoder). No noise is added in eval mode.
        noise_std: Standard deviation of the corruption noise.

    Raises:
        ValueError: If ``n_filters`` has fewer than two entries, or
            ``filter_sizes`` has fewer entries than there are layers.
    """

    def __init__(self, n_filters=(1, 10, 10, 10), filter_sizes=(3, 3, 3, 3),
                 corruption=False, noise_std=0.5):
        super().__init__()

        n_filters = list(n_filters)
        n_layers = len(n_filters) - 1
        if n_layers < 1:
            raise ValueError(
                "n_filters needs at least two entries "
                "(input channels plus one layer)")
        if len(filter_sizes) < n_layers:
            raise ValueError(
                f"filter_sizes has {len(filter_sizes)} entries but "
                f"{n_layers} are required for n_filters={n_filters}")
        # Use exactly the kernel sizes the encoder uses, so the decoder is a
        # true mirror even when extra entries are supplied.
        kernel_sizes = list(filter_sizes)[:n_layers]

        self.corruption = corruption
        if corruption:
            self.noise = _GaussianNoise(noise_std)

        # Encoder: input channels -> ... -> bottleneck channels.
        self.encoders = nn.ModuleList()
        in_channels = n_filters[0]
        for out_channels, kernel_size in zip(n_filters[1:], kernel_sizes):
            self.encoders.append(nn.Conv2d(in_channels, out_channels, kernel_size,
                                           padding=kernel_size // 2))
            self.encoders.append(nn.LeakyReLU())
            in_channels = out_channels

        # Decoder: mirror of the encoder, back to the input channel count.
        self.decoders = nn.ModuleList()
        for out_channels, kernel_size in zip(reversed(n_filters[:-1]),
                                             reversed(kernel_sizes)):
            self.decoders.append(nn.ConvTranspose2d(in_channels, out_channels,
                                                    kernel_size,
                                                    padding=kernel_size // 2))
            self.decoders.append(nn.LeakyReLU())
            in_channels = out_channels

        # Reconstruct as many channels as the input has (1 for the default).
        self.final_decoder = nn.Conv2d(in_channels, n_filters[0], 3, padding=1)

    def forward(self, x):
        """Return the reconstruction of ``x``, with values in [0, 1]."""
        if self.corruption:
            x = self.noise(x)

        for layer in self.encoders:
            x = layer(x)

        for layer in self.decoders:
            x = layer(x)

        return torch.sigmoid(self.final_decoder(x))

def _plot_reconstructions(originals, reconstructions, n):
    """Show ``n`` originals (top row) above their reconstructions (bottom row)."""
    plt.figure(figsize=(2 * n, 4))
    for i in range(n):
        for row, images in enumerate((originals, reconstructions)):
            ax = plt.subplot(2, n, row * n + i + 1)
            plt.imshow(images[i][0].numpy(), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()


def test_mnist(epochs=10, batch_size=256, n_display=10, data_dir='./data'):
    """Train an Autoencoder on MNIST, report losses, and plot reconstructions.

    For each epoch, prints the mean training MSE over all samples seen in that
    epoch. After training, prints the mean test MSE. Finally plots up to
    ``n_display`` test images (first test batch) next to their
    reconstructions.

    Args:
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both the train and test loaders.
        n_display: Maximum number of test images to plot.
        data_dir: Directory where MNIST is stored or downloaded to.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load MNIST dataset
    transform = transforms.Compose([transforms.ToTensor()])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, download=True, transform=transform),
        batch_size=batch_size, shuffle=False)

    # Create the Autoencoder model
    model = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters())

    # Train the model
    for epoch in range(epochs):
        model.train()
        running_loss, n_seen = 0.0, 0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, data)
            loss.backward()
            optimizer.step()
            # MSELoss averages per batch, so weight by batch size to get a
            # true per-sample mean even when the last batch is smaller.
            running_loss += loss.item() * data.size(0)
            n_seen += data.size(0)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / max(n_seen, 1):.4f}")

    # Evaluate on the full test set; keep the first batch for plotting.
    model.eval()
    test_loss, n_test = 0.0, 0
    shown_inputs = shown_outputs = None
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            outputs = model(data)
            test_loss += criterion(outputs, data).item() * data.size(0)
            n_test += data.size(0)
            if shown_inputs is None:
                shown_inputs, shown_outputs = data.cpu(), outputs.cpu()
    print(f"Test Loss: {test_loss / max(n_test, 1):.4f}")

    if shown_inputs is None:
        return
    # Plot the reconstructions (never more images than the batch holds).
    _plot_reconstructions(shown_inputs, shown_outputs,
                          min(n_display, shown_inputs.size(0)))

if __name__ == "__main__":
    # Script entry point: train on MNIST and display reconstructions.
    test_mnist()


# Recorded output of a previous run:
# Epoch [1/10], Loss: 0.0098
# Epoch [2/10], Loss: 0.0044
# Epoch [3/10], Loss: 0.0029
# Epoch [4/10], Loss: 0.0024
# Epoch [5/10], Loss: 0.0022
# Epoch [6/10], Loss: 0.0018
# Epoch [7/10], Loss: 0.0016
# Epoch [8/10], Loss: 0.0016
# Epoch [9/10], Loss: 0.0015
# Epoch [10/10], Loss: 0.0012
