In [None]:
import torch
import torchvision
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch.nn.functional as F

from pprint import pprint
import numpy as np
import torchinfo
from matplotlib import pyplot as plt
import time
import cv2
from PIL import Image
import time
from tqdm import tqdm

In [None]:
# Setting random seeds
# Seed every RNG the notebook draws from: torch (weight init, reparameterisation
# noise, DataLoader shuffling) and NumPy (display_photo picks its image with np.random).
# Note: bit-exact GPU runs may additionally require deterministic cuDNN settings.
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)


In [None]:
# Data configuration: CelebA aligned faces, center-cropped to 128x128.
channels_img = 3
batch_size = 64
data_dir = r'C:/Users/utkar/Desktop/ML/Dataset/Celeb_dataset/img_align_celeba'
# Presumably a 500-image subset of the same data; not used in this cell.
data_dir_new = r'C:/Users/utkar/Desktop/ML/Dataset/Celeb_dataset/500_img'

# ToTensor produces float CHW tensors in [0, 1], the same range as the decoder's sigmoid output.
transform = transforms.Compose([transforms.ToTensor(), transforms.CenterCrop((128, 128))])

# ImageFolder needs data_dir to contain class sub-directories; the labels are discarded by the training loop.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Sanity check: show one training image (CHW -> HWC for matplotlib), then a 5x5 box-blurred copy.
img = dataset[5][0].permute(1, 2, 0)
img_blur = torch.tensor(cv2.blur(np.array(img), (5, 5)))
for picture in (img, img_blur):
    plt.imshow(picture)
    plt.show()


In [None]:
def bn(in_c, out_c):
    """Build a 1x1 -> 3x3 -> 1x1 convolutional bottleneck block.

    Each convolution is followed by ``BatchNorm2d`` and ``LeakyReLU(0.1, inplace=True)``.
    The hidden width is ``2 * in_c``. The 3x3 convolution uses ``padding=1``, so the
    spatial size is unchanged and only the channel count goes from ``in_c`` to ``out_c``.

    Despite its name this is not a bare BatchNorm layer. The submodule order
    (indices 0-8, which appear in state_dict keys) must stay fixed so saved
    checkpoints keep loading.

    Args:
        in_c: input channel count.
        out_c: output channel count.

    Returns:
        nn.Sequential of 9 layers: (Conv2d, BatchNorm2d, LeakyReLU) x 3.
    """
    hidden = in_c * 2
    convs = (
        nn.Conv2d(in_c, hidden, kernel_size=1),
        nn.Conv2d(hidden, hidden, kernel_size=3, padding=1),
        nn.Conv2d(hidden, out_c, kernel_size=1),
    )
    layers = []
    for conv in convs:
        layers += [conv, nn.BatchNorm2d(conv.out_channels), nn.LeakyReLU(0.1, inplace=True)]
    return nn.Sequential(*layers)

In [None]:
# Defining the model

