In [1]:
import miditoolkit
import utils

In [2]:
import torch
import torch.nn as nn

In [3]:
# 1. GPU configuration
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Load model

In [4]:

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):
        
        #src = [src len, batch size]
        
        embedded = self.dropout(self.embedding(src))
        
        #embedded = [src len, batch size, emb dim]
        
        outputs, (hidden, cell) = self.rnn(embedded)
        
        #outputs = [src len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #outputs are always from the top hidden layer
        
        return hidden, cell
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
        
        self.fc_out = nn.Linear(hid_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, cell):
        
        #input = [batch size]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #n directions in the decoder will both always be 1, therefore:
        #hidden = [n layers, batch size, hid dim]
        #context = [n layers, batch size, hid dim]
        
        input = input.unsqueeze(0)
        
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(input))
        
        #embedded = [1, batch size, emb dim]
                
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        
        #output = [seq len, batch size, hid dim * n directions]
        #hidden = [n layers * n directions, batch size, hid dim]
        #cell = [n layers * n directions, batch size, hid dim]
        
        #seq len and n directions will always be 1 in the decoder, therefore:
        #output = [1, batch size, hid dim]
        #hidden = [n layers, batch size, hid dim]
        #cell = [n layers, batch size, hid dim]
        
        prediction = self.fc_out(output.squeeze(0))
        #prediction = [batch size, output dim]
        
        return prediction, hidden, cell
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
        assert encoder.hid_dim == decoder.hid_dim, \
            "Hidden dimensions of encoder and decoder must be equal!"
        assert encoder.n_layers == decoder.n_layers, \
            "Encoder and decoder must have equal number of layers!"
        
    def forward(self, src):
        
        #src = [src len, batch size]
        #trg = [trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
        
        ########################################
        ### inference
        ########################################
        #pdb.set_trace()
        
        hidden, cell = self.encoder(src)
        
        input = torch.IntTensor([int(0)]).cuda() ## sos
        eos = False
        infer_melody_list = []
        
        cnt = 0
        while not(eos):
            output, hidden, cell = self.decoder(input, hidden, cell)
            top1 = output.argmax(1)
            infer_melody_list.append(top1.item())
            input = top1
            print(top1.item())
            cnt+=1
            if top1.item() == 308:
                eos = True
            if cnt == 5000:
                break
        return infer_melody_list

In [5]:
'''
INPUT_DIM = 60
ENC_EMB_DIM = 16

OUTPUT_DIM = 309
DEC_EMB_DIM = 32

HID_DIM = 512
N_LAYERS = 2
'''
INPUT_DIM = 60
ENC_EMB_DIM = 32

OUTPUT_DIM = 309
DEC_EMB_DIM = 64

HID_DIM = 512
N_LAYERS = 3

ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)

model = Seq2Seq(enc, dec, device).to(device)

