In [21]:
from music21 import chord, note, stream
import yaml
from lightning.data import MusicDataWrapper
from lightning.rnn import LitAttentionRNN
from lightning.seq2seq import LitSeq2Seq

def show_and_write(notes, durs, fn):
    """Build a music21 stream from parallel note/duration sequences and write it as MIDI.

    Parameters
    ----------
    notes : iterable of str
        Note tokens. A token containing "." is split on "." and turned into a
        ``chord.Chord``; any other token becomes a ``note.Note``. The "START"
        token is skipped.
    durs : iterable
        Durations paired positionally with ``notes``. Each is assigned to the
        element's ``quarterLength``; entries with a falsy duration (e.g. 0) are
        skipped. Assumed to be numeric quarter-note lengths -- TODO confirm
        against the dataset's duration vocabulary.
    fn : str
        Output path passed to ``Stream.write("midi", ...)``.

    Returns
    -------
    The built stream, so a notebook cell can display or ``.show()`` it.
    """
    score = stream.Stream()
    for token, dur in zip(notes, durs):
        # "START" is the seed token used by this notebook, not a playable pitch;
        # a zero/empty duration would produce nothing audible.
        if token == "START" or not dur:
            continue

        if '.' in token:
            element = chord.Chord(token.split("."))
        else:
            element = note.Note(token)

        # Apply the generated duration. Previously it was only used as a skip
        # filter, so every element kept music21's default length.
        element.quarterLength = dur
        score.append(element)

    score.write("midi", fn)
    return score

def get_music_from_tokens(tokenizer_ds, token_notes, token_durs):
    """Decode parallel note/duration token sequences back into note and duration values.

    Parameters
    ----------
    tokenizer_ds :
        Object exposing ``tokens_to_notes`` and ``tokens_to_durations`` lookups
        indexed by token id.
    token_notes, token_durs : iterable
        Token ids, paired positionally. Each may be a tensor/NumPy scalar
        (anything with ``.item()``) or a plain int.

    Returns
    -------
    (notes, durs) : tuple of two new lists
        Built fresh on every call. Previously the function appended to whatever
        global ``notes``/``durs`` an earlier cell had left behind.
    """
    def _to_index(token):
        # Tensor and NumPy scalars expose .item(); plain Python ints are used as-is.
        return token.item() if hasattr(token, "item") else token

    notes, durs = [], []
    for tok_note, tok_dur in zip(token_notes, token_durs):
        notes.append(tokenizer_ds.tokens_to_notes[_to_index(tok_note)])
        durs.append(tokenizer_ds.tokens_to_durations[_to_index(tok_dur)])
    return notes, durs

# Attention RNN

In [6]:
# Load the attention-RNN config, then build an untrained model (baseline) and one
# restored from a saved checkpoint, both sized to the data's vocabularies.
with open("config_rnn.yaml", "r") as f:
    # safe_load: assumes the config uses only standard YAML types.
    config_rnn = yaml.safe_load(f)

# Forward slash works on Windows and POSIX. The original backslash ("\m") was an
# invalid escape sequence and is not a path separator outside Windows.
checkpoint_path = "checkpoints/model-epoch=09-val_loss=2.29.ckpt"
dm_rnn = MusicDataWrapper(config_rnn)
lit_rnn = LitAttentionRNN(config_rnn, dm_rnn.num_notes_classes, dm_rnn.num_duration_classes)
lit_rnn_trained = LitAttentionRNN.load_from_checkpoint(
    checkpoint_path,
    config=config_rnn,
    input_note_size=dm_rnn.num_notes_classes,
    input_dur_size=dm_rnn.num_duration_classes,
)

In [11]:
# Baseline: sample from the untrained attention RNN and save it for comparison.
notes, durs = lit_rnn.generate(dm_rnn.dataset)
show_and_write(notes=notes, durs=durs, fn="untrained_rnn.midi")

In [7]:
# Sample from the checkpoint-restored attention RNN and save the result.
notes, durs = lit_rnn_trained.generate(dm_rnn.dataset)
show_and_write(notes=notes, durs=durs, fn="trained_rnn.midi")

# Encoder-Decoder model

In [3]:
# Load the encoder-decoder config, then build an untrained model (baseline) and one
# restored from a saved checkpoint, both sized to the data's vocabularies.
with open("config_seq2seq.yaml", "r") as f:
    # safe_load: assumes the config uses only standard YAML types.
    config_seq2seq = yaml.safe_load(f)

# Forward slash works on Windows and POSIX. The original backslash ("\m") was an
# invalid escape sequence and is not a path separator outside Windows.
checkpoint_path = "checkpoints_seq2seq/model-epoch=06-val_loss=100.53.ckpt"
dm_seq2seq = MusicDataWrapper(config_seq2seq)
lit_seq2seq = LitSeq2Seq(config_seq2seq, dm_seq2seq.num_notes_classes, dm_seq2seq.num_duration_classes)
lit_seq2seq_trained = LitSeq2Seq.load_from_checkpoint(
    checkpoint_path,
    config=config_seq2seq,
    input_note_size=dm_seq2seq.num_notes_classes,
    input_dur_size=dm_seq2seq.num_duration_classes,
)

In [4]:
# Baseline: untrained encoder-decoder seeded only with the START token.
notes, durs = lit_seq2seq.generate(
    dm_seq2seq.dataset,
    start_seq=(["START"], [0]),
)
show_and_write(notes=notes, durs=durs, fn="untrained_seq2seq.midi")

In [10]:
# Checkpoint-restored encoder-decoder seeded with two A4 notes of duration 1.
notes, durs = lit_seq2seq_trained.generate(
    dm_seq2seq.dataset,
    start_seq=(["A4", "A4"], [1, 1]),
)
show_and_write(notes=notes, durs=durs, fn="trained_seq2seq.midi")

In [29]:
# Decode the dataset example at `position` back into notes/durations and write it
# out, so it can be compared with the model's recreation of the same example.
tokenizer = dm_seq2seq.dataset
position = 10000
# Index with `position`. The previous hardcoded 5000 left `position` unused and
# wrote a different example from the one the recreation cell seeds with.
show_and_write(*get_music_from_tokens(tokenizer, *tokenizer[position]), "example.midi")

In [30]:
# Seed the trained encoder-decoder with the decoded dataset example at `position`
# and save what it generates.
position = 10000
notes, durs = lit_seq2seq_trained.generate(
    dm_seq2seq.dataset,
    start_seq=get_music_from_tokens(tokenizer, *tokenizer[position]),
)
show_and_write(notes=notes, durs=durs, fn="example_recreate.midi")