<a href="https://colab.research.google.com/github/raphaelletseng/AI4Good2021/blob/main/AutoEncoder_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#https://pytorch.org/vision/stable/datasets.html#mnist
import torch
import torchvision
import torchvision.transforms as transforms

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Every sample is converted to a tensor when it is read from the dataset.
transform = transforms.ToTensor()

# MNIST training and test splits, downloaded into /tmp if not already there.
trainset = torchvision.datasets.MNIST(
    '/tmp',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.MNIST(
    '/tmp',
    train=False,
    download=True,
    transform=transform,
)


1. Define the network, optimizer, and loss function.
2. Train: for N epochs, iterate through the dataset and, for each batch of data:
      1. Compute output = Net(input)
      2. Compute the loss and perform backpropagation
      3. Run the optimizer step

In [None]:
import matplotlib.pyplot as plt
# Preview one raw training image and report its shape and the active device.
sample_image = trainset.data[2]
plt.imshow(sample_image, cmap="gray")
print(sample_image.shape)
print(device)
plt.show()

Dataloaders are used to efficiently split the datasets into batches.

In [None]:
# Number of images per mini-batch for both loaders.
BATCH_SIZE = 32

# Training batches are reshuffled every epoch; drop_last discards the final
# incomplete batch so every batch holds exactly BATCH_SIZE images.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Test batches keep the dataset order; the final incomplete batch is dropped
# here as well.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable

class AutoencoderModel(nn.Module):
    """Convolutional autoencoder for 28x28 images such as MNIST digits.

    The encoder applies three 3x3 convolutions (stride 1, padding 1, so the
    28x28 spatial size is preserved), flattens the 64-channel feature map and
    compresses it through two fully connected layers into a latent vector.
    The decoder mirrors this: two fully connected layers expand the latent
    vector back to a 64x28x28 feature map, which three transposed
    convolutions turn back into an image with ``input_channels`` channels.

    Args:
        input_channels: Number of channels of the input images
            (1 for grayscale, 3 for RGB).
        latent_dim: Size of the latent vector produced by the encoder.
    """

    def __init__(self, input_channels, latent_dim):
        super().__init__()

        # Shape (channels, height, width) of the last convolutional feature
        # map; the fully connected layers are sized from it.
        self._feature_shape = (64, 28, 28)
        flat_features = 64 * 28 * 28

        # Convolution arguments: in_channels, out_channels (= number of
        # filters), kernel size.
        self.encoder_layer1 = nn.Conv2d(input_channels, 16, (3, 3), stride=1, padding=1)
        self.encoder_layer2 = nn.Conv2d(16, 32, (3, 3), stride=1, padding=1)
        self.encoder_layer3 = nn.Conv2d(32, 64, (3, 3), stride=1, padding=1)

        # NOTE: this pooling layer is constructed but never applied by
        # encoder() or decoder(); it is kept so the attribute stays available.
        self.pool_layer4 = nn.MaxPool2d(kernel_size=3, stride=None, padding=0, dilation=1)

        # Fully connected layers that compress the features to the latent code.
        self.encoder_fc_1 = nn.Linear(flat_features, 512)
        self.encoder_fc_2 = nn.Linear(512, latent_dim)

        # The decoder input must match the encoder output, so it is sized
        # from latent_dim rather than a fixed constant.
        self.decoder_fc_1 = nn.Linear(latent_dim, 512)
        self.decoder_fc_2 = nn.Linear(512, flat_features)

        self.decoder_layer1 = nn.ConvTranspose2d(64, 32, (3, 3), stride=1, padding=1)
        self.decoder_layer2 = nn.ConvTranspose2d(32, 16, (3, 3), stride=1, padding=1)
        self.decoder_layer3 = nn.ConvTranspose2d(16, input_channels, (3, 3), stride=1, padding=1)

    def encoder(self, x):
        """Encode a batch of images into latent vectors.

        Args:
            x: Tensor of shape (batch, input_channels, 28, 28).

        Returns:
            Tensor of shape (batch, latent_dim).
        """
        x = F.relu(self.encoder_layer1(x))
        x = F.relu(self.encoder_layer2(x))
        x = F.relu(self.encoder_layer3(x))

        # Flatten everything except the batch dimension; the batch size is
        # taken from the tensor itself so any batch size works.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.encoder_fc_1(x))
        x = F.relu(self.encoder_fc_2(x))

        return x

    def decoder(self, x):
        """Decode latent vectors back into images.

        Args:
            x: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, input_channels, 28, 28).
        """
        x = F.relu(self.decoder_fc_1(x))
        x = F.relu(self.decoder_fc_2(x))

        # Restore the convolutional feature-map layout for the batch at hand.
        x = torch.reshape(x, (x.size(0), *self._feature_shape))

        x = F.relu(self.decoder_layer1(x))
        x = F.relu(self.decoder_layer2(x))
        x = F.relu(self.decoder_layer3(x))

        return x

    def forward(self, x):
        """Return the reconstruction of ``x`` (encode, then decode)."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
        


In [None]:
# Build the autoencoder for single-channel images with a 10-dimensional
# latent code, then move its parameters to the selected device.
model = AutoencoderModel(input_channels=1, latent_dim=10)
model = model.to(device)

# Adam over all model parameters; reconstruction quality is measured with
# mean squared error.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_function = nn.MSELoss()


In [None]:
EPOCHS = 10
epoch_loss = 0
import tqdm

# Make sure the model is in training mode (a previous evaluation run may have
# switched it to eval mode).
model.train()

for epoch in tqdm.trange(EPOCHS):
    # Restart the accumulators so that each epoch reports its own loss
    # instead of a running total over all epochs so far.
    epoch_loss = 0
    samples_seen = 0

    # The labels are not needed: the target of an autoencoder is its input.
    for images, _ in trainloader:
        images = images.to(device)

        # output = model(input)
        reconstructions = model(images)
        # compute loss function: (prediction, target) order
        loss = loss_function(reconstructions, images)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        # run optimizer
        optimizer.step()

        # bookkeeping: the loss is a per-batch mean, so weight it by the
        # number of images in the batch to get a sum over samples.
        batch_count = images.size(0)
        epoch_loss += loss.item() * batch_count
        samples_seen += batch_count

    # tqdm.write prints without breaking the progress bar.
    tqdm.tqdm.write('Epoch {}/{}: average training loss {:.4f}'.format(
        epoch + 1, EPOCHS, epoch_loss / max(samples_seen, 1)))
      

Test Dataset

In [None]:
model.eval()
test_loss = 0

# Only the first few batches are visualised; the loss is still accumulated
# over every batch.
MAX_PREVIEWS = 5

# no need to compute gradients - saves time and memory
with torch.no_grad():
    for batch_index, (images, _) in enumerate(testloader):
        images = images.to(device)
        # output = model(input)
        reconstructions = model(images)

        # bookkeeping: mean loss over the whole batch, weighted by the batch
        # size, so test_loss is a sum over all evaluated samples.
        batch_loss = F.mse_loss(reconstructions, images, reduction='mean').item()
        test_loss += batch_loss * images.size(0)

        if batch_index < MAX_PREVIEWS:
            # Show the second image of the batch next to its reconstruction;
            # squeezing dim 0 drops the channel axis so imshow gets a 2-D image.
            reconstruction = torch.squeeze(reconstructions.cpu()[1], dim=0)
            image = torch.squeeze(images.cpu()[1], dim=0)

            plt.imshow(reconstruction, cmap='gray')
            plt.show()
            plt.imshow(image, cmap='gray')
            plt.show()
            # Pause until Enter is pressed so the pair can be inspected.
            input()


In [None]:
# testloader is built with drop_last=True, so only full batches are evaluated.
# Normalise by the number of samples actually seen rather than by the dataset
# size, which would also count the dropped remainder. test_loss is expected to
# hold the loss summed over the evaluated samples.
num_evaluated = len(testloader) * testloader.batch_size
test_loss /= max(num_evaluated, 1)
print('\nTest set: Average Loss: {:.4f}\n'.format(test_loss))