In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os, time, json, torch 
import torch.nn.functional as F 
from torch.utils.tensorboard import SummaryWriter 
from torch.utils.data import DataLoader 
import torch.optim as optim 
from model import generator, multi_period_discriminator, multi_scale_discriminator, feature_loss, generator_loss, discriminator_loss
import wandb 
import pandas as pd 
import itertools
import dataset 
import sounddevice as sd 
from tqdm import tqdm, trange
import matplotlib.pyplot as plt 

# Select the GPU only when CUDA is actually usable. `torch.cuda.is_available`
# must be *called*: the bare attribute is a function object, which is always
# truthy and would force 'cuda' (and a crash below) on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(torch.cuda.get_device_name() if device == 'cuda' else 'cpu')

NVIDIA GeForce RTX 3090


In [3]:
def plot_spectrogram(spectrogram):
  """Render a 2-D spectrogram as a matplotlib figure (used for TensorBoard logging).

  Args:
    spectrogram: 2-D array-like image; row 0 is drawn at the bottom
      (``origin='lower'``). Torch tensors are accepted as well, including ones
      on a GPU or tracking gradients: they are detached and copied to host
      memory first.

  Returns:
    The drawn ``matplotlib.figure.Figure``. It has already been closed in
    pyplot, so repeated calls do not accumulate open figures; the figure
    object itself stays usable by the caller (here: ``writer.add_figure``).
  """
  # imshow cannot consume CUDA or grad-tracking tensors directly.
  if hasattr(spectrogram, 'detach'):
    spectrogram = spectrogram.detach().cpu().numpy()
  fig, ax = plt.subplots(figsize=(10, 2))
  im = ax.imshow(spectrogram, aspect='auto', origin='lower', interpolation='none')
  fig.colorbar(im, ax=ax)
  fig.canvas.draw()
  # Close exactly this figure instead of whatever pyplot considers "current".
  plt.close(fig)
  return fig

In [4]:
class Attr(dict):
  """A dict whose items can also be read and written as attributes.

  The instance ``__dict__`` is aliased to the dict itself, so ``obj.key`` and
  ``obj['key']`` always refer to the same storage.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Attribute lookups/assignments now go straight to the dict entries.
    self.__dict__ = self

In [5]:
# Load hyper-parameters from config.json and wrap them so they can be read as
# either config['key'] or config.key. The file is opened in binary mode;
# json.load accepts bytes and detects the UTF-8/16/32 encoding itself.
with open('config.json', 'rb') as config_file:
  config = Attr(json.load(config_file))

In [6]:
# Generator plus the two discriminators (multi-period and multi-scale), each
# moved to `device`. Keep this construction order: weight initialisation
# presumably draws from torch's global RNG, so reordering would change the
# initial weights.
generator_model = generator(config).to(device=device)
mpd = multi_period_discriminator().to(device=device)
msd = multi_scale_discriminator().to(device=device)

In [7]:
def _make_adamw(params):
  """Create an AdamW optimizer over `params` with config's learning rate and betas."""
  return optim.AdamW(params, lr=config.learning_rate, betas=[config.adam_b1, config.adam_b2])


# One optimizer for the generator, one shared by both discriminators
# (MSD parameters first, then MPD).
optim_g = _make_adamw(generator_model.parameters())
optim_d = _make_adamw(itertools.chain(msd.parameters(), mpd.parameters()))

# Multiply each learning rate by config.lr_decay per scheduler.step();
# the training cells step both schedulers once per epoch.
scheduler_g, scheduler_d = (
    optim.lr_scheduler.ExponentialLR(opt, gamma=config.lr_decay) for opt in (optim_g, optim_d)
)

In [8]:
# Reproducible 90/10 train/test split of the LJSpeech clips.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so
# slicing it directly gives a split that can differ between machines or copies
# of the data. Sorting first and then shuffling with a fixed seed keeps the
# split random across recordings but identical on every run.
import random

data_path_base = '/mnt/sda1/data/lj_speech/LJSpeech-1.1/wavs/'
TRAIN_FRACTION = 0.9
SPLIT_SEED = 1234

