In [61]:
import torch
from models import MelodyLSTM
from pathlib import Path
from prepare_data import encoding
import matplotlib.pyplot as plt
import numpy as np


def sample_with_temperature(scores, t: float = 1.0):
    """Sample an index from non-negative scores reshaped by a temperature.

    The sampling distribution is ``p_i ∝ scores_i ** (1 / t)``: ``t < 1``
    sharpens it towards the highest score, ``t > 1`` flattens it. The power is
    computed in log space and shifted by the maximum (log-sum-exp style), so a
    very small temperature does not underflow every probability to zero.

    Args:
        scores: 1-D array-like of non-negative scores (e.g. probabilities).
            Zero scores are never sampled.
        t: Temperature; must be positive.

    Returns:
        The sampled index as a Python ``int``. NumPy's global random state is
        used, so ``np.random.seed`` makes the draw reproducible.

    Raises:
        ValueError: If ``t`` is not positive.
    """
    if t <= 0:
        raise ValueError(f"temperature must be positive, got {t}")
    scores = np.asarray(scores, dtype=np.float64)
    # log(0) = -inf is intended: it maps back to a probability of exactly 0.
    with np.errstate(divide="ignore"):
        log_prob = np.log(scores) / t
    # Shift so the largest term becomes exp(0) = 1; avoids underflow to all-zeros.
    log_prob -= log_prob.max()
    prob = np.exp(log_prob)
    prob /= prob.sum()
    return int(np.random.choice(len(prob), p=prob))

In [3]:
# Load a trained checkpoint: a dict holding the hyper-parameters ("config")
# and the trained weights ("state_dict").
model_file = Path("models/model_time_series_2024-10-15T20-31-27.pth")
# map_location="cpu": the sampling code below converts model outputs with
# `.numpy()`, which needs CPU tensors. It also lets GPU-trained checkpoints
# load on CPU-only machines.
# weights_only=True: avoids unpickling arbitrary objects and silences the
# torch.load warning. It requires the checkpoint to contain only tensors and
# plain Python containers/primitives. NOTE(review): if `config` holds other
# types, loading fails; use weights_only=False only for trusted local files.
model_dict = torch.load(model_file, map_location="cpu", weights_only=True)
config, state_dict = model_dict["config"], model_dict["state_dict"]

model = MelodyLSTM(
    num_unique_tokens=config["num_unique_tokens"],
    embedding_size=config["embedding_size"],
    hidden_size=config["hidden_size"],
)
model.load_state_dict(state_dict)

model.eval()

  model_dict = torch.load(model_file)


MelodyLSTM(
  (embedding): Embedding(132, 8)
  (lstm): LSTM(8, 8)
  (fc): Linear(in_features=8, out_features=132, bias=True)
)

# Illustration: generating without randomness

Given a starting note sequence, below is the continuation produced by always taking the model's most likely next token (greedy decoding). It only ever predicts the `HOLD` token (`129`). The melodies use a 16th-note grid, so hold tokens are very frequent and dominate the predictions.

In [6]:
def generate_melody(model, initial_sequence, num_notes, sequence_length):
    """Greedily extend a token sequence with the model's highest-scoring token.

    Args:
        model: Callable mapping a list of token ids to a tensor of per-position
            scores; only the scores of the last position (``[-1]``) are used.
        initial_sequence: Iterable of token ids used as the starting melody.
        num_notes: Number of tokens to generate.
        sequence_length: Maximum context length; at each step only the last
            ``sequence_length`` tokens are passed to the model.

    Returns:
        A new list: the initial tokens followed by ``num_notes`` generated tokens.
    """
    melody = list(initial_sequence)
    # Inference only: don't build an autograd graph on every generation step.
    with torch.no_grad():
        for _ in range(num_notes):
            context = melody[-sequence_length:]
            last_scores = model(context)[-1]
            melody.append(torch.argmax(last_scores).item())
    return melody


# Greedy continuation of a short seed phrase ("H" = hold the previous note).
seq1 = ["36", "H", "H", "H", "37", "38", "H", "H"]
greedy_prompt_tokens = [encoding[token] for token in seq1]
# Next-token scores after the seed phrase, kept for the bar chart further below.
scores = torch.exp(model(greedy_prompt_tokens)[-1]).detach()
generate_melody(
    model=model,
    initial_sequence=greedy_prompt_tokens,
    num_notes=10,
    sequence_length=config["sequence_length"],
)

[36,
 129,
 129,
 129,
 37,
 38,
 129,
 129,
 129,
 129,
 129,
 129,
 129,
 129,
 129,
 129,
 129,
 129]

Below are the model's scores (the exponentiated model outputs) for the next token after `seq1`. Instead of always selecting the token with the largest score, we will sample from this distribution, or from a temperature-adjusted version of it.

