In [None]:
import pickle
import random
import fluidsynth
import pretty_midi
import numpy as np

from time import time
from scipy import stats
from scipy.special import boxcox, inv_boxcox

import IPython.display
import matplotlib.pyplot as plt

import torch
import pytorch_lightning as pl

from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint

from utils.data import *
from utils.model import *
from utils.common_utils import *

%load_ext autoreload
%autoreload 2

# The 'seaborn' style sheet was renamed to 'seaborn-v0_8' in newer Matplotlib
# releases and the old name was later removed. Prefer the new name when it is
# available and fall back to the legacy one on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

In [None]:
# Fix the global random seed through PyTorch Lightning so runs are repeatable.
random_seed = 0
pl.seed_everything(seed=random_seed)

In [None]:
# Select the compute device: CUDA when it is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('using', device)

### Load Data

In [None]:
# Read the pickled training data from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
path = './data/wav_cov.pkl'
with open(path, mode='rb') as pkl_file:
    data = pickle.load(pkl_file)

print('the number of data :', data.shape)

In [None]:
# shuffle and split
# `random.shuffle` must not be applied to a multi-dimensional array: it swaps
# elements through row *views*, so one row overwrites the other and samples end
# up duplicated/lost. Shuffling a permutation of row indices avoids that and
# also leaves `data` itself untouched, which keeps this cell re-runnable.
# (assumes `data` supports integer-array indexing, e.g. a NumPy array)
num_data = data.shape[0]
num_train = int(num_data * 0.8)

shuffled_idx = np.random.permutation(num_data)

train_data = data[shuffled_idx[:num_train]]
val_data = data[shuffled_idx[num_train:]]

print('the number of train :', train_data.shape)
print('the number of validation :', val_data.shape)

In [None]:
# dataloader
# Both loaders share every option except `shuffle`: the training loader
# reshuffles each epoch, the validation loader keeps a fixed order.
batch_size = 64
shared_params = dict(batch_size=batch_size, pin_memory=True, num_workers=4)
train_params = dict(shared_params, shuffle=True)
val_params = dict(shared_params, shuffle=False)

train_set = DataLoader(DatasetSampler(train_data), **train_params)
val_set = DataLoader(DatasetSampler(val_data), **val_params)

### Get Autoencoder Model

In [None]:
# model
# The autoencoder is sized from the second dimension of `data`; the checkpoint
# callback monitors the logged 'val_loss' metric.
num_features = data.shape[1]
AE = AutoEncoder(num_features)
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='loop-detection-{epoch:02d}-{val_loss:.2f}',
)

In [None]:
# training
# Request a GPU only when CUDA is actually available (`use_cuda` is set in the
# device-initialization cell); a hard-coded `gpus=1` makes the Trainer fail on
# CPU-only machines.
trainer = pl.Trainer(gpus=1 if use_cuda else 0,
                     num_nodes=1,
                     max_epochs=1000,
                     deterministic=True,
                     default_root_dir='./model',
                     callbacks=[checkpoint_callback])

trainer.fit(AE, train_set, val_set)

print('best model path :', checkpoint_callback.best_model_path)
print('final results :', trainer.logged_metrics)

In [None]:
# # load model
# ckpt_path = './model/loop-detection-epoch=927-val_loss=0.02.ckpt'
# AE = AE.load_from_checkpoint(ckpt_path, num_features=num_features)

### Get Center from Autoencoder

In [None]:
# Encode every training batch with the trained encoder (eval mode, no
# gradients) and use the mean embedding as the center.
AE.eval()
with torch.no_grad():
    train_z = [AE.encoder(x_train) for x_train in train_set]

center = torch.vstack(train_z).mean(0)
print('center shape :', center.shape)

In [None]:
# center = [3.4064, -2.3389, -2.8335, -1.2972, -2.0128, -1.1937, 1.1904]
# center = torch.as_tensor(center)

### Get Deep SVDD

In [None]:
# model
# Put the autoencoder back into training mode and wrap its encoder, together
# with the center moved to `device`, in the SVDD model.
AE.train()
SVDD_model = SVDD(AE.encoder, center.to(device))
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='SVDD-{epoch:02d}-{val_loss:.2f}',
)

