# MusicVAE

- input: 130차원(velocity) -> T, 9
- drum class: 9 canonical classes
- 16개의 note intervals -> 16 events
- 2-bar data -> T=32, 16-bar data -> T=256
- U = 16 -> 1개의 bar 안에 16개

- 2-bar -> T는 32 = 16*2 (bar 안에 16개의 note * 2 bar)
- input: T, 130, 9
- bar별로 16개의 note들
- rnn model: slide 1bar
- output: U=16

In [None]:
# dataset structure
"""
    Create a pandas DataFrame to store the piano rolls in.
    It'll look a bit like this:
    
    |   Index     | t | C0 |...| C8 |
    |Song_name_3:0| 0 | 40 |...| 0   |
    |             | 1 | 40 |...| 0   |
    |             |...|... |...| ... |
    |             | N | 40 |...| 0   |
    |Song_name_3:1| 0 | 40 |...| 0   |
    |             |...|... |...| ... |
    |             | N | 40 |...| 0   |
    |Song_name_4:0| 0 | 40 |...| 0   |
    |             |...|... |...| ... |
    |             | N | 40 |...| 0   |

"""

In [2]:
import pandas as pd
import pretty_midi
import os
from getpass import getuser
import numpy as np
from yullab.utils import class_map
from yullab.prep import generate_midi_df, BarTransform
from yullab.datasets import GrooveDataset
from torch.utils.data import DataLoader

In [3]:
# Locate the groove MIDI data under /Users/<current user>.
username = getuser()
base_path = f'/Users/{username}/data/midi/groove'

# Both CSV files sit directly in the dataset root.
info_csv_filename, midi_csv_filename = (
    os.path.join(base_path, csv_name) for csv_name in ('info.csv', 'midi.csv')
)

# Per-recording metadata (drummer, style, bpm, split, ...); displayed as the
# cell output.
info_df = pd.read_csv(info_csv_filename)
info_df

Unnamed: 0,drummer,session,id,style,bpm,beat_type,time_signature,midi_filename,audio_filename,duration,split
0,drummer1,drummer1/eval_session,drummer1/eval_session/1,funk/groove1,138,beat,4-4,drummer1/eval_session/1_funk-groove1_138_beat_...,drummer1/eval_session/1_funk-groove1_138_beat_...,27.872308,test
1,drummer1,drummer1/eval_session,drummer1/eval_session/10,soul/groove10,102,beat,4-4,drummer1/eval_session/10_soul-groove10_102_bea...,drummer1/eval_session/10_soul-groove10_102_bea...,37.691158,test
2,drummer1,drummer1/eval_session,drummer1/eval_session/2,funk/groove2,105,beat,4-4,drummer1/eval_session/2_funk-groove2_105_beat_...,drummer1/eval_session/2_funk-groove2_105_beat_...,36.351218,test
3,drummer1,drummer1/eval_session,drummer1/eval_session/3,soul/groove3,86,beat,4-4,drummer1/eval_session/3_soul-groove3_86_beat_4...,drummer1/eval_session/3_soul-groove3_86_beat_4...,44.716543,test
4,drummer1,drummer1/eval_session,drummer1/eval_session/4,soul/groove4,80,beat,4-4,drummer1/eval_session/4_soul-groove4_80_beat_4...,drummer1/eval_session/4_soul-groove4_80_beat_4...,47.987500,test
...,...,...,...,...,...,...,...,...,...,...,...
1145,drummer2,drummer2/session2,drummer2/session2/11,rock,130,beat,4-4,drummer2/session2/11_rock_130_beat_4-4.mid,,1.909613,train
1146,drummer2,drummer2/session2,drummer2/session2/12,rock,130,beat,4-4,drummer2/session2/12_rock_130_beat_4-4.mid,,1.808652,train
1147,drummer2,drummer2/session2,drummer2/session2/13,rock,130,beat,4-4,drummer2/session2/13_rock_130_beat_4-4.mid,,1.864421,train
1148,drummer2,drummer2/session2,drummer2/session2/14,rock,130,beat,4-4,drummer2/session2/14_rock_130_beat_4-4.mid,,1.875960,train


In [4]:
# Load the piano-roll table: one row per (filename, timestep) plus nine
# numbered columns 0-8 (presumably one per drum class).
# NOTE(review): the 'Unnamed: 0' column shown in the output is presumably the
# index written by an earlier to_csv call -- confirm before relying on it.
midi_df = pd.read_csv(midi_csv_filename)
midi_df

Unnamed: 0,filename,timestep,0,1,2,3,4,5,6,7,8
0,1_funk_80_beat_4-4,0,0.000000,0.055118,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
1,1_funk_80_beat_4-4,1,0.000000,0.078740,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
2,1_funk_80_beat_4-4,2,0.000000,0.440945,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
3,1_funk_80_beat_4-4,3,0.000000,0.000000,0.511811,0.0,0.0,0.0,0.0,0.0,0.000000
4,1_funk_80_beat_4-4,4,0.000000,0.377953,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...
357613,15_rock_130_beat_4-4,33,0.000000,0.275591,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
357614,15_rock_130_beat_4-4,34,0.488189,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
357615,15_rock_130_beat_4-4,35,0.000000,0.259843,0.000000,0.0,0.0,0.0,0.0,0.0,0.000000
357616,15_rock_130_beat_4-4,36,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0,0.614173


