In [1]:
import os
import sys

# Make the parent of the kernel's current working directory importable so the
# project-local modules below (`data`, `utils`, `train`, `models`) resolve.
# NOTE(review): this is relative to the cwd, so the notebook assumes it is
# launched from its own directory; it must run before the project imports.
sys.path.insert(0, os.path.abspath(".."))

import torch
import torch.nn.functional as F
from data import TinyVCTK
from torch.utils.data import DataLoader
from torch.optim import Adam, lr_scheduler
from utils import load_configs, NucleusSampler
from train import train
from models.wavenet import WaveNet

# Project configuration file, also resolved relative to the kernel's cwd.
configs = load_configs("../configs.json")

In [2]:
# Seed torch's default generator before the model is built so weight
# initialisation (and other torch random draws, e.g. DataLoader shuffling)
# repeat across Restart & Run All.
# NOTE(review): GPU kernels (cuDNN) may still be non-deterministic.
SEED = 0
torch.manual_seed(SEED)

sampler = NucleusSampler(threshold=1.0)  # degenerate to standard temperature(=1) sampling

# quantization=True with input_dim=256: presumably 256 mu-law classes per sample
# (the dataset's mu_law_decode is used as postprocessing later) — confirm.
model = WaveNet(
    input_dim=256,
    hidden_dim=32,
    skip_dim=256,
    kernel_size=2,
    layers_per_block=10,
    num_blocks=5,
    quantization=True,
    padding_mode="learnable",
    sampler=sampler,
)

# Adam with its default learning rate. StepLR multiplies the LR by gamma=0.1
# every 25 scheduler steps (presumably one step per epoch inside `train`).
optimizer = Adam(model.parameters())
scheduler = lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)
epochs = 30
batch_size = 8
num_batches_eval = 10  # forwarded to `train`; presumably caps the number of evaluated batches
start_length = 8000  # real-audio seed length, presumably in samples (0.5 s at 16 kHz)

root = configs["dataset_path"]

def loss_fn(xhat, x):
    """Next-step cross-entropy for autoregressive prediction.

    The logits produced at time step t are scored against the target at step
    t + 1: the last logit step and the first target step are dropped before
    comparing.

    Args:
        xhat: logits, indexed as (batch, classes, time) — assumed to match the
            WaveNet output layout; TODO confirm.
        x: class targets, indexed as (batch, time); cast to int64.

    Returns:
        Mean cross-entropy over non-ignored positions; targets equal to -1 are
        excluded from the loss.
    """
    shifted_logits = xhat[:, :, :-1]
    next_targets = x[:, 1:].to(torch.long)
    return F.cross_entropy(input=shifted_logits, target=next_targets, ignore_index=-1)

# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): PyTorch recommends moving a model to its device before
# constructing its optimizer, but `optimizer` was created earlier. Module.to()
# updates parameters in place by default, so the optimizer still references
# them; consider creating the optimizer after this call anyway.
model.to(device)

# Given the computational cost, and since the goal is human-like speech
# generated from real-audio seeds, the model is trained on the (smaller) test
# split. Both loaders therefore read the same clips, so the reported
# `test_loss` is NOT a held-out generalisation estimate.
trainset = TinyVCTK.load_default(root=root, train=False)
testset = TinyVCTK.load_default(root=root, train=False)

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runs.
use_pinned_memory = device.type == "cuda"
trainloader = DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=use_pinned_memory
)
testloader = DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=use_pinned_memory
)

# No transform on model inputs; generated outputs go through the dataset's
# mu-law decoder (presumably mapping quantized codes back to amplitudes).
preprocessing = None
postprocessing = trainset.transform.mu_law_decode

train(
    model,
    trainloader,
    testloader,
    epochs,
    optimizer,
    loss_fn,
    device,
    scheduler=scheduler,
    data_type="audio",
    preprocessing=preprocessing,
    postprocessing=postprocessing,
    additional_eval=None,
    num_batches_eval=num_batches_eval,
    index_type=None,
    save_dir="..",
    sampling_rate=16000,
    start_length=start_length,  # length of the real-audio seed
    generate_length=24000,  # presumably samples (1.5 s at 16 kHz); confirm whether it includes the seed
    n_tracks=3,
)

1/30 epochs: 100%|██████████| 205/205 [09:59<00:00,  2.92s/it, train_loss=3.45, test_loss=2.78] 
2/30 epochs: 100%|██████████| 205/205 [09:54<00:00,  2.90s/it, train_loss=2.74, test_loss=2.59] 
3/30 epochs: 100%|██████████| 205/205 [09:53<00:00,  2.89s/it, train_loss=2.61, test_loss=2.52] 
4/30 epochs: 100%|██████████| 205/205 [09:56<00:00,  2.91s/it, train_loss=2.54, test_loss=2.47] 
5/30 epochs: 100%|██████████| 205/205 [09:53<00:00,  2.90s/it, train_loss=2.51, test_loss=2.43] 
6/30 epochs: 100%|██████████| 205/205 [09:56<00:00,  2.91s/it, train_loss=2.47, test_loss=2.41] 
7/30 epochs: 100%|██████████| 205/205 [09:56<00:00,  2.91s/it, train_loss=2.45, test_loss=2.39] 
8/30 epochs: 100%|██████████| 205/205 [09:54<00:00,  2.90s/it, train_loss=2.43, test_loss=2.37] 
9/30 epochs: 100%|██████████| 205/205 [09:57<00:00,  2.91s/it, train_loss=2.41, test_loss=2.35] 
10/30 epochs: 100%|██████████| 205/205 [09:59<00:00,  2.93s/it, train_loss=2.4, test_loss=2.34] 
11/30 epochs: 100%|██████████|