In [1]:
import datetime
import json
import os
import time
from collections import Counter

import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.data import NSynthDataset
from src.models import AutoencoderClassifier
from src.utils import print_and_log

Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit


In [2]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Every tunable setting of the run lives in this one dict; it is also dumped
# to hparams.json in the results directory.
hparams = dict(
    # Checkpoint file to resume from; a falsy value starts a fresh run.
    checkpoint=None,
    # Dataset options, forwarded to NSynthDataset.
    feature_type='mel',
    instrument_source=[0, 1, 2],
    scaling='normalize',
    resize=(128, 128),
    # Optimisation and model settings.
    n_epochs=25,
    batch_size=64,
    lr=0.001,
    hidden_dim=128,
    # Training metrics are logged every `display_iters` iterations.
    display_iters=100,
    # Validation runs every `val_iters` iterations and stops once at least
    # `n_val_samples` examples have been evaluated.
    val_iters=1000,
    n_val_samples=1000,
    # stop if validation doesn't improve after this number of validation cycles
    n_early_stopping=5,
)

In [4]:
# The training and validation datasets differ only in the split directory they
# read from; every other option comes from the shared hyperparameters.
train_dataset, val_dataset = (
    NSynthDataset(
        'music-ml-gigioli',
        split_root,
        instrument_source=hparams['instrument_source'],
        feature_type=hparams['feature_type'],
        scaling=hparams['scaling'],
        resize=hparams['resize'],
        include_meta=True,
    )
    for split_root in ('data/nsynth/nsynth-train', 'data/nsynth/nsynth-valid')
)

In [5]:
# Number of training examples per instrument family.
class_ctr = Counter(entry['instrument_family_str'] for entry in train_dataset.meta.values())
# Class index -> family name (families in sorted order), and the reverse lookup.
class_dict = {index: family for index, family in enumerate(sorted(class_ctr))}
inv_class_dict = {family: index for index, family in class_dict.items()}
# Per-class loss weight: (size of the largest class) / (size of this class),
# listed in class-index order.
class_weights = np.array(
    [max(class_ctr.values()) / class_ctr[family] for family in class_dict.values()]
)

In [6]:
# Shuffled mini-batch loaders for both splits, sharing the configured batch size.
train_dataloader, val_dataloader = [
    torch.utils.data.DataLoader(dataset, batch_size=hparams['batch_size'], shuffle=True)
    for dataset in (train_dataset, val_dataset)
]

In [7]:
# Build the model with one output class per instrument family, then move it
# to the selected device.
model = AutoencoderClassifier(n_classes=len(class_dict), h_dim=hparams['hidden_dim'])
model = model.to(device)

# Report the number of trainable scalars (parameters with requires_grad set).
print('# of parameters : {}'.format(
    sum(map(torch.numel, (param for param in model.parameters() if param.requires_grad)))
))

# of parameters : 8400908


In [8]:
# Adam over all model parameters; the learning rate comes from the hyperparameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=hparams['lr'])

In [9]:
# Each run writes into its own directory named after the local start time.
timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
results_dir = 'train_results/multitask/{}'.format(timestamp)
os.makedirs(results_dir)

# Store the hyperparameters alongside the results of the run.
with open(os.path.join(results_dir, 'hparams.json'), 'w') as fp:
    fp.write(json.dumps(hparams))

# Create (or truncate) the log file, then record the model class name as its
# first entry.
log_file = os.path.join(results_dir, 'train_log.txt')
with open(log_file, 'w') as log:
    pass
print_and_log(type(model).__name__, log_file)

AutoencoderClassifier


In [None]:
# Multitask objective: binary cross-entropy on the reconstruction plus a
# class-weighted cross-entropy on the instrument-family logits.
recon_loss_fn = nn.BCELoss()
# The weight tensor has to sit on the same device as the logits it scales.
class_loss_fn = nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float32, device=device)
)

ckpt_weights_path = None  # weights file most recently written by this run
best_loss = 1e10          # lowest validation loss seen so far
since_best = 0            # validation cycles since best_loss last improved
done = False              # set to True once early stopping triggers


def _prepare_batch(batch):
    """Turn one dataloader batch into (features, labels) tensors on `device`.

    batch[0] holds the features; a singleton axis is inserted at dim 1
    (presumably the channel axis the model expects -- confirm against
    AutoencoderClassifier). batch[1]['instrument_family_str'] holds the
    family names, which are mapped to integer class ids via inv_class_dict.
    """
    features = batch[0].unsqueeze(1).to(device)
    labels = torch.tensor(
        [inv_class_dict[family] for family in batch[1]['instrument_family_str']],
        dtype=torch.long,
        device=device,
    )
    return features, labels


if hparams['checkpoint']:
    print_and_log('Resuming training from {}'.format(hparams['checkpoint']), log_file)
    # map_location lets a checkpoint saved on one device be resumed on another.
    ckpt = torch.load(hparams['checkpoint'], map_location=device)
    epoch = ckpt['epoch']
    itr = ckpt['itr']
    optimizer.load_state_dict(ckpt['optimizer'])
    model.load_state_dict(ckpt['model'])
    best_loss = ckpt['best_loss']
else:
    epoch = 0
    itr = 0