class VAE(nn.Module):
    """Convolutional variational autoencoder with concatenating skip blocks.

    Encoder: four kernel-4 / stride-2 / padding-1 convolutions, each halving H and W.
    After each of the first three, the output of a ``bn`` bottleneck block is
    concatenated with that block's input along the channel axis (hence the ``*3``
    channel counts). The final 64-channel map is flattened and projected by two
    linear layers to ``mu`` and ``logvar`` of a diagonal Gaussian over ``z_dim``
    latent variables.

    Decoder: a linear layer back to ``feature_dim``, then four stride-2 transposed
    convolutions (each doubling H and W) with the same concatenating blocks. It ends
    in a sigmoid, so outputs lie in (0, 1).

    Args:
        imgChannels: number of input/output image channels.
        feature_dim: flattened size of the encoder's last feature map,
            ``64 * (H/16) * (W/16)``. The default ``64*8*8`` matches 128x128 inputs.
            The decoder unflattens to a square map, so this should be ``64 * s * s``.
        z_dim: dimensionality of the latent space.

    Submodule names and their registration order are part of the checkpoint format:
    state_dict keys, and the parameter order the optimizer state is indexed by.
    """

    def __init__(self, imgChannels=3, feature_dim=64*8*8, z_dim=512):
        super(VAE, self).__init__()
        # encoder
        self.encConv1 = nn.Conv2d(imgChannels, 32*2, kernel_size=4, stride=2, padding=1)
        self.bn_1 = bn(32*2, 32)
        self.encConv2 = nn.Conv2d(32*3, 64*2, kernel_size=4, stride=2, padding=1)
        self.bn_2 = bn(64*2, 64)
        self.encConv3 = nn.Conv2d(64*3, 64*2, kernel_size=4, stride=2, padding=1)
        self.bn_3 = bn(64*2, 64)
        self.encConv4 = nn.Conv2d(64*3, 64, kernel_size=4, stride=2, padding=1)

        # reparameterize: heads for the Gaussian mean and log-variance
        self.encFC1 = nn.Linear(feature_dim, z_dim)
        self.encFC2 = nn.Linear(feature_dim, z_dim)

        # decoder
        self.decFC = nn.Linear(z_dim, feature_dim)
        self.decTconv0 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
        self.bn_4 = bn(64, 64*2)
        self.decTconv1 = nn.ConvTranspose2d(64*3, 64*2, kernel_size=4, stride=2, padding=1)
        self.bn_5 = bn(64*2, 64)
        self.decTconv2 = nn.ConvTranspose2d(64*3, 32*2, kernel_size=4, stride=2, padding=1)
        self.bn_6 = bn(32*2, 32)
        self.decTconv3 = nn.ConvTranspose2d(32*3, imgChannels, kernel_size=4, stride=2, padding=1)

        # Side length of the square 64-channel map the decoder unflattens to
        # (8 for the default feature_dim). Derived rather than hard-coded so a
        # non-default feature_dim stays consistent. A plain int, so it is not
        # part of the state_dict.
        self._dec_side = int(round((feature_dim / 64) ** 0.5))

    def encoder(self, x):
        """Encode images (N, C, H, W) into ``(mu, logvar)``, each of shape (N, z_dim)."""
        x = F.leaky_relu(self.encConv1(x), 0.1, inplace=True)
        y = self.bn_1(x)
        x = F.leaky_relu(self.encConv2(torch.cat([y, x], dim=1)), 0.1, inplace=True)
        y = self.bn_2(x)
        x = F.leaky_relu(self.encConv3(torch.cat([y, x], dim=1)), 0.1, inplace=True)
        y = self.bn_3(x)
        x = F.leaky_relu(self.encConv4(torch.cat([y, x], dim=1)), 0.1, inplace=True)
        # Flatten per sample and keep the batch dimension. view(-1, 64*8*8) would
        # silently fold extra spatial positions into the batch for inputs other than
        # 128x128, producing more latent rows than input images. A shape mismatch
        # now surfaces as an error in encFC1 instead.
        x = torch.flatten(x, start_dim=1)
        mu = self.encFC1(x)
        logvar = self.encFC2(x)
        return mu, logvar

    def reparameterise(self, mu, logvar):
        """Return ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)`` in training mode, else ``mu``."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, x):
        """Decode latents (N, z_dim) into images (N, imgChannels, 16*s, 16*s) with values in (0, 1)."""
        x = F.leaky_relu(self.decFC(x), 0.1, inplace=True)
        # -1 is safe here: decFC always yields exactly feature_dim values per row.
        x = x.view(-1, 64, self._dec_side, self._dec_side)
        x = F.leaky_relu(self.decTconv0(x), 0.1, inplace=True)
        y = self.bn_4(x)
        x = F.leaky_relu(self.decTconv1(torch.cat([y, x], dim=1)), 0.1, inplace=True)
        y = self.bn_5(x)
        x = F.leaky_relu(self.decTconv2(torch.cat([y, x], dim=1)), 0.1, inplace=True)
        y = self.bn_6(x)
        return torch.sigmoid(self.decTconv3(torch.cat([y, x], dim=1)))

    def forward(self, x):
        """Run encode -> reparameterise -> decode.

        Returns:
            Tuple ``(z, mu, logvar, reconstruction)``.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterise(mu, logvar)
        out = self.decoder(z)
        return z, mu, logvar, out

# Build the network with default hyper-parameters and move it to the selected device.
model = VAE()
model = model.to(device)
# Uncomment for a layer-by-layer summary at a 1x3x128x128 input:
#pprint(torchinfo.summary(model, (1, 3, 128, 128)))

In [None]:
# Setting the optimiser
# Adam over all model parameters at a learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# reduction='none' keeps the squared error per element; the training loop sums it per image.
criterion = nn.MSELoss(reduction='none')