In [None]:
# Next-token scores after `seq1`, one bar per entry of the token vocabulary.
fig, ax = plt.subplots(figsize=(3, 3))
ax.barh(range(len(scores)), scores)
ax.set_xlabel("Model scores")
# The bars index the whole token vocabulary (MIDI pitches plus special tokens
# such as HOLD = 129), not only MIDI notes.
ax.set_ylabel("Token index")
ax.set_title("Next-token scores")
plt.show()

# Generate with temperature sampling

In [170]:
def generate_melody2(
    model, initial_sequence, num_notes, sequence_length, temperature=1.0
):
    """Extend a token sequence by sampling each next token with a temperature.

    Args:
        model: Callable mapping a list of token ids to a tensor of per-position
            outputs; only the last position (``[-1]``) is used.
        initial_sequence: Iterable of token ids used as the starting melody.
        num_notes: Number of tokens to generate.
        sequence_length: Maximum context length; at each step only the last
            ``sequence_length`` tokens are passed to the model.
        temperature: Passed to ``sample_with_temperature``; lower values make
            the sampling greedier, higher values more random.

    Returns:
        A new list: the initial tokens followed by ``num_notes`` sampled tokens.
    """
    melody = list(initial_sequence)
    # Inference only: don't build an autograd graph on every generation step.
    with torch.no_grad():
        for _ in range(num_notes):
            context = melody[-sequence_length:]
            # Outputs are exponentiated into non-negative scores; presumably the
            # model returns log-probabilities (TODO confirm in MelodyLSTM).
            probs = np.exp(model(context)[-1].detach().numpy())
            melody.append(sample_with_temperature(probs, t=temperature))
    return melody


seq1 = ["30", "H", "H", "H", "30", "32", "H", "H"]
sampled_prompt_tokens = [encoding[token] for token in seq1]
# Fixed seed right before sampling so the continuation below is reproducible.
np.random.seed(202)
mel1 = generate_melody2(
    model=model,
    initial_sequence=sampled_prompt_tokens,
    num_notes=100,
    sequence_length=config["sequence_length"],
    temperature=1.0,
)
mel1[:20]

[30,
 129,
 129,
 129,
 30,
 32,
 129,
 129,
 60,
 71,
 129,
 129,
 129,
 129,
 129,
 62,
 129,
 128,
 129,
 129]

# Save to MIDI



In [146]:
# Inverse of `encoding`: token id -> token string (e.g. 129 -> "H").
decoding = dict(zip(encoding.values(), encoding.keys()))

In [None]:
import music21 as m21
from prepare_data import HOLD, REST


def _time_series_to_events(sequence, hold_token, rest_token):
    """Collapse a fixed-step token sequence into (token, num_steps) events."""
    events = []
    for token in sequence:
        if token != hold_token:
            events.append([token, 1])
        elif events:
            # A hold extends whichever note or rest is currently sounding.
            events[-1][1] += 1
        else:
            # Holds before any note have nothing to sustain: keep them as silence.
            events.append([rest_token, 1])
    return [(token, num_steps) for token, num_steps in events]


def time_series_to_midi(
    sequence: list[str],
    step_duration: float,
    filename: str | Path | None = None,
    hold_token=HOLD,
    rest_token=REST,
):
    """Convert a fixed-step time series melody to a music21 stream (and MIDI).

    Each element of ``sequence`` occupies one time step. A pitch or rest token
    starts a new event, and every following ``hold_token`` extends that event
    by one step. For example ``["60", hold, hold, "62"]`` becomes a 3-step
    note 60 followed by a 1-step note 62.

    Args:
        sequence: Melody as pitch strings (MIDI numbers such as ``"60"``),
            ``rest_token`` or ``hold_token`` at fixed time steps.
        step_duration: Length of one time step in music21 quarter lengths
            (e.g. 0.25 for a 16th-note grid).
        filename: Path to save a MIDI file to. Nothing is written if None.
        hold_token: Token meaning "keep the previous event sounding".
        rest_token: Token meaning silence.

    Returns:
        music21 stream
    """
    stream = m21.stream.Stream()
    for token, num_steps in _time_series_to_events(sequence, hold_token, rest_token):
        length = step_duration * num_steps
        if token == rest_token:
            event = m21.note.Rest(quarterLength=length)
        else:
            event = m21.note.Note(pitch=int(token), quarterLength=length)
        stream.append(event)

    if filename is not None:
        stream.write(fmt="midi", fp=filename)
    return stream


# NOTE(review): the notes earlier in this notebook say the melodies use a
# 16th-note grid. With step_duration=1 every step lasts a quarter note, i.e.
# playback is 4x slower; step_duration=0.25 would keep the original rhythm.
melody_tokens = [decoding[token_id] for token_id in mel1]
stream1 = time_series_to_midi(melody_tokens, step_duration=1)
stream1.show("midi")

# Baseline generators

**TODO**: make baseline generators:

- uniform over tokens
- uniform over neighboring notes
- probability proportional to distance

Distance could be measured in pitch and/or chroma. Should we also condition on a key?