# This version works with 28x28-pixel images (e.g. MNIST).

In [20]:
import torch.nn as nn
import matplotlib.pyplot as plt
import torch
import numpy as np
import torchvision
import pandas as pd


In [21]:
data_dir = 'dataset'

# Download MNIST into data_dir (if not already present) and load both the
# training split (train=True) and the test split (train=False).
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root=data_dir, train=split_is_train, download=True)
    for split_is_train in (True, False)
)

In [22]:
from torchvision import transforms
from torch.utils.data import DataLoader,random_split

In [23]:
class AddGaussianNoise(object):
    """Transform that adds element-wise Gaussian noise to a tensor.

    Every element receives ``N(0, 1) * std + mean``, i.e. noise with the given
    mean and standard deviation. The result is not clipped.

    Args:
        mean: mean of the added noise.
        std: standard deviation of the added noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # randn_like draws the noise with the input's shape, dtype and device,
        # so GPU or reduced-precision tensors work without a device-mismatch
        # error or dtype promotion (torch.randn(tensor.size()) always produced
        # a default-dtype CPU tensor).
        return tensor + torch.randn_like(tensor) * self.std + self.mean

    def __repr__(self):
        return f"{type(self).__name__}(mean={self.mean}, std={self.std})"

# NOTE(review): unused legacy snippet kept as a bare string literal; it is never
# executed. It sketches a NumPy-based way of building noisy train/test arrays
# (X_train / X_test and the `numpy` name are not defined in this notebook).
# Noise is actually applied per batch by add_noise() further below.
'''noise_factor = 0.2
x_train_noisy = X_train + noise_factor * numpy.random.normal(loc=0.0, scale=1.0, size=X_train.shape)
x_test_noisy = X_test + noise_factor * numpy.random.normal(loc=0.0, scale=1.0, size=X_test.shape)
x_train_noisy = numpy.clip(x_train_noisy, 0., 1.)
x_test_noisy = numpy.clip(x_test_noisy, 0., 1.)'''

In [24]:


# Convert images to float tensors; no augmentation or normalization, so inputs
# stay directly comparable with the decoder's Sigmoid outputs and with
# add_noise()'s [0, 1] clipping.
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset.transform = train_transform
test_dataset.transform = test_transform

# Hold out 20% of the training set for validation. The train size is computed
# as the remainder so the two lengths always sum to len(train_dataset), which
# random_split requires (int(m - m*0.2) + int(m*0.2) can fall short of m).
m = len(train_dataset)
val_size = int(m * 0.2)
train_data, val_data = random_split(train_dataset, [m - val_size, val_size])

batch_size = 256

# Reshuffle the training subset every epoch so batches differ between epochs.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

valid_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [25]:
# Image tensor of the first sample in the training subset
# (train_data[0] is an (image, label) pair; [0] keeps the image).
s = train_data[0][0]

In [26]:
# Sanity check: a 2x2 max-pool (stride defaults to the kernel size) halves each
# spatial dimension of the sample image.
pool = nn.MaxPool2d(kernel_size=2, stride=2)
pool(s).shape

torch.Size([1, 14, 14])

In [27]:
def add_noise(inputs, noise_factor=0.3):
    """Corrupt ``inputs`` with additive Gaussian noise, then clamp to [0, 1].

    The noise is ``noise_factor * N(0, 1)``, drawn element-wise with the same
    shape, dtype and device as ``inputs`` (via ``torch.randn_like``).

    Args:
        inputs: tensor to corrupt.
        noise_factor: standard deviation of the added noise.

    Returns:
        A new tensor with every element clamped into [0, 1].
    """
    corrupted = inputs + noise_factor * torch.randn_like(inputs)
    return corrupted.clamp(min=0.0, max=1.0)

In [42]:
latent_space_size = 8


class Encoder(nn.Module):
    """Convolutional denoising autoencoder for 1x28x28 images.

    Despite its name, this module holds both halves of the autoencoder:
    ``encoder`` + ``enc_lin`` map an image batch to ``latent_dim``-sized codes,
    and ``dec_lin`` + ``decoder`` map the codes back to images in [0, 1]
    (final Sigmoid).

    Spatial sizes for a 28x28 input:
        encoder: 28 -> 28 -> 28 -> pool 14 -> 14 -> pool 7 -> conv(k3, s2) 3
        decoder: 3 -> convT(k3, s2) 7 -> convT(k2, s2) 14 -> convT(k2, s2) 28 -> 28

    Args:
        latent_dim: size of the bottleneck. If None (default), the module-level
            ``latent_space_size`` is used.
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            # Resolved at construction time (not class-definition time) so that
            # re-assigning the notebook global before instantiating still works.
            latent_dim = latent_space_size
        self.latent_dim = latent_dim

        # Convolutional feature extractor: (N, 1, 28, 28) -> (N, 32, 3, 3).
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=2, stride=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=2, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 24, kernel_size=2, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(24, 32, kernel_size=3, stride=2, padding=0),
            nn.ReLU(),
        )

        # Attribute name kept as-is for backward compatibility with callers.
        self.Flatten = nn.Flatten(start_dim=1)

        # (N, 288) -> (N, latent_dim)
        self.enc_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )

        # (N, latent_dim) -> (N, 32, 3, 3)
        self.dec_lin = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 32 * 3 * 3),
            nn.ReLU(),
            nn.Unflatten(dim=1, unflattened_size=(32, 3, 3)),
        )

        # (N, 32, 3, 3) -> (N, 1, 28, 28), squashed into [0, 1] by the Sigmoid.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 24, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(24, 16, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv2d(8, 1, kernel_size=2, stride=1, padding='same'),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Map an (N, 1, 28, 28) image batch to (N, latent_dim) latent codes."""
        x = self.encoder(x)
        x = self.Flatten(x)
        return self.enc_lin(x)

    def decode(self, z):
        """Map (N, latent_dim) latent codes to (N, 1, 28, 28) images in [0, 1]."""
        return self.decoder(self.dec_lin(z))

    def forward(self, x):
        """Reconstruct ``x`` by encoding it to the latent space and decoding it back."""
        return self.decode(self.encode(x))


