In [1]:
# Standard-library modules, used here only to extend the import search path.
import os
import sys

# Put the parent directory first on sys.path. This must run before the imports
# of `data`, `utils`, `train` and `models` below -- presumably those packages
# live one level above this notebook (TODO confirm the repository layout).
sys.path.insert(0, os.path.abspath(".."))

import torch
import torch.nn as nn
import torch.nn.functional as F
from data import TinyVCTK
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm
from utils import load_configs
from train import train
from models.wavenet import WaveNet

# Run configuration read from the JSON file next to the parent directory's code;
# the next cell reads configs["dataset_path"] from it.
configs = load_configs("../configs.json")

In [2]:
# Seed torch's global RNG before anything random happens, so that a
# Restart & Run All repeats the same run. DataLoader(shuffle=True) without an
# explicit generator draws its permutation seed from this global RNG; the
# model's weight initialisation presumably does too (TODO confirm in WaveNet).
seed = 0
torch.manual_seed(seed)

model = WaveNet(layers_per_block=8)

# Adam with its default hyperparameters over every model parameter.
optimizer = Adam(model.parameters())

# Training-loop settings handed to train() further down in this cell.
epochs = 30
batch_size = 16
num_batches_eval = 10

# Dataset location comes from the loaded configuration.
root = configs["dataset_path"]
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def loss_fn(x_hat, x):
    """Next-sample cross-entropy between predicted logits and quantised audio.

    The logits at time step ``t`` are scored against the sample at ``t + 1``:
    the last prediction and the first sample are dropped so both sides line up.

    Args:
        x_hat: Logits indexed as ``(batch, class, time)``.
        x: Waveform indexed as ``(batch, time)``. Presumably mu-law companded
            values in ``[-1, 1]`` -- TODO confirm against ``TinyVCTK``.

    Returns:
        Scalar tensor holding the mean cross-entropy over all scored steps.
    """
    logits = x_hat[:, :, :-1]
    # Scale amplitudes by 128 and shift by +128 to get integer class indices.
    # ``.long()`` truncates toward zero, which is exact when x lies on the
    # k/128 grid. NOTE(review): confirm the dataset emits on-grid values.
    target = (128 * x[:, 1:]).long() + 128
    # x == 1.0 maps to 128 + 128 = 256 and x < -1 maps below 0; either index is
    # outside the class axis and makes cross_entropy fail, so pin the targets
    # to the valid range. In-range inputs are unaffected.
    target = target.clamp(0, logits.size(1) - 1)
    return F.cross_entropy(logits, target, reduction="mean")

# Place the network on the device chosen earlier in this cell.
model.to(device)

# Build the dataset from the configured path (`root` holds configs["dataset_path"]).
trainset = TinyVCTK.load_default(root=root)

# Both loaders wrap the same dataset: shuffled for training, in order for
# evaluation. The reported test_loss is therefore not measured on held-out data.
trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
testloader = DataLoader(trainset, shuffle=False, batch_size=batch_size)  # reuse the training data

# No preprocessing hook; the postprocessing hook is the dataset transform's mu-law decoder.
preprocessing = None
postprocessing = trainset.transform.mu_law_decode

# Keyword options for train(), collected in one place and passed unchanged.
train_options = dict(
    data_type="audio",
    preprocessing=preprocessing,
    postprocessing=postprocessing,
    additional_eval=None,
    num_batches_eval=num_batches_eval,
    index_type=None,
    save_dir="..",
    sampling_rate=16000,
    person_id=27,
    n_tracks=3,
)

train(
    model,
    trainloader,
    testloader,
    epochs,
    optimizer,
    loss_fn,
    device,
    **train_options,
)

1/30 epochs: 100%|██████████| 724/724 [04:32<00:00,  2.65it/s, train_loss=2.71, test_loss=2.33]
2/30 epochs: 100%|██████████| 724/724 [04:54<00:00,  2.46it/s, train_loss=2.29, test_loss=2.11]
3/30 epochs: 100%|██████████| 724/724 [04:43<00:00,  2.55it/s, train_loss=2.17, test_loss=2.02]
4/30 epochs: 100%|██████████| 724/724 [04:41<00:00,  2.57it/s, train_loss=2.11, test_loss=2.06]
5/30 epochs: 100%|██████████| 724/724 [05:02<00:00,  2.39it/s, train_loss=2.07, test_loss=2]
6/30 epochs: 100%|██████████| 724/724 [04:48<00:00,  2.51it/s, train_loss=2.03, test_loss=1.97]
7/30 epochs: 100%|██████████| 724/724 [04:37<00:00,  2.61it/s, train_loss=2.03, test_loss=1.97]
8/30 epochs: 100%|██████████| 724/724 [04:31<00:00,  2.67it/s, train_loss=1.99, test_loss=1.93]
9/30 epochs: 100%|██████████| 724/724 [04:37<00:00,  2.61it/s, train_loss=1.98, test_loss=1.97]
10/30 epochs: 100%|██████████| 724/724 [04:33<00:00,  2.65it/s, train_loss=1.96, test_loss=1.87]
11/30 epochs: 100%|██████████| 724/724 [04