all_files = sorted(
    os.path.join(data_path_base, name)
    for name in os.listdir(data_path_base)
    if name.lower().endswith('.wav')  # ignore stray non-audio files
)
assert all_files, f'no .wav files found in {data_path_base}'
random.Random(SPLIT_SEED).shuffle(all_files)

cutoff = int(TRAIN_FRACTION * len(all_files))
training_filelist = all_files[:cutoff]
testing_filelist = all_files[cutoff:]

In [9]:
# Positional arguments shared by both datasets, in the order
# dataset.mel_dataset takes them after the file list: segment size, then the
# STFT/mel settings.
mel_args = (config.segment_size, config.n_fft, config.num_mels,
            config.hop_size, config.win_size, config.sampling_rate,
            config.fmin, config.fmax)
# Keyword arguments shared by both datasets.
mel_kwargs = dict(n_cache_reuse=0, fmax_loss=config.fmax_for_loss, device=device,
                  fine_tuning=False, base_mels_path=data_path_base)

# Training set: shuffle=True and the dataset's default `split`
# (presumably random segment_size crops — verify in dataset.py).
trainset = dataset.mel_dataset(training_filelist, *mel_args, shuffle=True, **mel_kwargs)
# Test set: split=False, shuffle=False (presumably whole clips in a fixed order — verify).
testset = dataset.mel_dataset(testing_filelist, *mel_args, split=False, shuffle=False, **mel_kwargs)

In [10]:
# Training loader: config.batch_size items per batch; a trailing partial batch is dropped.
# NOTE(review): shuffle=False makes the loader visit indices in the same order every
# epoch; any ordering randomness must come from dataset.mel_dataset(shuffle=True)
# itself — confirm, or consider shuffle=True here.
train_dl = DataLoader(
    trainset,
    batch_size=config.batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=config.num_workers,
    pin_memory=True,
)
# Validation loader: one item per batch, fixed order.
test_dl = DataLoader(
    testset,
    batch_size=1,
    shuffle=False,
    drop_last=True,
    num_workers=1,
    pin_memory=True,
)

In [11]:
# Start a Weights & Biases run. With sync_tensorboard=True, wandb mirrors what the
# TensorBoard SummaryWriter logs; keep the writer created after wandb.init.
wandb.init(
    entity='uuzall',
    project='HiFi GAN',
    name='',
    sync_tensorboard=True,
)
# TensorBoard event files are written under ./runs/.
writer = SummaryWriter('runs/')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33muuzall[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668224550085143, max=1.0…

In [12]:
def validation_loop(global_step, num_logged_samples=5):
  """Evaluate the generator on ``test_dl`` and log the results to TensorBoard.

  For every batch the generator synthesises audio from the input ``x``. The L1
  distance between the reference mel ``y_mel`` and the mel of the generated
  audio (computed with ``config.fmax_for_loss``) is averaged over batches and
  logged as ``testing/mel_spec_error``. For the first ``num_logged_samples``
  batches the generated audio and its spectrogram (computed with
  ``config.fmax``) are logged. The reference audio ``y`` and the input
  spectrogram ``x`` are logged only when ``global_step == 0``.

  Args:
    global_step: Step used as the TensorBoard x-axis value.
    num_logged_samples: Number of leading batches whose audio/spectrograms are
      logged (default 5, equivalent to the former ``j <= 4``).

  Returns:
    Mean validation mel L1 error as a float, or ``nan`` when ``test_dl`` yields
    no batches (instead of failing on an unbound loop index).

  The generator is put back into train mode even if evaluation raises.
  """
  generator_model.eval()
  torch.cuda.empty_cache()
  val_err_tot = 0.0
  n_batches = 0
  try:
    with torch.no_grad():
      for j, batch in enumerate(test_dl):
        x, y, _, y_mel = batch
        # Same input permute / mel layout handling as the training cells.
        y_g_hat = generator_model(x.permute(0, 2, 1).to(device))
        y_mel = y_mel.to(device)
        y_g_hat_mel = dataset.mel_spectrogram(y_g_hat.squeeze(1), config.n_fft, config.num_mels, config.sampling_rate,
                                              config.hop_size, config.win_size, config.fmin, config.fmax_for_loss)
        val_err_tot += F.l1_loss(y_mel, y_g_hat_mel.permute(2, 0, 1)).item()
        n_batches += 1

        if j < num_logged_samples:
          if global_step == 0:
            # Reference data presumably does not change between validations, so log it once.
            writer.add_audio(f'gt/y_{j}', y[0], global_step, config.sampling_rate)
            writer.add_figure(f'gt/y_spec_{j}', plot_spectrogram(x[0]), global_step)

          writer.add_audio(f'generated/y_hat_{j}', y_g_hat[0], global_step, config.sampling_rate)
          # The displayed spectrogram uses config.fmax; the loss above uses fmax_for_loss.
          y_hat_spec = dataset.mel_spectrogram(y_g_hat.squeeze(1), config.n_fft, config.num_mels,
                                               config.sampling_rate, config.hop_size, config.win_size,
                                               config.fmin, config.fmax)
          writer.add_figure(f'generated/y_hat_spec_{j}', plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), global_step)

    # Guard against an empty loader (the loop body never ran).
    val_err_tot = val_err_tot / n_batches if n_batches else float('nan')
    writer.add_scalar('testing/mel_spec_error', val_err_tot, global_step)
  finally:
    generator_model.train()
  return val_err_tot


