In [1]:
import torch
from torch import nn, optim
from jcopdl.callback import Callback, set_config

# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Bare expression so the notebook displays which device was selected.
device

device(type='cpu')

# Dataset & Dataloader

In [2]:
from src.dataset import VCTKTripletDataset, VCTKTripletDataloader
from torch.utils.data import DataLoader

In [3]:
bs = 32  # batch size for both triplet dataloaders

# Both splits read the same VCTK corpus with identical sampling settings.
# min_dur is in seconds (see the "shorter than 1.5s" message it produces);
# n_data appears to be the number of triplets sampled — confirm in src.dataset.
wav_dir = "vctk_dataset/wav48/"
txt_dir = "vctk_dataset/txt/"
sampling = {"n_data": 3000, "min_dur": 1.5}

# NOTE(review): train and test are drawn independently from the same
# directories, so they may share speakers/utterances. Confirm whether
# VCTKTripletDataset supports a held-out split before trusting test_cost.
train_set = VCTKTripletDataset(wav_dir, txt_dir, **sampling)
trainloader = VCTKTripletDataloader(train_set, batch_size=bs)

test_set = VCTKTripletDataset(wav_dir, txt_dir, **sampling)
testloader = VCTKTripletDataloader(test_set, batch_size=bs)

HBox(children=(FloatProgress(value=0.0, description='Sample Data', max=3000.0, style=ProgressStyle(description…


Excluding 6 triplet containing audio shorter than 1.5s


HBox(children=(FloatProgress(value=0.0, description='Sample Data', max=3000.0, style=ProgressStyle(description…


Excluding 3 triplet containing audio shorter than 1.5s


# Architecture & Config

In [4]:
from src.model import Encoder

In [5]:
# Model hyperparameters plus the audio settings the training set was built
# with. The resulting config is handed to the jcopdl Callback further down.
config = set_config(dict(
    ndim=512,                      # embedding size passed to Encoder
    margin=1,                      # margin for nn.TripletMarginLoss
    sr=train_set.sr,               # copied from the dataset
    n_mfcc=train_set.n_mfcc,       # copied from the dataset
    min_dur=train_set.min_dur,     # minimum clip duration used when sampling
))

# Training Preparation

In [6]:
from jcopdl.optim import RangerLARS

In [7]:
# Encoder in triplet mode: its three outputs are fed to the loss as
# (anchor, positive, negative) in the training loop.
model = Encoder(ndim=config.ndim, triplet=True).to(device)

# Penalises triplets whose anchor-negative distance does not exceed the
# anchor-positive distance by at least `margin`.
criterion = nn.TripletMarginLoss(margin=config.margin)

# jcopdl Callback: logging, checkpoints under model_vctk/, and early stopping
# after 15 epochs without improvement.
callback = Callback(
    model,
    config,
    outdir="model_vctk",
    early_stop_patience=15,
)

optimizer = RangerLARS(model.parameters(), lr=1e-3)

# Training

In [8]:
from tqdm.auto import tqdm

In [None]:
# Draw a fresh set of training triplets every RESAMPLE_EVERY epochs.
RESAMPLE_EVERY = 15


def _train_epoch(loader, n_samples):
    """Run one optimisation pass over `loader` and return the mean triplet loss.

    Each batch's loss is weighted by ``inputs.shape[0]`` and the total is divided
    by `n_samples`. This assumes the batch dimension comes first in the tensors
    the dataloader yields. NOTE(review): confirm against VCTKTripletDataloader.
    """
    model.train()
    cost = 0
    for inputs, _labels in tqdm(loader, desc="Train"):
        inputs = inputs.to(device)

        # The encoder's three outputs are used as (anchor, positive, negative).
        output = model(inputs)
        loss = criterion(output[0], output[1], output[2])
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        cost += loss.item() * inputs.shape[0]
    return cost / n_samples


@torch.no_grad()
def _evaluate(loader, n_samples):
    """Return the mean triplet loss over `loader` without building autograd graphs."""
    model.eval()
    cost = 0
    for inputs, _labels in tqdm(loader, desc="Test"):
        inputs = inputs.to(device)

        output = model(inputs)
        loss = criterion(output[0], output[1], output[2])

        cost += loss.item() * inputs.shape[0]
    return cost / n_samples


while True:
    # NOTE(review): if callback.ckpt.epoch starts at 0, this also fires on the
    # first iteration and replaces the training set built earlier. Confirm
    # against jcopdl.
    if callback.ckpt.epoch % RESAMPLE_EVERY == 0:
        # Apply the same minimum clip duration as the initial training set and
        # the value stored in config, so resampled triplets obey the same filter.
        train_set = VCTKTripletDataset(
            "vctk_dataset/wav48/", "vctk_dataset/txt/",
            n_data=3000, min_dur=config.min_dur,
        )
        trainloader = VCTKTripletDataloader(train_set, batch_size=bs)

    train_cost = _train_epoch(trainloader, len(train_set))
    test_cost = _evaluate(testloader, len(test_set))

    # Logging
    callback.log(train_cost, test_cost)

    # Checkpoint
    callback.save_checkpoint()

    # Runtime Plotting
    callback.cost_runtime_plotting()

    # Early Stopping
    if callback.early_stopping(model, monitor="test_cost"):
        callback.plot_cost()
        break

HBox(children=(FloatProgress(value=0.0, description='Train', max=94.0, style=ProgressStyle(description_width='…