In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConditionalVariationalEncoder(nn.Module):
  """Encoder half of a conditional variational autoencoder.

  The input features are concatenated with a one-hot encoding of the targets,
  passed through one hidden linear layer with a leaky-ReLU activation, and
  projected to a mean and a log-variance. A latent sample is then drawn from
  those statistics with the reparameterization trick.

  Args:
    latent_dims: size of the latent vector.
    num_features: number of input features (one-hot columns not included).
    num_hidden_layers: number of units in the single hidden layer (it is used
      as a layer width, not as a layer count).
    num_classes: number of classes, i.e. width of the one-hot encoding.
    to_onehot_fn: callable invoked as ``to_onehot_fn(targets, num_classes,
      device)``; its result is concatenated to the features along dim 1.

  Attributes:
    kl_divergence: value stored by the most recent KL computation (0 until
      the first one). Its shape depends on ``kl_version``, see ``klDivergence``.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers, num_classes, to_onehot_fn):
    super(ConditionalVariationalEncoder, self).__init__()
    # Three linear layers: one hidden layer plus the two latent projections.
    self.hidden_1 = nn.Linear(num_features+num_classes, num_hidden_layers)
    self.z_mean = nn.Linear(num_hidden_layers, latent_dims)
    self.z_log_var = nn.Linear(num_hidden_layers, latent_dims)

    self.num_classes = num_classes
    self.kl_divergence = 0
    self.to_onehot_fn = to_onehot_fn

  def latentVector(self, z_mu, z_log_var, device):
    """Draw ``z = z_mu + eps * sigma`` with ``eps`` sampled from N(0, 1).

    Args:
      z_mu: latent means.
      z_log_var: latent log-variances, same shape as ``z_mu``.
      device: kept for backward compatibility; the noise is now created
        directly on ``z_mu``'s device, so this argument is not used.

    Returns:
      Tensor with the same shape as ``z_mu``.
    """
    # Sampling "like" z_mu avoids creating the noise on the CPU and copying
    # it over, and guarantees eps lives on the same device as z_mu.
    eps = torch.randn_like(z_mu)
    # log(x^2) = 2*log(x), hence std_dev = exp(log(var) / 2).
    sigma = torch.exp(z_log_var/2.)
    return z_mu + eps * sigma

  def _klFromStats(self, z_mean, z_log_var, kl_version):
    """Store the KL term for already-computed latent statistics."""
    if kl_version == 1:
      # v1 (VAE style): sum over dim 1, scale by -0.5, then average.
      kl = -0.5 * torch.sum(1 + z_log_var - z_mean**2 - torch.exp(z_log_var), axis=1)
      self.kl_divergence = kl.mean()
    elif kl_version == 2 or kl_version == 4:
      # v2 (CVAE style): scale element-wise, then sum over every element.
      # NOTE(review): version 4 ("CVAE with mean") has always applied .mean()
      # to this fully reduced scalar, which is a no-op, so it yields the same
      # value as version 2. Behavior kept; TODO confirm whether a per-sample
      # sum followed by a batch mean was intended.
      self.kl_divergence = (0.5 * (z_mean**2 + torch.exp(z_log_var) - z_log_var - 1)).sum()
    elif kl_version == 3:
      # v3: scale element-wise, sum over dim 1, no averaging -> one value per row.
      self.kl_divergence = torch.sum(-0.5 * (1 + z_log_var - z_mean**2 - torch.exp(z_log_var)), axis=1)
    # Any other kl_version leaves self.kl_divergence untouched.

  # Kullback-Leibler divergence
  def klDivergence(self, x, kl_version):
    """Compute the KL term from a hidden activation and store it.

    The result is written to ``self.kl_divergence``; nothing is returned.

    Args:
      x: hidden activation fed to the ``z_mean`` / ``z_log_var`` layers.
      kl_version: selects the reduction.
        1 - sum over dim 1, multiply by -0.5, mean of the result (scalar).
        2 - multiply by 0.5 element-wise, sum of all elements (scalar).
        3 - multiply by -0.5 element-wise, sum over dim 1 (one value per row).
        4 - same value as version 2.
        Any other value leaves ``self.kl_divergence`` unchanged.
    """
    if kl_version not in (1, 2, 3, 4):
      return
    # Each projection is evaluated once instead of once per occurrence.
    self._klFromStats(self.z_mean(x), self.z_log_var(x), kl_version)

  def forward(self, features, targets, device, kl_version):
    """Encode ``features`` conditioned on ``targets`` and sample a latent vector.

    Also updates ``self.kl_divergence`` according to ``kl_version``.

    Args:
      features: input tensor; concatenated with the one-hot targets on dim 1.
      targets: class labels handed to ``to_onehot_fn``.
      device: device handed to ``to_onehot_fn``.
      kl_version: KL variant, see ``klDivergence``.

    Returns:
      The sampled latent tensor produced by ``latentVector``.
    """
    onehot_targets = self.to_onehot_fn(targets, self.num_classes, device)
    x = torch.cat((features, onehot_targets), dim=1)

    x = self.hidden_1(x)
    x = F.leaky_relu(x)  # leaky variant: the gradient is not zero for negative inputs

    z_mean = self.z_mean(x)
    z_log_var = self.z_log_var(x)

    # Reuse the statistics computed above rather than re-running both
    # projections inside klDivergence.
    self._klFromStats(z_mean, z_log_var, kl_version)

    return self.latentVector(z_mean, z_log_var, device)


In [None]:
class ConditionalDecoder(nn.Module):
  """Decoder half of a conditional variational autoencoder.

  A latent code is concatenated with a one-hot encoding of the targets and
  mapped back through two linear layers (leaky-ReLU in between, sigmoid at
  the end).

  Args:
    latent_dims: size of the latent code.
    num_features: number of features; the output has
      ``num_features + num_classes`` columns.
    num_hidden_layers: number of units in the hidden layer.
    num_classes: width of the one-hot encoding.
    to_onehot_fn: callable invoked as ``to_onehot_fn(targets, num_classes,
      device)``; its result is concatenated to the latent code along dim 1.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers, num_classes, to_onehot_fn):
    super().__init__()
    self.num_classes = num_classes
    self.to_onehot_fn = to_onehot_fn

    # Mirror image of the encoder: latent (+ classes) -> hidden -> features (+ classes).
    conditioned_width = latent_dims + num_classes
    output_width = num_features + num_classes
    self.linear1 = nn.Linear(conditioned_width, num_hidden_layers)
    self.linear2 = nn.Linear(num_hidden_layers, output_width)

  def forward(self, z_encoded, targets, device):
    """Decode ``z_encoded`` conditioned on ``targets``.

    Args:
      z_encoded: latent code coming from the encoder.
      targets: class labels handed to ``to_onehot_fn``.
      device: device handed to ``to_onehot_fn``.

    Returns:
      Tensor of sigmoid activations, i.e. values between 0 and 1.
    """
    class_onehot = self.to_onehot_fn(targets, self.num_classes, device)
    conditioned = torch.cat([z_encoded, class_onehot], dim=1)

    # First layer consumes the conditioned latent code.
    hidden = F.leaky_relu(self.linear1(conditioned))
    # Second layer followed by a sigmoid to squash the output into (0, 1).
    return torch.sigmoid(self.linear2(hidden))


