In [None]:
import os
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as tt
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from torch.utils.data import random_split
from tqdm import tqdm

In [None]:
# Directory whose entries are listed and loaded as training images further down.
train_path = '../input/celeba-dataset/img_align_celeba/img_align_celeba'
batch_size = 16          # images per DataLoader batch
IMAGES_COUNT = 20_000    # how many image files are loaded into memory

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# GAN-style settings; none of the later cells in this notebook reference them.
nc = 3    # presumably the number of image channels

ngf = 64  # size of feature maps in the generator
ndf = 64  # size of feature maps in the discriminator

In [None]:
# os.listdir returns entries in arbitrary order; sorting makes the
# IMAGES_COUNT-long slice taken when loading the same on every run.
img_file_list = sorted(os.listdir(train_path))
img_file_file = [os.path.join(train_path, fname) for fname in img_file_list]


In [None]:
# One float64 slot per image in channels-first (3, 64, 64) layout, the shape
# produced by train_tfms. Sized by the files actually available so that no row
# is left as an all-zero (black) "image" when the folder has fewer than IMAGES_COUNT.
train_x = np.zeros((min(IMAGES_COUNT, len(img_file_file)), 3, 64, 64))

In [None]:
# Per-image preprocessing. The loading loop applies it once per file, so the
# random flip / colour jitter is fixed for the whole run; it is not re-drawn each epoch.
train_tfms = tt.Compose([
    tt.RandomHorizontalFlip(),
    tt.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    tt.ToTensor(),
])

In [None]:
# Load, resize and transform each image file into its row of train_x.
# NOTE(review): resizing straight to 64x64 does not keep the source aspect ratio
# unless the source image is square.
# tqdm wraps the sliced list (not the enumerate object) so it knows the total.
for i, pic_file in enumerate(tqdm(img_file_file[:IMAGES_COUNT], desc='loading')):
    # `with` closes the file handle promptly; convert('RGB') guarantees the
    # 3 channels train_x expects even for grayscale/RGBA/palette files.
    with Image.open(pic_file) as img:
        train_x[i, :, :, :] = train_tfms(img.convert('RGB').resize((64, 64)))

In [None]:
# The NumPy array itself acts as the dataset (it supports len() and integer
# indexing); DataLoader shuffles its rows and batches them into tensors.
train_dl = DataLoader(
    dataset=train_x,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)

In [None]:
def show_batch(dl):
    """Plot the first batch yielded by ``dl`` as an image grid, 4 images per row.

    At most 64 images are drawn and pixel values are clamped to [0, 1].
    If ``dl`` yields nothing, no figure is created.
    """
    batches = iter(dl)
    try:
        images = next(batches)
    except StopIteration:
        return
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(images[:64], nrow=4)
    ax.imshow(grid.permute(1, 2, 0).clamp(0, 1))

In [None]:
# Preview one batch of the preprocessed training images.
show_batch(dl=train_dl)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
kernel_size = 4
stride = 1
padding = 0
init_kernel = 16

