## Reconstruct MNIST Digits with an Autoencoder


### Import Packages

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms,datasets
import numpy as np

# Report library versions so results can be tied to a specific environment.
# One print per library avoids the stray trailing spaces and doubled space
# that the previous single multi-argument print produced.
print(f'pytorch version: {torch.__version__}')
print(f'torchvision version: {torchvision.__version__}')
print(f'numpy version: {np.__version__}')

### Settings

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Hyperparameters
batch_size = 64       # samples per mini-batch fed to the DataLoaders
num_epochs = 10       # number of passes over the training loader
learning_rate = 1e-3  # step size handed to the optimizer

### Dataset: MNIST


In [None]:
# MNIST training split; downloaded into ./data on first use.
# ToTensor is the only transform applied to each image.
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
)

# MNIST test split, stored under the same root. No download is requested here.
# NOTE(review): this relies on the test files already being under 'data'
# (presumably fetched by the train-split download above) — confirm.
test_dataset = datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Shuffle only the training batches; keep the test order fixed for evaluation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Define the Autoencoder (AE)

In [None]:
class AE(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder compresses each flattened 784-element image to a
    20-dimensional code. The decoder maps that code back to 784 values
    squashed into [0, 1] by a final Sigmoid. Those values are then reshaped
    to ``[b, 1, 28, 28]``.
    """

    def __init__(self):
        super().__init__()
        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            # Outputs lie in [0, 1]. This presumably matches the ToTensor
            # pixel range of the inputs (confirm).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Args:
            x: tensor whose first dimension is the batch size. The remaining
                dimensions must hold 784 elements per sample, e.g.
                ``[b, 1, 28, 28]`` or ``[b, 784]``.

        Returns:
            Tuple ``(x_hat, None)``. ``x_hat`` has shape ``[b, 1, 28, 28]``;
            the second element is always ``None``. Callers must unpack the
            tuple rather than treat the return value as a tensor.
        """
        batchsz = x.size(0)
        # Use reshape, not view, so that non-contiguous inputs (e.g.
        # transposed or sliced batches) are flattened instead of raising.
        x = x.reshape(batchsz, 784)
        x = self.encoder(x)
        x = self.decoder(x)
        # The decoder output is a freshly allocated, contiguous tensor, so
        # view is safe here.
        x = x.view(batchsz, 1, 28, 28)

        return x, None
    

### Initialize the AE, Optimizer, and Loss Function

In [None]:
# Build the autoencoder and move its parameters to the selected device.
# Module.to() returns the module itself.
model = AE().to(device)

# Adam at lr=1e-3 is the usual optimizer for a dense autoencoder like this.
# Plain SGD without momentum at the same step size tends to converge far more
# slowly, so reconstructions improve little within the configured epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Mean squared error between reconstruction and input. MSELoss holds no
# parameters or buffers, so it does not need to be moved to the device.
criteon = nn.MSELoss()

### Train the AE

In [None]:
for epoch in range(num_epochs):

    model.train()
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.to(device)
        x = x.view(-1, 28*28)
        
        # forward
        x_hat = model(x)
        loss = criteon(x_hat, x)

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    
    model.eval()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.to(device)
            x = x.view(-1, 28*28)

            x_hat = model(x)