In [None]:
class ConditionalVariationalAutoencoder(nn.Module):
  """Conditional VAE made of a ConditionalVariationalEncoder and a ConditionalDecoder.

  Args:
    latent_dims: size of the latent vector.
    num_features: number of input features.
    num_hidden_layers: number of units in the hidden layer of each half.
    num_classes: number of classes used for the one-hot conditioning.
  """

  # Constructor
  def __init__(self, latent_dims, num_features, num_hidden_layers, num_classes):
    super(ConditionalVariationalAutoencoder, self).__init__()
    # Both halves share this instance's to_onehot as their one-hot function.
    self.encoder = ConditionalVariationalEncoder(latent_dims, num_features, num_hidden_layers, num_classes, self.to_onehot)
    self.decoder = ConditionalDecoder(latent_dims, num_features, num_hidden_layers, num_classes, self.to_onehot)

  def to_onehot(self, labels, num_classes, device):
    """Binarize class labels: one row per label with a 1 in the label's column.

    Args:
      labels: tensor of integer class indices; it is flattened to one index
        per row, and its first dimension sets the number of output rows.
      num_classes: number of columns of the result.
      device: device on which the result is allocated.

    Returns:
      Tensor of shape ``(labels.size(0), num_classes)`` on ``device``, with
      the default floating dtype.
    """
    # Allocate directly on the target device instead of CPU + copy.
    labels_onehot = torch.zeros(labels.size()[0], num_classes, device=device)
    # scatter_ needs an int64 index living on the same device as the output,
    # so coerce the labels rather than failing on int32 / off-device input.
    index = labels.to(device=device, dtype=torch.int64).reshape(-1, 1)
    labels_onehot.scatter_(1, index, 1)
    return labels_onehot

  def forward(self, features, targets, kl_version):
    """Encode and decode ``features`` conditioned on ``targets``.

    Args:
      features: input tensor.
      targets: class labels used for conditioning.
      kl_version: KL variant forwarded to the encoder.

    Returns:
      The decoder output for the sampled latent code.
    """
    # Follow the input's device. Choosing "cuda:0" whenever CUDA is available
    # puts the one-hot tensors on the GPU even when the inputs are on the CPU,
    # which makes the concatenation fail with a device mismatch.
    device = features.device
    z = self.encoder(features, targets, device, kl_version)
    return self.decoder(z, targets, device)