# 1. load data

In [1]:
import os
import random
import numpy as np
import torch
import torch.utils.data as Data
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable

import pypianoroll as ppr
import pretty_midi
from pypianoroll import Multitrack, Track
from matplotlib import pyplot as plt

In [14]:
from vae_rnn import *

In [3]:
random.seed(0)

In [4]:
train_xs = [ x for x in os.listdir('./../explore_data/data') if '.npy' in x ] 
for i, x in enumerate(train_xs):
    print('[{}]: {}'.format(i, x)) 

[0]: train_x_drum_reduced_World.npy
[1]: train_x_drum_reduced_Country.npy
[2]: train_x_drum_reduced_Punk.npy
[3]: train_x_drum_reduced_Folk.npy
[4]: train_x_drum_reduced_Pop.npy
[5]: train_x_drum_reduced_New-Age.npy
[6]: train_x_drum_reduced_Rock.npy
[7]: train_x_drum_reduced_Metal.npy
[8]: train_x_drum_reduced_Latin.npy
[9]: train_x_drum_reduced_Blues.npy
[10]: train_x_drum_reduced_Electronic.npy
[11]: train_x_drum_reduced_RnB.npy
[12]: train_x_drum_reduced_Rap.npy
[13]: train_x_drum_reduced_Reggae.npy
[14]: train_x_drum_reduced_Jazz.npy


In [5]:
prefix = './../explore_data/data/'
train_x_reduced = np.zeros((0,SEQ_LEN,NUM_FEATURES))

for i, fn in enumerate(train_xs):
    data = np.load(prefix + fn)
    train_x_reduced = np.vstack((train_x_reduced, data))

print(train_x_reduced.shape)

(392836, 96, 9)


In [15]:
LR = 0.015
NUM_EPOCHS = 100
BATCH_SIZE = 256
BETA = 15.0
BEAT = 48

TESTING_RATIO = 0.05
N_DATA = train_x_reduced.shape[0]
N_TRAINING = int(train_x_reduced.shape[0]*TESTING_RATIO)
N_TESTING = N_DATA - N_TRAINING

In [16]:
def pltReducedDrumTrack(track, beat_resolution=12, cmap='Blues'):
    track = np.append(track, np.zeros((track.shape[0], 119)), axis=1)
    # track = np.where(track == 1, 128, 0)
    track = track * 128
    track = Track(pianoroll=track)
    
    fig, axs = track.plot(
        xtick='beat',
        yticklabel='number',
        beat_resolution=beat_resolution,
        cmap=cmap,
    )
    fig.set_size_inches(30,10)
    y = axs.set_ylim(0, 10) # C0 - C2
    y = axs.set_yticks(range(10))
    plt.show()

In [17]:
def parse_data(training_data):
    ratio = TESTING_RATIO
    T = int(training_data.shape[0]*ratio)
    
    train_x = training_data[:-T]
    test_x = training_data[-T:]

    train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
    test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
    
    return train_x, test_x

# 2. training

In [18]:
train_x, test_x = parse_data(train_x_reduced)
train_dataset = Data.TensorDataset(train_x)
test_dataset = Data.TensorDataset(test_x)

train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=1,
)
test_loader = Data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=1,
)

In [19]:
encoder = Encoder().to(device)
decoder = Decoder(beat=BEAT).to(device)
vae = VAE(encoder, decoder).to(device)

optimizer = optim.Adam(vae.parameters(), lr=LR)
test_err = 0

In [None]:
for epoch in range(NUM_EPOCHS):
    loss_sum = 0
    bce_sum = 0
    kld_sum = 0
    for batch_i, data in enumerate(train_loader):
        data = Variable(data[0]).type(torch.float32).to(device)
        optimizer.zero_grad()
        data_out = vae(data)
        
        loss, bce, kld = elbo(
            data_out,
            data,
            vae.z_mean,
            vae.z_sigma,
            beta=BETA
        )
        loss.backward()
        optimizer.step()

        loss_sum += loss.data.item()
        bce_sum += bce.data.item()
        kld_sum += kld.data.item()
        
        if batch_i % 5 == 0:
            print('Train Epoch: {} [{:4d}/{} ({:2.0f}%)]      Loss: {:.6f}'.format(
                epoch,
                batch_i * BATCH_SIZE,
                len(train_loader.dataset),
                100. * batch_i / len(train_loader),
                loss.data.item() / BATCH_SIZE))
            print('bce: {:.6f}, kld: {:.6f}'.format(
                bce.data.item() / BATCH_SIZE,
                kld.data.item() / BATCH_SIZE))
    print('====> Epoch: {} Average loss: {:.4f}, bce: {:.4f}, kld: {:.4f}'.format(
        epoch, loss_sum / len(train_loader.dataset),
        bce_sum / len(train_loader.dataset),
        kld_sum / len(train_loader.dataset),
    ))
    
    if epoch % 5 == 0:
        loss_sum_test = 0
        for batch_i, data in enumerate(test_loader):
            with torch.no_grad():
                data = Variable(data[0]).type(torch.float32).to(device)
                data_out = vae(data)

                loss = F.binary_cross_entropy(
                    data_out,
                    data,
                    reduction='sum'
                )
                loss_sum_test += loss.item()

        print('====> Testing Average Loss: {}'.format(
            loss_sum_test / len(test_loader.dataset)))
        test_err = loss_sum_test / len(test_loader.dataset)

In [12]:
from decimal import Decimal
import time

sn_loss = '%.0E' % Decimal(test_err)
sn_lr = '%.0E' % Decimal(LR)
sn_beta = '%.0E' % Decimal(BETA)

t = time.strftime("%Y%m%d_%H%M%S")

model_file_name = '_'.join([
    './models/all/vae',
    'L{}'.format(sn_lr),
    'beta{}'.format(sn_beta),
    'beat{}'.format(BEAT),
    'loss{}'.format(sn_loss),
     ACTIVATION,
    'gru{}'.format(GRU_HIDDEN_SIZE),
    'e{}'.format(NUM_EPOCHS),
    'b{}'.format(BATCH_SIZE),
    'hd{}-{}'.format(LINEAR_HIDDEN_SIZE[0], LINEAR_HIDDEN_SIZE[1]),
    t,
])

print(model_file_name)
torch.save(vae.state_dict(), model_file_name + '.pt')

./models/all/vae_L1E-02_beta2E+01_beat48_loss2E+01_tanh_gru32_e100_b256_hd64-32_20181008_034323


# 3. reconstruction

In [None]:
for batch_i, data in enumerate(test_loader):
    if batch_i > 0:
        break
    with torch.no_grad():
        data = Variable(data[0]).type(torch.float32).to(device)
        data_out = vae(data)
        
        
        for i in range(len(data)):
            if i < 20:
                data_i = data[i].cpu().data.numpy()
                data_o = data_out[i].cpu().data.numpy()
                print(data_o.max(), data_o.min())
                data_o = np.where(data_o > 0.2, 1, 0)
                pltReducedDrumTrack(data_i)
                pltReducedDrumTrack(data_o, cmap='Oranges')