In [21]:
saved_model_path =  './model_ckpt/0628_12_epoch_6000_step.pt'
model.load_state_dict(torch.load(saved_model_path))
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(60, 32)
    (rnn): LSTM(32, 512, num_layers=3, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(309, 64)
    (rnn): LSTM(64, 512, num_layers=3, dropout=0.5)
    (fc_out): Linear(in_features=512, out_features=309, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

## Inference

In [22]:
import pickle5 as pickle

chord_dict_path = "./data/chord2idx_dict.pkl"
chord2idx = pickle.load(open(chord_dict_path, 'rb'))

input_data_path =  "./data/all_chord_4_bars.pkl"
all_chord_event_list = pickle.load(open(input_data_path, 'rb'))
type(all_chord_event_list[0]["Chord"])

list

In [73]:
## -- generate test chord sequence
#test_chord_sequence = ['E:dom', 'B:dom','C#:min','A:dom','E:dom', 'B:dom','C#:min','A:dom']
#test_chord_sequence = ['C:maj','A:min','D:min','G:maj']
test_chord_sequence = ['C#:min','E:maj','B:maj','A:maj']
test_chord_idx_sequence = [chord2idx[x] for x in test_chord_sequence]
all_chord_event_list[0]["Chord"]
print(chord2idx)
test_chord_idx_sequence


{'C#:maj': 0, 'B:maj': 1, 'C#:min': 2, 'A:maj': 3, 'F:maj': 4, 'C:maj': 5, 'D:maj': 6, 'G:maj': 7, 'A#:maj': 8, 'G:min': 9, 'E:min': 10, 'A:min': 11, 'F#:maj': 12, 'D#:maj': 13, 'D:aug': 14, 'D#:aug': 15, 'C:min': 16, 'G#:maj': 17, 'F#:min': 18, 'B:min': 19, 'G#:min': 20, 'D#:min': 21, 'F:dom': 22, 'F:min': 23, 'D:min': 24, 'B:dim': 25, 'E:dom': 26, 'E:maj': 27, 'D#:dim': 28, 'A#:min': 29, 'D:dom': 30, 'A#:dom': 31, 'A:aug': 32, 'C:dim': 33, 'A:dom': 34, 'G:dom': 35, 'G#:dim': 36, 'E:dim': 37, 'C#:dim': 38, 'D#:dom': 39, 'F:aug': 40, 'G:dim': 41, 'C#:aug': 42, 'C:aug': 43, 'F#:dim': 44, 'A#:dim': 45, 'G#:dom': 46, 'G:aug': 47, 'B:aug': 48, 'D:dim': 49, 'F:dim': 50, 'B:dom': 51, 'C:dom': 52, 'A:dim': 53, 'F#:aug': 54, 'E:aug': 55, 'G#:aug': 56, 'F#:dom': 57, 'C#:dom': 58, 'A#:aug': 59}


[2, 27, 1, 3]

In [74]:
## -- inference loop
random_seed = 200
torch.manual_seed(random_seed)
import pdb
input = torch.tensor(test_chord_idx_sequence).unsqueeze(-1).cuda() # [seq_len, 1]
pdb.set_trace()
infer_melody_list = model(input)
len(infer_melody_list)

--Return--
None
> [0;32m/tmp/ipykernel_60796/958066533.py[0m(6)[0;36m<module>[0;34m()[0m
[0;32m      4 [0;31m[0;32mimport[0m [0mpdb[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m[0minput[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mtest_chord_idx_sequence[0m[0;34m)[0m[0;34m.[0m[0munsqueeze[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mcuda[0m[0;34m([0m[0;34m)[0m [0;31m# [seq_len, 1][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 6 [0;31m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      7 [0;31m[0minfer_melody_list[0m [0;34m=[0m [0mmodel[0m[0;34m([0m[0minput[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m[0mlen[0m[0;34m([0m[0minfer_melody_list[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


1
213
1
3
4
1
42
184
44
1
42
163
66
1
42
152
66
1
28
140
66
1
28
140
94
5
3
103
5
28
25
94
5
28
140
94
5
28
140
94
72
42
163
94
72
42
152
94
72
28
140
53
72
28
25
53
72
28
157
53
6
3
63
6
42
163
94
6
42
152
94
6
42
140
94
6
28
25
94
6
28
156
94
6
42
140
94
77
42
25
53
77
3
156
53
8
3
63
8
42
189
30
8
42
163
30
8
42
152
30
8
28
140
30
8
28
25
30
8
28
156
30
8
28
25
53
60
28
25
53
60
28
34
53
0
1
211
1
3
70
1
42
183
94
1
42
99
94
1
42
140
94
1
28
156
94
1
28
156
94
1
42
25
94
5
3
63
5
19
184
94
5
19
196
94
5
19
163
94
5
19
152
94
5
19
25
94
5
42
25
94
5
28
156
94
5
42
25
94
5
42
25
94
6
3
4
6
42
99
94
6
42
189
94
6
42
100
94
6
42
152
94
6
42
25
94
6
28
25
94
6
28
156
94
6
3
4
8
42
25
94
8
42
25
94
8
42
25
94
8
42
25
94
8
42
25
94
8
42
25
94
8
28
156
94
0
1
211
1
3
70
1
42
183
61
1
42
99
94
1
42
140
94
1
28
25
94
1
28
156
94
1
28
156
94
5
3
70
5
19
184
94
5
19
196
94
5
19
163
94
5
19
152
94
5
19
25
94
5
42
25
94
5
28
156
94
5
42
25
94
6
3
4
6
42
99
94
6
42
189
94
6
42
100
94
6
42
152
94
6

524

In [75]:
## -- melody token to MIDI
event_dict_path =  "./REMI-tempo-chord-checkpoint/dictionary.pkl"
event2word, word2event = pickle.load(open(event_dict_path, 'rb'))

infer_melody_list.insert(0,0)
infer_melody_list

[0,
 1,
 213,
 1,
 3,
 4,
 1,
 42,
 184,
 44,
 1,
 42,
 163,
 66,
 1,
 42,
 152,
 66,
 1,
 28,
 140,
 66,
 1,
 28,
 140,
 94,
 5,
 3,
 103,
 5,
 28,
 25,
 94,
 5,
 28,
 140,
 94,
 5,
 28,
 140,
 94,
 72,
 42,
 163,
 94,
 72,
 42,
 152,
 94,
 72,
 28,
 140,
 53,
 72,
 28,
 25,
 53,
 72,
 28,
 157,
 53,
 6,
 3,
 63,
 6,
 42,
 163,
 94,
 6,
 42,
 152,
 94,
 6,
 42,
 140,
 94,
 6,
 28,
 25,
 94,
 6,
 28,
 156,
 94,
 6,
 42,
 140,
 94,
 77,
 42,
 25,
 53,
 77,
 3,
 156,
 53,
 8,
 3,
 63,
 8,
 42,
 189,
 30,
 8,
 42,
 163,
 30,
 8,
 42,
 152,
 30,
 8,
 28,
 140,
 30,
 8,
 28,
 25,
 30,
 8,
 28,
 156,
 30,
 8,
 28,
 25,
 53,
 60,
 28,
 25,
 53,
 60,
 28,
 34,
 53,
 0,
 1,
 211,
 1,
 3,
 70,
 1,
 42,
 183,
 94,
 1,
 42,
 99,
 94,
 1,
 42,
 140,
 94,
 1,
 28,
 156,
 94,
 1,
 28,
 156,
 94,
 1,
 42,
 25,
 94,
 5,
 3,
 63,
 5,
 19,
 184,
 94,
 5,
 19,
 196,
 94,
 5,
 19,
 163,
 94,
 5,
 19,
 152,
 94,
 5,
 19,
 25,
 94,
 5,
 42,
 25,
 94,
 5,
 28,
 156,
 94,
 5,
 42,
 25,
 94,
 5,
 42,
 25,
 94,


In [76]:
## -- check in REMI token
infer_token_list = [word2event[x] for x in infer_melody_list[:-1]]
infer_token_list

['Bar_None',
 'Position_1/16',
 'Chord_E:maj',
 'Position_1/16',
 'Tempo Class_mid',
 'Tempo Value_30',
 'Position_1/16',
 'Note Velocity_17',
 'Note On_44',
 'Note Duration_32',
 'Position_1/16',
 'Note Velocity_17',
 'Note On_56',
 'Note Duration_11',
 'Position_1/16',
 'Note Velocity_17',
 'Note On_61',
 'Note Duration_11',
 'Position_1/16',
 'Note Velocity_18',
 'Note On_68',
 'Note Duration_11',
 'Position_1/16',
 'Note Velocity_18',
 'Note On_68',
 'Note Duration_7',
 'Position_5/16',
 'Tempo Class_mid',
 'Tempo Value_37',
 'Position_5/16',
 'Note Velocity_18',
 'Note On_71',
 'Note Duration_7',
 'Position_5/16',
 'Note Velocity_18',
 'Note On_68',
 'Note Duration_7',
 'Position_5/16',
 'Note Velocity_18',
 'Note On_68',
 'Note Duration_7',
 'Position_7/16',
 'Note Velocity_17',
 'Note On_56',
 'Note Duration_7',
 'Position_7/16',
 'Note Velocity_17',
 'Note On_61',
 'Note Duration_7',
 'Position_7/16',
 'Note Velocity_18',
 'Note On_68',
 'Note Duration_3',
 'Position_7/16',
 'N

In [77]:
try:
    infer_melody_item = [word2event[x] for x in infer_melody_list]
except KeyError:
    infer_melody_item = [word2event[x] for x in infer_melody_list[:-1]]
    infer_melody_list = infer_melody_list[:-1]
len(infer_melody_item)

524

In [78]:
import utils
MIDI_PATH = "./CEBA.midi"
utils.write_midi(words = infer_melody_list,
                word2event = word2event,
                output_path = MIDI_PATH,
                prompt_path = None)