In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class VariationalEncoder(nn.Module):
  """Encoder half of a variational autoencoder (VAE).

  The input is flattened, passed through one ReLU hidden layer and projected
  to the mean and log-variance of a diagonal Gaussian over the latent space.
  ``forward`` returns a sample of that Gaussian and stores the batch-averaged
  KL divergence against a standard normal in ``self.kl_divergence``.

  Args:
    latent_dims: Size of the latent vector.
    num_features: Number of input features per sample after flattening.
    num_hidden_layers: Number of units in the single hidden layer. Despite
      the name this is a width, not a layer count; the name is kept so that
      existing callers using keyword arguments keep working.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers):
    super(VariationalEncoder, self).__init__()
    # Three linear layers: a shared hidden layer and two heads that output
    # the mean and the log-variance of the latent distribution.
    self.hidden_1 = nn.Linear(num_features, num_hidden_layers)
    self.z_mean = nn.Linear(num_hidden_layers, latent_dims)
    self.z_log_var = nn.Linear(num_hidden_layers, latent_dims)
    # KL term of the most recent forward pass; stays 0 until forward() runs.
    self.kl_divergence = 0

  def latentVector(self, z_mu, z_log_var):
    """Draw a latent sample with the reparameterization trick.

    Args:
      z_mu: Mean of the latent Gaussian.
      z_log_var: Log-variance of the latent Gaussian, same shape as ``z_mu``.

    Returns:
      ``z_mu + eps * sigma`` with ``eps`` drawn from a standard normal.
    """
    # randn_like creates the noise on the same device and with the same dtype
    # as z_mu. Picking "cuda:0" whenever CUDA is available breaks models that
    # run on the CPU (or on another GPU) of a CUDA-capable machine.
    eps = torch.randn_like(z_mu)

    # log(x^2) = 2*log(x), so halving the log-variance yields log(std_dev):
    # std_dev = exp(log(var) / 2)
    sigma = torch.exp(z_log_var / 2.)
    return z_mu + eps * sigma

  @staticmethod
  def _kl_from_params(z_mu, z_log_var):
    """Batch mean of the per-sample KL divergence, summed over dim 1."""
    per_sample = -0.5 * torch.sum(
        1 + z_log_var - z_mu ** 2 - torch.exp(z_log_var), dim=1)
    return per_sample.mean()

  # Kullback-Leibler divergence
  def klDivergence(self, x):
    """Compute the KL term from hidden activations and store it.

    Kept for backward compatibility; ``forward`` no longer needs it because
    it already holds the mean and log-variance.

    Args:
      x: Output of the hidden layer after the ReLU, i.e. the tensor that is
        fed to the ``z_mean`` and ``z_log_var`` heads.
    """
    # Each head is evaluated once instead of re-running z_log_var twice.
    self.kl_divergence = self._kl_from_params(self.z_mean(x), self.z_log_var(x))

  def forward(self, features):
    """Encode ``features`` and return a sampled latent vector.

    Args:
      features: Batch of inputs; every dimension after the first is
        flattened before the hidden layer is applied.

    Returns:
      Tensor with one sampled latent vector per batch element.

    Side effects:
      Overwrites ``self.kl_divergence`` with the KL term of this batch.
    """
    x = torch.flatten(features, start_dim=1)
    x = F.relu(self.hidden_1(x))

    z_mean = self.z_mean(x)  # mean (mu) of the latent Gaussian
    z_log_var = self.z_log_var(x)

    # Reuse the tensors computed above rather than running the heads again.
    self.kl_divergence = self._kl_from_params(z_mean, z_log_var)

    return self.latentVector(z_mean, z_log_var)


In [None]:
class Decoder(nn.Module):
  """Decoder that maps a latent code back to feature space.

  Two linear layers: latent -> hidden (ReLU) -> features (sigmoid).

  Args:
    latent_dims: Size of the incoming latent vector.
    num_features: Number of output features per sample.
    num_hidden_layers: Number of units in the hidden layer.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers):
    super().__init__()
    # Layers are laid out in the reverse direction of the encoder.
    self.linear1 = nn.Linear(latent_dims, num_hidden_layers)
    self.linear2 = nn.Linear(num_hidden_layers, num_features)

  def forward(self, z_encoded):
    """Reconstruct features from the latent code ``z_encoded``.

    Returns:
      Tensor of reconstructed features with every value in [0, 1]. The
      output is returned flat; reshaping it (for example to single-channel
      28x28 images) is left to the caller.
    """
    # First layer, activated on the latent code produced by the encoder.
    hidden = F.relu(self.linear1(z_encoded))
    # Second layer followed by a sigmoid to normalize the output to [0, 1].
    return torch.sigmoid(self.linear2(hidden))


In [None]:
class VariationalAutoencoder(nn.Module):
  """Variational autoencoder built from ``VariationalEncoder`` and ``Decoder``.

  Args:
    latent_dims: Size of the latent vector shared by encoder and decoder.
    num_features: Number of features per sample (encoder input size and
      decoder output size).
    num_hidden_layers: Number of units in the hidden layer of each half.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers):
    super().__init__()
    # Both halves are built from the same three sizes.
    sizes = (latent_dims, num_features, num_hidden_layers)
    self.encoder = VariationalEncoder(*sizes)
    self.decoder = Decoder(*sizes)

  def forward(self, features):
    """Encode ``features`` to a latent sample and decode it back."""
    return self.decoder(self.encoder(features))
