In [128]:
#!g1.1
import os

import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import typing as tp

In [129]:
#!g1.1
# Target device for models and tensors in this notebook: the current CUDA device.
DEVICE_TYPE = 'cuda'
device = torch.device(DEVICE_TYPE)

In [195]:
#!g1.1
from music_project.common.melspec import MelSpectrogram, MelSpectrogramConfig
from music_project.common.dataset import MusicDataset
from music_project.common.collator import split_collate_fn
from music_project.common.utils import init_scheduler
from music_project.config import TaskConfig
from music_project.models.pretrained_vocoder import Vocoder
from music_project.vae_trainer.vae_trainer import SpecVaeTrainer
from IPython import display

In [282]:
#!g1.1
# `device` is already created (as torch.device('cuda')) in the device cell near the
# top of the notebook. The identical re-assignment that used to live here was a
# redundant duplicate, so this cell no longer redefines it — one source of truth
# for the target device.

In [None]:
#!g1.1
import subprocess

# NVIDIA WaveGlow sources; the next cell adds this checkout to sys.path before
# the pretrained vocoder is constructed.
WAVEGLOW_REPO_URL = 'https://github.com/NVIDIA/waveglow.git'
WAVEGLOW_DIR = 'waveglow'

# Clone only when the checkout is missing: `git clone` refuses to write into an
# existing directory, so an unguarded clone errors on every re-run of the notebook.
# check=True makes a genuine clone failure raise instead of printing and continuing.
if not os.path.isdir(WAVEGLOW_DIR):
    subprocess.run(['git', 'clone', WAVEGLOW_REPO_URL, WAVEGLOW_DIR], check=True)

In [None]:
#!g1.1
import sys

# Make the cloned WaveGlow checkout importable. It is appended before Vocoder() is
# constructed, presumably because the pretrained checkpoint needs its modules —
# confirm in music_project.models.pretrained_vocoder.
# Guarded so that re-running this cell does not keep appending duplicate entries.
WAVEGLOW_PATH = 'waveglow/'
if WAVEGLOW_PATH not in sys.path:
    sys.path.append(WAVEGLOW_PATH)

# Re-imported here so the cell is self-contained (it is also imported in the imports cell).
from music_project.models.pretrained_vocoder import Vocoder

# The vocoder is only used for inference below (`vocoder.inference`), so put it in eval mode.
vocoder = Vocoder().to(device).eval()

In [284]:
#!g1.1
CONFIG = TaskConfig()
_loader_params = CONFIG.dataloaders_params

# Run-level hyper-parameters, all sourced from the project-wide TaskConfig.
N_EPOCHS, BATCH_SIZE, MEL_LENGTH = (
    CONFIG.n_epochs,
    _loader_params["batch_size"],
    CONFIG.mel_length,
)

In [285]:
#!g1.1
DATA_PATH = "../nsynth-train/audio"
SPLIT_SEED = 42

data_params = CONFIG.dataloaders_params
dataset = MusicDataset(data_path=DATA_PATH)
total_len = len(dataset)

