In [None]:
from midi import Midi
from QLTSM.LSTMusic import LSTMusic
import torch

from pathlib import Path

from pennylane import numpy as np

import matplotlib.pyplot as plt

In [None]:
# Run configuration. These values are consumed by the cells below:
#   seq_length -> passed to Midi(...)
#   n_epochs   -> forwarded to lstm.train(..., n_epochs=...)
#   cutoff     -> forwarded to lstm.train(..., cutoff=...)
#   n_qubits   -> passed to LSTMusic(..., n_qubits=...)
seq_length = 25
n_epochs = 10
cutoff = 20
n_qubits = 4

# Every hyperparameter is encoded in the artifact name, so changing any of
# them above selects a different checkpoint / output file.
model_name = "-".join(
    (
        "lstm",
        f"seq{seq_length}",
        f"cut{cutoff}",
        f"epcs{n_epochs}",
        f"qu{n_qubits}",
    )
)
# Location of the saved state dict (relative to the working directory).
model_str = "saved_models/" + model_name + ".pt"

In [None]:
# Build the MIDI helper for the configured sequence length. Later cells read
# midi.n_vocab, midi.network_input, midi.network_output and midi.int_to_note
# from this object.
midi = Midi(seq_length)
# Report success only after construction has actually completed, so a failure
# inside Midi(...) is not preceded by a misleading "Initialized" message.
print("Initialized Midi")

In [None]:
# Build the model; the hidden dimension is tied to the MIDI vocabulary size.
lstm = LSTMusic(hidden_dim=midi.n_vocab, n_qubits=n_qubits)
# Report success only after the constructor has returned.
print("Initialized LSTM")

if Path(model_str).is_file():
    # A checkpoint for this exact configuration already exists: reuse it
    # instead of retraining.
    print("Loading model")
    lstm.load_state_dict(torch.load(model_str))
    # NOTE(review): presumably switches the module to inference mode
    # (torch.nn.Module.eval) — confirm against LSTMusic.
    lstm.eval()
    # No training happened in this run; bind the name anyway so later cells
    # see a defined value regardless of which branch executed.
    train_history = None
else:
    # Create the checkpoint directory *before* training so that a missing
    # folder cannot make torch.save fail after the expensive training step.
    Path(model_str).parent.mkdir(parents=True, exist_ok=True)
    print("Training LSTM")
    train_history = lstm.train(
        True, midi.network_input, midi.network_output, n_epochs=n_epochs, cutoff=cutoff
    )
    # Persist only the weights (state dict), matching the load branch above.
    torch.save(lstm.state_dict(), model_str)

In [None]:
# Ask the model for a fresh sequence of 20 notes, handing it the prepared
# network input together with the index-to-note mapping and vocabulary size
# from the Midi helper.
print("Generating notes")
notes = lstm.generate_notes(
    midi.network_input,
    midi.int_to_note,
    midi.n_vocab,
    n_notes=20,
)

In [None]:
# Write the generated notes to a MIDI file named after the model configuration.
output_file = f"generated_songs/{model_name}_generated.mid"
# Make sure the destination folder exists so the export does not fail on a
# fresh checkout; this is a no-op when the directory is already present.
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
print("Saving as MIDI file.")
midi.create_midi_from_model(notes, output_file)