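Dataset source: https://github.com/jukedeck/nottingham-dataset/tree/master/MIDI/chords (the code below expects a local copy at `nottingham-dataset/MIDI` with `melody/` and `chords/` subfolders)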

Chord → melody training: learn to generate a melody token sequence conditioned on a chord progression.

In [53]:
# Probably more imports than are really necessary...
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from tqdm import tqdm
import librosa
import numpy as np
import miditoolkit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, average_precision_score, accuracy_score
import random

In [54]:
def get_midi_paths(root_dir):
    """Return every ``.mid`` / ``.midi`` file under *root_dir*, recursively, sorted.

    Extension matching is case-insensitive. A missing *root_dir* yields an
    empty list (``os.walk`` reports no entries rather than raising), so an
    unexpectedly empty result usually means the path is wrong.
    """
    return sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(root_dir)
        for file in files
        if file.lower().endswith(('.mid', '.midi'))
    )


# Dataset root; matches the path used by the extraction and diagnostics cells.
midi_root = 'nottingham-dataset/MIDI'
midi_files = get_midi_paths(midi_root)

In [55]:
def extract_chords_and_melody(midi_path):
    """Parse one MIDI file into chord markers and melody notes.

    Args:
        midi_path: Path to a MIDI file.

    Returns:
        ``(chords, melody)`` where ``chords`` is a list of
        ``(marker_text, beat)`` tuples taken from the file's marker events, and
        ``melody`` is a list of ``(pitch, start_beat, duration_beats)`` tuples
        from the first instrument, ordered by onset. ``melody`` is empty when
        the file has no instruments or the first instrument is a drum track.
    """
    midi_obj = miditoolkit.midi.parser.MidiFile(midi_path)
    ticks_per_beat = midi_obj.ticks_per_beat

    # Drum-ness is a property of the Instrument, not of individual notes
    # (miditoolkit's Note has no ``is_drum`` attribute), so filter per track.
    # Track 0 is assumed to carry the melody.
    notes = []
    if midi_obj.instruments and not midi_obj.instruments[0].is_drum:
        notes = sorted(midi_obj.instruments[0].notes, key=lambda note: note.start)

    melody = [
        (
            note.pitch,
            note.start / ticks_per_beat,
            (note.end - note.start) / ticks_per_beat,
        )
        for note in notes
    ]

    # Chord labels come from MIDI marker meta-events; files without markers yield [].
    chords = [(marker.text, marker.time / ticks_per_beat) for marker in midi_obj.markers]

    return chords, melody


### Tokenize

In [56]:
def chord_to_token(chord_name):
    """Wrap a chord name in square brackets, e.g. ``"C:maj"`` -> ``"[C:maj]"``."""
    return "[{}]".format(chord_name)
def melody_to_tokens(melody, step=0.25):
    """Convert ``(pitch, start, duration)`` notes into ``"pitch_duration"`` tokens.

    Durations (in beats) are rounded to the nearest multiple of *step* and
    formatted with two decimals (``"76_1.00"``), the same
    ``{pitch}_{duration:.2f}`` format used by the marker-based pipeline, so
    both produce compatible vocabularies. Pass ``step=None`` (or ``0``) to
    keep raw durations. Onset times are not encoded.
    """
    tokens = []
    for pitch, _start, dur in melody:
        if step:
            dur = round(dur / step) * step
        tokens.append(f"{pitch}_{dur:.2f}")
    return tokens


### Dataset Class

In [57]:
class ChordMelodyDataset(Dataset):
    """Map-style dataset of ``(chord_tokens, melody_tokens)`` pairs.

    Every MIDI file directly inside *midi_dir* (not recursive) is parsed with
    ``extract_chords_and_melody``. Chord names are wrapped with
    ``chord_to_token`` and notes are tokenized with ``melody_to_tokens``.
    Files that yield no chords or no melody notes are skipped.
    """

    def __init__(self, midi_dir):
        self.data = []
        # Sort for a deterministic item order; os.listdir order is arbitrary.
        for fname in sorted(os.listdir(midi_dir)):
            if not fname.lower().endswith(('.mid', '.midi')):
                continue
            chords, melody = extract_chords_and_melody(os.path.join(midi_dir, fname))
            chord_tokens = [chord_to_token(c[0]) for c in chords]
            melody_tokens = melody_to_tokens(melody)
            if chord_tokens and melody_tokens:
                self.data.append((chord_tokens, melody_tokens))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        chords, melody = self.data[idx]
        return chords, melody


### Model Build

In [58]:
class ChordToMelody(nn.Module):
    """LSTM encoder-decoder that predicts melody tokens from a chord sequence.

    Chords and melody share one embedding table (a single joint vocabulary).
    The encoder's final ``(hidden, cell)`` state initialises the decoder, which
    reads the melody tokens and emits logits over the vocabulary at every step.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.encoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, chord_seq, melody_seq):
        # Only the encoder's final state is used; its per-step outputs are discarded.
        _, encoder_state = self.encoder(self.embedding(chord_seq))
        decoded, _ = self.decoder(self.embedding(melody_seq), encoder_state)
        return self.output(decoded)


### Get Our Tokens

In [59]:


### --- STEP 1: Pair MIDI files from melody/ and chords/ ---
def get_melody_chord_pairs(midi_root):
    """Pair same-named ``.mid`` files from ``<midi_root>/melody`` and ``<midi_root>/chords``.

    Returns:
        A list of ``(chord_path, melody_path)`` tuples, ordered by file name.
        Melody files without a same-named chord file are dropped.

    Raises:
        FileNotFoundError: if either subdirectory is missing.
    """
    melody_dir = os.path.join(midi_root, 'melody')
    chords_dir = os.path.join(midi_root, 'chords')

    melody_files = sorted(f for f in os.listdir(melody_dir) if f.endswith('.mid'))
    # One directory listing + set lookup instead of an os.path.exists call per file.
    chord_names = {f for f in os.listdir(chords_dir) if f.endswith('.mid')}

    return [
        (os.path.join(chords_dir, fname), os.path.join(melody_dir, fname))
        for fname in melody_files
        if fname in chord_names
    ]

### --- STEP 2: Load and parse MIDI files ---
def load_pair(chord_path, melody_path):
    """Parse the chord file and the melody file of one tune with miditoolkit.

    Returns:
        ``(chords_midi, melody_midi)``: the chord file first, then the melody file.
    """
    parsed = tuple(miditoolkit.MidiFile(path) for path in (chord_path, melody_path))
    return parsed

### --- STEP 3: Extract melody tokens ---
def extract_melody_tokens(midi, quantize_step=0.25):
    """Tokenize the first instrument of *midi* as ``"pitch_duration"`` strings.

    Notes are ordered by onset (stable for simultaneous notes). Each duration
    is measured in beats (``ticks / midi.ticks_per_beat``), rounded to the
    nearest multiple of *quantize_step*, and formatted with two decimals,
    e.g. ``"76_1.00"``. Onset times and rests are not encoded in the tokens.

    Returns:
        A list of tokens; empty when *midi* has no instruments.
    """
    if not midi.instruments:
        return []
    ticks_per_beat = midi.ticks_per_beat
    # Sorted copy: token order must follow time, and the MIDI object is not mutated.
    notes = sorted(midi.instruments[0].notes, key=lambda note: note.start)
    tokens = []
    for note in notes:
        beats = (note.end - note.start) / ticks_per_beat
        duration_beat = round(beats / quantize_step) * quantize_step
        tokens.append(f"{note.pitch}_{duration_beat:.2f}")
    return tokens

### --- STEP 4: Extract chord tokens from markers ---
def extract_chord_tokens(midi):
    """Return ``(beat, "[name]")`` chord events from *midi*'s markers, ordered by beat.

    Marker text is stripped; markers with blank text are skipped because they
    would otherwise become a meaningless ``"[]"`` token.

    NOTE(review): only marker meta-events are read. If a chord file stores its
    chords as notes rather than markers, this returns ``[]`` and every melody
    token is later labelled ``'[N]'`` — confirm the dataset files carry markers.
    """
    ticks_per_beat = midi.ticks_per_beat
    events = []
    for marker in midi.markers:
        chord_name = marker.text.strip()
        if not chord_name:
            continue
        events.append((marker.time / ticks_per_beat, f"[{chord_name}]"))
    # align_chords scans chord events left to right, so they must be time-ordered.
    events.sort(key=lambda event: event[0])
    return events

### --- STEP 5: Align chords to melody time grid ---
def _melody_onsets(melody_tokens, step):
    """Estimate each melody token's onset (in beats) by summing the preceding durations."""
    onsets = []
    time = 0.0
    for token in melody_tokens:
        onsets.append(time)
        _, sep, dur_str = token.rpartition('_')
        try:
            time += float(dur_str) if sep else step
        except ValueError:
            # Not a "pitch_duration" token: advance by one grid step.
            time += step
    return onsets


def align_chords(chord_tokens, melody_tokens, step=0.25, onsets=None):
    """Label each melody token with the chord sounding at its onset.

    Melody tokens are whole notes with variable durations (``"76_1.00"``), so
    token *i* does not start at ``i * step``. By default, onsets are estimated
    by accumulating the token durations. This assumes there are no rests
    between notes. Callers that know the true onsets can pass them explicitly.

    Args:
        chord_tokens: ``(beat, token)`` pairs sorted by beat.
        melody_tokens: ``"pitch_duration"`` strings (durations in beats).
        step: Advance, in beats, for tokens whose duration cannot be parsed.
        onsets: Optional non-decreasing onset (in beats) for each melody token.

    Returns:
        One chord token per melody token: the latest chord starting at or
        before that onset, or ``'[N]'`` before the first chord / when there
        are no chords at all.

    Raises:
        ValueError: if *onsets* is given with a length other than ``len(melody_tokens)``.
    """
    if onsets is None:
        onsets = _melody_onsets(melody_tokens, step)
    elif len(onsets) != len(melody_tokens):
        raise ValueError(f"expected {len(melody_tokens)} onsets, got {len(onsets)}")

    chord_times = [beat for beat, _ in chord_tokens]
    result = []
    chord_idx = -1  # index of the chord currently sounding; -1 means none has started
    for time in onsets:
        while chord_idx + 1 < len(chord_tokens) and chord_times[chord_idx + 1] <= time:
            chord_idx += 1
        result.append(chord_tokens[chord_idx][1] if chord_idx >= 0 else '[N]')
    return result

### --- STEP 6: Extract token pair for one file ---
def extract_token_pair(chord_path, melody_path, step=0.25):
    """Build aligned ``(chord_tokens, melody_tokens)`` lists for one tune.

    The melody is tokenized with a quantization grid of *step* beats. The
    chord markers are then mapped onto the melody tokens by ``align_chords``,
    so both lists have the same length.
    """
    chords_midi, melody_midi = load_pair(chord_path, melody_path)
    melody = extract_melody_tokens(melody_midi, quantize_step=step)
    aligned = align_chords(extract_chord_tokens(chords_midi), melody, step=step)
    return aligned, melody

### --- STEP 7: Load all pairs into memory ---
def extract_all_token_pairs(midi_root, step=0.25):
    """Extract token pairs for every paired tune under *midi_root*.

    Tunes that raise while being parsed are reported with ``print`` and
    skipped (best effort). Pairs that are empty or whose two sequences differ
    in length are dropped.

    Returns:
        A list of ``(chord_tokens, melody_tokens)`` tuples.
    """
    all_data = []
    for chord_path, melody_path in get_melody_chord_pairs(midi_root):
        try:
            chords, melody = extract_token_pair(chord_path, melody_path, step=step)
        except Exception as err:
            print(f"Failed on {chord_path} / {melody_path}: {err}")
            continue
        if chords and len(chords) == len(melody):
            all_data.append((chords, melody))
    return all_data

### --- STEP 8: Example usage ---
midi_root = "nottingham-dataset/MIDI"
data = extract_all_token_pairs(midi_root, step=0.25)

print(f"✅ Extracted {len(data)} chord–melody token sequences.")
# Guard the preview: data[0] would raise IndexError when nothing was extracted.
if data:
    print("🎵 Sample chord tokens:", data[0][0][:10])
    print("🎶 Sample melody tokens:", data[0][1][:10])
else:
    print(f"⚠️ No chord–melody pairs found; check that {midi_root} has melody/ and chords/ subfolders.")


✅ Extracted 1021 chord–melody token sequences.
🎵 Sample chord tokens: ['[N]', '[N]', '[N]', '[N]', '[N]', '[N]', '[N]', '[N]', '[N]', '[N]']
🎶 Sample melody tokens: ['76_1.00', '74_2.00', '71_1.00', '69_1.50', '71_0.50', '72_1.00', '71_2.00', '67_1.00', '69_2.00', '76_1.00']


### Map Tokens to Vocabulary Indices

In [60]:
# Build vocabulary from all chord and melody tokens.
# '[N]' is seeded explicitly: later cells use it as the padding id, the ignored
# loss index and the generation start token, so it must exist even if no
# extracted chord sequence happens to contain it.
token_set = {'[N]'}
for chords, melody in data:
    token_set.update(chords)
    token_set.update(melody)

# Create mapping: token <-> index (sorted, so ids are deterministic across runs)
token2idx = {token: i for i, token in enumerate(sorted(token_set))}
idx2token = {i: token for token, i in token2idx.items()}

print("🧠 Vocab size:", len(token2idx))
print("🔁 Sample tokens:", list(token2idx.items())[:10])


🧠 Vocab size: 228
🔁 Sample tokens: [('55_0.50', 0), ('55_1.00', 1), ('55_2.00', 2), ('55_3.00', 3), ('56_0.50', 4), ('57_0.25', 5), ('57_0.50', 6), ('57_1.00', 7), ('57_1.50', 8), ('57_2.00', 9)]


### Training Skeleton: Encode, Batch, and Train

In [61]:
import torch
from torch.nn.utils.rnn import pad_sequence
# Train on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def encode_sequence(tokens, token2idx):
    """Map token strings to a 1-D ``torch.long`` tensor of vocabulary ids.

    Raises:
        KeyError: if a token is missing from *token2idx*.
    """
    ids = [token2idx[tok] for tok in tokens]
    return torch.tensor(ids, dtype=torch.long)
# Id-encode every (chords, melody) pair once, up front.
encoded_data = [
    (encode_sequence(chord_seq, token2idx), encode_sequence(melody_seq, token2idx))
    for chord_seq, melody_seq in data
]
def collate_fn(batch):
    """Pad a list of ``(chord_ids, melody_ids)`` pairs into two batch tensors.

    Each sequence is right-padded with the id of ``'[N]'`` to the longest
    sequence of its kind in the batch. Each result is ``[batch, longest]``
    and is moved to the module-level ``device`` here, inside the DataLoader.
    """
    chord_seqs, melody_seqs = zip(*batch)
    pad_id = token2idx['[N]']

    def _pad_and_move(seqs):
        return pad_sequence(seqs, batch_first=True, padding_value=pad_id).to(device)

    return _pad_and_move(chord_seqs), _pad_and_move(melody_seqs)

# Batches of up to 16 padded (chord_ids, melody_ids) pairs, reshuffled every epoch.
loader = DataLoader(encoded_data, batch_size=16, shuffle=True, collate_fn=collate_fn)

model = ChordToMelody(vocab_size=len(token2idx)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# '[N]' is the padding id used by collate_fn, so padded target positions add no loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=token2idx['[N]'])  # ignore padding

for epoch in range(30):
    model.train()
    total_loss = 0
    for chords, melodies in loader:
        # Teacher forcing: the decoder reads melody tokens 0..T-2 and is trained
        # to predict tokens 1..T-1 (the first melody token is never a target).
        melody_input = melodies[:, :-1]
        melody_target = melodies[:, 1:]

        # Forward pass
        logits = model(chords, melody_input)  # [batch, seq_len, vocab]
        # Flatten batch and time so the loss compares [B*T, vocab] logits with [B*T] targets.
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), melody_target.reshape(-1))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of per-batch losses (not weighted by each batch's number of non-padding tokens).
    avg_loss = total_loss / len(loader)
    print(f"📘 Epoch {epoch+1}: Loss = {avg_loss:.4f}")



📘 Epoch 1: Loss = 3.6397
📘 Epoch 2: Loss = 2.9518
📘 Epoch 3: Loss = 2.7652
📘 Epoch 4: Loss = 2.6637
📘 Epoch 5: Loss = 2.5852
📘 Epoch 6: Loss = 2.5337
📘 Epoch 7: Loss = 2.4844
📘 Epoch 8: Loss = 2.4413
📘 Epoch 9: Loss = 2.4015
📘 Epoch 10: Loss = 2.3643
📘 Epoch 11: Loss = 2.3309
📘 Epoch 12: Loss = 2.3034
📘 Epoch 13: Loss = 2.2684
📘 Epoch 14: Loss = 2.2345
📘 Epoch 15: Loss = 2.1987
📘 Epoch 16: Loss = 2.1770
📘 Epoch 17: Loss = 2.1418
📘 Epoch 18: Loss = 2.1064
📘 Epoch 19: Loss = 2.0773
📘 Epoch 20: Loss = 2.0423
📘 Epoch 21: Loss = 2.0191
📘 Epoch 22: Loss = 1.9905
📘 Epoch 23: Loss = 1.9557
📘 Epoch 24: Loss = 1.9270
📘 Epoch 25: Loss = 1.8982
📘 Epoch 26: Loss = 1.8681
📘 Epoch 27: Loss = 1.8389
📘 Epoch 28: Loss = 1.8070
📘 Epoch 29: Loss = 1.7845
📘 Epoch 30: Loss = 1.7561


### Generation

In [62]:
def generate_melody(model, chord_tokens, token2idx, idx2token, max_len=100, temperature=1.0, device=None):
    """Autoregressively sample a melody conditioned on a chord progression.

    Args:
        model: Module called as ``model(chord_ids, melody_ids)`` that returns
            logits shaped ``[batch, seq_len, vocab]`` (e.g. ChordToMelody).
        chord_tokens: Chord tokens. Bare names such as ``"C:maj"`` are mapped
            to their bracketed vocabulary form (``"[C:maj]"``). Anything not in
            the vocabulary falls back to ``'[N]'``.
        token2idx / idx2token: Vocabulary mappings; must contain ``'[N]'``.
        max_len: Maximum number of tokens to sample.
        temperature: Softmax temperature; ``<= 0`` selects greedy (argmax) decoding.
        device: Device for the input tensors; defaults to the model's device.

    Returns:
        The list of sampled token strings. Sampling stops early once ``'[N]'``
        is produced after more than 10 tokens (that ``'[N]'`` is included).
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    unk_id = token2idx['[N]']

    def _chord_id(token):
        # Vocabulary chord tokens are bracketed; accept bare names as well.
        if token in token2idx:
            return token2idx[token]
        return token2idx.get(f"[{token}]", unk_id)

    chord_ids = torch.tensor([[_chord_id(t) for t in chord_tokens]], dtype=torch.long, device=device)
    # NOTE(review): during training '[N]' only appears as trailing padding, never
    # as the first decoder input, so this start token is out of distribution.
    melody_input = torch.tensor([[unk_id]], dtype=torch.long, device=device)

    generated = []
    with torch.no_grad():
        for _ in range(max_len):
            logits = model(chord_ids, melody_input)
            next_token_logits = logits[:, -1, :]  # take last timestep
            if temperature > 0:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)
            else:
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            next_token = idx2token[next_token_id.item()]
            generated.append(next_token)

            if next_token == '[N]' and len(generated) > 10:
                break  # early stopping on silence

            melody_input = torch.cat([melody_input, next_token_id], dim=1)

    return generated
test_chords = ['C:maj', 'G:maj', 'Am', 'F', 'C:maj', 'G:maj', 'Am', 'F']
# Vocabulary chord tokens are bracketed (see chord_to_token); bare names would
# silently fall back to '[N]'. NOTE(review): 'Am'/'F' use a different naming
# style than 'C:maj' — confirm they exist in token2idx.
generated_melody = generate_melody(model, [chord_to_token(c) for c in test_chords], token2idx, idx2token)

print("🎵 Chords:", test_chords)
print("🎶 Generated Melody:", generated_melody)


🎵 Chords: ['C:maj', 'G:maj', 'Am', 'F', 'C:maj', 'G:maj', 'Am', 'F']
🎶 Generated Melody: ['78_0.25', '76_0.25', '74_1.00', '76_0.50', '78_1.00', '78_0.50', '79_1.00', '76_0.50', '78_0.50', '76_0.50', '74_0.50', '73_1.00', '71_1.00', '69_1.00', '73_0.50', '71_4.00', '73_1.00', '81_1.00', '79_1.00', '76_0.50', '78_0.50', '76_0.50', '78_0.50', '76_0.50', '74_0.50', '73_0.50', '69_1.00', '81_1.00', '78_1.00', '76_0.50', '78_0.50', '76_0.50', '73_0.50', '74_0.50', '76_0.50', '78_0.50', '76_0.50', '74_0.50', '73_0.50', '74_0.50', '76_0.50', '78_0.50', '76_0.50', '78_0.50', '76_0.50', '73_0.50', '76_4.00', '76_1.00', '74_1.00', '76_0.50', '78_0.50', '81_0.50', '78_0.50', '76_0.25', '74_0.25', '73_0.50', '76_0.50', '74_1.00', '78_1.00', '78_0.50', '76_0.50', '78_1.00', '76_0.50', '78_0.50', '79_0.50', '81_0.50', '78_0.50', '76_0.50', '73_0.50', '76_0.50', '78_1.00', '74_1.00', '78_0.50', '79_0.50', '81_0.50', '78_0.50', '76_0.50', '78_0.50', '76_0.50', '73_0.50', '71_0.50', '69_0.50', '71_0.50

### Convert Back to Midi

In [None]:
import miditoolkit

def melody_tokens_to_midi(tokens, output_path="generated_melody.mid", start_time=0.0, velocity=100, step=0.25):
    """Write ``"pitch_duration"`` melody tokens to a single-track MIDI file.

    Notes are laid end to end, starting at *start_time* beats. Each ``'[N]'``
    token inserts a rest of *step* beats. A token is skipped with a printed
    warning if it cannot be parsed as ``pitch_duration`` (e.g. a chord token
    such as ``"[C:maj]"``), its pitch is outside 0-127, or its duration is
    not positive.

    Returns:
        *output_path*, for convenience.
    """
    midi = miditoolkit.midi.parser.MidiFile()
    # Convert beats to ticks with the file's own resolution rather than a hard-coded 480.
    ticks_per_beat = midi.ticks_per_beat
    track = miditoolkit.midi.containers.Instrument(program=0, is_drum=False, name="melody")

    current_time = start_time
    for token in tokens:
        if token == '[N]':
            current_time += step
            continue

        try:
            pitch_str, dur_str = token.split('_')
            pitch = int(pitch_str)
            duration = float(dur_str)
            if not 0 <= pitch <= 127:
                raise ValueError(f"pitch {pitch} outside MIDI range 0-127")
            if duration <= 0:
                raise ValueError(f"non-positive duration {duration}")
        except ValueError as e:
            print(f"⚠️ Skipping token '{token}': {e}")
            continue

        # round() rather than int(): float products like 0.7 * 480 would truncate one tick short.
        note = miditoolkit.Note(
            start=round(current_time * ticks_per_beat),
            end=round((current_time + duration) * ticks_per_beat),
            pitch=pitch,
            velocity=velocity,
        )
        track.notes.append(note)
        current_time += duration

    midi.instruments.append(track)
    midi.dump(output_path)
    print(f"✅ Saved MIDI to {output_path}")
    return output_path

    
# Render the sampled tokens to a MIDI file.
# NOTE(review): the saved cell output below names "my_generated_song.mid";
# it is stale relative to this call, which writes "my_generated_song2.mid".
melody_tokens_to_midi(generated_melody, "my_generated_song2.mid")



✅ Saved MIDI to my_generated_song.mid


### Play Our Song!

In [67]:
from IPython.display import Audio
import pretty_midi

def play_midi(path, rate=22050):
    """Synthesize a MIDI file and return an inline IPython audio player.

    Args:
        path: MIDI file to load with pretty_midi.
        rate: Sample rate in Hz, used for both synthesis and playback.

    Returns:
        An ``IPython.display.Audio`` object.
    """
    pm = pretty_midi.PrettyMIDI(path)
    # synthesize() defaults to fs=44100. The player must be given the same rate;
    # otherwise playback runs at half speed and an octave low.
    audio = pm.synthesize(fs=rate)
    return Audio(audio, rate=rate)

# Synthesize and play the melody written by the "Convert Back to Midi" cell.
play_midi("my_generated_song2.mid")

In [65]:

def list_subdirectories(path):
    """Return the names (not full paths) of the directories directly inside *path*."""
    def _is_dir(name):
        return os.path.isdir(os.path.join(path, name))

    return list(filter(_is_dir, os.listdir(path)))

midi_root = "nottingham-dataset/MIDI"
root_exists = os.path.isdir(midi_root)

# Guard the listings so a missing dataset prints "Exists: False" instead of
# raising FileNotFoundError from os.listdir before the check is reached.
subdirs = list_subdirectories(midi_root) if root_exists else []
print("📁 Subdirectories:", subdirs)
print("✅ Exists:", root_exists)
print("📁 Contents:", os.listdir(midi_root) if root_exists else [])

print("Current working directory:", os.getcwd())

📁 Subdirectories: ['chords', 'melody']
✅ Exists: True
📁 Contents: ['ashover1.mid', 'ashover10.mid', 'ashover11.mid', 'ashover12.mid', 'ashover13.mid', 'ashover14.mid', 'ashover15.mid', 'ashover16.mid', 'ashover17.mid', 'ashover18.mid', 'ashover19.mid', 'ashover2.mid', 'ashover20.mid', 'ashover21.mid', 'ashover22.mid', 'ashover23.mid', 'ashover24.mid', 'ashover25.mid', 'ashover26.mid', 'ashover27.mid', 'ashover28.mid', 'ashover29.mid', 'ashover3.mid', 'ashover30.mid', 'ashover31.mid', 'ashover32.mid', 'ashover33.mid', 'ashover34.mid', 'ashover35.mid', 'ashover36.mid', 'ashover37.mid', 'ashover38.mid', 'ashover39.mid', 'ashover4.mid', 'ashover40.mid', 'ashover41.mid', 'ashover42.mid', 'ashover43.mid', 'ashover44.mid', 'ashover45.mid', 'ashover46.mid', 'ashover5.mid', 'ashover6.mid', 'ashover7.mid', 'ashover8.mid', 'ashover9.mid', 'chords', 'hpps1.mid', 'hpps10.mid', 'hpps11.mid', 'hpps12.mid', 'hpps13.mid', 'hpps14.mid', 'hpps15.mid', 'hpps16.mid', 'hpps17.mid', 'hpps18.mid', 'hpps19.mid