# pytorch_autoencoder.ipynb
# WESmith 07/08/23
## reference https://www.youtube.com/watch?v=zp8clK9yCro

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [None]:
# Convert each PIL image into a float tensor shaped (C, H, W); per torchvision's
# ToTensor docs, 8-bit pixel values are scaled into [0, 1] (inspected a few cells below).
transform = transforms.ToTensor()

In [None]:
# Hyperparameters and file locations used by the cells below.
batch_size       = 64     # images per mini-batch yielded by the DataLoader
lr               = 1e-3   # Adam learning rate
decay            = 1e-5   # passed to Adam as weight_decay
data_dir         = 'data' # root directory where torchvision stores/downloads MNIST
# NOTE(review): model_path / optimizer_path are not read by any cell shown here;
# presumably intended as torch.save checkpoint targets — confirm.
model_path       = 'results/model_autoencoder.pth'
optimizer_path   = 'results/optimizer_autoencoder.pth'

In [None]:
# MNIST training split (train=True); fetched into data_dir if missing (download=True),
# and every sample is passed through `transform` when it is indexed.
mnist_data = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)

In [None]:
# Shape of the raw stored image tensor, before `transform` is applied;
# presumably (num_images, 28, 28) — check the cell output.
mnist_data.data.shape

In [None]:
# Yields (images, labels) mini-batches of `batch_size`; shuffle=True reorders the
# dataset at the start of every pass over it.
data_loader = torch.utils.data.DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)

In [None]:
#torch.utils.data.DataLoader?

In [None]:
# examine the data
dataiter = iter(data_loader)
images, labels = next(dataiter)

In [None]:
# The DataLoader stacks samples along a new leading batch dimension; the singleton
# dimension after it is the channel axis produced by ToTensor (C x H x W), so images
# should come out as (batch_size, 1, 28, 28) and labels as (batch_size,).
images.shape, labels.shape

In [None]:
# Pixel range after ToTensor; expected to lie within [0, 1], which is why both
# decoders below end in a Sigmoid.
torch.min(images), torch.max(images)

In [None]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder narrows a 784-vector through widths 128 -> 64 -> 12 down to a
    3-dimensional code, with no activation on the code itself. The decoder
    walks the same widths in reverse back to 784 and finishes with a Sigmoid,
    so reconstructions lie in [0, 1] like the ToTensor-scaled inputs.
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 128, 64, 12, 3)
        self.encoder = self._dense_stack(widths)
        self.decoder = self._dense_stack(widths[::-1], nn.Sigmoid())

    @staticmethod
    def _dense_stack(widths, head=None):
        """Chain Linear layers through `widths`, with ReLU between them.

        No ReLU follows the final Linear layer; `head`, when given, is
        appended after it instead.
        """
        layers = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if depth:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        if head is not None:
            layers.append(head)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Reconstruct `x` (shape (N, 784)) by encoding then decoding it."""
        return self.decoder(self.encoder(x))

In [None]:
# Dense autoencoder, mean-squared reconstruction loss, and Adam using the
# learning rate and weight decay defined in the hyperparameter cell.
model     = Autoencoder()
loss_fn   = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=decay)

In [None]:
# Total number of scalar parameters (weights and biases) in the dense autoencoder.
count = sum(param.numel() for param in model.parameters())
count

In [None]:
n_epochs = 5  # 5 epochs took 1m 30s to run on acer
# One (epoch, last input batch, last reconstruction) tuple per epoch, for plotting.
outputs = []
for epoch in range(n_epochs):
    running_loss = 0.0  # sum of per-batch mean losses, weighted by batch size
    n_seen = 0          # images processed this epoch
    for (img, _) in data_loader:
        # flatten (N, 1, 28, 28) -> (N, 784) for the fully connected model
        img   = img.reshape(-1, 28*28)
        recon = model(img)
        loss  = loss_fn(recon, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch is not over-counted.
        running_loss += loss.item() * img.size(0)
        n_seen += img.size(0)

    # Report the epoch-average loss rather than just the last (noisy) batch's loss.
    epoch_loss = running_loss / max(n_seen, 1)
    print(f'Epoch: {epoch + 1}, Loss: {epoch_loss:.4f}')
    # Detach so the stored tensors do not keep this batch's autograd graph alive.
    outputs.append((epoch, img.detach(), recon.detach()))

In [None]:
nc = 12  # digits shown per row
for k in range(0, n_epochs, 4):
    plt.figure(figsize=(14,2))
    plt.gray()
    # outputs[k] is (epoch, input batch, reconstruction batch) from the training cell
    originals, reconstructions = outputs[k][1:]
    panels = (originals.detach().numpy(), reconstructions.detach().numpy())
    # top row: inputs; bottom row: their reconstructions
    for row, batch in enumerate(panels):
        for col, flat in enumerate(batch[:nc]):
            plt.subplot(2, nc, row * nc + col + 1)
            # each entry is a flattened 784-vector here: view it as (1, 28, 28), show plane 0
            plt.imshow(flat.reshape(-1, 28, 28)[0])

In [None]:
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    Shape trace (N = batch size):
      encoder: (N,1,28,28) -> (N,16,14,14) -> (N,32,7,7) -> (N,64,1,1)
      decoder: (N,64,1,1) -> (N,32,7,7) -> (N,16,14,14) -> (N,1,28,28)

    The final 7x7 convolution collapses the 7x7 feature map into a 64-channel
    code, and the mirrored 7x7 transposed convolution expands it again. The
    closing Sigmoid keeps reconstructions in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with stride 2 halve (encoder) or double (decoder) H and W;
        # output_padding=1 lets the transposed convs land exactly on 14 and 28.
        halve = {"kernel_size": 3, "stride": 2, "padding": 1}
        double = {**halve, "output_padding": 1}

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, **halve),
            nn.ReLU(),
            nn.Conv2d(16, 32, **halve),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, **double),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, **double),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct `x` (shape (N, 1, 28, 28)) by encoding then decoding it."""
        return self.decoder(self.encoder(x))

