# Functional Baseline Model

Symbolic Melody Harmonization Model

This baseline model takes a tokenized melody (sequence of pitch tokens from 3–130) and generates an aligned harmony
sequence (e.g., bass line or chords), producing one harmony token per melody token.

Model Architecture:
1. Embedding Layer:
   - Converts integer pitch tokens into dense embeddings.
2. Encoder (BiLSTM):
   - Processes the entire melody, outputting a context summary from forward and backward hidden states.
3. Decoder (LSTM):
   - Initialized with encoder state; receives shifted harmony input during training.
   - Autoregressively generates harmony tokens during inference (greedy decoding).
4. Output Layer:
   - Linear projection + softmax to predict harmony token at each step.

Training:
- Loss: CrossEntropyLoss (ignoring <PAD> tokens).
- Teacher forcing: ground truth harmony tokens fed into decoder during training.
- Tokens:
  <PAD> = 0, <SOS> = 1, <EOS> = 2, pitch tokens = 3–130 (MIDI pitch values + 3 offset).

Limitations:
- No rhythm, duration, or polyphony (pitch-only representation).
- Fixed-length, uniform-duration tokens.
- No beat/bar-level structure or musical rules.
- Greedy decoding only.

Future Improvements:
- Integrate time/duration tokens (REMI-style).
- Use Transformer-based encoder-decoder.
- Support chord labels and multi-voice output.
- Apply musical constraints (e.g., avoid parallel 5ths).


In [13]:
import os
from glob import glob
import pretty_midi
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math
from tqdm import tqdm 
import torch.nn.functional as F

In [14]:
# Constants
MIDI_DIR = "MIDI"
MELODY_DIR = os.path.join(MIDI_DIR, "melody")
CHORDS_DIR = os.path.join(MIDI_DIR, "chords")
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2
VOCAB_START = 3  # actual MIDI pitches start here (3 -> 130)

In [15]:
# Utility: Extract pitches from MIDI file
def extract_pitches(path):
    midi = pretty_midi.PrettyMIDI(path)
    notes = []
    for instr in midi.instruments:
        if instr.is_drum:
            continue
        for note in instr.notes:
            notes.append((note.start, note.pitch))
    notes.sort()
    return [pitch for _, pitch in notes]

# Convert pitch list to tokenized form
def tokenize(pitches, max_len=64):
    tokens = [SOS_IDX] + [p + VOCAB_START for p in pitches[:max_len - 2]] + [EOS_IDX]
    return tokens + [PAD_IDX] * (max_len - len(tokens))

In [16]:
# Dataset class
class MelodyHarmonyDataset(Dataset):
    def __init__(self, melody_paths, chords_paths, max_len=64):
        self.pairs = list(zip(melody_paths, chords_paths))
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        mel_path, chord_path = self.pairs[idx]
        melody = tokenize(extract_pitches(mel_path), self.max_len)
        harmony = tokenize(extract_pitches(chord_path), self.max_len)
        return (
            torch.tensor(melody, dtype=torch.long),
            torch.tensor(harmony[:-1], dtype=torch.long),  # input
            torch.tensor(harmony[1:], dtype=torch.long),   # target
        )

In [17]:
# Model definition
class Seq2SeqHarmonizer(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=PAD_IDX)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.decoder = nn.LSTM(embed_dim, hidden_dim * 2, batch_first=True)
        self.output = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, melody_seq, harmony_input):
        mel_emb = self.embedding(melody_seq)
        harm_emb = self.embedding(harmony_input)
        _, (hidden, cell) = self.encoder(mel_emb)
        h_cat = torch.cat((hidden[0], hidden[1]), dim=1).unsqueeze(0)
        c_cat = torch.cat((cell[0], cell[1]), dim=1).unsqueeze(0)
        out, _ = self.decoder(harm_emb, (h_cat, c_cat))
        return self.output(out)

In [18]:

# Training loop
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for melody, harm_in, harm_out in dataloader:
        melody, harm_in, harm_out = melody.to(device), harm_in.to(device), harm_out.to(device)
        optimizer.zero_grad()
        pred = model(melody, harm_in)
        loss = criterion(pred.view(-1, pred.shape[-1]), harm_out.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

# Inference (greedy)
def generate(model, melody_seq, max_len=64):
    model.eval()
    with torch.no_grad():
        melody_seq = melody_seq.unsqueeze(0)
        mel_emb = model.embedding(melody_seq)
        _, (hidden, cell) = model.encoder(mel_emb)
        h_cat = torch.cat((hidden[0], hidden[1]), dim=1).unsqueeze(0)
        c_cat = torch.cat((cell[0], cell[1]), dim=1).unsqueeze(0)

        input_token = torch.tensor([[SOS_IDX]], device=melody_seq.device)
        result = []

        for _ in range(max_len):
            emb = model.embedding(input_token)
            out, (h_cat, c_cat) = model.decoder(emb, (h_cat, c_cat))
            logits = model.output(out[:, -1, :])
            next_token = torch.argmax(logits, dim=-1)
            if next_token.item() == EOS_IDX:
                break
            result.append(next_token.item())
            input_token = next_token.unsqueeze(0)
        return result


In [19]:
import pretty_midi

def write_combined_midi(melody_tokens, harmony_tokens, out_path, tempo=120):
    """
    Convert token sequences into MIDI file with melody and harmony instruments.
    Assumes tokens are pitch-only and does not encode rhythm/duration.
    """

    pm = pretty_midi.PrettyMIDI()
    melody_instr = pretty_midi.Instrument(program=0, name="Melody")
    harmony_instr = pretty_midi.Instrument(program=0, name="Harmony")

    beat_duration = 60.0 / tempo
    time = 0.0
    note_length = beat_duration  # 1 beat per note

    for token in melody_tokens:
        if 3 <= token <= 130:
            pitch = token - 3
            note = pretty_midi.Note(velocity=100, pitch=pitch, start=time, end=time + note_length)
            melody_instr.notes.append(note)
            time += note_length
        elif token == 2:  # EOS
            break

    time = 0.0
    for token in harmony_tokens:
        if 3 <= token <= 130:
            pitch = token - 3
            note = pretty_midi.Note(velocity=80, pitch=pitch, start=time, end=time + note_length)
            harmony_instr.notes.append(note)
            time += note_length
        elif token == 2:  # EOS
            break

    pm.instruments.append(melody_instr)
    pm.instruments.append(harmony_instr)
    pm.write(out_path)
    print(f"Saved combined MIDI to {out_path}")


In [20]:
def evaluate_baseline(model, dataloader, device):
    model.eval()
    total_loss = 0
    total_tokens = 0

    with torch.no_grad():
        for melody, harm_in, harm_out in tqdm(dataloader, desc="Evaluating Baseline"):
            melody = melody.to(device)
            harm_in = harm_in.to(device)
            harm_out = harm_out.to(device)

            output_logits = model(melody, harm_in)  # [batch, seq_len, vocab]
            logits_flat = output_logits.view(-1, output_logits.shape[-1])
            targets_flat = harm_out.view(-1)

            loss = F.cross_entropy(
                logits_flat,
                targets_flat,
                ignore_index=PAD_IDX,
                reduction='sum'  # sum to count valid tokens correctly
            )

            non_pad_mask = targets_flat != PAD_IDX
            total_loss += loss.item()
            total_tokens += non_pad_mask.sum().item()

    avg_ce_loss = total_loss / total_tokens
    perplexity = math.exp(avg_ce_loss)

    print(f"\n📊 Baseline Evaluation:")
    print(f"Cross-Entropy Loss: {avg_ce_loss:.4f}")
    print(f"Perplexity: {perplexity:.2f}")

    return avg_ce_loss, perplexity

In [21]:

# Main script
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    melody_files = {os.path.basename(p): p for p in glob(os.path.join(MELODY_DIR, "jigs*.mid"))}
    chords_files = {os.path.basename(p): p for p in glob(os.path.join(CHORDS_DIR, "jigs*.mid"))}

    common_filenames = sorted(set(melody_files.keys()) & set(chords_files.keys()))

    if not common_filenames:
        raise RuntimeError("No matching jigs*.mid files found in both melody and chords folders.")

    melody_paths = [melody_files[f] for f in common_filenames]
    chords_paths = [chords_files[f] for f in common_filenames]

    print(f"Using {len(common_filenames)} matched file pairs.")


    dataset = MelodyHarmonyDataset(melody_paths, chords_paths)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

    vocab_size = 131  # PAD=0, SOS=1, EOS=2, + pitches 0–127 mapped to 3–130
    model = Seq2SeqHarmonizer(vocab_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    for epoch in range(10):
        loss = train(model, dataloader, optimizer, criterion, device)
        print(f"Epoch {epoch+1}: Loss = {loss:.4f}")

    # Test generation from first melody
    melody, _, _ = dataset[0]
    harmony = generate(model, melody.to(device))

    write_combined_midi(melody.tolist(), harmony, "baseline_combined.mid")

    print("\nSample Melody Tokens:\n", melody.tolist())
    print("Generated Harmony Tokens:\n", harmony)

    print("\nEvaluating baseline model...")
    evaluate_baseline(model, dataloader, device)


if __name__ == "__main__":
    main()


Using 338 matched file pairs.
Epoch 1: Loss = 2.5120
Epoch 2: Loss = 1.1335
Epoch 3: Loss = 0.8960
Epoch 4: Loss = 0.7718
Epoch 5: Loss = 0.7096
Epoch 6: Loss = 0.6516
Epoch 7: Loss = 0.6074
Epoch 8: Loss = 0.5920
Epoch 9: Loss = 0.5599
Epoch 10: Loss = 0.5473
Saved combined MIDI to baseline_combined.mid

Sample Melody Tokens:
 [1, 81, 79, 76, 76, 76, 81, 79, 76, 76, 76, 81, 79, 76, 76, 76, 81, 74, 76, 74, 74, 81, 79, 76, 76, 76, 81, 79, 76, 76, 76, 76, 77, 79, 81, 79, 77, 76, 74, 72, 76, 79, 84, 81, 79, 76, 76, 76, 81, 79, 76, 76, 76, 81, 79, 76, 76, 76, 81, 74, 76, 74, 74, 2]
Generated Harmony Tokens:
 [48, 52, 55, 48, 52, 55, 48, 52, 55, 41, 45, 48, 48, 52, 55, 50, 53, 57, 43, 47, 50, 53, 48, 52, 55, 48, 52, 55, 41, 45, 48, 48, 52, 55, 50, 53, 57, 43, 47, 50, 53, 48, 52, 55, 48, 52, 55, 41, 45, 48, 48, 52, 55, 50, 53, 57, 43, 47, 50, 53, 48, 52, 55, 48]

Evaluating baseline model...


Evaluating Baseline: 100%|██████████| 22/22 [00:03<00:00,  6.19it/s]


📊 Baseline Evaluation:
Cross-Entropy Loss: 0.5195
Perplexity: 1.68



