In [None]:
import pickle
from time import time

import fluidsynth
import pretty_midi

import numpy as np
import IPython.display
import matplotlib.pyplot as plt

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import StochasticWeightAveraging

from torch import nn
from torch.utils.data import DataLoader

from utils.data import *
from utils.loss import *
from utils.model import *
from utils.metric import *
from utils.common_utils import *

%load_ext autoreload
%autoreload 2

In [None]:
# Fix the global random seed (via Lightning's helper) so runs are reproducible.
random_seed = 0
pl.seed_everything(seed=random_seed)

In [None]:
# `device` is handed to the sampling helpers and the SVDD detector below:
# the GPU when CUDA is available, otherwise the CPU.
# (Training selects its GPU separately through the Trainer's `gpus` argument.)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

### Load Data

In [None]:
# Load the pickled code sequences (presumably VQ-VAE codebook indices, one
# sequence per sample — TODO confirm) and stack them into one array.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
path = './data/code_list_num_dict_512_kick.pkl'
with open(path, mode='rb') as code_file:
    data = np.stack(pickle.load(code_file))

print('data shape :', data.shape)

In [None]:
# Train/validation split. NOTE: nothing is shuffled here — the first 95% of
# `data`, in stored order, is used for training and the remainder for validation.
num_data = len(data)
num_train = int(num_data * 0.95)

train_data, val_data = data[:num_train], data[num_train:]

print('The number of train: %d' % len(train_data))
print('The number of validation: %d' % len(val_data))

In [None]:
# DataLoaders: training batches are shuffled; validation keeps a fixed order.
batch_size = 2048
train_params = dict(batch_size=batch_size,
                    shuffle=True,
                    pin_memory=True,
                    num_workers=4)
# Same settings as training, except for shuffling.
val_params = {**train_params, 'shuffle': False}

train_set = DataLoader(DatasetSamplerInt(train_data), **train_params)
val_set = DataLoader(DatasetSamplerInt(val_data), **val_params)

### Get Model

In [None]:
# model
embed_size = 128
hidden_size = 512
vocab_size = 512  # presumably the VQ-VAE codebook size (num_embed) — keep in sync

model = LSTM(embed_size, hidden_size, vocab_size, num_layers=4)
swa_callback = StochasticWeightAveraging(swa_epoch_start=0.7, swa_lrs=5e-5, annealing_epochs=10)
# Keep the checkpoint with the LOWEST validation loss. With mode='max' the
# callback would retain the epoch with the highest (worst) loss instead.
checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                      mode='min',
                                      filename='LSTM-{epoch:02d}-{val_loss:.4f}')

In [None]:
# training
# NOTE(review): GPU index 1 is hard-coded and independent of `device` above;
# adjust `gpus` on machines with a single GPU.
trainer_kwargs = dict(
    gpus=[1],
    num_nodes=1,
    max_epochs=100,
    deterministic=True,
    default_root_dir='./model',
    callbacks=[swa_callback, checkpoint_callback],
)
trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(model, train_set, val_set)

In [None]:
# Report where the best checkpoint was written and the last logged metrics.
training_summary = (
    ('best model path :', checkpoint_callback.best_model_path),
    ('final results :', trainer.logged_metrics),
)
for summary_label, summary_value in training_summary:
    print(summary_label, summary_value)

### Sampling

In [None]:
# get x1 prob: empirical distribution of the first code token over `data`;
# it is passed to `sampling_code` below to draw the first token.
num_classes = vocab_size  # presumably one class per VQ-VAE codebook entry
first_codes = data[:, 0]
# Count each code directly instead of materialising an (num_data, num_classes)
# one-hot matrix with np.eye, which wastes memory for large datasets.
counts_x1 = np.bincount(first_codes, minlength=num_classes)
assert counts_x1.shape == (num_classes,), 'first-token codes fall outside [0, num_classes)'
prob_x1 = counts_x1 / counts_x1.sum()

In [None]:
# plot parameters
CHAR_FONT_SIZE = 15
NUM_FONT_SIZE = 12
WIDTH = 17
HEIGHT = 5
LABEL_PAD = 13

