<a href="https://colab.research.google.com/github/satani99/generative_deep_learning/blob/main/VAE_CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch.nn as nn
import torch
import torchvision

In [3]:
class Reshape(nn.Module):
    """Module that reshapes its input to a fixed target shape.

    Args:
        *args: the target shape, handed unchanged to ``Tensor.view``.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        """Return ``x`` viewed with the stored target shape."""
        target_shape = self.shape
        return x.view(target_shape)

In [4]:
class Lambda(nn.Module):
  """Wrap an arbitrary callable so it can be used as an ``nn.Module``.

  Args:
    lambd: callable invoked on every forward pass.
  """
  def __init__(self, lambd):
    super(Lambda, self).__init__()
    self.lambd = lambd

  def forward(self, *args, **kwargs):
    """Call the wrapped callable with exactly the arguments given to the module.

    Accepting ``*args, **kwargs`` (instead of a single ``x``) lets the wrapper
    front multi-argument callables while staying compatible with
    single-argument use.
    """
    return self.lambd(*args, **kwargs)

In [5]:
def sampling(args):
  """Draw a latent sample with the reparameterisation trick.

  Args:
    args: a ``(mu, log_var)`` pair of tensors with the same shape.

  Returns:
    ``mu + exp(log_var / 2) * eps`` with ``eps`` drawn from a standard normal,
    same shape as ``mu``.
  """
  mu, log_var = args
  # randn_like matches mu's shape, dtype and device; torch.normal(size=...)
  # allocates on the default device/dtype regardless of mu, which breaks
  # as soon as the model runs on a GPU.
  epsilon = torch.randn_like(mu)
  return mu + torch.exp(log_var / 2) * epsilon

In [15]:
def vae_r_loss(y_true, y_pred, r_loss_factor=10000):
  """Weighted element-wise reconstruction loss.

  Args:
    y_true: target tensor.
    y_pred: reconstruction, expected to have the same shape as ``y_true``.
    r_loss_factor: multiplier applied to the squared error.

  Returns:
    Unreduced tensor ``r_loss_factor * (y_pred - y_true) ** 2``.
  """
  # Use the functional form: nn.MSELoss(...) builds a loss *module* and does
  # not take the tensors to compare as constructor arguments.
  r_loss = nn.functional.mse_loss(y_pred, y_true, reduction='none')
  return r_loss_factor * r_loss

def vae_kl_loss(y_true, y_pred):
  """Element-wise KL-divergence term ``y_true * (log(y_true) - y_pred)``.

  ``y_pred`` is treated as log-probabilities and ``y_true`` as probabilities,
  as in ``torch.nn.functional.kl_div`` with ``log_target=False``.

  NOTE(review): a VAE's KL regulariser is normally computed from the encoder's
  ``mu`` and ``log_var`` (-0.5 * sum(1 + log_var - mu**2 - exp(log_var))),
  which this signature does not receive. Confirm this is the intended term.
  """
  # Functional form for the same reason as in vae_r_loss.
  return nn.functional.kl_div(y_pred, y_true, reduction='none')

def vae_loss(y_true, y_pred):
  """Total element-wise loss: weighted reconstruction error plus KL term."""
  # Reconstruction must compare the prediction against the target.
  r_loss = vae_r_loss(y_true, y_pred)
  kl_loss = vae_kl_loss(y_true, y_pred)
  return r_loss + kl_loss

In [8]:
class Encoder(nn.Module):
  """Convolutional VAE encoder that returns a sampled latent vector.

  Four blocks of Conv2d -> BatchNorm2d -> LeakyReLU -> Dropout, then the
  per-sample flattened features are projected to ``mu`` and ``log_var``
  (200 values each) and one latent sample is drawn through ``self.lam``,
  which wraps ``sampling``.

  NOTE(review): every conv uses kernel_size=3, stride=1 and no padding, so
  each one shrinks H and W by 2 (8 in total). The flattened size is
  64 * (H - 8) * (W - 8), which equals the 4096 features the Linear layers
  expect only for 16x16 inputs. Larger images would need strided or pooled
  convolutions; confirm the intended input resolution.
  """
  def __init__(self):
    super(Encoder, self).__init__()

    self.conv0 = nn.Conv2d(3, 32, kernel_size=3, stride=1)
    self.batch_norm1 = nn.BatchNorm2d(32)
    self.leaky_relu = nn.LeakyReLU()
    self.dropout = nn.Dropout()
    self.conv1 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
    self.batch_norm2 = nn.BatchNorm2d(64)
    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
    self.batch_norm3 = nn.BatchNorm2d(64)
    self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
    self.batch_norm4 = nn.BatchNorm2d(64)
    self.mu = nn.Linear(4096, 200)
    self.log_var = nn.Linear(4096, 200)
    self.lam = Lambda(sampling)

  def forward(self, x):
    """Encode a batch of images into sampled latent vectors.

    Args:
      x: image batch of shape (batch, 3, H, W); BatchNorm2d requires the
        batch dimension.

    Returns:
      Tensor of shape (batch, 200) holding one latent sample per input.
    """
    x = self.conv0(x)
    x = self.batch_norm1(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.conv1(x)
    x = self.batch_norm2(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.conv2(x)
    x = self.batch_norm3(x)
    x = self.dropout(self.leaky_relu(x))
    x = self.conv3(x)
    x = self.batch_norm4(x)
    x = self.dropout(self.leaky_relu(x))
    # Flatten per sample (keeping the batch dim) so an unexpected spatial size
    # raises in the Linear layers instead of silently mixing samples together.
    features = x.flatten(start_dim=1)
    mu = self.mu(features)
    log_var = self.log_var(features)
    # `sampling` unpacks a single (mu, log_var) tuple, so pass exactly one
    # argument through the Lambda wrapper.
    return self.lam((mu, log_var))

In [16]:
# Build the encoder and print its layer-by-layer summary.
encoder = Encoder()
print(repr(encoder))

Encoder(
  (conv0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (leaky_relu): LeakyReLU(negative_slope=0.01)
  (dropout): Dropout(p=0.5, inplace=False)
  (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (batch_norm4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mu): Linear(in_features=4096, out_features=200, bias=True)
  (log_var): Linear(in_features=4096, out_features=200, bias=True)
  (lam): Lambda()
)


tensor([[0.0419, 4.4632, 4.1102, 8.1344, 0.2808],
        [0.2386, 1.1621, 1.1126, 0.8074, 0.1105],
        [0.1706, 0.9230, 1.4701, 0.2561, 1.3024]], grad_fn=<MseLossBackward0>)