<a href="https://colab.research.google.com/github/tayfununal/PyTorch/blob/main/Vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.datasets import load_digits
from sklearn import datasets

In [None]:
class Digits(Dataset):
    """PyTorch ``Dataset`` over the feature rows of scikit-learn's ``load_digits``.

    Only ``digits.data`` is exposed (no labels). Rows are split by position:
    ``mode='train'`` -> rows [0, 1000), ``mode='val'`` -> rows [1000, 1350),
    any other ``mode`` -> rows [1350, end). Rows are stored as ``np.float32``.

    Args:
        mode: Which positional split to hold.
        transforms: Optional callable applied to each row in ``__getitem__``;
            skipped when falsy.
    """

    def __init__(self, mode='train', transforms=None):
        feature_matrix = load_digits().data

        # Pick the positional row range for the requested split.
        if mode == 'train':
            rows = slice(None, 1000)
        elif mode == 'val':
            rows = slice(1000, 1350)
        else:
            rows = slice(1350, None)

        self.data = feature_matrix[rows].astype(np.float32)
        self.transforms = transforms

    def __len__(self):
        """Number of rows in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return row ``idx``, passed through ``transforms`` when one is set."""
        row = self.data[idx]
        return self.transforms(row) if self.transforms else row

In [None]:
# Build the training split (the first 1000 rows of the digits features).
data = Digits(mode='train')

In [None]:
# pi as a 0-d float64 tensor so it can be passed straight to torch.log.
PI = torch.from_numpy(np.asarray(np.pi))
# Small constant; not used in this cell — presumably for numerical clamping
# elsewhere in the notebook (confirm).
EPS = 1.e-5


def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    """Elementwise log-density of a Gaussian with diagonal covariance.

    Every element of the result is
    ``-0.5 * log(2*pi) - 0.5 * log_var - 0.5 * exp(-log_var) * (x - mu)**2``,
    i.e. the log-density of that coordinate of ``x`` under
    ``N(mu, exp(log_var))``. Summing over the feature dimension yields the
    full multivariate diagonal-Gaussian log-density.

    Args:
        x: Points at which to evaluate the density; must broadcast with
            ``mu`` and ``log_var``.
        mu: Means.
        log_var: Natural-log variances.
        reduction: ``'avg'`` to average or ``'sum'`` to sum over ``dim``;
            any other value returns the elementwise tensor unchanged.
        dim: Dimension(s) to reduce over; ``None`` reduces over all elements.

    Returns:
        The elementwise log-density tensor, or its mean/sum.
    """
    # The normalising constant and the log-variance term are per coordinate.
    # Multiplying them by the number of features would double-count them once
    # the result is summed over that dimension.
    log_p = (
        -0.5 * torch.log(2. * PI)
        - 0.5 * log_var
        - 0.5 * torch.exp(-log_var) * (x - mu) ** 2.
    )
    if reduction == 'avg':
        # Call without `dim` when it is None: older torch rejects dim=None.
        return torch.mean(log_p) if dim is None else torch.mean(log_p, dim)
    if reduction == 'sum':
        return torch.sum(log_p) if dim is None else torch.sum(log_p, dim)
    return log_p

In [None]:
class Encoder(nn.Module):
    """Gaussian encoder that wraps an arbitrary network ``encoder_net``.

    The network output is split in two along dimension 1. The first half is
    read as the mean and the second half as the log-variance of a diagonal
    Gaussian over the latent ``z``. Dimension 1 is therefore expected to have
    even size.
    """

    def __init__(self, encoder_net):
        super(Encoder, self).__init__()

        self.encoder = encoder_net

    @staticmethod
    def reparametrization(mu, log_var):
        """Draw ``z = mu + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``.

        Sampling through a deterministic transform of standard noise keeps
        ``z`` differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + std * eps

    def encode(self, x):
        """Run the network and return ``(mu_e, log_var_e)`` split along dim 1."""
        h_e = self.encoder(x)
        mu_e, log_var_e = torch.chunk(h_e, 2, dim=1)

        return mu_e, log_var_e

    def sample(self, x=None, mu_e=None, log_var_e=None):
        """Draw one latent sample per input.

        Either pass ``x``, in which case it is encoded first, or pass both
        ``mu_e`` and ``log_var_e`` directly.

        Raises:
            ValueError: if only one of ``mu_e``/``log_var_e`` is given, or if
                neither they nor ``x`` are given.
        """
        if mu_e is None and log_var_e is None:
            if x is None:
                raise ValueError('x must be given when mu_e and log_var_e are not!')
            mu_e, log_var_e = self.encode(x)
        elif mu_e is None or log_var_e is None:
            raise ValueError('mu and log-var can`t be None!')

        return self.reparametrization(mu_e, log_var_e)

    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):
        """Elementwise log q(z | x) under the encoder's diagonal Gaussian.

        When ``x`` is given, it is encoded and a fresh ``z`` is sampled. Any
        passed ``mu_e``/``log_var_e``/``z`` are then ignored. Otherwise all
        three of ``mu_e``, ``log_var_e`` and ``z`` must be supplied. The
        result is not reduced over any dimension.
        """
        if x is not None:
            mu_e, log_var_e = self.encode(x)
            z = self.sample(mu_e=mu_e, log_var_e=log_var_e)
        elif (mu_e is None) or (log_var_e is None) or (z is None):
            raise ValueError('mu, log-var and z can`t be None!')

        return log_normal_diag(z, mu_e, log_var_e)