In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


In [None]:
# ToTensor turns each PIL image into a float tensor of shape (1, 28, 28) with
# pixel values scaled to [0, 1] -- the range the Sigmoid decoder outputs target.
transform = transforms.ToTensor()

# Alternative preprocessing: map pixels to [-1, 1] (pair with a Tanh decoder output).
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,)),
# ])

BATCH_SIZE = 64

# Training split of MNIST; downloaded into ./data on first use and reused afterwards.
mnist_data = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
data_loader = torch.utils.data.DataLoader(
    dataset=mnist_data,
    batch_size=BATCH_SIZE,
    shuffle=True,  # data is reshuffled at every epoch
)

In [None]:
# dataiter = iter(data_loader)    # Inspect the first batch to see what the data looks like
# images, labels = next(dataiter)
# print(torch.min(images), torch.max(images)) # This range changes if we change the transform

In [None]:
class Autoencoder_Linear(nn.Module):
  """Fully connected autoencoder for 28x28 single-channel images.

  The encoder compresses each 784-pixel image to a 3-dimensional code
  (784 -> 128 -> 64 -> 12 -> 3) and the decoder mirrors it back
  (3 -> 12 -> 64 -> 128 -> 784).

  ``forward`` accepts flat input of shape (N, 784) or (784,) as well as
  image-shaped batches such as (N, 1, 28, 28), and returns a tensor with the
  same shape as its input. This lets it be swapped in for ``Autoencoder_CNN``
  without reshaping data in the training or plotting code.
  """

  def __init__(self):
    super().__init__()
    # Shapes below use N for the batch size (number of images per batch).
    self.encoder = nn.Sequential(
        nn.Linear(28*28, 128),  # N, 784 -> N, 128
        nn.ReLU(),
        nn.Linear(128, 64),     # N, 128 -> N, 64
        nn.ReLU(),
        nn.Linear(64, 12),      # N, 64 -> N, 12
        nn.ReLU(),
        nn.Linear(12, 3),       # N, 12 -> N, 3 (latent code)
        # No activation on the code layer: latent values are left unbounded.
    )
    self.decoder = nn.Sequential(
        nn.Linear(3, 12),       # N, 3 -> N, 12
        nn.ReLU(),
        nn.Linear(12, 64),      # N, 12 -> N, 64
        nn.ReLU(),
        nn.Linear(64, 128),     # N, 64 -> N, 128
        nn.ReLU(),
        nn.Linear(128, 28*28),  # N, 128 -> N, 784
        # Unlike the encoder, the decoder ends in Sigmoid so reconstructions lie
        # in (0, 1), the pixel range produced by transforms.ToTensor().
        # If inputs are normalized to [-1, 1], use nn.Tanh() here instead.
        nn.Sigmoid(),
    )

  def forward(self, x):
    """Encode and decode ``x``; the result has the same shape as ``x``.

    Args:
      x: a flat batch (N, 784), a single flat image (784,), or a batch with
        784 values per sample laid out as an image, e.g. (N, 1, 28, 28).

    Returns:
      The reconstruction, reshaped to ``x.shape``.
    """
    # Inputs with more than two dims are treated as (N, ...) image batches and
    # flattened to (N, 784) for the Linear layers; 1-D/2-D inputs pass through.
    flat = x.reshape(x.shape[0], -1) if x.dim() > 2 else x
    encoded = self.encoder(flat)
    decoded = self.decoder(encoded)
    return decoded.reshape(x.shape)

# Note: mind the last layer. If the input images are in [-1, 1] (e.g. after a Normalize((0.5,), (0.5,)) transform), the Sigmoid is not appropriate; use Tanh instead.

