# MNIST auto-encoder
Here we build an auto-encoder for the MNIST dataset. Later, we will use the same network but make it generative with the M-estimator Auto-encoder (MAE).

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.autograd import Variable



transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

training_data = datasets.MNIST(
    root="data",
    train="true",
    download=True,
    transform=transform
)

test_data = datasets.MNIST(
    root="data",
    train="test",
    download=True,
    transform=transform
)

dataloader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True, num_workers=4)

In [None]:
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # 1 x 28 x 28
            nn.Conv2d(1, 4, kernel_size=5),
            # 4 x 24 x 24 = 2304
            nn.Flatten(),
            nn.ReLU(True),
            nn.Linear(2304, 10),
            # 10
            nn.ReLU(True),
            nn.Linear(10, 2)
            # 2
        )

        self.decoder = nn.Sequential(
            # 2
            nn.Linear(2, 10),
            # 10
            nn.ReLU(True),
            nn.Linear(10, 2304),
            # 2304
            nn.ReLU(True),
            nn.Unflatten(1, (4, 24, 24)),
            # 4 x 24 x 24
            nn.ConvTranspose2d(4, 1, kernel_size = 5),
            # 1 x 28 x 28
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    
model = AutoEncoder()
distance = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# fit

num_epochs = 100
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))



# The simplest

In [2]:
class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # 1 x 28 x 28 = 784
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(784, 2)
            # 2
        )

        self.decoder = nn.Sequential(
            nn.Linear(2, 784),
            # 784
            nn.ReLU(),
            nn.Unflatten(1, (1, 28, 28)),
            # 1 x 28 x 28
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    
model = AutoEncoder()
distance = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# fit

num_epochs = 100
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        output = model(img)
        loss = distance(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))



epoch [1/100], loss: 1.2964
epoch [2/100], loss: 1.2807
epoch [3/100], loss: 1.2785
epoch [4/100], loss: 1.3064
epoch [5/100], loss: 1.2748
epoch [6/100], loss: 1.2777
epoch [7/100], loss: 1.2894
epoch [8/100], loss: 1.2684
epoch [9/100], loss: 1.3062
epoch [10/100], loss: 1.2859
epoch [11/100], loss: 1.3150
epoch [12/100], loss: 1.2937
epoch [13/100], loss: 1.2624
epoch [14/100], loss: 1.2819
epoch [15/100], loss: 1.2332
epoch [16/100], loss: 1.2860
epoch [17/100], loss: 1.2503
epoch [18/100], loss: 1.2526
epoch [19/100], loss: 1.2413
epoch [20/100], loss: 1.3144
epoch [21/100], loss: 1.2939
epoch [22/100], loss: 1.2739
epoch [23/100], loss: 1.2522
epoch [24/100], loss: 1.2748
epoch [25/100], loss: 1.2599
epoch [26/100], loss: 1.2202
epoch [27/100], loss: 1.2535
epoch [28/100], loss: 1.2501
epoch [29/100], loss: 1.2921
epoch [30/100], loss: 1.2614
epoch [31/100], loss: 1.2537
epoch [32/100], loss: 1.2547
epoch [33/100], loss: 1.2518
epoch [34/100], loss: 1.2153
epoch [35/100], loss: 1