In [None]:
# training
# Request a GPU only when CUDA is actually available (`use_cuda` is set in the
# device-initialization cell); a hard-coded `gpus=1` makes the Trainer fail on
# CPU-only machines.
trainer = pl.Trainer(gpus=1 if use_cuda else 0,
                     num_nodes=1,
                     max_epochs=1000,
                     deterministic=True,
                     default_root_dir='./model',
                     callbacks=[checkpoint_callback])

trainer.fit(SVDD_model, train_set, val_set)

print('best model path :', checkpoint_callback.best_model_path)
print('final results :', trainer.logged_metrics)

In [None]:
# # load SVDD
# ckpt_path = './model/SVDD-epoch=562-val_loss=0.06.ckpt'
# SVDD_model = SVDD_model.load_from_checkpoint(ckpt_path, encoder=AE.encoder, center=center.to(device))

In [None]:
def get_dist_from_SVDD(data_set, model, center):
    """Return the mean squared distance of every sample's embedding to `center`.

    Args:
        data_set: iterable yielding the input batches passed to `model`
            (for example a DataLoader).
        model: torch module mapping a batch to embeddings. It is switched to
            eval mode and left in that mode.
        center: tensor that is broadcast against the stacked embeddings after
            a leading axis is added.

    Returns:
        NumPy array with one distance per sample, in iteration order.
    """
    # Run on whatever device the model lives on, so the function works whether
    # the model sits on the CPU or on a GPU. A model without parameters gives
    # no device hint, in which case the batches are used as they are.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    z_set = []

    model.eval()
    with torch.no_grad():
        for x in data_set:
            if model_device is not None and torch.is_tensor(x):
                x = x.to(model_device)
            z_set.append(model(x))

    z_set = torch.vstack(z_set)

    # compute distance (the center is moved next to the embeddings first)
    dist = z_set - center.to(z_set.device).unsqueeze(0)
    dist = dist.square().mean(1)

    return dist.cpu().detach().numpy()

In [None]:
# Score the training and validation loaders with the trained SVDD model.
train_dist, val_dist = (
    get_dist_from_SVDD(loader, SVDD_model, center)
    for loader in (train_set, val_set)
)

print('train dist :', train_dist.shape)
print('val dist :', val_dist.shape)

### Loop Detection

In [None]:
# Read the pickled MIDI pianorolls from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
path = './data/midi_full_strict_pianoroll.pkl'
with open(path, mode='rb') as pkl_file:
    midi = pickle.load(pkl_file)

print('the number of data :', len(midi))

In [None]:
# Score every MIDI pianoroll: densify it, build the model input with
# get_xor_corr, embed it, and store the mean squared distance to the center.
midi_dist = []
start_time = time()

SVDD_model.eval()
with torch.no_grad():
    for i, sparse_roll in enumerate(midi):
        pianoroll = sparse_roll.todense()

        inputs = torch.from_numpy(get_xor_corr(pianoroll))
        z = SVDD_model(inputs.float())

        # compute distance
        midi_dist.append((z - center).square().mean().item())

        # progress report (elapsed time since the previous report)
        if i % 1e6 == 0:
            print('I am on %d (%0.3f sec)' % (i, time()-start_time))
            start_time = time()

In [None]:
# get threshold
# The training distances are Box-Cox transformed with lambda = 0; the threshold
# is mean + one standard deviation in that space, mapped back with the inverse
# transform.
boxcox_train_dist = boxcox(train_dist, 0)
box_cox_thres = boxcox_train_dist.mean() + boxcox_train_dist.std()
thres = inv_boxcox(box_cox_thres, 0)

In [None]:
# plot parameters
CHAR_FONT_SIZE = 17  # font size for axis labels and legend text
NUM_FONT_SIZE = 15   # font size for tick labels
WIDTH = HEIGHT = 8   # figure size passed to plt.figure
LABEL_PAD = 13       # padding between axis and its label