In [82]:
# Select the MIDI files of every recording in the 'train' split and hand them
# to generate_midi_df together with the dataset root.
sample_filenames = info_df.loc[info_df['split'] == 'train', 'midi_filename']
drum_df = generate_midi_df(base_path, sample_filenames)
# Left disabled by the author: writes the generated table to midi.csv.
# drum_df.to_csv(os.path.join(base_path, 'midi.csv'))

In [249]:
# Build the dataset pipeline and pull a single batch out of it so the tensor
# shapes can be inspected in the following cells.
transform = BarTransform(2)  # presumably the number of bars per item -- TODO confirm
dataset = GrooveDataset(midi_csv_filename, transform)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=False,  # keep the order fixed so the inspected batch is reproducible
)
sample = next(iter(dataloader))

In [250]:
# Axes of the batch are (batch, time, feature).
sample.shape

torch.Size([16, 32, 9])

In [265]:
from yullab.musicvae import LSTMEncoder

# --- sequence geometry -------------------------------------------------------
bars = 2
# 16 steps per bar. Open question left by the author: should this be 4 per bar?
seq_len = bars * 16
u_size = bars
feature_size = sample.shape[-1]
batch_size = 16

# --- encoder -----------------------------------------------------------------
encoder_num_layers = 4
encoder_hidden_size = 128
# NOTE: latent_dimension is not passed to LSTMEncoder below; it is only used
# when the conductor is built.
latent_dimension = 512

# --- conductor ---------------------------------------------------------------
conductor_num_layers = 4
conductor_hidden_size = 256
lstm_conductor_input_size = 1

# --- decoder -----------------------------------------------------------------
decoder_num_layers = 2
decoder_hidden_size = 64
decoder_lstm_hidden_size = 32

# Run one batch through the encoder and display the shape of a sample drawn
# from the returned object with rsample().
lstm_encoder = LSTMEncoder(
    feature_size,
    num_layers=encoder_num_layers,
    hidden_size=encoder_hidden_size,
)
z_dist = lstm_encoder(sample.float())
z_dist.rsample().shape

torch.Size([16, 512])

In [257]:
from yullab.musicvae import Conductor

# Build the conductor and push one latent sample through it.
conductor = Conductor(
    latent_dimension,
    num_layers=conductor_num_layers,
    u_size=u_size,
    hidden_size=conductor_hidden_size,
)
embeddings = conductor(z_dist.rsample())
# Display how many embeddings came back and the shape of the first one.
(len(embeddings), embeddings[0].shape)

(2, torch.Size([16, 256]))

In [262]:
import torch
from torch import nn

# Scratch prototype of the decoder for a single conductor embedding.
# Everything lives at notebook scope, so nothing here may reference `self`.
hidden_size = decoder_hidden_size
num_layers = decoder_num_layers
# NOTE(review): lstm_hidden_size is not used below; the cells are sized with
# hidden_size so that the initial states produced by `fc` fit them. Confirm
# which width the final decoder is meant to use.
lstm_hidden_size = decoder_lstm_hidden_size
counter_size = 0

outputs = []
# "Previous output" for the very first decoding step: all zeros.
previous = torch.zeros((batch_size, feature_size))

# Maps one embedding to the initial state of every decoder layer:
# num_layers * hidden_size values for h, followed by as many for c.
fc = nn.Linear(in_features=conductor_hidden_size, out_features=hidden_size * num_layers * 2)

# Two stacked LSTM cells. The first consumes `inputs` (embedding concatenated
# with the previous output, i.e. conductor_hidden_size + feature_size wide);
# the second consumes the hidden state of the first.
lstm_l2_decoder_cell_1 = nn.LSTMCell(input_size=conductor_hidden_size + feature_size, hidden_size=hidden_size)
lstm_l2_decoder_cell_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)


embedding = embeddings[0]

t = torch.tanh(fc(embedding))
state_width = hidden_size * num_layers
# Split each sample's vector into per-layer chunks first, then move the layer
# axis to the front. Reshaping straight to (num_layers, -1, hidden_size) would
# re-chunk the row-major buffer and mix different samples of the batch.
hidden_states = t[:, :state_width].reshape(-1, num_layers, hidden_size).transpose(0, 1).contiguous()
cell_states = t[:, state_width:].reshape(-1, num_layers, hidden_size).transpose(0, 1).contiguous()

# Input of the first cell: embedding and previous output side by side.
inputs = torch.cat((embedding, previous), dim=1)


In [264]:
# TODO: feed the cells; the input is the embedding concatenated with the previous output.
# TODO: build two cells, pass h1 and h2 to them respectively, and emit the result.

torch.Size([16, 265])

In [259]:
# Number of time steps that fall to each of the u_size segments when seq_len
# is split evenly.
seq_len//u_size

16