In [None]:
class Autoencoder_CNN(nn.Module):
  """Convolutional autoencoder for (N, 1, 28, 28) images, where N is the batch size.

  Encoder: two stride-2 3x3 convolutions (28 -> 14 -> 7) followed by a 7x7
  convolution that collapses each feature map to a 64-channel 1x1 code.
  Decoder: the mirror image built from transposed convolutions
  (1 -> 7 -> 14 -> 28), ending in Sigmoid so outputs lie in (0, 1).
  """

  def __init__(self):
    super().__init__()
    encoder_layers = [
        nn.Conv2d(1, 16, 3, stride=2, padding=1),   # -> N, 16, 14, 14
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> N, 32, 7, 7
        nn.ReLU(),
        nn.Conv2d(32, 64, 7),                       # -> N, 64, 1, 1 (code; no activation)
    ]
    decoder_layers = [
        nn.ConvTranspose2d(64, 32, 7),              # -> N, 32, 7, 7
        nn.ReLU(),
        # Without output_padding this stride-2 transposed conv maps 7 -> 13;
        # output_padding=1 adds one extra row/column on one side to reach 14.
        nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> N, 16, 14, 14
        nn.ReLU(),
        # Likewise 14 -> 27 without output_padding, 28 with it.
        nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> N, 1, 28, 28
        # Sigmoid keeps reconstructions in (0, 1), the range of ToTensor() pixels;
        # switch to nn.Tanh() if inputs are normalized to [-1, 1].
        nn.Sigmoid(),
    ]
    self.encoder = nn.Sequential(*encoder_layers)
    self.decoder = nn.Sequential(*decoder_layers)

  def forward(self, x):
    """Return the reconstruction of ``x``; (N, 1, 28, 28) in gives (N, 1, 28, 28) out."""
    code = self.encoder(x)
    return self.decoder(code)

# Note: mind the last layer. If the input images are in [-1, 1] (e.g. after a Normalize((0.5,), (0.5,)) transform), the Sigmoid is not appropriate; use Tanh instead.

# Alternative downsampling: nn.MaxPool2d reduces the spatial size; nn.MaxUnpool2d computes its partial inverse when upsampling.


In [None]:
# Seed the global torch RNG so weight initialisation -- and DataLoader shuffling,
# which draws from the same default RNG when iteration starts -- is reproducible.
SEED = 0
torch.manual_seed(SEED)

model = Autoencoder_CNN()
criterion = nn.MSELoss()  # mean pixel-wise squared reconstruction error
# weight_decay adds a small L2 penalty on the weights.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)


In [None]:
num_epochs = 10
outputs = []  # per epoch: (epoch, last input batch, its reconstruction), used for plotting
model.train()
for epoch in range(num_epochs):
  running_loss = 0.0  # sum of per-batch mean losses, weighted by batch size
  n_seen = 0
  for img, _ in data_loader:  # labels are unused: the target is the input itself
    # The model built in the setup cell (Autoencoder_CNN) consumes (N, 1, 28, 28)
    # batches directly, so no flattening is needed here.
    recon = model(img)
    loss = criterion(recon, img)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # MSELoss averages within a batch; weight by batch size so the epoch figure
    # is a true mean even though the final batch is smaller.
    n_batch = img.size(0)
    running_loss += loss.item() * n_batch
    n_seen += n_batch

  # Report the mean over the whole epoch rather than only the (noisy) last batch.
  epoch_loss = running_loss / max(n_seen, 1)
  print(f"Epoch:{epoch+1}, Loss:{epoch_loss: .4f}")
  # Detach so the stored tensors do not keep references to the autograd graph.
  outputs.append((epoch, img.detach(), recon.detach()))


In [None]:
def plot_reconstructions(images, reconstructions, epoch, n_show=9):
  """Draw the first ``n_show`` inputs (top row) above their reconstructions.

  Args:
    images: batch of input images, e.g. (N, 1, 28, 28) or flat (N, 784).
    reconstructions: model outputs for ``images`` with the same layout.
    epoch: 0-based epoch index, shown 1-based in the title to match the
      training log.
    n_show: number of columns to draw.

  Returns:
    The created matplotlib Figure.
  """
  originals = images.detach().cpu().numpy()
  recons = reconstructions.detach().cpu().numpy()
  fig, axes = plt.subplots(2, n_show, figsize=(9, 2), squeeze=False)
  for row, batch in enumerate((originals, recons)):
    for col, ax in enumerate(axes[row]):
      ax.axis("off")
      if col < len(batch):
        # reshape handles both (1, 28, 28) CNN outputs and flat (784,) vectors.
        ax.imshow(batch[col].reshape(28, 28), cmap="gray")
  fig.suptitle(f"Epoch {epoch + 1}: inputs (top) vs reconstructions (bottom)")
  return fig


# Every 4th recorded epoch (0-based epochs 0, 4, 8 when num_epochs = 10).
for epoch_idx, batch_imgs, batch_recons in outputs[::4]:
  plot_reconstructions(batch_imgs, batch_recons, epoch_idx)
plt.show()