# Seeded train/validation split: the same items land in the same split on every run.
# data_params["train_size"] is used as the fraction of items assigned to training.
train_len = int(data_params["train_size"] * total_len)
val_len = total_len - train_len
train_part, val_part = torch.utils.data.random_split(
    dataset, [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Settings shared by both loaders.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    collate_fn=split_collate_fn,
    num_workers=data_params["num_workers"],
    pin_memory=True,
)
train_loader = torch.utils.data.DataLoader(train_part, shuffle=True, **loader_kwargs)
# Validation is NOT shuffled: evaluation order should be deterministic so validation
# metrics and the sample batch inspected in later cells are reproducible.
val_loader = torch.utils.data.DataLoader(val_part, shuffle=False, **loader_kwargs)


In [264]:
#!g1.1
# Pull one validation batch and display the shapes of its two tensors
# (waveform batch, mel batch).
first_val_batch = next(iter(val_loader))
wav, mel = first_val_batch
tuple(tensor.shape for tensor in (wav, mel))

(torch.Size([128, 8192]), torch.Size([128, 80, 33]))

In [350]:
#!g1.1
# NOTE(review): `MelVAE` is never imported in this notebook, so on a fresh kernel this
# cell raises NameError. TODO: add its import to the imports cell.

# Optional checkpoint to warm-start from (e.g. 'MelVAE_run1_epoch_5.pth').
# None means train from scratch.
CHECKPOINT_PATH = None
N_MELS = 80  # mel bins per frame; matches the (batch, 80, frames) mel batch shown above

model = MelVAE(input_size=(N_MELS, MEL_LENGTH))
if CHECKPOINT_PATH is not None:
    # map_location='cpu' lets a checkpoint saved on any device load before the
    # model is moved to `device` on the next line.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model = model.to(device)

In [351]:
#!g1.1
LEARNING_RATE = 1e-3
LR_DECAY_FACTOR = 0.1

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)

# StepLR multiplies the LR by LR_DECAY_FACTOR every `step_size` scheduler steps;
# step_size is set to the number of training batches per epoch.
# NOTE(review): if SpecVaeTrainer calls lr_scheduler.step() once per batch, the LR
# drops 10x every epoch (1e-3 -> 1e-8 after 5 epochs) — confirm this is intended.
steps_per_epoch = len(train_loader)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=steps_per_epoch,
    gamma=LR_DECAY_FACTOR,
)

In [353]:
#!g1.1
# Bundle the model, optimizer, data and scheduler into the project's VAE trainer.
trainer = SpecVaeTrainer(
    model,
    optimizer,
    device,
    mel_length=MEL_LENGTH,
    vocoder=vocoder,
    train_data_loader=train_loader,
    valid_data_loader=val_loader,
    lr_scheduler=lr_scheduler,
)

In [None]:
#!g1.1
# wandb is used below but was never imported, so this cell failed on a fresh kernel.
import wandb

# NOTE: deliberately overrides CONFIG.n_epochs (set in the config cell) for this run.
N_EPOCHS = 30

# Used both as the W&B run name and as the prefix of saved checkpoint files.
RUN_NAME = "MelVAE_no-rec-loss"
wandb.init(project='vae_music', name=RUN_NAME)

In [None]:
#!g1.1
CHECKPOINT_EVERY = 5

for epoch in range(N_EPOCHS):
    trainer._train_epoch(epoch)
    weights = trainer.model.state_dict()
    # Rolling checkpoint, overwritten after every epoch...
    torch.save(weights, f'{RUN_NAME}_lastepoch.pth')
    # ...plus a permanent snapshot every CHECKPOINT_EVERY epochs (0, 5, 10, ...).
    if epoch % CHECKPOINT_EVERY == 0:
        torch.save(weights, f'{RUN_NAME}_epoch_{epoch}.pth')

In [258]:
#!g1.1
# Vocode the ground-truth mel batch from the validation peek and play its first clip.
# This gives a baseline for judging the model's reconstructions.
SAMPLE_RATE = 22050  # NOTE(review): assumed to match the mel/vocoder sampling rate — confirm
y = mel
reconstructed_wav = (
    vocoder.inference(y.to(device))
    .cpu()
)
first_clip = reconstructed_wav[0]
display.display(display.Audio(first_clip, rate=SAMPLE_RATE))

In [259]:
#!g1.1
# NOTE(review): `out` is not defined anywhere in this notebook. It presumably held the
# model's reconstructed mel batch, produced in a cell that has since been deleted.
# On a fresh kernel this cell raises NameError. TODO: recreate `out` explicitly.
reconstructed_wav = (
    vocoder.inference(out.to(device))
    .cpu()
)
display.display(display.Audio(data=reconstructed_wav[0], rate=22050))