In [1]:
# Torch is for PyTorch which is a deep learning framework used for neural networks
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
#For plotting
import matplotlib.pyplot as plt

import pandas as pd

In [2]:
# For ...
transform = transforms.ToTensor()

#transform = transforms.Compose([
#    transforms.ToTensor(),
#    transforms.Normalize((0.5),(0.5))
#])

#coil_data_df = pd.read_csv('ticdata2000.txt', sep='\t', header=None)
#coil_data_df.to_csv('ticdata2000.csv', index=False)

mnist_data = datasets.MNIST(root='./',train=True, download=True, transform=transform)

# See Patrick Loeber video about this more specifically
data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
                                          batch_size=64,
                                          shuffle=True)

In [3]:
dataiter = iter(data_loader)
images, labels = next(dataiter)
print(torch.min(images), torch.max(images))

tensor(0.) tensor(1.)


In [4]:
class Autoencoder(nn.Module):
    def __init__(self):
        # N, 784

        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128), # N, 784 --> N, 128
            nn.ReLU(), 
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12,3) # --> N, 3
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(), 
            nn.Linear(12,64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28*28),
            nn.Sigmoid() # N, 3 --> N, 784
        )
            
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [5]:
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [13]:
num_epochs = 10
outputs = []
for epoch in range(num_epochs):
    for (img, _) in data_loader:
        img = img. reshape(-1, 28*28)
        recon = model(img)
        loss = criterion(recon, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch:{epoch+1}, Loss:{loss.item(): 4f}')
    outputs.append((epoch, img, recon))

Epoch:1, Loss: 0.043879
Epoch:2, Loss: 0.044147
Epoch:3, Loss: 0.043394
Epoch:4, Loss: 0.034475
Epoch:5, Loss: 0.032627
Epoch:6, Loss: 0.037653
Epoch:7, Loss: 0.039189
Epoch:8, Loss: 0.036148
Epoch:9, Loss: 0.034107
Epoch:10, Loss: 0.037434


In [None]:
for k in range (0, num_epochs, 4):
    plt.figure(figsize=(9,2))
    plt.gray()
    imgs = outputs[k][1].detach().numpy()
    recon = outputs[k][2].detach().numpy()
    for i, item in enumerate(imgs):
        if i >= 9: break
        plt.subplot(2, 9, i+1)
        item = item.reshape(-1, 28, 28)
        # item: 1, 28, 28
        plt.imshow(item[0])

    for i, item in enumerate(recon):
        if i >= 9: break
        plt.subplot(2,9, 9+i+1) # row_length + i + 1
        item = item.reshape(-1, 28, 28)
        # item: 1, 28, 28
        plt.imshow(item[0])