class VAE(nn.Module):
    """Fully convolutional variational autoencoder.

    Every layer uses a ``kernel_size`` x ``kernel_size`` kernel with the
    module-level ``stride`` and ``padding``. With the defaults above (4, 1, 0),
    each encoder layer shrinks both spatial sides by 3. A 3x64x64 input
    therefore becomes an ``init_kernel``-channel 49x49 latent map, and the
    transposed-convolution decoder grows it back to 3x64x64.

    The encoder has two independent output heads: ``enc5`` predicts the latent
    mean and ``enc5_log_var`` the latent log-variance.

    Args:
        device: accepted for backward compatibility; it is not used to place
            the parameters (call ``.to(device)`` on the instance for that).
    """

    def __init__(self, device):
        super(VAE, self).__init__()

        # encoder: 3 -> k -> 2k -> 4k -> 8k channels (k = init_kernel)
        self.enc1 = nn.Conv2d(
            in_channels=3, out_channels=init_kernel, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.enc2 = nn.Conv2d(
            in_channels=init_kernel, out_channels=init_kernel*2, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.enc3 = nn.Conv2d(
            in_channels=init_kernel*2, out_channels=init_kernel*4, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.enc4 = nn.Conv2d(
            in_channels=init_kernel*4, out_channels=init_kernel*8, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        # latent heads: separate layers so mean and log-variance are not tied together
        self.enc5 = nn.Conv2d(
            in_channels=init_kernel*8, out_channels=init_kernel, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.enc5_log_var = nn.Conv2d(
            in_channels=init_kernel*8, out_channels=init_kernel, kernel_size=kernel_size,
            stride=stride, padding=padding
        )

        # decoder: mirror image of the encoder, k -> 8k -> 4k -> 2k -> k -> 3 channels
        self.dec1 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=init_kernel*8, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=init_kernel*8, out_channels=init_kernel*4, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=init_kernel*4, out_channels=init_kernel*2, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.dec4 = nn.ConvTranspose2d(
            in_channels=init_kernel*2, out_channels=init_kernel, kernel_size=kernel_size,
            stride=stride, padding=padding
        )
        self.dec5 = nn.ConvTranspose2d(
            in_channels=init_kernel, out_channels=3, kernel_size=kernel_size,
            stride=stride, padding=padding
        )

    def encode(self, x):
        """Map an image batch to ``(mu, log_var)`` of the latent posterior."""
        h = F.relu(self.enc1(x))
        h = F.relu(self.enc2(h))
        h = F.relu(self.enc3(h))
        h = F.relu(self.enc4(h))
        return self.enc5(h), self.enc5_log_var(h)

    def decode(self, z):
        """Map a latent batch to images with values in (0, 1)."""
        h = F.relu(self.dec1(z))
        h = F.relu(self.dec2(h))
        h = F.relu(self.dec3(h))
        h = F.relu(self.dec4(h))
        return torch.sigmoid(self.dec5(h))

    def reparameterize(self, mu, log_var):
        """
        Draw z = mu + eps * std with eps ~ N(0, I), keeping z differentiable.

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5*log_var) # standard deviation
        eps = torch.randn_like(std) # `randn_like` as we need the same size
        sample = mu + (eps * std) # sampling
        return sample

    def forward(self, x):
        """Encode, sample a latent, decode; returns ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decode(z)
        return reconstruction, mu, log_var

In [None]:
# Build the model in float64 on `device`; training batches are cast with .double() to match.
vae = VAE(device).to(device=device, dtype=torch.float64)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-4, weight_decay=1e-5)
def vae_loss(recon_x, x, mu, log_var):
    """Negative ELBO: summed binary cross-entropy plus summed KL divergence to N(0, I).

    Args:
        recon_x: decoder output with values in [0, 1] (any shape).
        x: target batch with the same number of elements as ``recon_x``. It is
            moved to ``recon_x``'s device/dtype and reshaped to ``recon_x.shape``,
            so CPU or differently-typed targets are accepted.
        mu: latent mean.
        log_var: latent log-variance, same shape as ``mu``.

    Returns:
        Scalar tensor ``recon_loss + kl_loss``.
    """
    target = x.to(device=recon_x.device, dtype=recon_x.dtype).reshape(recon_x.shape)
    recon_loss = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, exp(log_var)) || N(0, 1)), summed over every latent element.
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss

In [None]:
from torchvision.utils import save_image

def save_fake_images(sample_vectors, model, name, out_dir='./'):
    """Run ``model`` on ``sample_vectors`` and save its first output as a PNG grid.

    Args:
        sample_vectors: batch fed to ``model``; must already match the model's
            device and dtype.
        model: callable whose return value's element ``[0]`` is an image batch
            (the VAE here returns ``(reconstruction, mu, log_var)``).
        name: filename prefix; the file is written as ``<name>fake_images.png``.
        out_dir: directory the PNG is written to (default: current directory).
    """
    # Inference only: don't build (and keep alive) an autograd graph for this pass.
    with torch.no_grad():
        fake_images = model(sample_vectors)[0]
    if fake_images.dim() != 4:
        # Non-image-shaped output is assumed to hold flattened 3x64x64 images.
        fake_images = fake_images.reshape(fake_images.size(0), 3, 64, 64)
    fake_fname = name+'fake_images.png'
    print('Saving', fake_fname)
    # save_image copies to the CPU itself, so no explicit device move is needed.
    save_image(fake_images, os.path.join(out_dir, fake_fname), nrow=4)

In [None]:
def fit(model, dataloader, epochs, save_every=100):
    """Train ``model`` as a VAE and return the final epoch's mean per-sample loss.

    Relies on the notebook-level ``optimizer``, ``vae_loss``,
    ``save_fake_images`` and ``device``. Each batch is cast to float64 (the
    model is built with ``.double()``) and moved to ``device`` once, then the
    same tensor is used for both the forward pass and the loss.

    Args:
        model: the VAE, called as ``model(batch) -> (reconstruction, mu, log_var)``.
        dataloader: yields image batches; ``len(dataloader.dataset)`` is used to
            average the loss per sample.
        epochs: number of passes over ``dataloader``.
        save_every: save a reconstruction grid every this many epochs (epoch 0
            included); 0 or None disables saving.

    Returns:
        Mean per-sample loss of the last epoch, or NaN when ``epochs <= 0``.

    Raises:
        ValueError: if the dataloader's dataset is empty.
    """
    n_samples = len(dataloader.dataset)
    if n_samples == 0:
        raise ValueError("fit() needs a non-empty dataset")
    model.train()
    train_loss = float('nan')
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")
        # Reset every epoch so the printed value is this epoch's average,
        # not a running total over all epochs so far.
        running_loss = 0.0
        last_batch = None
        for data in dataloader:
            batch = data.double().to(device)
            optimizer.zero_grad()
            reconstruction, mu, log_var = model(batch)
            loss = vae_loss(reconstruction, batch, mu, log_var)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            last_batch = batch
        if save_every and last_batch is not None and epoch % save_every == 0:
            # save_fake_images runs `model` itself, so give it real inputs rather
            # than an already-reconstructed batch (a reconstruction of a reconstruction).
            save_fake_images(last_batch, model, str(epoch))
        train_loss = running_loss / n_samples
        print(f"Train Loss: {train_loss:.4f}")
    return train_loss

In [None]:
# Train for 1000 epochs; the cell's output is the value fit returns.
fit(vae.to(device), train_dl, epochs=1000)

In [None]:
def show_res_batch(images):
    """Display the first 16 images of ``images`` as a 4x4 grid.

    ``images`` may live on any device and may require grad. It is detached and
    copied to the CPU before plotting, because ``.numpy()`` fails on CUDA tensors.
    Pixel values are clamped to [0, 1].
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.set_xticks([]); ax.set_yticks([])
    grid = make_grid(images[:16].detach().cpu(), nrow=4)
    ax.imshow(grid.permute(1, 2, 0).clamp(0, 1).numpy())

In [None]:
# 16 standard-normal tensors with the same 3x64x64 shape as a training image,
# drawn on the CPU and then moved to `device`.
r_input = torch.randn((16, 3, 64, 64)).to(device)

In [None]:
# Inference only, so no autograd graph is kept. Calling the module (not .forward)
# lets any registered hooks run.
# NOTE(review): this passes pixel-space noise through the encoder too, so the
# result is the VAE's reconstruction of noise, not a sample decoded from the latent prior.
with torch.no_grad():
    res_noise = vae(r_input.double())

In [None]:
# Write the reconstructed-noise batch out as a 4-column PNG grid.
# NOTE(review): save_fake_images runs `vae` on its argument again, so the file
# shows the model's reconstruction of res_noise[0] (a second pass), not res_noise[0] itself.
save_fake_images(sample_vectors=res_noise[0], model=vae, name='fake')
