In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa.display
import numpy as np
import datetime
import time
import os
import json
from skimage.transform import resize

from src.data import NSynth
from src.utils import print_and_log
from src.models import CVAE

Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.
  from numba.decorators import jit as optional_jit


In [2]:
if torch.cuda.is_available():  
    device = torch.device("cuda:0")
else:  
    device = torch.device("cpu") 

In [3]:
hparams = {
    'checkpoint' : None,
    'instrument_source' : [0, 1, 2],
    'sample_rate' : 16000,
    'n_samples' : 64000,
    'feature_type' : 'mel',
    'random_crop' : True,
    'resize' : None,
    'normalize' : True,
    'standardize' : True,
    'standardize_mean' : 0.3356,
    'standardize_std' : 0.2212,
    'spec_augment' : False,
    'remove_synth_lead' : True,
    'n_samples_per_class' : None,
    'depths' : (32, 64, 128, 128, 256, 256, 512), 
    'kl_loss_weight' : 0.001,
    'n_epochs' : 50,
    'batch_size' : 32,
    'lr' : 0.0001,
    'hidden_dim' : 1024,
    'display_iters' : 100,
    'val_iters' : 1000,
    'n_val_samples' : 1000, 
    'n_early_stopping' : 5 # stop if validation doesn't improve after this number of validation cycles
}

In [4]:
train_dataset = NSynth(
    'data/nsynth', 
    'train',
    include_meta=True, 
    instrument_source=hparams['instrument_source'], 
    sample_rate=hparams['sample_rate'], 
    n_samples=hparams['n_samples'], 
    feature_type=hparams['feature_type'], 
    random_crop=hparams['random_crop'], 
    resize=hparams['resize'], 
    normalize=hparams['normalize'], 
    standardize=hparams['standardize'], 
    standardize_mean=hparams['standardize_mean'], 
    standardize_std=hparams['standardize_std'], 
    spec_augment=hparams['spec_augment'],
    remove_synth_lead=hparams['remove_synth_lead'], 
    n_samples_per_class=hparams['n_samples_per_class']
)

val_dataset = NSynth(
    'data/nsynth', 
    'val',
    include_meta=True, 
    instrument_source=hparams['instrument_source'], 
    sample_rate=hparams['sample_rate'], 
    n_samples=hparams['n_samples'], 
    feature_type=hparams['feature_type'], 
    random_crop=hparams['random_crop'], 
    resize=hparams['resize'], 
    normalize=hparams['normalize'], 
    standardize=hparams['standardize'], 
    standardize_mean=hparams['standardize_mean'], 
    standardize_std=hparams['standardize_std'], 
    remove_synth_lead=hparams['remove_synth_lead']
)

In [5]:
class_ctr = dict(train_dataset.meta['character'].value_counts())
class_dict = dict(enumerate(sorted(class_ctr)))
inv_class_dict = dict([(v, k) for k, v in class_dict.items()])
class_weights = np.array([max(class_ctr.values())/class_ctr[class_dict[i]] for i in range(len(class_dict))])

In [6]:
sample_weights = [class_weights[inv_class_dict[x]] for x in train_dataset.meta['character'].tolist()]
sampler = torch.utils.data.sampler.WeightedRandomSampler(sample_weights, len(sample_weights))

In [7]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams['batch_size'], sampler=sampler)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=hparams['batch_size'], shuffle=True)

In [8]:
model = CVAE(
    len(class_dict), 
    h_dim=hparams['hidden_dim'], 
    sigmoid=(not hparams['standardize']),
    depths=hparams['depths']
).to(device)