In [13]:
# HiFi-GAN style adversarial training. Each iteration does one discriminator
# step (MPD + MSD on real vs. detached generated audio), then one generator step
# (adversarial + feature-matching + weighted mel L1 loss). Validation runs every
# VALIDATE_EVERY steps, and the generator and both discriminators are saved to
# CHECKPOINT_DIR whenever the validation mel error improves. Learning rates
# decay once per epoch.
MEL_LOSS_WEIGHT = 45     # weight of the mel-spectrogram L1 term in the generator loss
VALIDATE_EVERY = 500     # steps between validation runs / checkpoint checks
CHECKPOINT_DIR = 'models'

generator_model.train()
mpd.train()
msd.train()
global_step = 0
n_epochs = 100
best_loss = float('inf')   # any finite validation error counts as an improvement
test_loss = float('nan')   # shown in the progress bar until the first validation
# torch.save does not create directories; without this the first checkpoint
# (saved at step 0) raises if models/ does not exist yet.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(n_epochs):
  loop = tqdm(train_dl, total=len(train_dl), desc=f'Epoch {epoch+1}/{n_epochs}')
  for batch in loop:
    x, y, _, y_mel = batch
    # NOTE(review): tensor layouts come from dataset.mel_dataset / mel_spectrogram;
    # these permutes are kept exactly as before — confirm shapes there.
    x, y, y_mel = x.permute(0, 2, 1).to(device), y.unsqueeze(1).to(device), y_mel.to(device)

    y_g_hat = generator_model(x)
    y_g_hat_mel = dataset.mel_spectrogram(y_g_hat.squeeze(1), config.n_fft, config.num_mels, config.sampling_rate,
                                          config.hop_size, config.win_size, config.fmin, config.fmax_for_loss)

    # Discriminator step: generated audio is detached so only D receives gradients.
    optim_d.zero_grad()
    y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
    loss_disc_f, _, _ = discriminator_loss(y_df_hat_r, y_df_hat_g)
    y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
    loss_disc_s, _, _ = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
    loss_disc_all = loss_disc_s + loss_disc_f
    loss_disc_all.backward()
    optim_d.step()

    # Generator step. The discriminators also accumulate gradients here; those
    # are cleared by optim_d.zero_grad() at the start of the next iteration.
    optim_g.zero_grad()
    loss_mel = F.l1_loss(y_mel, y_g_hat_mel.permute(2, 0, 1)) * MEL_LOSS_WEIGHT
    y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
    y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
    loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
    loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
    loss_gen_f, _ = generator_loss(y_df_hat_g)
    loss_gen_s, _ = generator_loss(y_ds_hat_g)
    loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
    loss_gen_all.backward()
    optim_g.step()

    # Read each scalar once (.item() on a CUDA tensor forces a host sync).
    gen_loss_value = loss_gen_all.item()
    disc_loss_value = loss_disc_all.item()
    writer.add_scalar('training/gen_loss_total', gen_loss_value, global_step)
    writer.add_scalar('training/mel_spec_error', loss_mel.item() / MEL_LOSS_WEIGHT, global_step)
    writer.add_scalar('training/disc_loss_total', disc_loss_value, global_step)

    if global_step % VALIDATE_EVERY == 0:
      test_loss = validation_loop(global_step)
      if best_loss > test_loss:
        best_loss = test_loss
        torch.save(generator_model.state_dict(), os.path.join(CHECKPOINT_DIR, 'gen_model'))
        torch.save(mpd.state_dict(), os.path.join(CHECKPOINT_DIR, 'mpd'))
        torch.save(msd.state_dict(), os.path.join(CHECKPOINT_DIR, 'msd'))

    global_step += 1
    loop.set_postfix(loss_disc_all=disc_loss_value, loss_gen_all=gen_loss_value,
                     best_loss=best_loss, current_test_loss=test_loss)

  # Log the learning rates used during this epoch, then decay them.
  writer.add_scalar('learning/d_lr', scheduler_d.get_last_lr()[0], global_step)
  writer.add_scalar('learning/g_lr', scheduler_g.get_last_lr()[0], global_step)
  scheduler_d.step()
  scheduler_g.step()