for epoch in range(epoch, hparams['n_epochs']):
    for batch in train_dataloader:
        itr += 1
        features, labels = _prepare_batch(batch)
        logits, recon = model(features)

        class_loss = class_loss_fn(logits, labels)
        recon_loss = recon_loss_fn(recon, features)
        loss = class_loss + recon_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Periodic training log (also emitted for the very first iteration).
        if (itr % hparams['display_iters'] == 0) or (itr == 1):
            acc = (logits.argmax(-1) == labels).float().mean()
            print_and_log('[{}, {:5d}] loss : (total - {:.4f}, class - {:.4f}, recon - {:.4f}), acc : {:.4f}'\
                          .format(epoch, itr, loss.item(), class_loss.item(), recon_loss.item(), acc.item()), log_file)

        # Periodic validation on a random subset of the validation split.
        if itr % hparams['val_iters'] == 0:
            model.eval()

            n_seen, n_batches = 0, 0
            val_loss, val_class_loss, val_recon_loss, val_acc = 0.0, 0.0, 0.0, 0.0
            # No gradients are needed for evaluation; skipping them avoids
            # building an autograd graph for every validation batch.
            with torch.no_grad():
                # val_dataloader shuffles, so each pass draws a fresh subset.
                for val_batch in val_dataloader:
                    n_batches += 1
                    n_seen += val_batch[0].size(0)
                    val_features, val_labels = _prepare_batch(val_batch)
                    val_logits, val_recon = model(val_features)

                    batch_class_loss = class_loss_fn(val_logits, val_labels)
                    batch_recon_loss = recon_loss_fn(val_recon, val_features)
                    batch_loss = batch_class_loss + batch_recon_loss
                    batch_acc = (val_logits.argmax(-1) == val_labels).float().mean()

                    # Running means over the batches evaluated so far.
                    val_loss += (batch_loss.item() - val_loss) / n_batches
                    val_class_loss += (batch_class_loss.item() - val_class_loss) / n_batches
                    val_recon_loss += (batch_recon_loss.item() - val_recon_loss) / n_batches
                    val_acc += (batch_acc.item() - val_acc) / n_batches

                    if n_seen >= hparams['n_val_samples']:
                        break

            print_and_log('Val - loss : (total - {:.4f}, class - {:.4f}, recon - {:.4f}), acc : {:.4f}'\
                          .format(val_loss, val_class_loss, val_recon_loss, val_acc), log_file)

            # Show the first example of the last validation batch next to its
            # reconstruction. Tensors are moved to the CPU before the NumPy
            # conversion so this also works when training on a GPU.
            librosa.display.specshow(val_features[0][0].cpu().numpy(), sr=16000, x_axis='time', y_axis='hz')
            plt.title('Validation input (itr {})'.format(itr))
            plt.show()
            librosa.display.specshow(val_recon[0][0].detach().cpu().numpy(), sr=16000, x_axis='time', y_axis='hz')
            plt.title('Reconstruction (itr {})'.format(itr))
            plt.show()

            if val_loss < best_loss:
                since_best = 0
                best_loss = val_loss

                # save weights, keeping only the best file written by this run
                if ckpt_weights_path:
                    os.remove(ckpt_weights_path)
                ckpt_weights_path = os.path.join(results_dir, 'model-{}.weights'.format(itr))
                torch.save(model.state_dict(), ckpt_weights_path)
                print_and_log('Weights saved in {}'.format(ckpt_weights_path), log_file)

                # save meta information needed to resume training
                ckpt_meta_path = os.path.join(results_dir, 'checkpoint')
                torch.save({
                    'best_loss' : best_loss,
                    'epoch' : epoch,
                    'itr' : itr,
                    'optimizer' : optimizer.state_dict(),
                    'model' : model.state_dict()
                }, ckpt_meta_path)
            else:
                since_best += 1
                if since_best >= hparams['n_early_stopping']:
                    done = True
                    print_and_log('Early stopping... training complete', log_file)

            model.train()

        if done:
            break

    if done:
        break

[0,     1] loss : (total - 3.7081, class - 2.5680, recon - 1.1401), acc : 0.0938
[0,   100] loss : (total - 1.7914, class - 1.7054, recon - 0.0860), acc : 0.4219
[0,   200] loss : (total - 1.9315, class - 1.9004, recon - 0.0311), acc : 0.3125
[0,   300] loss : (total - 1.5662, class - 1.5454, recon - 0.0208), acc : 0.4062
[0,   400] loss : (total - 1.2716, class - 1.2469, recon - 0.0247), acc : 0.4219
[0,   500] loss : (total - 1.3510, class - 1.3299, recon - 0.0211), acc : 0.4688
[0,   600] loss : (total - 1.1750, class - 1.1607, recon - 0.0144), acc : 0.4844
[0,   700] loss : (total - 0.9242, class - 0.9022, recon - 0.0221), acc : 0.6250
[0,   800] loss : (total - 1.3098, class - 1.2951, recon - 0.0147), acc : 0.4531
[0,   900] loss : (total - 1.6469, class - 1.6286, recon - 0.0184), acc : 0.4844
[0,  1000] loss : (total - 0.8387, class - 0.8159, recon - 0.0228), acc : 0.7031
Val - loss : (total - 4.0125, class - 3.9871, recon - 0.0254), acc : 0.1543
Weights saved in train_results/mu