# Generate HMM samples


## Setup


### Imports


In [1]:
import glob
import re
import sys
from functools import wraps
from typing import cast

import numpy as np
import pandas as pd
from tqdm.auto import tqdm, trange

### Misc. config


In [2]:
# IPython magics (notebook-only, not plain Python): render matplotlib figures
# inline and at "retina" resolution, which looks sharper on high-dpi displays.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Print numpy floats with 4 digits after the decimal point to keep output readable
np.set_printoptions(precision=4)

### Sampling/utility functions


In [3]:
def temperature_scaling(p, temperature):
    """Return ``p`` re-weighted by a softmax temperature.

    Computes ``softmax(log(p + eps) / temperature)``: temperatures below 1 sharpen
    the distribution, temperatures above 1 flatten it, and 1 leaves it (up to the
    tiny ``eps``) unchanged.

    Args:
        p: 1-D array-like of non-negative probabilities.
        temperature: positive scaling factor.

    Returns:
        A numpy array of the same length as ``p`` that sums to 1.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # Add small epsilon to avoid log(0) for zero-probability entries
    eps = 1e-8
    logits = np.log(np.asarray(p, dtype=np.float64) + eps) / temperature
    if logits.size == 0:
        return logits
    # Shift by the max before exponentiating (log-sum-exp trick). This is a
    # mathematical no-op, but without it a small temperature drives every exp()
    # to 0 and the normalization below becomes 0/0 = NaN.
    scaled = np.exp(logits - np.max(logits))
    return scaled / np.sum(scaled)


def sample(p, temperature=None):
    """Draw one index from the distribution ``p``.

    If ``temperature`` is given, ``p`` is first passed through
    ``temperature_scaling``; otherwise it is used as-is.
    """
    probs = p if temperature is None else temperature_scaling(p, temperature)
    return np.random.choice(len(probs), p=probs)


def build_matrix_sampler(df):
    """Wrap a DataFrame as a function that samples a row label given a column label.

    Each column of ``df`` is used as a probability distribution over the row
    labels. The returned callable takes a column label ``id_`` plus any extra
    positional/keyword arguments, which are forwarded to ``sample`` (e.g.
    ``temperature``), and returns the sampled row label.
    """
    # Convert to numpy once up front so each call only slices an array column
    # instead of going through pandas indexing.
    values = df.to_numpy()

    @wraps(sample)
    def draw(id_, *args, **kwargs):
        column = values[:, df.columns.get_loc(id_)]
        return df.index[sample(column, *args, **kwargs)]

    return draw


def check_lists_match(list1, list2):
    """Raise AssertionError unless ``list1`` and ``list2`` are element-wise equal.

    Lengths are compared first: a bare ``zip`` stops at the shorter list, so a
    prefix would otherwise pass silently. Errors are raised explicitly (not via
    ``assert``) so the check also runs under ``python -O``.

    Raises:
        AssertionError: on a length mismatch or the first differing position.
    """
    if len(list1) != len(list2):
        raise AssertionError(f"Length mismatch: {len(list1)} != {len(list2)}")
    for i, (item1, item2) in enumerate(zip(list1, list2)):
        if item1 != item2:
            raise AssertionError(f"Item mismatch at position {i}: {item1} != {item2}")


## Setup chord-to-note ($S_0$)


In [None]:
# Load data/chord-to-note.csv (its first column becomes the row index), then sort
# the columns by the first space-separated token of each header and the rows by
# label.
chord_to_note = (
    pd.read_csv("data/chord-to-note.csv", index_col=0)
    .sort_index(axis=1, key=lambda labels: labels.str.split(" ").str[0])
    .sort_index()
)
# Renormalize so every column is a probability distribution over the rows (notes)
chord_to_note = chord_to_note / chord_to_note.sum(axis=0)

# Display the first rows of the sorted table
chord_to_note.head()

### Define $S_0$ sampling method


In [5]:
# Given a column label of chord_to_note, sample one of its row labels (a note)
sample_chord_to_note = build_matrix_sampler(chord_to_note)

In [None]:
# Demo: draw one note from the "Em" column
sample_chord_to_note("Em")

## Setup note-to-note ($S_j$)


In [None]:
# Load data/note-to-note.csv; its first column becomes the row index
note_to_note = pd.read_csv("data/note-to-note.csv", index_col=0)
# Convert headers and index to integers so they sort numerically, not as strings
note_to_note.columns = note_to_note.columns.astype(int)
note_to_note.index = note_to_note.index.astype(int)
# Sort headers
note_to_note = note_to_note.sort_index(axis=1)
# Sort index
note_to_note = note_to_note.sort_index()
# Renormalize columns: column = previous note, each column becomes a probability
# distribution over the next note (row)
note_to_note = note_to_note / note_to_note.sum(axis=0)
# Display the sorted DataFrame
note_to_note.head()

### Sanity check: relationship between note-to-note header and index


In [8]:
# Collect the note labels used as columns (header) and as rows (index).
# These names are reused by the chord-note comparison cell further down.
note_to_note_header = {int(label) for label in note_to_note.columns}
note_to_note_index = {int(label) for label in note_to_note.index}

# Rows and columns must cover exactly the same notes, in both directions
assert not (note_to_note_header - note_to_note_index)
assert not (note_to_note_index - note_to_note_header)

### Define $S_j$ sampling method


In [9]:
# Given the previous note (a column label), sample the next note (a row label)
sample_note_to_note = build_matrix_sampler(note_to_note)

In [None]:
# Demo: sample the note that follows note 38
sample_note_to_note(38)

### Define phrase generation method


In [11]:
def phrase_generator(s_0, temperature=None):
    """Yield an endless note sequence: ``s_0`` first, then successive samples.

    Each note after the first is drawn with ``sample_note_to_note``, conditioned
    on the note yielded just before it.
    """
    note = s_0
    while True:
        yield note
        # Only sample once the consumer asks for the next note
        note = sample_note_to_note(note, temperature)

In [12]:
def sample_phrase(s_0, k, temperature=None):
    """Return the first ``k`` notes of ``phrase_generator(s_0, temperature)`` as a list."""
    notes = phrase_generator(s_0, temperature)
    # range(k) is zipped first so the generator is never advanced past k notes
    # (advancing it would consume an extra random sample).
    return [note for _, note in zip(range(k), notes)]

In [None]:
# Demo: a 5-note phrase starting from note 38 at temperature 1.
# sample_phrase already returns a list, so it is not wrapped in another list().
sample_phrase(38, 5, temperature=1)

## Setup triad variations ($V_t$)


In [14]:
# Load one probability table per chord from data/triad-variation/*.csv.
# Each CSV's first column becomes the row index (the variation labels). Its only
# remaining column header is the chord name, which becomes the dict key.
triad_variations = {}

for file_path in glob.glob("data/triad-variation/*.csv"):
    with open(file_path, "r") as f:
        df = pd.read_csv(f, index_col=0)
        key = df.columns[0]
        # Rename the single data column so every table is accessed as df["p"]
        df.columns = ["p"]
        assert np.allclose(df["p"].sum(), 1), f"Probabilities for {key} do not sum to 1"
        triad_variations[key] = df

In [None]:
# Inspect the variation table loaded for "A#M"
triad_variations["A#M"]

### Define $V_t$ sampling method


In [16]:
def sample_triad_variation(chord, temperature=None):
    """Sample one variation label for ``chord`` from its loaded probability table."""
    table = triad_variations[chord]
    return table.index[sample(table["p"], temperature)]

In [None]:
# Demo: draw one variation of "A#M"
sample_triad_variation("A#M")

### Sanity check: relationship between chord-note index and note-note index


In [None]:
# Compare the note labels (rows) of chord_to_note with those of note_to_note.
# NOTE(review): a first note sampled from chord_to_note that is missing from
# note_to_note would make sample_note_to_note raise KeyError (columns.get_loc)
# when the phrase continues; the authors judged the mismatch acceptable - confirm.
chord_to_note_index = set(map(int, chord_to_note.index))
print(
    "Notes in chord_to_note but not in note_to_note",
    set(chord_to_note_index) - set(note_to_note_index),
)
print(
    "Notes in note_to_note but not in chord_to_note",
    set(note_to_note_index) - set(chord_to_note_index),
)
print()
print("\033[1mWe've decided this is fine.\033[0m")

## Setup triad-to-triad ($U_t$)


In [19]:
# Lagged triad-to-triad transition tables, keyed by lag: an "N-step" file is
# stored under lag N - 1, so lag 0 conditions on the immediately preceding chord.
triad_to_triad = {}

# Row/column label order of the first file loaded; every other file must match it
rows_order = None
columns_order = None

# Maps a row position in the transition matrices to its triad label
triad_to_triad_index_to_id = {}

for file_path in glob.glob("data/triad-to-triad/*.csv"):
    step_match = re.search(r"(\d+)-step", file_path)
    if step_match is None:
        raise ValueError(f"Cannot find an 'N-step' marker in file name: {file_path}")
    lookback_length = int(step_match.group(1)) - 1
    with open(file_path, "r") as f:
        # Setting the first column as the index makes a lot of the work we do much nicer
        df = pd.read_csv(f, index_col=0)
        # Sort headers, then rows, by label
        df = df.sort_index(axis=1).sort_index()
        # Renormalize columns so each is a probability distribution over the rows
        df = df / df.sum(axis=0)

        if rows_order is None or columns_order is None:
            rows_order = df.index.tolist()
            columns_order = df.columns.tolist()

            # Get triad_to_triad header (columns) and index (rows)
            triad_to_triad_header_set = set(columns_order)
            triad_to_triad_index_set = set(rows_order)

            # Assert difference set is empty in both directions
            assert len(triad_to_triad_header_set - triad_to_triad_index_set) == 0, (
                f"triad_to_triad_header - triad_to_triad_index: {triad_to_triad_header_set - triad_to_triad_index_set}"
            )
            assert len(triad_to_triad_index_set - triad_to_triad_header_set) == 0, (
                f"triad_to_triad_index - triad_to_triad_header: {triad_to_triad_index_set - triad_to_triad_header_set}"
            )

            # Sampling picks a position within a column, i.e. a row, so map row
            # positions to labels. (Rows and columns are the same set sorted the
            # same way, so this matches the column order too.)
            triad_to_triad_index_to_id = dict(enumerate(rows_order))

        # Compare each item to ensure they match (this is defensive)
        check_lists_match(df.index.tolist(), rows_order)
        check_lists_match(df.columns.tolist(), columns_order)

        # Again, including the numpy version is defensive to the point of being hard
        # to defend.
        triad_to_triad[lookback_length] = (df, df.to_numpy())

### Define $U_t$ sampling method


In [20]:
def triad_to_triad_lag_distribution(lag, chord):
    """Return the lag-``lag`` transition probabilities out of ``chord``.

    The result is the ``chord`` column of the lag's transition table as a numpy
    vector, ordered like the table's rows.
    """
    frame, values = triad_to_triad[lag]
    return values[:, frame.columns.get_loc(chord)]

In [21]:
def sample_triad_to_triad(past_chords, m, weights, temperature=None):
    """Sample the next triad from a weighted mix of lagged transition distributions.

    For each lag ``j`` (``j = 0`` is the most recent chord), the lag-``j``
    distribution conditioned on ``past_chords[-(j + 1)]`` is scaled by
    ``weights[j]``. The sum is renormalized and sampled. Lags run up to the
    smallest of ``m``, the history length and the number of loaded lag tables.

    Raises:
        ValueError: if the combined distribution has no probability mass, e.g.
            when all non-zero weights fall on lags the history does not reach.
    """
    p = np.zeros(len(triad_to_triad_index_to_id))

    # Bound the lag by the loaded tables rather than hard-coding the file count.
    # Assumes the tables are keyed 0..K-1, as the "N-step" -> lag N - 1 loader
    # produces for contiguous step files.
    max_lag = min(m, len(past_chords), len(triad_to_triad))
    for j in range(max_lag):
        # For j, we want the chord at t - j, which is just -j - 1
        chord = past_chords[-(j + 1)]
        # Get the probability distribution of the triad to triad transition
        p += weights[j] * triad_to_triad_lag_distribution(j, chord)

    total = np.sum(p)
    # `not total > 0` also catches NaN. Without this check a zero-mass mix would
    # turn into NaNs and fail later inside np.random.choice with a vaguer message.
    if not total > 0:
        raise ValueError(
            "Combined triad-to-triad distribution has no probability mass "
            f"(m={m}, history length={len(past_chords)}); check that the non-zero "
            "weights fall on lags covered by the chord history"
        )
    # Normalize combined probability distribution
    p /= total

    return triad_to_triad_index_to_id[sample(p, temperature)]

In [None]:
# Inspect the lag-0 transition probabilities out of "A#M"
triad_to_triad_lag_distribution(0, "A#M")

In [None]:
# Demo: mix the lags of a 3-chord history with uniform weights. Lag 0 uses the
# last element ("C#M"), lag 2 the first ("A#M").
example_past_chords = ["A#M", "GM", "C#M"]
example_past_chords_length = len(example_past_chords)
example_past_chords_weights = (
    np.ones(example_past_chords_length) / example_past_chords_length
)
sample_triad_to_triad(
    example_past_chords,
    example_past_chords_length,
    example_past_chords_weights,
    temperature=None,
)

## Generation


### Define burn-in initialization


In [24]:
# Configure as desired
MAX_BURN_IN_RETRIES = 10_000
MAX_BURN_IN_LENGTH = 10_000_000

In [25]:
def _burn_in_attempt(u_0, n, temperature, retry):
    """Run one burn-in walk from ``u_0``; return its chords, or None if it ran too long."""
    chords = [u_0]
    last_chord = u_0
    t = 0
    with tqdm(total=n, desc="Burn in") as pbar:
        # Walk at least n steps, then keep going until we are back on u_0 so the
        # returned history ends on the requested chord.
        while t < n or last_chord != u_0:
            t += 1
            # The lag-0 distribution depends only on last_chord, so compute it once
            # per step, then resample until the draw is not "End".
            p = triad_to_triad_lag_distribution(0, last_chord)
            triad_u_t = "End"
            while triad_u_t == "End":
                triad_u_t = triad_to_triad_index_to_id[sample(p, temperature)]
            chords.append(triad_u_t)
            last_chord = triad_u_t
            pbar.update(1)
            if t > MAX_BURN_IN_LENGTH:
                print(
                    f"Burn in attempt {retry} failed after {t} steps",
                    file=sys.stderr,
                )
                return None
    return chords


def initialize_burn_in(u_0, n, temperature=None):
    """Generate a burn-in chord history that starts and ends on ``u_0``.

    Performs a lag-0 triad-to-triad random walk from ``u_0`` that never emits
    ``"End"``. The walk runs for at least ``n`` steps and until it returns to
    ``u_0``. A walk longer than ``MAX_BURN_IN_LENGTH`` steps is abandoned and
    retried, up to ``MAX_BURN_IN_RETRIES`` attempts in total.

    Returns:
        The list of visited chords, beginning and ending with ``u_0``.

    Raises:
        Exception: if every attempt is abandoned.
    """
    # Each attempt reports its own outcome, so a failed attempt cannot cause a
    # later successful one to be discarded.
    for retry in range(MAX_BURN_IN_RETRIES):
        chords = _burn_in_attempt(u_0, n, temperature, retry)
        if chords is not None:
            return chords

    print(
        f"Burn in failed after maximum {MAX_BURN_IN_RETRIES} retries",
        file=sys.stderr,
    )
    raise Exception("Burn in failed")

In [None]:
# Demo: 500-step burn-in from "CM"; show the last five chords (the walk ends on "CM").
# `cast` only informs the type checker; it has no runtime effect.
cast(list[str], initialize_burn_in("CM", 500))[-5:]

### Define generation sampler


In [27]:
# Default cap on the number of (chord, phrase) pairs generate_music produces
DEFAULT_MAX_LENGTH = 100

In [None]:
def generate_music(
    u_0, k, m, n, initial_chords=None, weights=None, max_length=DEFAULT_MAX_LENGTH
):
    """Generate a sequence of ``(chord, phrase)`` pairs starting from chord ``u_0``.

    Each step does the following:

    1. Sample a triad variation of the current chord.
    2. Sample a first note for that variation from ``chord_to_note``.
    3. Extend it to a ``k``-note phrase via ``note_to_note``.
    4. Draw the next chord from a ``weights``-mixed combination of up to ``m``
       lagged triad-to-triad distributions over the chord history.

    Generation stops after ``min(max_length, max(n, m))`` pairs, or as soon as
    ``"End"`` is sampled as the next chord.

    Args:
        u_0: starting chord label.
        k: number of notes per phrase.
        m: number of lags mixed when sampling the next chord (length of ``weights``).
        n: requested number of pairs; raised to at least ``m``.
        initial_chords: optional chord history (e.g. from ``initialize_burn_in``)
            that must end with ``u_0``. It is copied, never modified.
        weights: length-``m`` lag weights summing to 1; uniform if omitted.
        max_length: hard cap on the number of pairs.

    Returns:
        A list of ``(chord, phrase)`` tuples, where ``phrase`` is a list of ``k`` notes.
    """
    weights = np.array(weights) if weights is not None else np.ones(m) / m
    assert weights.shape == (m,), f"Weights shape {weights.shape} does not match m {m}"
    assert np.allclose((weights_sum := weights.sum()), 1), (
        f"Weights do not sum to 1: {weights_sum}"
    )
    output = []
    triad_u_t = u_0
    # I chose to allow initial_chords instead of hardcoding an initialization function
    # because that lets us easily try out different initialization methods without
    # making it the responsibility of this method.
    # Copy the history: chords.append below must not mutate the caller's list
    # (e.g. a burn-in result reused across several generations).
    chords = [u_0] if initial_chords is None else list(initial_chords)
    n = max(n, m)
    # Notice here that we set t to 0, which ignores any initial chords. This lets us
    # get an output of consistent length (at least with the current burn-in
    # implementation), but we could also do:
    #
    # t = len(initial_chords) if initial_chords is not None else 0
    #
    # Again, depends how much we care about getting a consistent length of output from
    # the model.
    t = 0

    assert chords[-1] == u_0

    with tqdm(total=min(max_length, n), desc="Generation") as pbar:
        while triad_u_t != "End" and t < min(max_length, n):
            # Sample the first variation
            triad_variation_x_t = sample_triad_variation(triad_u_t, temperature=None)
            # Sample the first note
            first_note_s_0 = sample_chord_to_note(triad_variation_x_t, temperature=None)
            # Sample phrase
            phrase_z_t = sample_phrase(first_note_s_0, k, temperature=None)
            # Add to results
            output.append((triad_u_t, phrase_z_t))

            # Sample next triad
            triad_u_t = sample_triad_to_triad(chords, m, weights, temperature=None)
            if triad_u_t == "End":
                break
            # Update past chords
            chords.append(triad_u_t)

            t += 1
            pbar.update(1)

    return output

### Generate music


In [None]:
# Based on counts in https://en.wikipedia.org/wiki/List_of_chord_progressions
# Lag weights for sample_triad_to_triad: entry j scales the lag-j distribution
# (j = 0 is the most recent chord). Normalized to sum to 1, as generate_music requires.
FREQUENCY_WEIGHTS = np.array([35, 34, 32, 17, 4, 2, 1, 1], dtype=np.float64)
FREQUENCY_WEIGHTS /= FREQUENCY_WEIGHTS.sum()

In [None]:
# Demo: 5-step burn-in from "A#M", then generate with k=8 notes per phrase, m=8
# lags and n=10 chords using the frequency-based lag weights; show the last 5 pairs.
burned_in_chords = initialize_burn_in("A#M", 5)
music = generate_music(
    "A#M", 8, 8, 10, initial_chords=burned_in_chords, weights=FREQUENCY_WEIGHTS
)
music[-5:]

### Postprocess for output


In [None]:
# from music21 import note, tie, harmony
from music21.note import Note, Rest
from music21.tie import Tie
from music21.harmony import ChordSymbol
from music21.expressions import TextExpression
from music21.metadata import Metadata

# These imports are this way because ruff complains about the more direct imports
from music21.meter.base import TimeSignature
from music21.stream.base import Measure, Score, Part

In [32]:
def chunk_every(seq, size):
    """Lazily yield consecutive slices of ``seq`` of length ``size``.

    The last slice may be shorter. Slices keep the type of ``seq``
    (list -> lists, tuple -> tuples, str -> strs).
    """
    starts = range(0, len(seq), size)
    return (seq[start : start + size] for start in starts)

In [33]:
def postprocess_output(music, metadata=None):
    """Convert generated ``(chord, phrase)`` pairs into a two-part music21 Score.

    Pairs are grouped two per 4/4 measure. Each chord becomes a ChordSymbol of
    quarterLength 2 in the harmony part. Each phrase symbol is a 0.25-quarterLength
    (sixteenth-note) step in the melody part:

    * ``symbol >= 0``: a Note whose MIDI pitch is ``symbol``;
    * ``-1``: a Rest;
    * ``-2``: hold. This extends the previous note/rest by 0.25, or, at the start
      of a measure, appends a tied copy of the previous measure's last element.

    Args:
        music: sequence of ``(chord_name, phrase)`` pairs, as returned by
            ``generate_music``.
        metadata: optional string stored as the score's composer field.

    Returns:
        A Score whose first part is the harmony part and second the melody part.
    """
    # TODO: There is a bug in here dealing with tying things across measures
    chords, phrases = zip(*music)
    # Two (chord, phrase) pairs per measure
    chords = list(chunk_every(chords, 2))
    phrases = list(chunk_every(phrases, 2))
    score = Score()
    melody_part = Part()
    harmony_part = Part()
    melody_part.append(TimeSignature("4/4"))
    harmony_part.append(TimeSignature("4/4"))

    score.metadata = Metadata()
    score.metadata.title = "Generated score"
    if metadata is not None:
        score.metadata.composer = metadata

    # Holds the most recently added melody note or rest. It persists across
    # measures, so a -2 at the start of a measure can carry it over as a tie.
    last_element = None

    for j, (chord_names, phrase_group) in enumerate(zip(chords, phrases)):
        melody_measure = Measure()
        harmony_measure = Measure()
        # NOTE(review): quarter_sum is accumulated below but never read.
        quarter_sum = 0

        for chord_name, phrase in zip(chord_names, phrase_group):
            # Add the chord symbol. It lasts 2 quarter lengths, i.e. half of the
            # 4/4 measure, since two chords share each measure.
            cs = ChordSymbol(chord_name)
            cs.quarterLength = 2
            harmony_measure.append(cs)
            # Process each symbol in the phrase; each is a 0.25-quarterLength step.
            # NOTE(review): assumes 8 symbols per phrase so the melody spans the
            # chord's 2 quarter lengths - confirm for other phrase lengths.
            # Symbols other than >= 0, -1 and -2 are silently ignored.
            for i, symbol in enumerate(phrase):
                if symbol >= 0:
                    # Create a new note at MIDI pitch `symbol`, lasting 0.25
                    # quarter lengths (a sixteenth note).
                    n = Note()
                    n.pitch.midi = symbol
                    n.quarterLength = 0.25
                    melody_measure.append(n)
                    quarter_sum += 0.25
                    # Update last element
                    last_element = n
                elif symbol == -1:
                    # Create a rest lasting 0.25 quarter lengths (a sixteenth rest).
                    r = Rest(quarterLength=0.25)
                    melody_measure.append(r)
                    quarter_sum += 0.25
                    # Update last element
                    last_element = r
                elif symbol == -2:
                    # Hold. Chord symbols go into harmony_measure, so an empty
                    # melody_measure means no note/rest has been added to this
                    # measure yet.
                    if len(melody_measure) == 0:
                        # Measure starts with -2: carry over the last element from
                        # previous measure
                        if last_element is None:
                            # At the very start: no previous element. Create a rest.
                            r = Rest(quarterLength=0.25)
                            # print(f"Rest in measure {j} (last_element is None)")
                            melody_measure.append(r)
                            quarter_sum += 0.25
                            last_element = r
                        else:
                            # Create a tied note or rest in the new measure based on
                            # last_element.
                            # NOTE(review): unlike the other branches, this one does
                            # not add to quarter_sum.
                            if isinstance(last_element, Note):
                                new_elem = Note(last_element.pitch)
                            else:
                                new_elem = Rest()
                            new_elem.quarterLength = 0.25
                            # For ties, update the previous element's tie:
                            # If it doesn't have one, mark it as starting a tie.
                            if (
                                not hasattr(last_element, "tie")
                                or last_element.tie is None
                            ):
                                last_element.tie = Tie("start")
                            else:
                                # If already tied, set it to 'continue'
                                last_element.tie = Tie("continue")
                            # Mark the new element as the ending tie.
                            new_elem.tie = Tie("stop")
                            melody_measure.append(new_elem)
                            last_element = new_elem
                    else:
                        # Otherwise, lengthen the element most recently added to this
                        # measure by 0.25 quarter lengths. That element is
                        # last_element, because every branch that appends to
                        # melody_measure also assigns last_element.
                        if last_element is not None and hasattr(
                            last_element, "quarterLength"
                        ):
                            last_element.quarterLength += 0.25
                            quarter_sum += 0.25
                            last_element = melody_measure[-1]
                        else:
                            # Fallback: if for some reason there is no valid element, add a rest.
                            r = Rest(quarterLength=0.25)
                            print(f"Rest in measure {j} (no valid element)")
                            melody_measure.append(r)
                            quarter_sum += 0.25
                            last_element = r

        melody_part.append(melody_measure)
        harmony_part.append(harmony_measure)

    # Harmony part first, then melody part
    score.append(harmony_part)
    score.append(melody_part)
    return score

In [34]:
# Hand-written (chord, phrase) pairs in generate_music's output format, used to
# exercise postprocess_output without sampling. Phrase symbols: a MIDI note
# (>= 0), -1 for a rest, -2 to hold the previous element.
example_generated_music = [
    ("FM", [74, 77, -2, 67, -2, 75, -2, -2]),
    ("CM", [-2, -2, 74, -2, -2, -2, 66, -2]),
    ("A#m", [79, 78, -2, 75, -2, -2, 58, -2]),
    ("FM", [78, -2, -2, -2, -2, -2, -2, 68]),
    ("FM", [69, -2, -2, -2, -2, 66, -2, -2]),
    ("D#m", [81, -2, -2, 75, -2, 64, -2, 73]),
    ("G#m", [68, -2, -2, -2, 76, -2, 84, -2]),
    ("D#m", [61, -2, 82, -2, -2, -2, -2, 72]),
    ("EM", [77, -2, -2, -2, 82, -2, -1, -2]),
    ("CM", [76, -2, -2, -2, -2, -2, -2, -2]),
]

In [35]:
# Build a score from the example pairs, with "Example" as the composer field
example_generated_score = postprocess_output(example_generated_music, "Example")

In [None]:
# Write the example score as MusicXML to output.musicxml in the working directory
example_generated_score.write("musicxml", fp="output.musicxml")

## Try things


In [37]:
# Shell command (IPython "!" syntax): create the outputs/ directory used below
!mkdir -p outputs

In [None]:
# No burn-in: the chord history starts as just ["AM"]. k=8 notes per phrase,
# m=8 lags, n=100 chords, frequency-based lag weights.
no_burn_in_music = generate_music(
    "AM", 8, 8, 100, initial_chords=None, weights=FREQUENCY_WEIGHTS
)
no_burn_in_score = postprocess_output(no_burn_in_music, "No burn-in")
no_burn_in_score.write("musicxml", fp="outputs/no-burn-in.musicxml")

In [None]:
# Long burn-in: a 100,000-step walk from `key` supplies the chord history;
# the output file name includes the key.
long_burn_in_duration = 100_000
key = "AM"
burned_in_chords = initialize_burn_in(key, long_burn_in_duration)
long_burn_in_music = generate_music(
    key, 8, 8, 100, initial_chords=burned_in_chords, weights=FREQUENCY_WEIGHTS
)
long_burn_in_score = postprocess_output(
    long_burn_in_music, f"Long burn-in: {long_burn_in_duration}, key: {key}"
)
long_burn_in_score.write("musicxml", fp=f"outputs/long-burn-in-{key}.musicxml")


In [None]:
# NOTE(review): this result appears unused. The next cell passes
# initial_chords=None, and the cell after it re-runs this same burn-in.
burned_in_chords = initialize_burn_in("CM", 100)

In [None]:
# No burn-in, uniform weights (weights=None -> 1/m for each of the m=8 lags)
no_burn_in_no_weights_music = generate_music(
    "CM", 8, 8, 100, initial_chords=None, weights=None
)
no_burn_in_no_weights_score = postprocess_output(
    no_burn_in_no_weights_music, "No burn-in, uniform weights"
)
no_burn_in_no_weights_score.write(
    "musicxml", fp="outputs/no-burn-in-no-weights.musicxml"
)

In [None]:
# 100-step burn-in (short, despite the "long-burn-in" file name), uniform weights
burned_in_chords = initialize_burn_in("CM", 100)
long_burn_in_no_weights_music = generate_music(
    "CM", 8, 8, 100, initial_chords=burned_in_chords, weights=None
)
long_burn_in_no_weights_score = postprocess_output(
    long_burn_in_no_weights_music, "burn-in 100, uniform weights"
)
long_burn_in_no_weights_score.write(
    "musicxml", fp="outputs/long-burn-in-no-weights.musicxml"
)

In [None]:
# Custom weights favoring more recent chords
# NOTE(review): weights[j] scales the lag-j distribution, and lag 0 is the most
# recent chord (sample_triad_to_triad uses past_chords[-(j + 1)]). This vector
# therefore puts all weight on lag 7, the OLDEST chord considered, which is the
# opposite of "favoring more recent chords". Confirm whether it should be reversed.
# It also needs at least 8 chords of history, or the mixed distribution has no
# mass; the long burn-in below provides that.
# custom_weights = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
# recent_preferring_weights = np.array([0.1, 0.1, 0.1, 0.1, 0.1, 8, 9, 10])
recent_preferring_weights = np.array([0, 0, 0, 0, 0, 0, 0, 1], dtype=np.float64)
recent_preferring_weights /= recent_preferring_weights.sum()
long_burn_in_duration = 100_000
key = "CM"
burned_in_chords = initialize_burn_in(key, long_burn_in_duration)
recent_preferring_music = generate_music(
    key, 8, 8, 100, weights=recent_preferring_weights, initial_chords=burned_in_chords
)
recent_preferring_score = postprocess_output(
    recent_preferring_music,
    f"Long burn in: {long_burn_in_duration}, recent preferring weights: {recent_preferring_weights.tolist()}, key: {key}",
)
recent_preferring_score.write("musicxml", fp="outputs/recent-preferring.musicxml")


In [None]:
# Long burn-in with linearly decreasing lag weights (10, 9, ..., 3).
# NOTE(review): weights[j] scales the lag-j distribution, and lag 0 is the most
# recent chord (sample_triad_to_triad uses past_chords[-(j + 1)]). These weights
# therefore favour the MOST recent chords, despite the variable name. Confirm intent.
# The burn-in length is defined here, rather than reusing another cell's variable,
# so the score title always matches the burn-in actually run.
long_burn_in_duration = 100_000
burned_in_chords = initialize_burn_in("CM", long_burn_in_duration)
less_recent_preferring_weights = np.array([10, 9, 8, 7, 6, 5, 4, 3], dtype=np.float64)
less_recent_preferring_weights /= less_recent_preferring_weights.sum()
less_recent_preferring_music = generate_music(
    "CM",
    8,
    8,
    100,
    initial_chords=burned_in_chords,
    weights=less_recent_preferring_weights,
)
less_recent_preferring_score = postprocess_output(
    less_recent_preferring_music,
    f"Long burn in: {long_burn_in_duration}, less recent preferring weights: {less_recent_preferring_weights.round(2).tolist()}",
)
less_recent_preferring_score.write(
    "musicxml", fp="outputs/less-recent-preferring.musicxml"
)