Epoch 1/100: 100%|██████████| 736/736 [08:12<00:00,  1.49it/s, best_loss=0.641, current_test_loss=0.641, loss_disc_all=3.25, loss_gen_all=35.3]
Epoch 2/100: 100%|██████████| 736/736 [07:16<00:00,  1.69it/s, best_loss=0.545, current_test_loss=0.545, loss_disc_all=3.26, loss_gen_all=29.9]  
Epoch 3/100: 100%|██████████| 736/736 [07:51<00:00,  1.56it/s, best_loss=0.494, current_test_loss=0.508, loss_disc_all=3.65, loss_gen_all=29.2] 
Epoch 4/100: 100%|██████████| 736/736 [07:17<00:00,  1.68it/s, best_loss=0.478, current_test_loss=0.478, loss_disc_all=3.48, loss_gen_all=27.5]  
Epoch 5/100: 100%|██████████| 736/736 [07:53<00:00,  1.55it/s, best_loss=0.453, current_test_loss=0.453, loss_disc_all=3.49, loss_gen_all=27.1] 
Epoch 6/100: 100%|██████████| 736/736 [07:16<00:00,  1.69it/s, best_loss=0.444, current_test_loss=0.444, loss_disc_all=3.51, loss_gen_all=28]    
Epoch 7/100: 100%|██████████| 736/736 [07:58<00:00,  1.54it/s, best_loss=0.444, current_test_loss=0.454, loss_disc_all=3.34, los

In [14]:
# Continue training for epochs 101-200. This cell resumes from the kernel state
# left by the previous training cell. It reuses global_step, best_loss and
# test_loss, plus the models, optimizers and LR schedulers, and resets none of
# them, so it must run after that cell.
MEL_LOSS_WEIGHT = 45     # weight of the mel-spectrogram L1 term in the generator loss
VALIDATE_EVERY = 500     # steps between validation runs / checkpoint checks
CHECKPOINT_DIR = 'models'
n_epochs = 200
# torch.save does not create directories; make sure the checkpoint dir exists.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(100, n_epochs):
  loop = tqdm(train_dl, total=len(train_dl), desc=f'Epoch {epoch+1}/{n_epochs}')
  for batch in loop:
    x, y, _, y_mel = batch
    # NOTE(review): tensor layouts come from dataset.mel_dataset / mel_spectrogram;
    # these permutes are kept exactly as before — confirm shapes there.
    x, y, y_mel = x.permute(0, 2, 1).to(device), y.unsqueeze(1).to(device), y_mel.to(device)

    y_g_hat = generator_model(x)
    y_g_hat_mel = dataset.mel_spectrogram(y_g_hat.squeeze(1), config.n_fft, config.num_mels, config.sampling_rate,
                                          config.hop_size, config.win_size, config.fmin, config.fmax_for_loss)

    # Discriminator step: generated audio is detached so only D receives gradients.
    optim_d.zero_grad()
    y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
    loss_disc_f, _, _ = discriminator_loss(y_df_hat_r, y_df_hat_g)
    y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
    loss_disc_s, _, _ = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
    loss_disc_all = loss_disc_s + loss_disc_f
    loss_disc_all.backward()
    optim_d.step()

    # Generator step. The discriminators also accumulate gradients here; those
    # are cleared by optim_d.zero_grad() at the start of the next iteration.
    optim_g.zero_grad()
    loss_mel = F.l1_loss(y_mel, y_g_hat_mel.permute(2, 0, 1)) * MEL_LOSS_WEIGHT
    y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
    y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
    loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
    loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
    loss_gen_f, _ = generator_loss(y_df_hat_g)
    loss_gen_s, _ = generator_loss(y_ds_hat_g)
    loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
    loss_gen_all.backward()
    optim_g.step()

    # Read each scalar once (.item() on a CUDA tensor forces a host sync).
    gen_loss_value = loss_gen_all.item()
    disc_loss_value = loss_disc_all.item()
    writer.add_scalar('training/gen_loss_total', gen_loss_value, global_step)
    writer.add_scalar('training/mel_spec_error', loss_mel.item() / MEL_LOSS_WEIGHT, global_step)
    writer.add_scalar('training/disc_loss_total', disc_loss_value, global_step)

    if global_step % VALIDATE_EVERY == 0:
      test_loss = validation_loop(global_step)
      if best_loss > test_loss:
        best_loss = test_loss
        torch.save(generator_model.state_dict(), os.path.join(CHECKPOINT_DIR, 'gen_model'))
        torch.save(mpd.state_dict(), os.path.join(CHECKPOINT_DIR, 'mpd'))
        torch.save(msd.state_dict(), os.path.join(CHECKPOINT_DIR, 'msd'))

    global_step += 1
    loop.set_postfix(loss_disc_all=disc_loss_value, loss_gen_all=gen_loss_value,
                     best_loss=best_loss, current_test_loss=test_loss)

  # Log the learning rates used during this epoch, then decay them.
  writer.add_scalar('learning/d_lr', scheduler_d.get_last_lr()[0], global_step)
  writer.add_scalar('learning/g_lr', scheduler_g.get_last_lr()[0], global_step)
  scheduler_d.step()
  scheduler_g.step()

Epoch 101/200: 100%|██████████| 736/736 [06:49<00:00,  1.80it/s, best_loss=0.258, current_test_loss=0.261, loss_disc_all=2.65, loss_gen_all=25.5]
Epoch 102/200: 100%|██████████| 736/736 [07:24<00:00,  1.66it/s, best_loss=0.258, current_test_loss=0.258, loss_disc_all=2.75, loss_gen_all=24.3]  
Epoch 103/200: 100%|██████████| 736/736 [06:51<00:00,  1.79it/s, best_loss=0.255, current_test_loss=0.255, loss_disc_all=2.88, loss_gen_all=24.9]
Epoch 104/200: 100%|██████████| 736/736 [07:23<00:00,  1.66it/s, best_loss=0.255, current_test_loss=0.256, loss_disc_all=2.66, loss_gen_all=25]    
Epoch 105/200: 100%|██████████| 736/736 [06:53<00:00,  1.78it/s, best_loss=0.255, current_test_loss=0.26, loss_disc_all=2.83, loss_gen_all=23.4] 
Epoch 106/200: 100%|██████████| 736/736 [07:24<00:00,  1.66it/s, best_loss=0.255, current_test_loss=0.256, loss_disc_all=2.66, loss_gen_all=25]    
Epoch 107/200: 100%|██████████| 736/736 [06:52<00:00,  1.78it/s, best_loss=0.253, current_test_loss=0.253, loss_disc_a