def display_photo(index=None):
    """Plot the model's reconstruction of one dataset image, then the image itself.

    Uses the module-level ``model``, ``dataset`` and ``device``.

    Args:
        index: dataset index to show. When None (the default), an index is drawn
            uniformly from the whole dataset with ``np.random``.
    """
    if index is None:
        # Draw from the real dataset length: a fixed bound skipped index 0, could
        # miss images and raised IndexError on a smaller image folder.
        index = np.random.randint(len(dataset))
    img, _ = dataset[index]

    # Reconstruct in eval mode: BatchNorm uses its running statistics (and does
    # not update them) and the latent is the mean rather than a sample. Afterwards,
    # restore whichever mode the caller had the model in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, _, _, decoded = model(torch.unsqueeze(img, 0).to(device))
    finally:
        model.train(was_training)

    plt.imshow(decoded[0].cpu().permute(1, 2, 0))
    plt.show()
    plt.imshow(img.permute(1, 2, 0))
    plt.show()

In [None]:
epochs = 52

# Per-batch histories. They are re-created at the start of every epoch, so after
# training they hold only the final epoch's values.
LOSS = []
PIXELWISE = []
KLD_DIV = []

# Epochs are numbered from 1, so the printed and checkpointed `epoch` equals the count completed.
for epoch in range(1, epochs + 1):
    LOSS = []
    PIXELWISE = []
    KLD_DIV = []
    train_loss = 0.0   # sum of per-sample losses over the epoch
    n_samples = 0      # images seen this epoch
    loop = tqdm(train_loader)
    model.train()
    for n, data in enumerate(loop, start=1):
        img, _ = data
        img = img.to(device)
        # model returns (z, mu, logvar, reconstruction)
        encoded, mean, logvar, decoded = model(img)

        # KL(N(mean, exp(logvar)) || N(0, I)) per sample, summed over latent dimensions.
        KLD = -0.5 * torch.sum(1 + logvar - mean**2 - torch.exp(logvar), dim=1)
        # Local name on purpose: assigning to `batch_size` would overwrite the
        # DataLoader setting from the data cell with the (possibly smaller) last batch size.
        n_batch = KLD.size(0)
        KLD = KLD.mean()
        # Per-element error from `criterion`, summed over each image, then averaged over the batch.
        pixelwise = criterion(decoded, img).view(n_batch, -1).sum(dim=1).mean()
        loss = pixelwise + KLD  # reconstruction and KL terms weighted equally

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # read the scalar once
        train_loss += loss_value * n_batch
        n_samples += n_batch
        loop.set_postfix(loss=loss_value)

        LOSS.append(loss_value)
        PIXELWISE.append(pixelwise.item())
        KLD_DIV.append(KLD.item())

        # Periodically preview a reconstruction in eval mode, then resume training.
        if n % 1500 == 0:
            print(n)
            model.eval()
            display_photo()
            model.train()

    time.sleep(30)  # 30 s pause between epochs (reason not stated in the notebook)
    # Report the epoch's mean per-sample loss rather than the last batch's loss tensor.
    epoch_loss = train_loss / max(n_samples, 1)
    print(f'Epoch:{epoch}/{epochs} training loss {epoch_loss:.4f}')

In [None]:
# saving the model
# Checkpoint: last completed epoch, model and optimizer state, and the last batch's loss.

PATH = r'C:/Users/utkar/Desktop/ML/pytorch/autoencoder/autoencoder_CELEBA_DATASET.pth.tar'
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # Store a detached CPU copy of the value. The live training tensor may carry
    # the last batch's autograd graph and sits on the training device.
    'loss': torch.as_tensor(loss).detach().cpu(),
}, PATH)

In [None]:
# The training loop leaves the model in train() mode. Switch to eval() so the
# preview uses BatchNorm running statistics and the latent mean, then restore
# training mode (the trailing ';' suppresses the module repr).
model.eval()
display_photo()
model.train();

In [None]:
# loading the model
# NOTE(review): this path differs from the one the save cell writes
# (autoencoder_CELEBA_DATASET.pth.tar). Confirm which checkpoint is intended.

PATH = r'C:/Users/utkar/Desktop/ML/pytorch/autoencoder/autoencoder_using_mseloss.pth.tar'
model = VAE().to(device)
# Re-create the optimizer for the *new* model. The optimizer from the earlier
# cell still references the previous model's parameters, so loading state into
# it would leave this model untouched by optimizer.step().
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # also restores the saved lr
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Resume training; use model.eval() instead for inference only.
# The trailing ';' suppresses the module repr in the notebook output.
model.train();