In [None]:
import pickle
import fluidsynth
import pretty_midi

import numpy as np
import IPython.display
import matplotlib.pyplot as plt

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import StochasticWeightAveraging

from torch import nn
from torch.utils.data import DataLoader

from utils.data import *
from utils.loss import *
from utils.model import *
from utils.metric import *
from utils.common_utils import *

%load_ext autoreload
%autoreload 2

In [None]:
# Seed every RNG Lightning knows about (Python, NumPy, torch) so runs repeat.
random_seed = 0
pl.seed_everything(seed=random_seed)

In [None]:
# Choose the torch device used for inference after training: the default CUDA
# device when one is visible, otherwise the CPU. (Training itself is placed by
# the Trainer's own `gpus` setting.)
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

### Load Data

In [None]:
# Load the pre-extracted piano-roll dataset. Presumably a sequence of 2-D
# arrays (it is np.stack-ed and axis-swapped below) — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
path = './data/midi_detected_strict_kick_pianoroll.pkl'
with open(path, mode='rb') as pkl_file:
    data = pickle.load(pkl_file)
print('The number of data : %d' % len(data))

In [None]:
# Stack the per-example arrays into one (N, A, B) array, then swap the last two
# axes to get (N, B, A). Axis meanings are not visible here — TODO confirm.
# Not idempotent: re-running this cell swaps the axes back.
data = np.stack(data).transpose(0, 2, 1)
print('data shape :', data.shape)

In [None]:
# Shuffle with the notebook seed (so the split is reproducible and re-running
# the cell is idempotent), then split 80/20 into train / validation.
# Without the shuffle, an ordered pickle would give a validation set drawn
# from a different slice of the material than training.
# NOTE(review): if examples are overlapping segments of the same piece,
# consider splitting by piece instead to avoid train/val leakage.
train_ratio = 0.8
num_data = len(data)
num_train = int(num_data * train_ratio)

perm = np.random.default_rng(random_seed).permutation(num_data)
train_data = data[perm[:num_train]]
val_data = data[perm[num_train:]]

print('The number of train: %d' % len(train_data))
print('The number of validation: %d' % len(val_data))

In [None]:
# DataLoaders: training batches are shuffled; validation uses the same settings
# but keeps a fixed order.
batch_size = 2048
train_params = dict(batch_size=batch_size,
                    shuffle=True,
                    pin_memory=True,
                    num_workers=4)
val_params = {**train_params, 'shuffle': False}

train_set = DataLoader(DatasetSampler(train_data), **train_params)
val_set = DataLoader(DatasetSampler(val_data), **val_params)

### Build and Train the Model

In [None]:
# Build the VQ-VAE and the callbacks used during training.
# NOTE(review): argument semantics (including `thres`) are defined by VQVAE in
# utils.model; the names suggest channel width, pitch rows, latent size and
# codebook size — confirm there.
ch, num_pitch, latent_dim, num_embed = 128, 57, 16, 256

model = VQVAE(ch, num_pitch, latent_dim, num_embed, thres=1)

# Stochastic Weight Averaging starting at 70% of max_epochs (a float start is
# a fraction of max_epochs), annealing toward swa_lrs over 20 epochs.
swa_callback = StochasticWeightAveraging(
    swa_epoch_start=0.7,
    swa_lrs=5e-5,
    annealing_epochs=20,
)
# Track 'val_loss' (ModelCheckpoint defaults: mode='min', save_top_k=1) and
# embed epoch and loss in the checkpoint filename.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='VQVAE-{epoch:02d}-{val_loss:.4f}',
)

In [None]:
# Train the VQ-VAE with the SWA and best-val_loss checkpoint callbacks.
# GPU index 5 is specific to the original machine. Fall back to CPU when CUDA
# is unavailable instead of failing at Trainer construction.
gpu_ids = [5]
trainer = pl.Trainer(gpus=gpu_ids if use_cuda else None,
                     num_nodes=1,
                     max_epochs=500,
                     deterministic=True,          # request deterministic ops for reproducibility
                     default_root_dir='./model',  # logs/checkpoints are written under this dir
                     callbacks=[swa_callback, checkpoint_callback])

trainer.fit(model, train_set, val_set)

In [None]:
# Report where the best checkpoint was saved and the last metrics logged.
for label, value in (('best model path :', checkpoint_callback.best_model_path),
                     ('final results :', trainer.logged_metrics)):
    print(label, value)

### Extract Latent Codes

In [None]:
# Reload the best checkpoint from training for latent-code extraction.
# These hyperparameters must match those used to build the model before training.
ch = 128
num_pitch = 57
latent_dim = 16
num_embed = 256

ckpt_path = checkpoint_callback.best_model_path
if not ckpt_path:
    # best_model_path is empty until a checkpoint has been saved in this kernel.
    raise RuntimeError('No best checkpoint recorded; run the training cell first.')

# load_from_checkpoint is a classmethod that builds a fresh instance, so it is
# called on the class (newer Lightning versions reject calls on an instance).
# thres=1 is passed so the reloaded model matches the training configuration.
model = VQVAE.load_from_checkpoint(ckpt_path,
                                   ch=ch,
                                   num_pitch=num_pitch,
                                   latent_dim=latent_dim,
                                   num_embed=num_embed,
                                   thres=1)
model = model.to(device)

In [None]:
# Encode the training set and collect the codebook indices chosen by the
# quantizer. `train_set` shuffles every epoch, which would scramble the saved
# codes relative to `train_data`. Use a sequential loader (validation settings)
# so row i of code_list corresponds to train_data[i].
code_loader = DataLoader(DatasetSampler(train_data), **val_params)

code_list = []

model.eval()
with torch.no_grad():
    for x in code_loader:
        x = x.to(device)
        z = model.encoder(x)
        _, quant_idx, _ = model.quantize(z)  # only the indices are kept
        code_list.append(quant_idx.cpu().numpy())

code_list = np.vstack(code_list)
print('code_list shape :', code_list.shape)

In [None]:
# Persist the training-set latent code indices for downstream use.
path = './data/code_list_num_dict_256.pkl'
with open(path, mode='wb') as out_file:
    pickle.dump(code_list, out_file, protocol=pickle.HIGHEST_PROTOCOL)