# NOTE(review): editdistance is not used in any cell shown here — confirm it is
# still needed. If it is, prefer `%pip install editdistance==<version>` so the
# kernel's own environment is targeted and the version is pinned.
!pip install editdistance

In [21]:
# Notebook setup: auto-reload edited local modules, import dependencies, make the
# bassline_transcription project importable, and seed the RNGs.
%load_ext autoreload
%autoreload 2

import os, sys
import datetime as dt

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

# Local modules. They are imported before project_dir is prepended to sys.path,
# so presumably they resolve from the notebook's working directory — TODO confirm.
import models
import encoders
import decoders
#from training import train, test, checkpoint, main_wandb, main_simple
from dataloaders import load_data, make_loaders, append_SOS


# NOTE(review): absolute, user-specific cluster path. Consider reading it from an
# environment variable or a single config constant so the notebook runs elsewhere.
project_dir = '/scratch/users/udemir15/ELEC491/bassline_transcription'
sys.path.insert(0, project_dir)

# Star import from the project; it presumably supplies helpers used later
# (e.g. get_directories), but the exact names are not visible here.
from utilities import *

from bassline_transcriber.transcription import decode_NN_output
from MIDI_output import create_MIDI_file


SEED = 27

# Seed numpy and torch (CPU + CUDA) and request deterministic cuDNN kernels.
# NOTE(review): torch.cuda.manual_seed seeds the current device only;
# torch.cuda.manual_seed_all covers multi-GPU runs — confirm which is intended.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cuda


In [54]:
# NOTE(review): `from utilities import *` already runs in the setup cell; this
# re-import looks redundant — confirm and remove.
from utilities import *

# hmmlearn provides MultinomialHMM used below. NOTE(review): its MultinomialHMM
# semantics reportedly differ between releases; pin the version
# (`%pip install hmmlearn==<version>`) — confirm which one these results used.
!pip install hmmlearn

In [29]:
# Resolution divisor. It is forwarded to load_data and reused by the print/MIDI
# helpers, where one beat spans 32 // M points.
M = 8

# The '[28, 51]' folder name matches the MIN_NOTE / MAX_NOTE range defined later.
# NOTE(review): this path repeats project_dir; consider building it from project_dir.
data_params = {'dataset_path': '/scratch/users/udemir15/ELEC491/bassline_transcription/data/datasets/[28, 51]',
               'dataset_name': 'TechHouse_bassline_representations',
               'scale_type': 'min',
               'M': M}

# X appears to be a 2-D integer token array with one bassline per row (see the
# shape printouts below) — confirm against load_data.
X, titles = load_data(data_params)

K = X.max()+1 # Number of classes, assumes consecutive [0,max] inclusive
sequence_length = X.shape[1]

print('Number of classes: {}\nSequence Length: {}'.format(K, sequence_length))
print('Number of data points: {}'.format(X.shape[0]))

Number of classes: 26
Sequence Length: 64
Number of data points: 4421


In [7]:
# NOTE(review): the saved output (4350, 128) disagrees with the loading cell's
# printout (4421 sequences of length 64), so it is stale from another run — re-run.
X.shape

(4350, 128)

In [4]:
# NOTE(review): move these imports into the setup cell so all dependencies are
# declared in one place.
from hmmlearn.hmm import MultinomialHMM
from sklearn.metrics import silhouette_score
# Evaluation currently disabled. silhouette_score expects one label per row of X,
# presumably one HMM-derived label per bassline — TODO confirm before enabling.
##silhouette_score(X,pred)

In [65]:
# Discrete-emission HMM over the integer token sequences in X: 25 hidden states,
# at most 100 EM iterations, with progress logged per iteration.
# random_state pins the parameter initialisation, so re-running this cell always
# yields the same model no matter how many earlier cells consumed numpy's
# global RNG.
# NOTE(review): on newer hmmlearn releases, CategoricalHMM is the model for
# integer symbol sequences and MultinomialHMM expects count vectors — confirm
# the installed version.
hmm = MultinomialHMM(n_components=25, n_iter=100, verbose=True, random_state=SEED)

In [66]:
# hmmlearn takes all sequences stacked into one (n_samples, 1) column plus the
# length of each sequence. Every row of X is one bassline of sequence_length
# tokens; using sequence_length rather than a literal 64 keeps this correct for
# other M / datasets.
hmm.fit(X.reshape(-1, 1), [sequence_length] * X.shape[0])