In [None]:
# Convolutional autoencoder with its own Adam optimizer (same lr / weight decay as
# the dense model); loss_fn is rebound to an identical MSELoss.
model_conv = ConvAutoencoder()
loss_fn    = nn.MSELoss()
optim_conv = optim.Adam(model_conv.parameters(), lr=lr, weight_decay=decay)

In [None]:
# Total number of scalar parameters (weights and biases) in the conv autoencoder.
count = sum(param.numel() for param in model_conv.parameters())
count

In [None]:
n_epochs = 5  # 5 epochs took 1m 30s to run on acer for conv model also: about same num of params
# One (epoch, last input batch, last reconstruction) tuple per epoch, for plotting.
outputs_conv = []
for epoch in range(n_epochs):
    running_loss = 0.0  # sum of per-batch mean losses, weighted by batch size
    n_seen = 0          # images processed this epoch
    for (img, _) in data_loader:
        # the conv model consumes (N, 1, 28, 28) batches directly; no flattening
        recon = model_conv(img)
        loss  = loss_fn(recon, img)

        optim_conv.zero_grad()
        loss.backward()
        optim_conv.step()

        # Weight by batch size so a short final batch is not over-counted.
        running_loss += loss.item() * img.size(0)
        n_seen += img.size(0)

    # Report the epoch-average loss rather than just the last (noisy) batch's loss.
    epoch_loss = running_loss / max(n_seen, 1)
    print(f'Epoch: {epoch + 1}, Loss: {epoch_loss:.4f}')
    # Detach so the stored tensors do not keep this batch's autograd graph alive.
    outputs_conv.append((epoch, img.detach(), recon.detach()))

In [None]:
nc = 20  # digits shown per row
for k in range(0, n_epochs, 4):
    plt.figure(figsize=(16,2))
    plt.gray()
    # outputs_conv[k] is (epoch, input batch, reconstruction batch) from the training cell
    originals, reconstructions = outputs_conv[k][1:]
    panels = (originals.detach().numpy(), reconstructions.detach().numpy())
    # top row: inputs; bottom row: their reconstructions
    for row, batch in enumerate(panels):
        for col, image in enumerate(batch[:nc]):
            plt.subplot(2, nc, row * nc + col + 1)
            # each entry is (1, 28, 28) (channel axis first); show the single plane
            plt.imshow(image.reshape(-1, 28, 28)[0])