In [29]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [40]:
# Build the autoencoder and move its parameters to the selected device
# (Module.to returns the module itself, so the binding is the same object).
autoencoder_model = Encoder().to(device)
autoencoder_model

Encoder(
  (encoder): Sequential(
    (0): Conv2d(1, 8, kernel_size=(2, 2), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1), padding=same)
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(16, 24, kernel_size=(2, 2), stride=(1, 1), padding=same)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(24, 32, kernel_size=(3, 3), stride=(2, 2))
    (9): ReLU()
  )
  (Flatten): Flatten(start_dim=1, end_dim=-1)
  (enc_lin): Sequential(
    (0): Linear(in_features=288, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
  (dec_lin): Sequential(
    (0): Linear(in_features=2, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=288, bias=True)
    (3): ReLU()
    (4): Unflatten(dim=1, unflattened_size=(32, 3, 3))
  )
  (d

In [31]:
# Re-fetch the image tensor of the first training sample (label dropped).
s= train_data[0][0]
# NOTE: the model expects a leading batch dimension (Flatten(start_dim=1) in the
# model assumes one), so call it as autoencoder_model(s.unsqueeze(0).to(device));
# calling it on the unbatched `s` is expected to fail.

In [32]:
# Pixel-wise mean-squared reconstruction error, minimised with Adam (lr = 1e-3).
loss_func = nn.MSELoss()
optim = torch.optim.Adam(params=autoencoder_model.parameters(), lr=1e-3)

In [33]:
from tqdm import tqdm

In [34]:
# Grab one (images, labels) batch from the training loader for quick shape checks.
samx,samy = next(iter(train_loader))
# e.g. autoencoder_model(samx.to(device)).shape  -> same shape as samx

In [35]:

def _show_image(ax, tensor, title=None):
    """Draw a single-channel image tensor on ``ax`` in grayscale with hidden axes."""
    ax.imshow(tensor.cpu().squeeze().numpy(), cmap='gist_gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    if title is not None:
        ax.set_title(title)


def plot_ae_outputs_den(encoder, n=10, noise_factor=0.2):
    """Plot original, noise-corrupted and reconstructed test images in a 3 x n grid.

    For each label i in range(n), the first sample of ``test_dataset`` with that
    label is corrupted with ``add_noise(img, noise_factor)`` and passed through
    ``encoder`` (the full autoencoder model). Row 1 shows the clean images,
    row 2 the noisy inputs and row 3 the model outputs.

    The model is switched to eval mode for inference and its previous
    train/eval mode is restored afterwards, so calling this from a training
    loop leaves no lasting side effect on the model.

    Args:
        encoder: model mapping a (1, 1, H, W) batch to an image batch
            (assumed; matches the autoencoder defined in this notebook).
        n: number of labels to show; every label 0..n-1 must occur in
            ``test_dataset`` or the index lookup raises IndexError.
        noise_factor: standard deviation of the noise passed to add_noise.

    Relies on the notebook globals ``test_dataset``, ``device`` and ``add_noise``.
    """
    plt.figure(figsize=(16, 4.5))
    targets = test_dataset.targets.numpy()
    # Index of the first test sample for each label 0..n-1.
    t_idx = {i: np.where(targets == i)[0][0] for i in range(n)}
    row_titles = ('Original images', 'Corrupted images', 'Reconstructed images')

    was_training = encoder.training
    encoder.eval()
    try:
        for i in range(n):
            img = test_dataset[t_idx[i]][0].unsqueeze(0)
            image_noisy = add_noise(img, noise_factor).to(device)
            with torch.inference_mode():
                rec_img = encoder(image_noisy)

            # Row-major subplot index: column i of row `row`.
            for row, (title, tensor) in enumerate(zip(row_titles, (img, image_noisy, rec_img))):
                ax = plt.subplot(3, n, i + 1 + row * n)
                # Only the middle column carries the row title.
                _show_image(ax, tensor, title if i == n // 2 else None)
    finally:
        encoder.train(was_training)

    plt.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.7,
                        top=0.9,
                        wspace=0.3,
                        hspace=0.3)
    plt.show()

In [36]:
def train_step(model, dataloader, loss_fn, optimizer, noise_factor, device=device):
    """Run one training epoch of the denoising autoencoder.

    Each batch ``X`` is corrupted with ``add_noise(X, noise_factor)`` and the
    model is trained to reconstruct the clean ``X`` from the noisy input; labels
    are ignored.

    Args:
        model: autoencoder to train (put into train mode here).
        dataloader: iterable of (inputs, labels) batches.
        loss_fn: reconstruction loss called as ``loss_fn(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        noise_factor: standard deviation of the corruption noise.
        device: device the batches are moved to.

    Returns:
        The epoch's mean loss, weighting each batch by its number of samples so
        a smaller final batch does not skew the average (assumes ``loss_fn``
        averages over the batch, e.g. nn.MSELoss() with its default reduction).
        NaN if the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0

    for X, _ in dataloader:
        noisy_X = add_noise(X, noise_factor).to(device)
        X = X.to(device)

        output = model(noisy_X)
        loss = loss_fn(output, X)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = X.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    return total_loss / n_samples if n_samples else float('nan')

In [37]:
def test_func(model, dataloader, loss_fn, noise_factor, device=device):
    """Evaluate the denoising reconstruction loss of ``model`` over ``dataloader``.

    Each batch is corrupted with ``add_noise(X, noise_factor)`` and the model's
    output is compared against the clean batch. The loss is accumulated per batch
    instead of concatenating every prediction and target, so memory use does not
    grow with the dataset size.

    Args:
        model: autoencoder to evaluate (put into eval mode here).
        dataloader: iterable of (inputs, labels) batches; labels are ignored.
        loss_fn: reconstruction loss called as ``loss_fn(pred, target)``.
        noise_factor: standard deviation of the corruption noise.
        device: device the batches are moved to.

    Returns:
        0-d tensor on ``device``: the sample-weighted mean of the per-batch losses.
        For a loss that averages over elements (e.g. nn.MSELoss() default) and
        equally-shaped samples, this equals the loss over all batches combined.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    with torch.inference_mode():
        total = None
        n_samples = 0
        for X, _ in dataloader:
            noisy_X = add_noise(X, noise_factor).to(device)
            X = X.to(device)
            pred = model(noisy_X)

            batch_size = X.size(0)
            weighted = loss_fn(pred, X) * batch_size
            total = weighted if total is None else total + weighted
            n_samples += batch_size

        if n_samples == 0:
            raise ValueError("dataloader yielded no batches")
        return total / n_samples

In [43]:
EPOCHS = 20
# Std of the Gaussian corruption, shared by training, validation and the plots
# so the three can never drift apart.
NOISE_FACTOR = 0.2
history = {'train_loss': [], 'val_loss': []}

for epoch in tqdm(range(EPOCHS)):
    train_loss = train_step(model=autoencoder_model,
                            dataloader=train_loader,
                            loss_fn=loss_func,
                            optimizer=optim,
                            noise_factor=NOISE_FACTOR,
                            device=device)
    val_loss = test_func(model=autoencoder_model,
                         dataloader=valid_loader,
                         loss_fn=loss_func,
                         noise_factor=NOISE_FACTOR,
                         device=device)

    # Store plain Python floats: test_func may return a tensor living on the
    # GPU, which numpy/matplotlib cannot convert directly when plotting history.
    train_loss = float(train_loss)
    val_loss = float(val_loss)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    # tqdm.write prints without corrupting the progress bar (unlike print()).
    tqdm.write(f'\n training loss =  {train_loss}\t validation loss = {val_loss}')
    plot_ae_outputs_den(autoencoder_model, n=10, noise_factor=NOISE_FACTOR)

  

Output hidden; open in https://colab.research.google.com to view.