<a href="https://colab.research.google.com/github/pppcwen/Autoencoder/blob/main/Autoencoder_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [None]:
# ToTensor yields float tensors scaled to [0, 1] (see the torchvision docs),
# which matches the Sigmoid output layer of the autoencoder.
# To train on inputs in [-1, 1] instead, compose ToTensor with
# Normalize((0.5,), (0.5,)) and end the decoder with Tanh rather than Sigmoid.
transform = transforms.ToTensor()

# MNIST training split; downloaded into ./data on first use.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Shuffled mini-batches of 64 images for training.
data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
                                          batch_size=64,
                                          shuffle=True)

In [None]:
# Pull a single batch and print its smallest and largest pixel value as a
# quick sanity check on the input scaling.
dataiter = iter(data_loader)
images, labels = next(dataiter)
print(images.min(), images.max())

In [None]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a max-pool / max-unpool bottleneck.

    The encoder downsamples with a strided convolution followed by a 2x2 max
    pool; the decoder mirrors it with a 2x2 max unpool (reusing the pooling
    indices) and a strided transposed convolution. The final Sigmoid keeps
    the reconstruction in [0, 1].

    The shape comments below assume an input of shape (N, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()
        # N, 1, 28, 28
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),  # -> N, 32, 14, 14
            nn.ReLU()
        )
        # The pooling layers hold no parameters, so registering them here
        # (instead of rebuilding them on every forward call) leaves the
        # state_dict keys untouched.
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)  # -> N, 32, 7, 7
        self.encoder2 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(32, 64, 3)  # -> N, 64, 5, 5
        )
        # N, 64, 5, 5
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3),  # -> N, 32, 7, 7
            nn.ReLU()
        )
        self.unpool = nn.MaxUnpool2d(2, stride=2)  # -> N, 32, 14, 14
        self.decoder2 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),  # -> N, 1, 28, 28
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction.

        Args:
            x: Batch of images; the layer sizes are chosen for (N, 1, 28, 28).

        Returns:
            Tensor of the same shape as ``x`` with values in [0, 1].
        """
        encoded_1 = self.encoder(x)
        # Keep the argmax positions so the decoder can undo the pooling.
        pool_output, indices = self.pool(encoded_1)
        encoded = self.encoder2(pool_output)

        decoded_1 = self.decoder(encoded)
        # decoded_1 has the same shape as pool_output, so the stored indices
        # are valid for the unpooling step.
        unpool_output = self.unpool(decoded_1, indices)
        return self.decoder2(unpool_output)


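# Note: if the encoder uses nn.MaxPool2d, invert it in the decoder with nn.MaxUnpool2d,
# or compensate with a different kernel size, stride, etc.
# If the inputs are scaled to [-1, +1], use nn.Tanh as the final activation instead of nn.Sigmoid.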

In [None]:
# Fresh autoencoder, scored by the mean-squared reconstruction error.
model = Autoencoder()
criterion = nn.MSELoss()

# Adam with a small weight decay on the model parameters.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-3,
    weight_decay=1e-6,
)

In [None]:
# Train the autoencoder to reconstruct its own input; the second element of
# each batch is discarded.
num_epochs = 20
outputs = []
for epoch in range(num_epochs):
    for (img, _) in data_loader:
        recon = model(img)
        loss = criterion(recon, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # This is the loss of the epoch's final mini-batch, not an epoch average.
    print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
    # Keep the last batch and its reconstruction for later inspection. The
    # reconstruction is detached so the stored tensor does not hold on to
    # the autograd graph.
    outputs.append((epoch, img, recon.detach()))

In [None]:
# For every fourth epoch, draw up to nine stored inputs (top row) above
# their reconstructions (bottom row).
for k in range(0, num_epochs, 4):
    plt.figure(figsize=(9, 2))
    plt.gray()
    snapshot = outputs[k]
    imgs = snapshot[1].detach().numpy()
    recon = snapshot[2].detach().numpy()
    for row, batch in enumerate((imgs, recon)):
        for col, item in enumerate(batch[:9]):
            # Subplot indices are 1-based and run row by row.
            plt.subplot(2, 9, row * 9 + col + 1)
            # item: 1, 28, 28 -> show its single channel
            plt.imshow(item[0])

In [None]:
# Homework: Use MaxPool2d, inspect the encoded data (can it be plotted as img?)