# plotting
# Overlaid histograms (50 bins each) of the train and validation distances on a
# log-scaled frequency axis.
plt.figure(figsize=(WIDTH, HEIGHT))
plt.hist(train_dist, facecolor='tab:blue', bins=50, alpha=0.8)
plt.hist(val_dist, facecolor='tab:orange', bins=50, alpha=0.4)
plt.xticks(fontsize=NUM_FONT_SIZE)
plt.yticks(fontsize=NUM_FONT_SIZE)
plt.yscale('log')
plt.xlabel('Loop Score', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
plt.ylabel('Frequency', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
plt.legend(['wav_train', 'wav_val'], fontsize=CHAR_FONT_SIZE)
# NOTE(review): assumes the ./images directory already exists - confirm.
plt.savefig('./images/distance.png', dpi=1000, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# plotting
# Histogram (50 bins) of the MIDI distances on a log-scaled frequency axis.
# Relies on the font-size/figure-size constants defined in the previous
# plotting cell.
plt.figure(figsize=(WIDTH, HEIGHT))
plt.hist(midi_dist, bins=50, alpha=0.8)
plt.xticks(fontsize=NUM_FONT_SIZE)
plt.yticks(fontsize=NUM_FONT_SIZE)
plt.yscale('log')
plt.xlabel('Loop Score', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
plt.ylabel('Frequency', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
plt.legend(['midi'], fontsize=CHAR_FONT_SIZE)
# Optional: uncomment to mark the detection threshold on the plot.
# plt.axvline(x=thres, color='tab:red', linestyle='--', linewidth=3.5)
# NOTE(review): assumes the ./images directory already exists - confirm.
plt.savefig('./images/midi_distance.png', dpi=1000, bbox_inches='tight', pad_inches=0)
plt.show()

### Get Loop from Loop Detection

In [None]:
# Keep the pianorolls whose loop score is below the threshold.
# The score of every MIDI sample was already computed and stored in `midi_dist`
# by the scoring cell above, so it is reused here instead of running the model
# over the entire dataset a second time.
assert len(midi_dist) == len(midi), 'midi_dist does not match midi; re-run the scoring cell'

loop_data = []
start_time = time()

for i, dist in enumerate(midi_dist):
    if dist < thres:
        loop_data.append(midi[i].todense())

    # progress report (elapsed time since the previous report)
    if i % 1000000 == 0:
        print('I am on %d (%0.3f sec)' % (i, time()-start_time))
        start_time = time()

### Test

In [None]:
# play demo: render and play the first detected loop, if there is one
# random.shuffle(loop_data)
if loop_data:
    loop = loop_data[0]
    pm = play_pianoroll(loop, fs=8)
    audio = IPython.display.Audio(pm.fluidsynth(fs=16000), rate=16000)
    IPython.display.display(audio)

In [None]:
# Print the loop score (mean squared distance to the center) of `loop`.
SVDD_model.eval()
with torch.no_grad():
    inputs = torch.from_numpy(get_xor_corr(loop))
    z = SVDD_model(inputs.float())

    # compute distance
    dist = (z - center).square().mean().item()

    print(dist)

In [None]:
# save pickle
save_path = './data/midi_detected_strict_pianoroll.pkl'
with open(save_path, 'wb') as f:
    pickle.dump(loop_data, f, protocol=pickle.HIGHEST_PROTOCOL)

# Report success only after the dump has finished and the file is closed.
print('File saved!')

# `loop_data` is a Python list, which has no `.shape`; report its length.
print('the number of data :', len(loop_data))

### Loop Detector Evaluation

In [None]:
# play demo
# Draw random MIDI samples until one has the value 1 at both [0, 0] and [0, 6].
# NOTE(review): the loop has no upper bound; it never ends if no sample matches.
while True:
    rand_idx = np.random.randint(len(midi))
    # Densify the stored pianoroll first, as every other cell does before
    # indexing it or handing it to play_pianoroll.
    loop = midi[rand_idx].todense()

    if loop[0, 0] == 1 and loop[0, 6] == 1:
        break

pm = play_pianoroll(loop, fs=8)
IPython.display.display(IPython.display.Audio(pm.fluidsynth(fs=16000), rate=16000))
print(rand_idx)

In [None]:
# Fixed list of 100 indices into `midi` that forms the first ("loop") group of
# the evaluation below.
# NOTE(review): presumably hand-picked loop examples - confirm the provenance.
# NOTE(review): index 578129 appears three times in the list; verify whether
# the repetition is intended.
loop_idx = [5383096, 3978016, 618411, 5186352, 5261371, 5528929, 2975564, 2652515, 1356022, 890982,
            3017643, 2272088, 5086286, 725082, 350768, 2283479, 5011746, 3861869, 4733529, 7228,
            4402904, 511009, 3499391, 39603, 4082163, 1896876, 3159032, 4534427, 2248818, 4270337,
            5648465, 4283338, 516120, 2020602, 1943102, 2959300, 1957977, 5270308, 993345, 2741232,
            3634553, 1172096, 4643747, 2389123, 1983880, 2211931, 3795814, 2613799, 5065649, 4723833,
            4018700, 578129, 578129, 578129, 1167437, 5043236, 686352, 3127576, 5410844, 4623258,
            3977773, 5133501, 3917701, 4792378, 596408, 275077, 5264031, 5574783, 4011830, 2127268,
            773880, 5658692, 4213914, 5511406, 1526841, 5292300, 1035078, 836215, 2335737, 282704,
            1590806, 2724512, 148811, 1990322, 4261188, 51099, 4190463, 1708669, 920486, 3178536,
            1309008, 4529539, 3151077, 3170320, 83849, 5052372, 4776581, 4355603, 1066846, 2585470]

print('the number of loop_idx :', len(loop_idx))

In [None]:
# Evaluation indices: the fixed loop indices followed by 400 random indices
# drawn from the whole MIDI set.
rand_idx = np.random.randint(0, len(midi), size=400).tolist()
eval_idx = [*loop_idx, *rand_idx]

print('the number of eval_idx :', len(eval_idx))

In [None]:
# Score every evaluation sample: mean squared distance of its embedding to the
# center, in the order given by `eval_idx`.
eval_dist = []
start_time = time()

SVDD_model.eval()
with torch.no_grad():
    for idx in eval_idx:
        pianoroll = midi[idx].todense()

        inputs = torch.from_numpy(get_xor_corr(pianoroll))
        z = SVDD_model(inputs.float())

        # compute distance
        eval_dist.append((z - center).square().mean().item())

In [None]:
# statistical test
# Independent two-sample t-test of the first 100 scores against each
# consecutive block of 100 scores. The first comparison pairs the first block
# with itself.
loop_scores = eval_dist[:100]
for start in range(0, 500, 100):
    print(stats.ttest_ind(loop_scores, eval_dist[start:start + 100]))

In [None]:
# plot parameters
CHAR_FONT_SIZE = 17     # font size for axis labels
NUM_FONT_SIZE = 14      # font size for tick labels
WIDTH = 10; HEIGHT = 5  # figure size passed to plt.figure
LABEL_PAD = 13          # padding between axis and its label

# plotting
# Horizontal box plots of the evaluation scores, one box per block of 100:
# the first block is labelled 'Loop Set', the remaining four 'Random Set 1-4'.
plt.figure(figsize=(WIDTH, HEIGHT))
plt.boxplot([eval_dist[:100], eval_dist[100:200], eval_dist[200:300], eval_dist[300:400], eval_dist[400:500]], 
            vert=False)
plt.yticks([1, 2, 3, 4, 5], ['Loop Set', 'Random Set 1', 'Random Set 2', 'Random Set 3', 'Random Set 4'], fontsize=NUM_FONT_SIZE)
plt.xticks(fontsize=NUM_FONT_SIZE)
plt.xlabel('Loop Score', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
plt.ylabel('Groups', fontsize=CHAR_FONT_SIZE, labelpad=LABEL_PAD)
# NOTE(review): assumes the ./images directory already exists - confirm.
plt.savefig('./images/loop_score.png', dpi=1000, bbox_inches='tight', pad_inches=0)
plt.show()