In [None]:
import os
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

In [None]:
# Root folder that holds (or will receive) the MNIST files.
# Set the MNIST_DIR environment variable to point somewhere else; the fallback
# is the original machine-specific location so existing setups keep working.
path_mnist = os.environ.get('MNIST_DIR', r'C:/Users/utkar/Desktop/ML/Dataset')

# Pick the accelerator once; later cells move the model and batches to it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Both splits share the same preprocessing: PIL image -> tensor, no normalisation.
to_tensor = transforms.Compose([transforms.ToTensor(),])

# download=True only fetches the files when they are missing under path_mnist,
# so a fresh kernel on a new machine can still run the notebook end to end.
train = datasets.MNIST(
    path_mnist,
    train=True,
    download=True,
    transform=to_tensor,
)
test = datasets.MNIST(
    path_mnist,
    train=False,
    download=True,
    transform=to_tensor,
)

# Settings common to both loaders. Pinned host memory only helps when batches
# are copied to a GPU, so it is switched off on CPU-only machines.
loader_kwargs = dict(
    batch_size=64,
    pin_memory=(device == 'cuda'),
    num_workers=4,
)

# Only the training split is shuffled; the test order stays fixed.
train_loader = DataLoader(train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test, shuffle=False, **loader_kwargs)

print(device)

In [None]:
# Print the dataset description behind each loader (train first, then test).
print(*(loader.dataset for loader in (train_loader, test_loader)))

In [None]:
class DenoiseAutoencoder(nn.Module):
    """Convolutional autoencoder that maps a noisy 1-channel image to a clean one.

    Encoder: four 3x3 convolutions (padding 1, so spatial size is kept) with a
    2x2 max-pool after each of the first three. Down-sampling is done by the
    pooling layer rather than by strided convolutions.

    Decoder: three stride-2 transposed convolutions; the last one is followed
    by a sigmoid, so every output value lies in [0, 1].

    With a (N, 1, 28, 28) input the output is again (N, 1, 28, 28).
    """

    def __init__(self):
        super().__init__()

        # --- encoder -------------------------------------------------------
        self.pool = nn.MaxPool2d((2, 2), stride=2)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 8, kernel_size=3, padding=1)

        # --- decoder -------------------------------------------------------
        # The first layer uses a 3x3 kernel, the other two use 2x2 kernels.
        self.up_conv1 = nn.ConvTranspose2d(8, 32, kernel_size=3, stride=2)
        self.up_conv2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
        self.up_conv3 = nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2)

    def forward(self, x):
        """Encode ``x`` to an 8-channel bottleneck and decode it back.

        Args:
            x: image batch with a single channel, laid out as (N, 1, H, W).

        Returns:
            Tensor of reconstructed images with values in [0, 1].
        """
        # Three conv -> ReLU -> pool stages, each halving the spatial size.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))

        # Bottleneck convolution (no pooling afterwards).
        x = F.relu(self.conv4(x))

        # Two ReLU up-sampling stages followed by the sigmoid output layer.
        for up_conv in (self.up_conv1, self.up_conv2):
            x = F.relu(up_conv(x))
        return torch.sigmoid(self.up_conv3(x))

In [None]:
# Build the autoencoder, then move its parameters onto the selected device.
model = DenoiseAutoencoder()
model = model.to(device)

In [None]:
# Layer-by-layer summary for a single 1-channel 28x28 input.
summary_input_size = (1, 1, 28, 28)
model_summary = torchinfo.summary(model, input_size=summary_input_size)
pprint(model_summary)

In [None]:
# Sanity check: number of samples per training batch.
print(f'{train_loader.batch_size}')

In [None]:
# Pixel-wise reconstruction error between the denoised output and the clean image.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

# Loss scaling for mixed-precision training: the loss is multiplied by a scale
# factor before backward() so that small float16 gradients do not underflow to
# zero, and the gradients are unscaled again before the optimizer step.
# It is only useful together with CUDA autocast, so it is disabled (scale() and
# step() become pass-throughs) when no GPU is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

In [None]:
# Train the denoising autoencoder: every batch is corrupted with Gaussian noise,
# the model reconstructs it, and the loss compares the reconstruction with the
# clean batch. After each epoch the average per-sample loss is printed and one
# clean / reconstructed pair from the last batch is plotted.

# number of epochs to train the model
n_epochs = 20

# multiplier applied to the standard-normal noise added to the images
noise_factor = 0.5

# Mixed precision only applies on CUDA; on any other device run in float32.
use_amp = device == 'cuda'

for epoch in range(1, n_epochs + 1):
    # running sum of per-sample losses for this epoch
    train_loss = 0.0
    loop = tqdm(train_loader, desc=f'Epoch {epoch}/{n_epochs}')
    ###################
    # train the model #
    ###################
    for data in loop:
        # _ stands in for the labels, which are not needed here;
        # images stay 2-D (no flattening) because the model is convolutional
        images, _ = data
        images = images.to(device, non_blocking=True)

        # add random noise to the inputs (sampled directly on the target device)
        # and clamp the result back to the [0, 1] pixel range
        noisy_imgs = images + noise_factor * torch.randn_like(images)
        noisy_imgs = torch.clamp(noisy_imgs, 0., 1.)

        # forward pass on the *noisy* images; the target is the clean batch
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(noisy_imgs)
            loss = criterion(outputs, images)

        # clear old gradients, back-propagate the scaled loss, then let the
        # scaler unscale the gradients and perform the optimizer step
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # criterion returns the batch mean, so weight it by the batch size to
        # accumulate a per-sample sum
        train_loss += loss.item() * images.size(0)
        loop.set_postfix(loss=loss.item())

    # average over the number of samples (not the number of batches), matching
    # the per-sample sum accumulated above
    train_loss = train_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(
        epoch,
        train_loss
        ))

    # Preview: third sample of the last batch, clean image on the left and the
    # reconstruction on the right. Both are already in [0, 1] (inputs are only
    # converted with ToTensor, outputs pass through a sigmoid), so they are
    # shown without rescaling. Channel 0 is selected to get a 2-D image, and
    # the possibly float16 output is cast to float32 before plotting.
    img = images[2, 0].detach().float().cpu().numpy()
    gen = outputs[2, 0].detach().float().cpu().numpy()
    img = np.concatenate((img, gen), axis=1)
    plt.imshow(img, cmap='gray', vmin=0., vmax=1.)
    plt.show()