print('# of parameters : {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

# of parameters : 11531331


In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'])

In [10]:
timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y%m%d-%H%M%S')
results_dir = 'train_results/cvae/{}'.format(timestamp)
os.makedirs(results_dir)

with open(os.path.join(results_dir, 'hparams.json'), 'w') as fp:
    json.dump(hparams, fp)

log_file = os.path.join(results_dir, 'train_log.txt')
log = open(log_file, 'w')
log.close()
print_and_log('{} {}'.format(train_dataset.__class__.__name__, model.__class__.__name__), log_file)

for k, v in hparams.items(): print_and_log('{} : {}'.format(k, v), log_file)

NSynth CVAE
checkpoint : None
instrument_source : [0, 1, 2]
sample_rate : 16000
n_samples : 64000
feature_type : mel
random_crop : True
resize : None
normalize : True
standardize : True
standardize_mean : 0.3356
standardize_std : 0.2212
spec_augment : False
remove_synth_lead : True
n_samples_per_class : None
depths : (32, 64, 128, 128, 256, 256, 512)
kl_loss_weight : 0.001
n_epochs : 50
batch_size : 32
lr : 0.0001
hidden_dim : 1024
display_iters : 100
val_iters : 1000
n_val_samples : 1000
n_early_stopping : 5


In [None]:
if hparams['standardize']:
    recon_loss_fn = nn.MSELoss()
else:
    recon_loss_fn = nn.BCELoss()
    
kl_loss_fn = lambda mu, log_var : torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
clf_loss_fn = nn.CrossEntropyLoss()

ckpt_weights_path = None
best_loss = 1e10
since_best = 0
done = False

if hparams['checkpoint']:
    print_and_log('Resuming training from {}'.format(hparams['checkpoint']), log_file)
    ckpt = torch.load(hparams['checkpoint'])
    epoch = ckpt['epoch']
    itr = ckpt['itr']
    optimizer.load_state_dict(ckpt['optimizer'])
    model.load_state_dict(ckpt['model'])
    best_loss = ckpt['best_loss']
else:
    epoch = 0
    itr = 0

for epoch in range(epoch, hparams['n_epochs']):
    if done:
        break
    
    for batch in train_dataloader:
        if done:
            break
            
        itr += 1
        features = batch[0].to(device)
        labels = torch.tensor(
            [inv_class_dict[i] for i in batch[1]['character']], dtype=torch.long, device=device
        )
        outputs, logits, mu, log_var = model(features, sample=True)
        
        recon_loss = recon_loss_fn(outputs, features)
        kl_loss = kl_loss_fn(mu, log_var)
        clf_loss = clf_loss_fn(logits, labels)
        loss = clf_loss #recon_loss + hparams['kl_loss_weight']*kl_loss + clf_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (itr % hparams['display_iters'] == 0) or (itr == 1):
            print_and_log('[{}, {:5d}] loss : (total : {:.4f}, recon : {:.4f}, kl : {:.4f}, clf : {:.4f})'\
                          .format(epoch, itr, loss.item(), recon_loss.item(), kl_loss.item(), clf_loss.item()), log_file)
            
    # validation
    model.eval()
    with torch.no_grad():
        ct, i, val_loss, val_recon_loss, val_kl_loss, val_clf_loss = 0, 0, 0.0, 0.0, 0.0, 0.0
        for batch in val_dataloader:
            i += 1
            ct += batch[0].size(0)
            features = batch[0].to(device)
            labels = torch.tensor(
                [inv_class_dict[i] for i in batch[1]['character']], dtype=torch.long, device=device
            )
            outputs, logits, mu, log_var = model(features, sample=False)

            recon_loss = recon_loss_fn(outputs, features)
            kl_loss = kl_loss_fn(mu, log_var)
            clf_loss = clf_loss_fn(logits, labels)
            loss = recon_loss + hparams['kl_loss_weight']*kl_loss + clf_loss
            
            val_recon_loss += (recon_loss.item() - val_recon_loss)/i
            val_kl_loss += (kl_loss.item() - val_kl_loss)/i
            val_clf_loss += (clf_loss.item() - val_clf_loss)/i
            val_loss += (loss.item() - val_loss)/i

            if ct >= hparams['n_val_samples']:
                break

    print_and_log('Val - loss : (total : {:.4f}, recon : {:.4f}, kl : {:.4f}, clf : {:.4f})'\
                  .format(val_loss, val_recon_loss, val_kl_loss, val_clf_loss), log_file)
    
    print('Class : {}'.format(batch[1]['character'][0]))
    librosa.display.specshow(batch[0][0].cpu().numpy().squeeze(), sr=16000, x_axis='time', y_axis='hz')
    plt.show()
    librosa.display.specshow(outputs[0][0].detach().cpu().numpy(), sr=16000, x_axis='time', y_axis='hz')
    plt.show()

    if val_loss < best_loss:
        since_best = 0
        best_loss = val_loss

        # save weights
        ckpt_weights_path = os.path.join(results_dir, 'model-{}.weights'.format(itr))
        torch.save(model.state_dict(), ckpt_weights_path)
        print_and_log('Weights saved in {}'.format(ckpt_weights_path), log_file)

        # save meta information
        ckpt_meta_path = os.path.join(results_dir, 'checkpoint')
        torch.save({
            'best_loss' : best_loss,
            'epoch' : epoch,
            'itr' : itr,
            'optimizer' : optimizer.state_dict(),
            'model' : model.state_dict()
        }, ckpt_meta_path)
    else:
        since_best += 1
        if since_best >= hparams['n_early_stopping']:
            done = True
            print_and_log('Early stopping... training complete', log_file)

    model.train()

[0,     1] loss : (total : 0.6713, recon : 4.0495, kl : 261.0703, clf : 0.6713)
[0,   100] loss : (total : 0.2847, recon : 4.0493, kl : 263.2803, clf : 0.2847)
[0,   200] loss : (total : 0.1884, recon : 4.2776, kl : 265.9010, clf : 0.1884)
[0,   300] loss : (total : 0.2031, recon : 4.0753, kl : 264.3385, clf : 0.2031)
[0,   400] loss : (total : 0.2553, recon : 4.0679, kl : 261.3989, clf : 0.2553)
[0,   500] loss : (total : 0.4030, recon : 4.2980, kl : 265.4588, clf : 0.4030)
[0,   600] loss : (total : 0.2317, recon : 4.1084, kl : 265.2944, clf : 0.2317)
[0,   700] loss : (total : 0.4816, recon : 4.0959, kl : 264.2976, clf : 0.4816)
[0,   800] loss : (total : 0.2563, recon : 4.2224, kl : 264.4899, clf : 0.2563)
[0,   900] loss : (total : 0.1577, recon : 4.2060, kl : 262.1267, clf : 0.1577)
[0,  1000] loss : (total : 0.1820, recon : 4.2223, kl : 264.8611, clf : 0.1820)
[0,  1100] loss : (total : 0.4554, recon : 3.9038, kl : 266.7531, clf : 0.4554)
[0,  1200] loss : (total : 0.1013, recon