<a href="https://colab.research.google.com/github/satani99/generative_deep_learning/blob/main/VAE_CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch.nn as nn
import torch
import torchvision

In [2]:
class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)

In [3]:
class Lambda(nn.Module):
  def __init__(self, lambd):
    super(Lambda, self).__init__()
    self.lambd = lambd
  def forward(self, x):
    return self.lambd(x)

In [4]:
def sampling(args): 
  mu, log_var = args
  epsilon = torch.normal(mean=0., std=1., size=mu.shape)
  return mu + torch.exp(log_var / 2) * epsilon

In [5]:
def vae_r_loss(y_true, y_pred, r_loss_factor=10000):
  r_loss = nn.MSELoss(y_true, y_pred, reduction='none')
  return r_loss_factor * r_loss

def vae_kl_loss(y_true, y_pred):
  kl_loss = nn.KLDivLoss(y_true, y_pred, reduction='none')
  return kl_loss 

def vae_loss(y_true, y_pred):
  r_loss = vae_r_loss(y_true, y_true)
  kl_loss = vae_kl_loss(y_true, y_pred)
  return r_loss + kl_loss

In [6]:
class Encoder(nn.Module):
  def __init__(self):
    super(Encoder, self).__init__()

    self.conv0 = nn.Conv2d(3, 32, kernel_size=3, stride=1)
    self.batch_norm1 = nn.BatchNorm2d(32)
    self.leaky_relu = nn.LeakyReLU() 
    self.dropout = nn.Dropout()
    self.conv1 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
    self.batch_norm2 = nn.BatchNorm2d(64)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
    self.batch_norm3 = nn.BatchNorm2d(64)
    self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
    self.batch_norm4 = nn.BatchNorm2d(64)
    self.mu = nn.Linear(4096, 200)
    self.log_var = nn.Linear(4096, 200)
    self.lam = Lambda(sampling) 

  def forward(self, x):
    x = self.conv0(x)
    x = self.batch_norm1(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.conv1(x)
    x = self.batch_norm2(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.conv2(x)
    x = self.batch_norm3(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.conv3(x)
    x = self.batch_norm4(x)
    x = self.dropout(self.leaky_relu(x))
    mu = self.mu(x.view(-1, 4096))
    log_var = self.log_var(x.view(-1, 4096))
    output = self.lam(mu, log_var)
    return output

In [7]:
encoder = Encoder()
print(encoder)

Encoder(
  (conv0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (leaky_relu): LeakyReLU(negative_slope=0.01)
  (dropout): Dropout(p=0.5, inplace=False)
  (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mu): Linear(in_features=4096, out_features=200, bias=True)
  (log_var): Linear(in_features=4096, out_features=200, bias=True)
  (lam): Lambda()
)


In [8]:
class Decoder(nn.Module):
  def __init__(self):
    super(Decoder, self).__init__()
    self.linear = nn.Linear(200, 4096)
    self.convt0 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
    self.batchnorm1 = nn.BatchNorm2d(64)
    self.leaky_relu = nn.LeakyReLU()
    self.dropout = nn.Dropout()
    self.convt1 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
    self.batchnorm2 = nn.BatchNorm2d(64)
    self.convt2 = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
    self.batchnorm3 = nn.BatchNorm2d(32)
    self.convt3 = nn.ConvTranspose2d(32, 3, kernel_size=3, padding=1, stride=2, output_padding=1)
    self.relu = nn.ReLU()

  def forward(self, x):
    x = self.linear(x)
    x = self.convt0(x.view(-1, 64, 8, 8))
    x = self.batchnorm1(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.convt1(x)
    x = self.batchnorm2(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.convt2(x)
    x = self.batchnorm3(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.convt3(x)
    out = self.relu(x)
    return out
    

In [9]:
class AE(nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = Encoder()
    self.decoder = Decoder()

  def forward(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

In [10]:
from torchvision import datasets 
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import os 
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform

In [11]:
! pip install -q kaggle

In [None]:
from google.colab import files

files.upload()

In [13]:
 ! mkdir ~/.kaggle

! cp kaggle.json ~/.kaggle/

In [14]:
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d jessicali9530/celeba-dataset
!unzip celeba-dataset.zip

In [16]:
class FaceDataset(Dataset):
  def __init__(self, csv_file, root_dir, transform=None):
    self.landmarks_frame = pd.read_csv(csv_file)
    self.root_dir = root_dir
    self.transform = transform 

  def __len__(self):
    return len(self.landmarks_frame)

  def __getitem__(self, idx):
    if torch.is_tensor(idx):
      idx =  idx.tolist()
    img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
    image = io.imread(img_name) 
    sample = {'image': image, 'target': image}

    if self.transform:
      sample = self.transform(sample) 
    
    return sample

In [17]:
import pandas as pd

In [35]:
face_dataset = FaceDataset(csv_file='/content/list_landmarks_align_celeba.csv',
                           root_dir='/content/img_align_celeba/img_align_celeba',
                           transform=transforms.ToTensor())



In [None]:
tensor_transform = transforms.ToTensor()

dataset = datasets.MNIST(root = "./data",
                         train = True,
                         download = True,
                         transform = tensor_transform)

loader = torch.utils.data.DataLoader(dataset = dataset,
                                     batch_size = 32,
                                     shuffle = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [36]:
model = AE()

loss_function = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(),
                             lr = 1e-1,
                             weight_decay = 1e-8)

In [37]:
loader = torch.utils.data.DataLoader(dataset = face_dataset,
                                     batch_size = 32,
                                     shuffle = True)

In [21]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [32]:
# model.to(device)
tensor_transform = transforms.ToTensor()

In [41]:
epochs = 20
losses = []

for epoch in range(epochs):
  for image in loader:
    image = image['image']
    # image = image.reshape(1, -1, 28, 28)
    reconstructed = model(image)
    # print(image.shape)
    # print(reconstructed.shape)
    reconstructed = reconstructed
    loss = loss_function(reconstructed, image)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss)

TypeError: ignored

TypeError: ignored