<a href="https://colab.research.google.com/github/tayfununal/PyTorch/blob/main/Vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [135]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.datasets import load_digits
from sklearn import datasets

In [136]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset split into fixed train/val/test row ranges.

    Rows ``[:1000]`` form the 'train' split, ``[1000:1350]`` the 'val' split
    and ``[1350:]`` the 'test' split. Features are cast to float32.
    """

    # Row ranges of ``load_digits().data`` used for each split.
    _SPLITS = {
        'train': slice(None, 1000),
        'val': slice(1000, 1350),
        'test': slice(1350, None),
    }

    def __init__(self, mode='train', transforms=None):
        """
        Args:
            mode: one of 'train', 'val' or 'test'.
            transforms: optional callable applied to each sample in ``__getitem__``.

        Raises:
            ValueError: if ``mode`` is not a known split name.
        """
        # Validate before loading so a typo (e.g. 'validation') fails loudly
        # instead of silently handing back the test split.
        if mode not in self._SPLITS:
            raise ValueError(f"mode must be one of {sorted(self._SPLITS)}, got {mode!r}")

        digits = load_digits()
        self.data = digits.data[self._SPLITS[mode]].astype(np.float32)

        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``self.data[idx]``, passed through ``transforms`` when one was given.

        ``idx`` indexes a NumPy array, so ints and slices both work.
        """
        sample = self.data[idx]
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample

In [137]:
# Training split of the digits dataset ('train' is also the default mode).
data = Digits(mode='train')

In [138]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.e-5


def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Log-density of a diagonal-covariance Gaussian, evaluated element-wise.

    Per element computes
    ``-0.5 * log(2*pi) - 0.5 * log_var - 0.5 * exp(-log_var) * (x - mu)**2``,
    i.e. ``log N(x | mu, exp(log_var))``. Summing over the event dimension
    yields the joint log-density of the diagonal Gaussian.

    Args:
        x, mu, log_var: broadcast-compatible tensors (any shape).
        reduction: None for the element-wise result, 'avg' for the mean,
            'sum' for the sum. Any other value returns the element-wise result.
        dim: dimension(s) to reduce over; None reduces over all elements.

    Returns:
        Tensor of element-wise log-densities, or its reduction.
    """
    # Each element carries its own normalising constant; scaling it (or
    # log_var) by the dimensionality would over-count once summed over dim.
    log_p = -0.5 * torch.log(2. * PI) - 0.5 * log_var - 0.5 * torch.exp(-log_var) * (x - mu)**2.
    if reduction == 'avg':
        return torch.mean(log_p) if dim is None else torch.mean(log_p, dim)
    elif reduction == 'sum':
        return torch.sum(log_p) if dim is None else torch.sum(log_p, dim)
    else:
        return log_p

In [139]:
class Encoder(nn.Module):
    """Gaussian encoder q(z|x) with a diagonal covariance.

    Wraps ``encoder_net``, whose output is split in half along dim 1 into
    the mean and the log-variance of the latent distribution.
    """

    def __init__(self, encoder_net):
        super(Encoder, self).__init__()

        self.encoder = encoder_net

    @staticmethod
    def reparameterization(mu, log_var):
        """Return ``mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``.

        Expressing the sample as a deterministic function of ``eps`` keeps it
        differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + std * eps

    def encode(self, x):
        """Return ``(mu_e, log_var_e)``: the network output split in two along dim 1."""
        h_e = self.encoder(x)
        mu_e, log_var_e = torch.chunk(h_e, 2, dim=1)

        return mu_e, log_var_e

    def sample(self, x=None, mu_e=None, log_var_e=None):
        """Draw a reparameterized latent sample.

        Uses ``mu_e``/``log_var_e`` when both are given; otherwise encodes ``x``.

        Raises:
            ValueError: if exactly one of ``mu_e``/``log_var_e`` is given, or
                if neither they nor ``x`` are given.
        """
        if (mu_e is None) and (log_var_e is None):
            if x is None:
                raise ValueError("x can't be None when mu and log-var are not given!")
            mu_e, log_var_e = self.encode(x)
        elif (mu_e is None) or (log_var_e is None):
            raise ValueError('mu and log-var can`t be None!')
        return self.reparameterization(mu_e, log_var_e)

    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        """Element-wise log q(z|x) under the diagonal Gaussian.

        If ``x`` is given, ``mu_e``/``log_var_e`` are recomputed from it and a
        fresh ``z`` is sampled (any passed-in values are ignored). Otherwise
        ``mu_e``, ``log_var_e`` and ``z`` must all be supplied.

        Raises:
            ValueError: if ``x`` is None and any of ``mu_e``/``log_var_e``/``z`` is None.
        """
        if x is not None:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        elif (mu_e is None) or (log_var_e is None) or (z is None):
            raise ValueError('mu, log-var and z can`t be None!')

        return log_normal_diag(z, mu_e, log_var_e)

    def forward(self, x, type='log_prob'):
        """Dispatch on ``type``: 'log_prob' -> ``log_prob(x)``, 'encode' -> ``sample(x)``.

        Note that 'encode' returns a reparameterized latent sample, not the
        ``(mu, log_var)`` pair returned by :meth:`encode`.

        Raises:
            ValueError: if ``type`` is neither 'encode' nor 'log_prob'.
        """
        # `type` shadows the builtin but is kept so keyword callers still work.
        if type == 'log_prob':
            return self.log_prob(x)
        if type == 'encode':
            return self.sample(x)
        # Explicit raise rather than `assert`, which `python -O` strips.
        raise ValueError('Type could be either encode or log_prob')

In [140]:
D = 64      # Input dimensionality: features per sample (in_features of the first Linear layer)
M = 256     # Number of neurons in each hidden layer
L = 16      # Number of latent dimensions (the encoder outputs 2 * L values: mean and log-variance)

In [141]:
 encoder = nn.Sequential(
                          nn.Linear(D, M), nn.ReLU(),
                          nn.Linear(M, M), nn.ReLU(),
                          nn.Linear(M, 2 * L)
                        )

In [142]:
# Wrap the MLP in the Gaussian encoder module.
encod = Encoder(encoder_net=encoder)

In [145]:
# Call the module itself rather than `.forward()` so registered hooks run.
# NOTE(review): 'log_prob' returns element-wise log-densities (one row per
# input); unpacking a batch of 2 assigns row 0 to `mu` and row 1 to `var`.
# If the mean and log-variance were intended, use `encod.encode(...)` instead.
mu, var = encod(torch.from_numpy(data[:2]), 'log_prob')

tensor([[ 1.3316,  0.0640,  0.3693, -0.9586, -0.2108,  1.1805, -0.3041, -0.6399,
          1.1513,  1.1011, -0.2846,  1.5771, -0.7675,  1.0661, -0.4585, -0.6613],
        [ 0.7910,  0.1572,  1.3897, -1.4139, -0.9077, -0.1735,  0.8505,  0.2297,
          0.6535,  0.5542, -1.2437,  1.4594, -0.3360,  1.0421, -0.2311, -0.8626]],
       grad_fn=<SplitBackward0>)
