In [None]:
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchaudio
# Select the compute device; `my_cuda` records the same choice as a 0/1 flag.
if not torch.cuda.is_available():
    device, my_cuda = 'cpu', 0
else:
    device, my_cuda = 'cuda:0', 1

# 44.1 kHz -> 48 kHz resampler using the 'kaiser_window' method, moved onto the
# selected device so it can be applied directly to tensors living there.
Resample = torchaudio.transforms.Resample(
    44100, 48000, resampling_method='kaiser_window'
).to(device)

from pathlib import Path
import random
import numpy as np
from scipy import interpolate as sp_interpolate
import json

import librosa
import librosa.display

import soundfile as sf
# import sounddevice as sd
import configparser
import random
import json
import matplotlib.pyplot as plt
import IPython.display as display

In [None]:
sampling_rate = 44100
sr = sampling_rate

hop_length = 128

segment_length = 1024
n_units = 2048
latent_dim = 256

batch_size = 256

# Either a directory of .wav files or a filelist text file whose lines start with
# an audio path followed by `|`-separated fields (here: presumably path|sid|text).
audio_fold = Path(r'../vits2_pytorch/data/filelists/lolo_audio_sid_text_test_filelist.txt')
audio = audio_fold


def list_audio_files(source, pattern='*.wav'):
    """Return the audio file paths described by ``source``.

    ``source`` may be:
      * a directory: every entry matching ``pattern`` is returned, sorted so the
        order is reproducible across runs and filesystems;
      * a filelist text file: one entry per non-blank line, where the audio path
        is the first ``|``-separated field (e.g. ``wavs/x.wav|0|some text``).
        Paths are returned as written, in file order; relative entries are
        therefore relative to the current working directory (NOTE: they may be
        relative to another repo root -- confirm against the filelist).

    A path that does not exist yields an empty list.
    """
    source = Path(source)
    if source.is_file():
        lines = source.read_text(encoding='utf-8').splitlines()
        return [Path(line.split('|', 1)[0].strip()) for line in lines if line.strip()]
    return sorted(source.glob(pattern))


# `audio_fold` is a filelist *file*, so a plain `.glob('*.wav')` on it would always
# find nothing; list_audio_files reads the filelist (and still accepts a directory).
lts_audio_files = list_audio_files(audio_fold)

In [None]:

# Sanity check: the number of audio files found must be greater than 0. If it is 0,
# `audio_fold` (set in the configuration cell) does not point at the dataset --
# fix that path before running the cells below.
assert len(lts_audio_files) > 0, f'No audio files found via {audio_fold}'
len(lts_audio_files)

In [None]:
# Models 

class raw_VAE(nn.Module):
    """Fully connected VAE operating on fixed-length raw-audio segments.

    Encoder: Linear(segment_length -> n_units) + ReLU, followed by two parallel
    Linear(n_units -> latent_dim) heads giving the posterior mean and
    log-variance. Decoder: Linear(latent_dim -> n_units) + ReLU, then
    Linear(n_units -> segment_length) + tanh, so reconstructions lie in [-1, 1].

    Args:
        segment_length: number of samples per input segment.
        n_units: width of the hidden layer in encoder and decoder.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, segment_length, n_units, latent_dim):
        super().__init__()

        self.segment_length = segment_length
        self.n_units = n_units
        self.latent_dim = latent_dim

        self.fc1 = nn.Linear(segment_length, n_units)   # encoder hidden layer
        self.fc21 = nn.Linear(n_units, latent_dim)      # posterior mean head
        self.fc22 = nn.Linear(n_units, latent_dim)      # posterior log-variance head
        self.fc3 = nn.Linear(latent_dim, n_units)       # decoder hidden layer
        self.fc4 = nn.Linear(n_units, segment_length)   # decoder output layer

    def encode(self, x):
        """Map ``x`` (last dim = segment_length) to ``(mu, logvar)`` (last dim = latent_dim)."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)   # sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latents (last dim = latent_dim) to segments (last dim = segment_length) in [-1, 1]."""
        h3 = F.relu(self.fc3(z))
        # torch.tanh is the canonical form; the F.tanh alias was flagged as
        # deprecated in older PyTorch releases. Values are identical.
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        """Encode, sample a latent, and decode.

        ``x`` is flattened to ``(-1, segment_length)``; reshape (not view) is used so
        non-contiguous inputs are accepted too.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.reshape(-1, self.segment_length))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Reconstruction + KL divergence losses, each averaged (not summed) over its elements.
def loss_function(recon_x, x, mu, logvar, kl_beta, segment_length):
    """beta-VAE objective: mean-squared reconstruction error + ``kl_beta`` * KL term.

    Both terms are means: the MSE is averaged over every sample of every segment,
    and the KL divergence to N(0, I) is averaged over all elements of ``mu`` /
    ``logvar``. Their relative weight therefore does not grow with
    segment_length or latent_dim; tune ``kl_beta`` with that in mind.

    Args:
        recon_x: reconstructions, expected shape (batch, segment_length).
        x: target segments; flattened to (-1, segment_length) to match ``recon_x``.
        mu: posterior mean from the encoder.
        logvar: posterior log-variance from the encoder.
        kl_beta: weight of the KL term.
        segment_length: samples per segment, used to flatten ``x``.

    Returns:
        Scalar tensor ``recon_loss + kl_beta * KLD``.
    """
    # reshape (not view) so non-contiguous targets (transposed/sliced batches) work.
    recon_loss = F.mse_loss(recon_x, x.reshape(-1, segment_length))

    # KL(N(mu, sigma^2) || N(0, I)); see Appendix B of
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # The paper gives -KL = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2); here the
    # sum is replaced by a mean over all elements.
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_beta * KLD
