<a href="https://colab.research.google.com/github/yugpsyfer/Playing_with_PyTorch/blob/main/Denoising_Auto_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as ds
from torchvision.transforms import v2,ToTensor, Resize, functional
from torch.utils.data import DataLoader, Dataset

import os
from PIL import Image
import numpy as np

In [2]:
# Pre-processing pipeline applied to every image on load:
#   1. ToTensor converts the PIL image into a float tensor,
#   2. Resize brings every image to a fixed 128x128 so they can be batched.
transf = v2.Compose(
    [
        ToTensor(),
        Resize(size=(128, 128)),
    ]
)

# Training split of Oxford Flowers-102; downloaded into ./flowers-102 on first use.
flowers = ds.Flowers102(
    root='./',
    split="train",
    download=True,
    transform=transf,
)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to flowers-102/102flowers.tgz


100%|██████████| 344862509/344862509 [00:16<00:00, 20488546.63it/s]


Extracting flowers-102/102flowers.tgz to flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to flowers-102/imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 446846.48it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to flowers-102/setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 18474411.59it/s]


In [3]:
# Number of images per mini-batch.
batch_size = 128

# Iterates over the Flowers-102 training images in a new random order each epoch.
loader = DataLoader(flowers, batch_size=batch_size, shuffle=True)

## Encoder - Decoder

*Output size of each layer:*

The output height and width are the same in our case, so only the height formula is shown for each layer type.

### Convolution Layers
* $H_{out} = \left\lfloor \dfrac{H_{in} + 2p - d\,(k - 1) - 1}{s} + 1 \right\rfloor$, where $p$ = `padding[0]`, $d$ = `dilation[0]`, $k$ = `kernel_size[0]` and $s$ = `stride[0]`.

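### Max Pooling
* $H_{out} = \left\lfloor \dfrac{H_{in} + 2p - d\,(k - 1) - 1}{s} + 1 \right\rfloor$, where $p$ = `padding[0]`, $d$ = `dilation[0]`, $k$ = `kernel_size[0]` and $s$ = `stride[0]`. When `stride` is not given it defaults to `kernel_size`.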

### Transposed Conv Layer:

* $H_{out} = (H_{in} - 1) \times s - 2p + d\,(k - 1) + o + 1$, where $s$ = `stride`, $p$ = `padding`, $d$ = `dilation`, $k$ = `kernel_size` and $o$ = `output_padding`.

In [129]:
class AutoEncoder(nn.Module):
    """Convolutional auto-encoder for batches of 3x128x128 images.

    The encoder reduces a ``(B, 3, 128, 128)`` batch to ``(B, 32, 13, 13)``,
    two linear layers map the flattened 5408 features to a 2048-dimensional
    code, and the decoder expands that code (viewed as ``(B, 32, 8, 8)``) back
    to ``(B, 3, 128, 128)``.

    The layer sizes are hard-wired to 128x128 inputs: any other spatial size
    changes the flattened feature count and the first linear layer will
    reject it.

    NOTE(review): there are no non-linear activations between the layers, so
    apart from the max-pooling steps the network is a linear map.
    """

    def __init__(self):
        super().__init__()
        # Spatial size for a 128x128 input: 128 -> 126 -> 124 -> 41 -> 39 -> 13
        # (MaxPool2d's stride defaults to its kernel size, i.e. 3).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3,3)),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3,3)),
            nn.MaxPool2d(kernel_size=(3,3)),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3,3)),
            nn.MaxPool2d(kernel_size=(3,3)),
        )

        # 5408 = 32 channels * 13 * 13; 2048 = 32 channels * 8 * 8 (decoder input).
        self.encoded_rep = nn.Sequential(
            nn.Linear(5408, 5408),
            nn.Linear(5408, 2048),
        )

        # Spatial size: 8 -> 42 -> 102 -> 114 -> 122 -> 126 -> 128.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=(7,7), stride=3, output_padding=2, dilation=3),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=(7,7), stride=2, output_padding=1, dilation=3),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(7,7), stride=1, output_padding=0, dilation=2),
            nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(5,5), stride=1, dilation=2),
            nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(5,5), stride=1),
            nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(3,3))
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction.

        Args:
            x: Image batch of shape ``(B, 3, 128, 128)``.

        Returns:
            Tensor of shape ``(B, 3, 128, 128)``.
        """
        out = self.encoder(x)

        # Flatten the feature maps so they can go through the linear bottleneck.
        b, f, m, n = out.shape
        out = self.encoded_rep(out.view(b, f*m*n))

        # Reinterpret the 2048-d code as 32 feature maps of 8x8 for the decoder.
        return self.decoder(out.view(b, 32, 8, 8))

In [138]:
# Optimisation hyper-parameters.
lr = 1e-3
# NOTE(review): weight_decay is defined but is not passed to the optimizer
# below, so it currently has no effect. Confirm whether it was meant to be
# used (e.g. Adam(..., weight_decay=weight_decay)) before wiring it in.
weight_decay = 0.03
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on a CPU-only runtime.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoEncoder()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Reconstruction loss: mean squared error between output and input images.
loss_fn = torch.nn.MSELoss()

In [147]:
@torch.no_grad()
def validation(model, dataloader, val_images=10, path='./out/'):
    """Save input/reconstruction image pairs for visual inspection.

    For each of the first ``val_images`` batches, the first image of the batch
    and the model's output for it are written to ``path`` as
    ``orig_img_<i>.jpg`` and ``img_<i>.jpg`` respectively.

    Args:
        model: Module that is called on each image batch.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        val_images: Maximum number of image pairs to save.
        path: Output directory; it is created if it does not exist.

    Returns:
        None. The results are the image files written to ``path``.
    """
    # Saving fails if the target directory is missing, so create it up front.
    os.makedirs(path, exist_ok=True)

    # Feed the batches on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Evaluate in eval mode, then restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        for count, (x, _) in enumerate(dataloader):
            # Stop after exactly `val_images` pairs have been written.
            if count >= val_images:
                break

            x = x.to(target_device)
            out = model(x)

            original = functional.to_pil_image(x[0].cpu())
            original.save(os.path.join(path, 'orig_img_' + str(count) + '.jpg'))

            # The model output is unbounded, but to_pil_image expects float
            # images in [0, 1]; clamp so out-of-range values do not corrupt
            # the saved picture.
            reconstruction = functional.to_pil_image(out[0].clamp(0.0, 1.0).cpu())
            reconstruction.save(os.path.join(path, 'img_' + str(count) + '.jpg'))
    finally:
        model.train(was_training)

In [None]:
def train(model, epochs, dataloader):
    """Train ``model`` to reconstruct its own input.

    Uses the module-level ``optimizer`` and ``loss_fn``. After every epoch the
    loss summed over all batches of that epoch is printed.

    Args:
        model: Module whose output is compared against its input batch.
        epochs: Number of full passes over ``dataloader``.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.

    Returns:
        None.
    """
    # Feed the batches on whatever device the model's weights live on.
    target_device = next(model.parameters()).device
    model.train()

    for eps in range(epochs):
        epoch_loss = 0.0
        for x, _ in dataloader:
            x = x.to(target_device)
            optimizer.zero_grad()
            out = model(x)
            # Auto-encoder objective: the target is the input itself.
            loss = loss_fn(out, x)
            loss.backward()
            optimizer.step()
            # .item() yields a plain Python float, so the running total does
            # not hold on to autograd history and prints as a number.
            epoch_loss += loss.item()

        print("EPOCH:{} | LOSS:{}".format(eps, epoch_loss))



# Run 100 passes over the training loader.
train(model=model, epochs=100, dataloader=loader)

In [148]:
# Write input/reconstruction image pairs to the default output directory.
# NOTE(review): this reuses the training loader, so the saved samples are
# images the model was trained on rather than held-out data.
validation(model=model, dataloader=loader)