In [14]:
import numpy as np
import pickle as pkl

from music21 import converter, instrument, note, chord, stream, duration, midi

from torch.utils.data import DataLoader, random_split
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from utils import PianoDataset
from model import LSTMModel

In [4]:
with open('processed_data.pkl', 'rb') as f:
    (normalized_sequences, encoded_labels, vocab, d_min, d_max) = pkl.load(f)

In [6]:
# Model parameters
vocab_size = len(vocab)
embedding_dim = 64
hidden_dim = 256
hidden_dim2 = 128
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, hidden_dim2)
model.load_state_dict(torch.load("model_parameters.pth", map_location=torch.device('cpu')))

  model.load_state_dict(torch.load("model_parameters.pth", map_location=torch.device('cpu')))


<All keys matched successfully>

In [7]:
#584136
idx = np.random.randint(len(normalized_sequences))
start_seq = normalized_sequences[idx]
print(idx)
print(start_seq)

214990
[(1221, 0.00206079340546116), (709, 0.0005151983513652744), (0, 0.0030911901081916524), (1743, 0.002060793405461101), (958, 0.0005151983513653328), (10, 0.006697578567748641), (51, 0.00206079340546116), (0, 0.0030911901081916524), (687, 0.0025759917568264354), (0, 0.0030911901081916524), (527, 0.003091190108191594), (291, 0.0036063884595569284), (148, 0.0051519835136526965), (106, 0.011334363730036061), (1, 0.0036063884595569275), (23, 0.0036063884595569275), (106, 0.00206079340546116), (8, 0.0036063884595569284), (260, 0.003091190108191594), (20, 0.0036063884595569284), (8, 0.003091190108191594), (133, 0.0020607934054611013), (133, 0.0061823802163833656), (0, 0.0030911901081916524), (236, 0.00206079340546116), (5, 0.0036063884595569865), (14, 0.0025759917568264354), (4, 0.0005151983513652744), (30, 0.00206079340546116), (297, 0.01030396702730557), (4, 0.0051519835136526965), (27, 0.0036063884595569284)]


In [11]:
def generate_notes(model, start_seq, num_notes):
    x_notes = torch.tensor([[pair[0] for pair in start_seq]], dtype=torch.long)  # Note indices
    x_durations = torch.tensor([[pair[1] for pair in start_seq]], dtype=torch.float)

    #gen_seq = [x for x in start_seq]
    gen_seq = []

    model.eval()
    for _ in range(num_notes):
        # Predicts the next note
        note_pred, duration_pred = model(x_notes, x_durations)

        probs = note_pred[0].detach().numpy()
        sorted_idxs = list(range(len(probs)))
        sorted_idxs.sort(key = lambda x: probs[x], reverse=True)
        new_note = np.random.choice(sorted_idxs[:3])
        #new_note = torch.argmax(note_pred, dim=1)
        new_dur = duration_pred.squeeze()

        gen_seq.append((new_note, float(new_dur)))

        x_notes = torch.cat((torch.tensor([x_notes[0].detach().numpy()[1:]], dtype=torch.long), torch.tensor([[new_note]], dtype=torch.long)), 1)
        x_durations = torch.cat((torch.tensor([x_durations[0].detach().numpy()[1:]], dtype=torch.float), torch.tensor([[new_dur]], dtype=torch.float)), 1)

    return gen_seq

In [16]:
def denormalize_duration(d_normalized, d_min, d_max):
    return d_normalized * (d_max - d_min) + d_min
inv_vocab = {v: i for i,v in vocab.items()}

In [18]:
# creates a new midi file from a flat sequence
def create_midi(sequence, out_path):
    notes = []

    prev_note = None
    rest_start = 0
    rest_duration = 0
    for i, (midi_code,dur) in enumerate(sequence):
        offset = i / 2
        if midi_code == 'REST':
            if prev_note == 'REST':
                rest_duration += dur
            else:
                prev_note = 'REST'
                rest_start = offset
                rest_duration = 0.5
        else:
            if prev_note == 'REST':
                new_note = note.Rest(rest_duration)
                new_note.offset = rest_start
                notes.append(new_note)
            
            if '.' in midi_code:
                components = []
                for comp in midi_code.split('.'):
                    cur_note = note.Note(int(comp))
                    cur_note.storedInstrument = instrument.Piano()
                    components.append(cur_note)
                new_chord = chord.Chord(components)
                new_chord.offset = offset
                new_chord.duration = duration.Duration(dur)
                notes.append(new_chord)
                prev_note = 'CHORD'
            # pattern is a note
            else:
                new_note = note.Note(int(midi_code))
                new_note.offset = offset
                new_note.storedInstrument = instrument.Piano()
                new_note.duration = duration.Duration(dur)
                notes.append(new_note)
                prev_note = 'NOTE'
    
    midi_stream = stream.Stream(notes)
    midi_stream.write('midi', fp=out_path)

In [12]:
gen_seq = generate_notes(model, start_seq, 500)
unemb_seq = [(inv_vocab[x], denormalize_duration(y, d_min, d_max)) for x,y in gen_seq]
create_midi(unemb_seq, 'test.midi')