print(hmm.monitor_.converged)

print(hmm.monitor_)

         1     -925755.9588             +nan
         2     -577224.3674     +348531.5914
         3     -576055.5558       +1168.8116
         4     -574771.9979       +1283.5580
         5     -573116.5109       +1655.4870
         6     -570873.8142       +2242.6967
         7     -567803.3213       +3070.4929
         8     -563780.8777       +4022.4436
         9     -559230.7354       +4550.1423
        10     -555183.1357       +4047.5997
        11     -552321.1427       +2861.9930
        12     -550442.5702       +1878.5725
        13     -549049.7636       +1392.8067
        14     -547815.4249       +1234.3387
        15     -546598.7601       +1216.6648
        16     -545360.3313       +1238.4287
        17     -544111.7170       +1248.6144
        18     -542889.4274       +1222.2896
        19     -541730.8003       +1158.6270
        20     -540647.5808       +1083.2196
        21     -539606.7829       +1040.7979
        22     -538528.4237       +1078.3592
        23

True
ConvergenceMonitor(
    history=[-925755.9587753956, -577224.3674038842, -576055.5558279086, -574771.9978633412, -573116.5109050425, -570873.8142064767, -567803.3213031514, -563780.8776677409, -559230.7353753791, -555183.1356528824, -552321.1426871789, -550442.5702367916, -549049.7635746846, -547815.424891451, -546598.7600616371, -545360.3313426548, -544111.7169911034, -542889.4273623001, -541730.8003302331, -540647.5807712049, -539606.7828883288, -538528.4236706565, -537301.9818854539, -535823.0359678883, -534053.0969608193, -532077.9439363109, -530090.3110788506, -528269.778014226, -526681.9475412602, -525297.5557430293, -524065.8496330399, -522956.51443790866, -521961.55640456046, -521081.99266018503, -520317.0897436258, -519660.52943938755, -519101.1026222488, -518624.900895018, -518217.3639250089, -517864.63875888044, -517554.24915838725, -517275.28003189445, -517018.32038554037, -516775.38947405847, -516540.046940825, -516307.81581619085, -516076.8019885993, -515847.97810975

       100     -501952.6584        +738.3503


In [56]:
# NOTE(review): the fit log above still improves by about +738 at iteration 100
# (= n_iter), yet this reports True. In the hmmlearn versions I know, `converged`
# is also True when the iteration cap is hit, so treat this as "stopped", not
# "converged" — confirm for the installed version.
hmm.monitor_.converged

True

# Most likely hidden-state path for every bassline. Like fit(), predict() needs
# the stacked (n_samples, 1) column plus per-sequence lengths. Passing the 2-D X
# directly would not decode the rows as separate sequences. The flat result is
# reshaped back to one row of states per bassline.
pred = hmm.predict(X.reshape(-1, 1), [sequence_length] * X.shape[0]).reshape(X.shape)

ConvergenceMonitor(
    history=[-900547.9834174844, -577580.7960102715, -576574.2808935986, -575581.3656861598, -574378.9234676844, -572663.6231316498, -570063.5901780436, -566322.3389531954, -561772.5935705422, -557467.7335203668],
    iter=10,
    n_iter=10,
    tol=0.01,
    verbose=False,
)

In [62]:
# Draw one new bassline from the fitted HMM. sample() returns
# (observations, hidden_states); keep the observations as a flat token vector
# whose length matches the training sequences.
code = hmm.sample(sequence_length)[0].flatten()
print('HMM OUTPUT:\n{}'.format(code))
# NOTE(review): code_to_MIDI and the print_* helpers are defined in later cells
# (or come from `utilities`), so this cell relies on out-of-order execution.
midi_number_array = code_to_MIDI(code)
print('\nCODE:\n{}\n'.format(midi_number_array))

# SIL/SUS are the MIDI-array codes (0 = silence, 100 = sustain), not the tokens.
print_beat_matrix(midi_number_array, M, SIL=0, SUS=100)
print_transposed_beat_matrix(midi_number_array, M, SIL=0, SUS=100)

HMM OUTPUT:
[ 0 25  6 11  7 16 25 25 10 25  0 25 25 25 25  6  0 12 25 25  0 25  8  0
 25  9 25 25  0 16 25 25 13  0 25  9 25  0 25  8  0  2 25 25 25 25 25 20
 25  0  0 25 21 11 25 21 25 25 25  0  0 25 25 21]

CODE:
[  0 100  33  38  34  43 100 100  37 100   0 100 100 100 100  33   0  39
 100 100   0 100  35   0 100  36 100 100   0  43 100 100  40   0 100  36
 100   0 100  35   0  29 100 100 100 100 100  47 100   0   0 100  48  38
 100  48 100 100 100   0   0 100 100  48]

SIL: 0, SUS: 100

        Bar 0            Bar 1        
Beat 0: [  0 100  33  38]   [  0  39 100 100]
Beat 1: [ 34  43 100 100]   [  0 100  35   0]
Beat 2: [ 37 100   0 100]   [100  36 100 100]
Beat 3: [100 100 100  33]   [  0  43 100 100]

        Bar 2            Bar 3        
Beat 0: [ 40   0 100  36]   [100   0   0 100]
Beat 1: [100   0 100  35]   [ 48  38 100  48]
Beat 2: [  0  29 100 100]   [100 100 100   0]
Beat 3: [100 100 100  47]   [  0 100 100  48]
SIL: 0, SUS: 100

       Beat 0           Beat 1       
Ba

In [59]:
# Convert a token sequence to MIDI numbers, decode it, and write a MIDI file
# into midi_dir[str(M)].
# NOTE(review): `bassline` is not defined in any visible cell, and code_to_MIDI,
# SUSTAIN_CODE and midi_dir are only defined in later cells, so this cell fails
# on Restart & Run All. Presumably `code` from the sampling cell was intended —
# confirm. The meaning of the literal 8 passed to decode_NN_output is not
# visible here.
midi_array = decode_NN_output(code_to_MIDI(bassline), 8, M, sustain_code=SUSTAIN_CODE)

BPM = 125 

# NOTE(review): hard-coded output name; consider parameterising it.
create_MIDI_file(midi_array, BPM, 'lol5', midi_dir[str(M)])

In [23]:
# Log-likelihood of the whole training set under the fitted model. Sequences
# are passed the same way as in fit(), i.e. stacked with per-row lengths.
hmm.score(X.reshape(-1, 1), [sequence_length] * X.shape[0])

-573585.5920598782

In [25]:
# get_directories presumably comes from `utilities` (star-imported) — confirm.
directories = get_directories(project_dir+'/data/directories.json')

# Output folder for generated MIDI. It is indexed by str(M) when exporting.
midi_dir = directories['midi']['generated']

# Token values in the model/HMM sequences.
SUSTAIN = 25 
SILENCE = 0

#SOS = None
# MIDI note range of the dataset (matches the '[28, 51]' dataset folder).
MAX_NOTE = 51 
MIN_NOTE = 28

# Values used for silence / sustain in MIDI-number arrays.
SILENCE_CODE = 0
SUSTAIN_CODE = 100

def code_to_MIDI(code, min_note=None, sustain=None, sustain_code=None):
    """Map model token codes to MIDI note numbers.

    Tokens follow the layout silence (0), notes (1 .. sustain-1), sustain.
    Silence stays 0, note token ``k`` becomes MIDI note ``min_note + k - 1``,
    and the sustain token becomes ``sustain_code``.

    Parameters
    ----------
    code : array-like of int
        Token sequence. It is not modified.
    min_note : int, optional
        MIDI number of note token 1. Defaults to the module constant ``MIN_NOTE``.
    sustain : int, optional
        Token used for sustain in ``code``. Defaults to ``SUSTAIN``.
    sustain_code : int, optional
        Value written for sustain tokens. Defaults to ``SUSTAIN_CODE``.

    Returns
    -------
    numpy.ndarray
        New array with the same shape as ``code``.
    """
    # Resolve defaults at call time so edits to the module constants take
    # effect without redefining the function.
    if min_note is None:
        min_note = MIN_NOTE
    if sustain is None:
        sustain = SUSTAIN
    if sustain_code is None:
        sustain_code = SUSTAIN_CODE

    code = np.asarray(code)
    midi = code.copy()
    # Identify sustain on the *input* tokens. Taking the max of the shifted
    # array would turn the highest note into a sustain whenever the sequence
    # contains no sustain token, and an all-silence sequence into all-sustain.
    is_sustain = code == sustain
    is_note = (code != 0) & ~is_sustain
    midi[is_note] += min_note - 1
    midi[is_sustain] = sustain_code
    return midi

In [25]:
# NOTE(review): `bassline` is not defined in any visible cell. It is presumably a
# token sequence (SUS=25 matches SUSTAIN) — confirm its source before re-running.
print_beat_matrix(bassline, M, SIL=0, SUS=25)

SIL: 0, SUS: 25

        Bar 0            Bar 1        
Beat 0: [ 0  9  6 25]   [ 0 15 25 11]
Beat 1: [13  0  8  0]   [25 23 25 13]
Beat 2: [25  0 17 25]   [25 15  0 25]
Beat 3: [ 9 25  9 12]   [ 0 14 25 25]

        Bar 2            Bar 3        
Beat 0: [ 0  5 25 25]   [ 0 25  0 25]
Beat 1: [25 25 12 25]   [ 0  0 25  0]
Beat 2: [25  4 25 25]   [25  0  9 12]
Beat 3: [21  0 25  9]   [25 21 25 25]


In [20]:
def print_transposed_beat_matrix(representation, M, SIL=1, SUS=26, N_bars=4):
    """Pretty-print a bassline with one row per bar and two beats per column group.

    Transposed counterpart of ``print_beat_matrix``. For each pair of beats
    (0-1, then 2-3), it prints every bar's values for those beats side by side.

    Parameters
    ----------
    representation : array-like
        Flat sequence of ``N_bars * 4 * (32 // M)`` values (4 beats per bar).
    M : int
        Resolution divisor; ``32 // M`` points make up one beat.
    SIL, SUS : int
        Silence / sustain codes. They are only echoed in the header line.
    N_bars : int
        Number of bars in ``representation``.

    Raises
    ------
    ValueError
        If ``representation`` cannot be reshaped to ``(N_bars, 4, 32 // M)``.
    """
    ppb = 32//M  # points per beat, 32 comes from the pYIN frame size
    beats_per_bar = 4
    # Reshape with ppb so the beat width always matches the printed layout
    # (equal to 4*(8//M) for M in {1, 2, 4, 8}, and also valid for M = 16, 32).
    representation = np.asarray(representation).reshape((N_bars, beats_per_bar, ppb))
    tab = 2*ppb + (ppb-1) + 2  # pretty print
    print('SIL: {}, SUS: {}'.format(SIL, SUS))
    # Column groups are beat pairs (fixed by beats_per_bar); rows are all bars.
    for i in range(beats_per_bar//2):
        print('\n{:>7}{:<{}}  {:<{}}'.format(' ', 'Beat {}'.format(2*i), tab+2, 'Beat {}'.format(2*i+1), tab))
        for j in range(N_bars):
            print('Bar {}: {}   {}'.format(j, representation[j, 2*i, :], representation[j, 2*i+1, :]))
            
def print_beat_matrix(representation, M, SIL=1, SUS=26, N_bars=4):
    """Pretty-print a bassline as beat rows, with two bars shown side by side.

    Parameters
    ----------
    representation : array-like
        Flat sequence of ``N_bars * 4 * (32 // M)`` values (4 beats per bar).
    M : int
        Resolution divisor; ``32 // M`` points make up one beat.
    SIL, SUS : int
        Silence / sustain codes. They are only echoed in the header line.
    N_bars : int
        Number of bars in ``representation``. An odd final bar is printed alone.

    Raises
    ------
    ValueError
        If ``representation`` cannot be reshaped to ``(N_bars, 4, 32 // M)``.
    """
    ppb = 32//M  # points per beat, 32 comes from the pYIN frame size
    beats_per_bar = 4
    # Reshape with ppb so the beat width always matches the printed layout
    # (equal to 4*(8//M) for M in {1, 2, 4, 8}, and also valid for M = 16, 32).
    representation = np.asarray(representation).reshape((N_bars, beats_per_bar, ppb))
    tab = 2*ppb + (ppb-1) + 2  # pretty print
    print('SIL: {}, SUS: {}'.format(SIL, SUS))
    for left in range(0, N_bars, 2):
        right = left + 1
        if right < N_bars:
            print('\n{:>8}{:<{}}  {:<{}}'.format(' ', 'Bar {}'.format(left), tab+2, 'Bar {}'.format(right), tab))
            for j in range(beats_per_bar):
                print('Beat {}: {}   {}'.format(j, representation[left, j, :], representation[right, j, :]))
        else:
            # Odd N_bars: the last bar has no partner; print it alone rather than drop it.
            print('\n{:>8}{}'.format(' ', 'Bar {}'.format(left)))
            for j in range(beats_per_bar):
                print('Beat {}: {}'.format(j, representation[left, j, :]))