[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sebascarag/AI-SyntheticSound/blob/main/Model_CVAE.ipynb)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import copy

In [None]:
class ConditionalVariationalEncoder(nn.Module):
  """Encoder half of a conditional variational autoencoder (CVAE).

  The input features are concatenated with a one-hot encoding of the class
  labels, passed through a single hidden layer and projected to the mean and
  log-variance of a diagonal Gaussian over the latent space. ``forward``
  returns one sample of that Gaussian and stores the KL term of the most
  recent call in ``self.kl_divergence``.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers, num_classes, to_onehot_fn):
    """Create the layers.

    Args:
      latent_dims: size of the latent vector z.
      num_features: number of input features per sample (the one-hot labels
        are appended to these before the hidden layer).
      num_hidden_layers: number of units of the single hidden layer (despite
        the name, it is a width, not a layer count).
      num_classes: number of label classes, i.e. width of the one-hot vector.
      to_onehot_fn: callable ``(labels, num_classes, device) -> tensor`` that
        returns the one-hot encoding of ``labels`` on ``device``.
    """
    super(ConditionalVariationalEncoder, self).__init__()
    self.hidden_1 = nn.Linear(num_features + num_classes, num_hidden_layers)  # hidden layer
    self.z_mean = nn.Linear(num_hidden_layers, latent_dims)  # latent mean (mu)
    self.z_log_var = nn.Linear(num_hidden_layers, latent_dims)  # latent log-variance, log(sigma^2)
    # Kept for callers that read it; the computations below follow the device
    # of their input tensors instead of this cached value.
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.num_classes = num_classes
    self.kl_divergence = 0  # replaced by a 0-dim tensor on every forward/klDivergence call
    self.to_onehot_fn = to_onehot_fn

  def latentVector(self, z_mu, z_log_var):
    """Sample z ~ N(mu, sigma^2) with the reparameterization trick.

    Args:
      z_mu: tensor of latent means.
      z_log_var: tensor of latent log-variances, same shape as ``z_mu``.

    Returns:
      ``z_mu + eps * sigma`` with ``eps`` drawn from a standard normal and
      ``sigma = exp(z_log_var / 2)``.
    """
    # randn_like draws the noise directly on z_mu's device and dtype, so the
    # sample can never end up on a different device than the statistics.
    eps = torch.randn_like(z_mu)
    sigma = torch.exp(z_log_var / 2.)
    return z_mu + eps * sigma

  def _klFromStats(self, z_mu, z_log_var, kl_version):
    """Store the KL divergence to N(0, I) in ``self.kl_divergence``.

    ``kl_version`` selects the reduction:
      1: sum over latent dimensions, mean over the batch.
      2: sum over every element (latent dimensions and batch).
      3: sum over latent dimensions, then sum over the batch.
      4: same value as 2.
    Any other value leaves ``self.kl_divergence`` untouched.
    """
    if kl_version in (1, 3):
      per_sample = -0.5 * torch.sum(1 + z_log_var - z_mu**2 - torch.exp(z_log_var), dim=1)
      self.kl_divergence = per_sample.mean() if kl_version == 1 else per_sample.sum()
    elif kl_version in (2, 4):
      # NOTE(review): version 4 was written as ``.sum().mean()``; the mean of
      # an already-reduced scalar is a no-op, so it is identical to version 2.
      # Confirm whether a per-batch mean was intended.
      self.kl_divergence = (0.5 * (z_mu**2 + torch.exp(z_log_var) - z_log_var - 1)).sum()

  def klDivergence(self, x, kl_version):
    """Compute the KL term from hidden activations ``x``.

    ``x`` is the output of the hidden layer (after activation). The two
    projection layers are evaluated once and the result is stored in
    ``self.kl_divergence``; nothing is returned.
    """
    self._klFromStats(self.z_mean(x), self.z_log_var(x), kl_version)

  def forward(self, features, targets, kl_version):
    """Encode a batch and return a latent sample.

    Args:
      features: float32 tensor of shape (batch, num_features).
      targets: class labels accepted by ``to_onehot_fn``.
      kl_version: KL reduction variant (1-4), see ``_klFromStats``.

    Returns:
      Tensor of shape (batch, latent_dims) sampled from the encoded Gaussian.

    Raises:
      ValueError: if ``features`` is not ``torch.float32``.
    """
    if features.dtype != torch.float32:
      raise ValueError("check dtype provided must be torch.float32")

    # Build the one-hot labels on the same device as the features so the
    # concatenation below cannot hit a device mismatch.
    onehot_targets = self.to_onehot_fn(targets, self.num_classes, features.device)
    x = torch.cat((features, onehot_targets), dim=1)

    x = F.leaky_relu(self.hidden_1(x))  # hidden layer + activation

    z_mean = self.z_mean(x)
    z_log_var = self.z_log_var(x)

    # Reuse the statistics computed above instead of re-running both layers.
    self._klFromStats(z_mean, z_log_var, kl_version)

    return self.latentVector(z_mean, z_log_var)


In [None]:
class ConditionalDecoder(nn.Module):
  """Decoder half of a conditional variational autoencoder (CVAE).

  Maps a latent vector, concatenated with the one-hot class labels, back to
  the input space. The output has ``num_features + num_classes`` columns and
  is squashed to (0, 1) with a sigmoid.
  """

  def __init__(self, latent_dims, num_features, num_hidden_layers, num_classes, to_onehot_fn):
    """Create the layers.

    Args:
      latent_dims: size of the latent vector z.
      num_features: number of feature columns to reconstruct.
      num_hidden_layers: number of units of the single hidden layer (despite
        the name, it is a width, not a layer count).
      num_classes: number of label classes, i.e. width of the one-hot vector.
      to_onehot_fn: callable ``(labels, num_classes, device) -> tensor`` that
        returns the one-hot encoding of ``labels`` on ``device``.
    """
    super(ConditionalDecoder, self).__init__()
    # Layers going from the (conditioned) latent space back to the input space.
    self.linear1 = nn.Linear(latent_dims + num_classes, num_hidden_layers)
    self.linear2 = nn.Linear(num_hidden_layers, num_features + num_classes)
    # Kept for callers that read it; forward() follows the device of its input.
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.num_classes = num_classes
    self.to_onehot_fn = to_onehot_fn

  def forward(self, z_encoded, targets):
    """Decode a batch of latent vectors.

    Args:
      z_encoded: tensor of shape (batch, latent_dims).
      targets: class labels accepted by ``to_onehot_fn``.

    Returns:
      Tensor of shape (batch, num_features + num_classes) with values in
      (0, 1).
    """
    # Build the one-hot labels on the same device as the latent vectors so the
    # concatenation below cannot hit a device mismatch.
    onehot_targets = self.to_onehot_fn(targets, self.num_classes, z_encoded.device)
    z = torch.cat((z_encoded, onehot_targets), dim=1)

    hidden = F.leaky_relu(self.linear1(z))  # first layer + activation
    return torch.sigmoid(self.linear2(hidden))  # normalise the output to (0, 1)


In [None]:
class ConditionalVariationalAutoencoder(nn.Module):
  """Conditional variational autoencoder: encoder + decoder + training loop."""

  def __init__(self, latent_dims, num_features, num_hidden_layers, num_classes, random_seed=None):
    """Build the encoder and decoder.

    Args:
      latent_dims: size of the latent vector z.
      num_features: number of input features per sample.
      num_hidden_layers: number of units of the hidden layer in each half.
      num_classes: number of label classes used for conditioning.
      random_seed: optional seed for the torch RNGs; when given, the initial
        weights are reproducible.
    """
    super(ConditionalVariationalAutoencoder, self).__init__()
    # Seed BEFORE the sub-networks are created: the linear layers draw their
    # initial weights from the global RNG at construction time, so seeding
    # afterwards would leave the initialisation unseeded.
    if random_seed is not None:
      torch.manual_seed(random_seed)
      torch.cuda.manual_seed(random_seed)
    self.encoder = ConditionalVariationalEncoder(latent_dims, num_features, num_hidden_layers, num_classes, self.to_onehot)
    self.decoder = ConditionalDecoder(latent_dims, num_features, num_hidden_layers, num_classes, self.to_onehot)
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.num_classes = num_classes
    print('Device:', self.device)

  def to_onehot(self, labels, num_classes, device):
    """One-hot encode class labels.

    Args:
      labels: integer tensor of class indices, one per sample.
      num_classes: number of columns of the result.
      device: device on which the result is created.

    Returns:
      Float tensor of shape (len(labels), num_classes) with a 1 in the column
      given by each label and 0 elsewhere.
    """
    # scatter_ needs an int64 index living on the same device as its target,
    # so normalise both here instead of relying on the caller.
    index = labels.reshape(-1, 1).to(device=device, dtype=torch.int64)
    labels_onehot = torch.zeros(labels.size(0), num_classes, device=device)
    labels_onehot.scatter_(1, index, 1)
    return labels_onehot

  def forward(self, features, targets, kl_version):
    """Encode ``features`` conditioned on ``targets`` and decode the sample.

    The KL term of the pass is left in ``self.encoder.kl_divergence``.
    """
    z = self.encoder(features, targets, kl_version)
    return self.decoder(z, targets)

  def train_fit(self, data, learning_rate=1e-3, num_epochs=20, kl_version=2, flatten=False, optuna_trial=None, max_error_epoch_loss=1000000):
    """Train the model with Adam and keep the weights of the best epoch.

    The loss of a batch is the KL divergence plus the summed binary
    cross-entropy between the decoder output and the input concatenated with
    its one-hot labels (features are therefore expected to lie in [0, 1]).

    Args:
      data: iterable of ``(features, labels)`` batches that also exposes
        ``dataset`` and ``batch_size`` (used for the progress message), e.g.
        a ``torch.utils.data.DataLoader``.
      learning_rate: Adam learning rate.
      num_epochs: number of passes over ``data``.
      kl_version: KL reduction variant forwarded to the encoder.
      flatten: flatten every feature tensor from dimension 1 onwards.
      optuna_trial: optional Optuna trial; when given, the epoch loss is
        reported to it and the trial may be pruned.
      max_error_epoch_loss: with ``optuna_trial``, prune as soon as the epoch
        loss reaches this value.

    Returns:
      ``(best_model, best_loss, epoch_losses)``: a copy of the model holding
      the weights of the epoch with the lowest mean training loss, that loss
      (``None`` if no epoch ran), and an array of shape (num_epochs, 3) with
      the per-epoch mean of total loss, KL term and reconstruction term.

    Raises:
      optuna.TrialPruned: when an Optuna trial is supplied and it asks for
        pruning, or the epoch loss is NaN or >= ``max_error_epoch_loss``.
    """
    start_time = time.time()
    optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

    if optuna_trial is not None:  # only needed for hyper-parameter search
      import optuna

    # Feed the batches to whichever device the weights actually live on;
    # ``self.device`` only records what was available at construction time.
    device = next(self.parameters()).device

    best_model = copy.deepcopy(self)  # snapshot container for the best weights
    best_loss = None
    epoch_history = []

    for epoch in range(num_epochs):

      start_time_elapsed = time.time()
      batch_history = []

      for batch_idx, (features, labels) in enumerate(data):

        if flatten:
          features = torch.flatten(features, start_dim=1)
        features = features.to(device)
        targets = labels.type(torch.int64).to(device)

        optimizer.zero_grad()  # reset gradients before the forward pass

        decoded = self(features, targets, kl_version)
        x_con = torch.cat((features, self.to_onehot(targets, self.num_classes, device)), dim=1)
        kl_divergence = self.encoder.kl_divergence
        pixelwise = F.binary_cross_entropy(decoded, x_con, reduction='sum')

        loss = kl_divergence + pixelwise  # cost = reconstruction loss + KL divergence

        # update the model parameters
        loss.backward()
        optimizer.step()

        batch_history.append((loss.item(), kl_divergence.item(), pixelwise.item()))

        # NOTE(review): the logging interval is tied to ``num_epochs`` and the
        # batch total is a floor division; both look accidental but are kept
        # as-is -- confirm the intended values.
        if not batch_idx % num_epochs:
          print('Epoch: %03d/%03d | Batch %03d/%03d | kl: %.4f + pw: %.4f = loss: %.4f | KL v%02d'
                %(epoch+1, num_epochs, batch_idx, len(data.dataset)//data.batch_size, kl_divergence, pixelwise, loss, kl_version))

      # mean loss of the epoch (total, KL, reconstruction)
      batch_losses = np.array(batch_history, dtype=float).reshape(-1, 3)
      epoch_loss = np.mean(batch_losses[:, 0])
      epoch_history.append((epoch_loss, np.mean(batch_losses[:, 1]), np.mean(batch_losses[:, 2])))

      if best_loss is None or epoch_loss < best_loss:
        best_loss = epoch_loss
        best_model.load_state_dict(self.state_dict())  # copy weights and biases

      print('Time elapsed: %.2f min' % ((time.time() - start_time_elapsed)/60))

      if optuna_trial is not None:
        optuna_trial.report(epoch_loss, epoch)
        # A diverged run yields NaN, which compares False against the
        # threshold, so it has to be tested for explicitly.
        if optuna_trial.should_prune() or epoch_loss >= max_error_epoch_loss or np.isnan(epoch_loss):
          print('Prune on epoch: {:0>3} | loss:{:.4f}'.format(epoch, epoch_loss))
          print(f'Params: {optuna_trial.params}')
          raise optuna.TrialPruned()

    print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
    epoch_losses = np.array(epoch_history, dtype=float).reshape(-1, 3)
    return best_model, best_loss, epoch_losses