# Bar chart of the first-token distribution. `prob_x1` is normalised to sum
# to 1, so the y-axis is a probability, not a raw count.
fig, ax = plt.subplots(figsize=(WIDTH, HEIGHT))
ax.bar(np.arange(num_classes), prob_x1)
ax.tick_params(axis='both', labelsize=NUM_FONT_SIZE)
ax.set_xlabel('Classes', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
ax.set_ylabel('Probability', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
plt.show()

In [None]:
# load LSTM
# load_from_checkpoint builds and returns a new model, so call it on the class
# (as done for SVDD below) instead of on a throw-away instance.
ckpt_path = './model/kick/num_dict_512/LSTM-epoch=98-val_loss=0.7666.ckpt'
AR_model = LSTM.load_from_checkpoint(ckpt_path,
                                     embed_size=embed_size,
                                     hidden_size=hidden_size,
                                     vocab_size=vocab_size,
                                     num_layers=4)
# Inference only: disable dropout/batch-norm training behaviour (a no-op if
# the sampling helpers already switch to eval mode).
AR_model.eval();

In [None]:
# load VQ-VAE
ch = 128
num_pitch = 57
latent_dim = 16
num_embed = 512

# load_from_checkpoint builds and returns a new model, so call it on the class
# instead of on a throw-away instance.
ckpt_path = './model/kick/num_dict_512/VQVAE-epoch=369-val_loss=0.0066.ckpt'
VQVAE_model = VQVAE.load_from_checkpoint(ckpt_path,
                                         ch=ch,
                                         num_pitch=num_pitch,
                                         num_embed=num_embed,
                                         latent_dim=latent_dim)
# Inference only: disable dropout/batch-norm training behaviour (a no-op if
# the decoding helper already switches to eval mode).
VQVAE_model.eval();

### Rejection Sampling

In [None]:
# load detector (an SVDD model built on an autoencoder's encoder)
num_features = 28
AE = AutoEncoder(num_features)

# Fixed centre passed to the detector; presumably the SVDD hypersphere centre
# in the encoder's 7-dim output space — TODO confirm.
center = [3.4064, -2.3389, -2.8335, -1.2972, -2.0128, -1.1937, 1.1904]
center = torch.as_tensor(center)

ckpt_path = './model/SVDD-epoch=562-val_loss=0.06.ckpt'
SVDD_model = SVDD.load_from_checkpoint(ckpt_path, encoder=AE.encoder, center=center.to(device))
# Used only for scoring: disable dropout/batch-norm training behaviour.
SVDD_model.eval();

In [None]:
# data generation under loop score (rejection sampling):
# repeatedly sample code sequences from the LSTM, decode them with the VQ-VAE,
# and keep only samples whose loop score is below `loop_score_threshold`,
# until `num_data` samples have been accepted.
# NOTE(review): if no sample ever passes the threshold this loop never ends.
loop_score_threshold = 0.001
num_codes_per_round = num_data // 10

accepted_samples = []
start_time = time()

while len(accepted_samples) < num_data:
    temp_code = sampling_code(batch_size, num_codes_per_round, AR_model,
                              num_classes, prob_x1, device,
                              len_code=32, top_k=30, top_p=0, temp=0.7)

    temp_data = sampling_from_code(batch_size, temp_code, VQVAE_model, device)

    # Do not name this loop variable `data`: that would overwrite the
    # training-code array loaded in the "Load Data" section and break
    # re-running the cells that read it.
    for sample in temp_data:
        if loop_score(sample[np.newaxis], SVDD_model, center) < loop_score_threshold:
            accepted_samples.append(sample)

gen_data = np.stack(accepted_samples[:num_data])
print('gen_data shape : %s (%0.3f sec)' % (gen_data.shape, time()-start_time))

### Evaluation

In [None]:
# Load the original piano rolls; they serve as the reference set for the
# precision/recall/density/coverage evaluation below.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
path = './data/midi_detected_strict_kick_pianoroll.pkl'
with open(path, mode='rb') as pianoroll_file:
    original_data = np.stack(pickle.load(pianoroll_file))

print('The number of data : %d' % len(original_data))

In [None]:
# Keep only the first 80% of the original data, in stored order (nothing is
# shuffled) — presumably the training portion; TODO confirm against the split
# used when the models were trained.
# NOTE: this overwrites `original_data`, so re-running this cell without
# re-running the load cell above truncates the data a second time.
num_data_origin = len(original_data)
num_train_origin = int(0.8 * num_data_origin)

original_data = original_data[:num_train_origin]

print('The number of data: %d' % len(original_data))

In [None]:
# play demo: render one generated piano roll with FluidSynth and play it inline.
audio_rate = 16000
pm = play_pianoroll(gen_data[1], fs=9)
demo_audio = pm.fluidsynth(fs=audio_rate)
IPython.display.display(IPython.display.Audio(demo_audio, rate=audio_rate))

In [None]:
# Summary statistics of the generated samples, printed in a fixed order.
# Each metric is computed lazily, right before its line is printed.
summary_metrics = (
    ('loop_score : %f', lambda: loop_score(gen_data, SVDD_model, center)),
    ('unique_pitch : %f', lambda: unique_pitch(gen_data)),
    ('note_density : %f', lambda: note_density(gen_data)),
)
for metric_fmt, compute_metric in summary_metrics:
    print(metric_fmt % compute_metric())

In [None]:
# precision & recall & density & coverage of the generated set against the
# original data. Each entry of `metrics` is reported as mean (std) — presumably
# over the `repeat` runs; TODO confirm in compute_prd.
metrics = compute_prd(original_data, gen_data, k=5, repeat=10)

for metric_name in ('precision', 'recall', 'density', 'coverage'):
    metric_values = metrics[metric_name]
    print('%s : %0.3f (%0.3f)' % (metric_name, np.mean(metric_values), np.std(metric_values)))