<a href="https://colab.research.google.com/github/tom-howes/DanceMusicGPT/blob/main/dancemusicgpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install miditok>=3.0.0 symusic torch tqdm matplotlib

In [5]:
!mkdir dance-midi
!mv *.mid dance-midi

mkdir: cannot create directory ‘dance-midi’: File exists
mv: cannot stat '*.mid': No such file or directory


In [9]:
from pathlib import Path
from symusic import Score

midi_files = list(Path('dance-midi').glob('*.mid')) + list(Path('dance-midi').glob('*.midi'))
print(f"Found {len(midi_files)} MIDI files\n")
## Optional Print Statements for track analysis ##
# Analyze a few
# total_bpm = 0
# for midi_path in midi_files:
#     try:
#         score = Score(str(midi_path))
#         print(f"{midi_path.name}")
#         print(f"  Duration: {score.end() / score.ticks_per_quarter / 4:.1f} bars (assuming 4/4) BPM: ", score.tempos[0].qpm)
#         print(f"  Tracks: {len(score.tracks)}")
#         for track in score.tracks:
#             print(f"    - {track.name or 'Unnamed'}: {len(track.notes)} notes, program {track.program}")
#         print()
#         if score.tempos[0].qpm < 50 or score.tempos[0].qpm > 150:
#           total_bpm += 120
#         else:
#           total_bpm += score.tempos[0].qpm
#     except Exception as e:
#         print(f"{midi_path.name}: Error - {e}\n")


# print(total_bpm / 38)

Found 38 MIDI files



In [16]:
from miditok import REMI, TokenizerConfig

# Configure tokenizer for dance music
config = TokenizerConfig(
    num_velocities=16,          # Quantize velocity into 16 bins
    use_chords=False,                    # Enable chord detection
    # chord_tokens_with_root_note=True,    # Include root note (e.g. "Chord_C:maj")
    use_programs=True,                  # Enable multi-instrument
    use_time_signatures=True,
    use_tempos=True,                    # Allows model to predict changes in tempo
    num_tempos=32,                      # Number of tempo bins
    tempo_range=(100, 140),             # Dance music tempo range
    one_token_stream_for_programs=True,  # Prepend instrument to Pitch, NoteOn, NoteOff (test with initially - might set to false to generate individual instruments if I run into issues)
    beat_res={(0, 4): 8, (4, 12): 4},
)

tokenizer = REMI(config)

  super().__init__(tokenizer_config, params)


In [17]:
all_tokens = []
total_tokens = 0
### Optional print for token count ###
# print(f"{'File':<60} {'Tokens':<10}")
# print("-" * 70)

for midi_path in midi_files:
  try:
    tokens = tokenizer(midi_path)
    all_tokens.append(tokens)
    total_tokens += len(tokens.ids)
    # print(f"{midi_path.name:<60} {len(tokens.ids):<10}")
  except Exception as e:
    print(f"{midi_path.name:<60} Error: {e}")

In [3]:
!pip install transformers accelerate -q

In [5]:
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch

In [18]:
# Dataset for tokens
class MidiDataset(Dataset):
  def __init__(self, all_tokens, seq_length=512, stride=512):
    self.seq_length = seq_length

    self.data = []
    for token in all_tokens:
      self.data.extend(tokens)

    self.data = torch.tensor(self.data, dtype=torch.long)
    # Use stride to avoid non-overlapping chunks for faster training
    self.indices = list(range(0, len(self.data) - seq_length - 1, stride))

  def __len__(self):
    return len(self.indices)

  def __getitem__(self, idx):
    start = self.indices[idx]
    chunk = self.data[start:start + self.seq_length + 1]
    return {
        'input_ids': chunk[:-1],
        'labels' : chunk[1:]
    }
# Create dataset
dataset = MidiDataset(all_tokens, seq_length=512, stride=256)
print(f"Dataset samples: {len(dataset)}")
print(f"Steps per epoch: {len(dataset) // 8}")
print(f"Total steps (10 epochs): {len(dataset) // 8 * 10}")

Dataset samples: 941
Steps per epoch: 117
Total steps (10 epochs): 1170


In [19]:
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=512,
    n_embd=256,
    n_layer=6,
    n_head=8,
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None']
)

model = GPT2LMHeadModel(config)
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

Model Parameters: 4,998,656


### Attempt 1: 10 epochs, batch_size = 8

In [None]:
# # Load tensorboard extension
# %load_ext tensorboard

# training_args = TrainingArguments(
#     output_dir="./midi-gpt2",
#     overwrite_output_dir=True,
#     num_train_epochs=10,
#     per_device_train_batch_size=8,
#     learning_rate=5e-4,
#     warmup_steps=100,
#     logging_steps=50,
#     logging_dir="./logs",       # TensorBoard logs
#     save_steps=500,
#     save_total_limit=2,
#     report_to="tensorboard",    # Use tensorboard
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=dataset,
# )

# trainer.train()

In [None]:
# # Generate a sequence
# model.eval()

# # Start with a bar token
# start_tokens = torch.tensor([[tokenizer['Bar_None']]], device=device)

# with torch.no_grad():
#     output = model.generate(
#         start_tokens,
#         max_length=500,
#         temperature=1.0,
#         top_p=0.95,
#         do_sample=True,
#     )

# # Decode
# generated = [tokenizer[tok_id.item()] for tok_id in output[0]]
# print("Generated tokens:")
# for i, tok in enumerate(generated[:100]):
#     print(f"  {i}: {tok}")

In [None]:
# # Decode the generated tokens back to MIDI
# from miditok import TokSequence

# # Get the generated token IDs (not strings)
# generated_ids = output[0].tolist()

# # Create a TokSequence and decode to MIDI
# generated_midi = tokenizer.decode(generated_ids)
# generated_midi.dump_midi("generated_sample.mid")

# print("Saved to generated_sample.mid")

In [None]:
# from IPython.display import FileLink
# FileLink("generated_sample.mid")

In [20]:
# Train for more epochs
training_args = TrainingArguments(
    output_dir="./midi-gpt2",
    overwrite_output_dir=True,
    num_train_epochs=50,
    per_device_train_batch_size=16,
    learning_rate=3e-4,
    warmup_steps=100,
    logging_steps=100,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
100,3.4187
200,1.3544
300,0.8583
400,0.5068
500,0.289
600,0.1602
700,0.0965
800,0.0605
900,0.0423
1000,0.0342


TrainOutput(global_step=2950, training_loss=0.24092079044398615, metrics={'train_runtime': 258.3668, 'train_samples_per_second': 182.105, 'train_steps_per_second': 11.418, 'total_flos': 684974093107200.0, 'train_loss': 0.24092079044398615, 'epoch': 50.0})

In [24]:
torch.cuda.empty_cache()
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [22]:
model.eval()

device = 'cpu'
model.to(device)
# Seed with real music
seed_length = 50
seed_tokens = torch.tensor([all_tokens[0][:seed_length]], device=device)

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=2000,        # Longer output
        temperature=0.9,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer['PAD_None'],
    )

generated_ids = output[0].tolist()
generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_v2.mid")

# Check
score = Score("generated_v2.mid")
print(f"Tracks: {len(score.tracks)}")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")
print(f"Duration: {score.end() / score.ticks_per_quarter / 2:.1f} bars")

from IPython.display import FileLink
FileLink("generated_v2.mid")

AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [20]:
# Get the correct vocab size and special tokens
vocab_size = len(tokenizer)
pad_token_id = tokenizer['PAD_None']
eos_token_id = tokenizer['EOS_None']
bos_token_id = tokenizer['BOS_None']

print(f"Vocab size: {vocab_size}")
print(f"PAD: {pad_token_id}, EOS: {eos_token_id}, BOS: {bos_token_id}")

# Check seed tokens are valid
seed_data = tokens.ids[:50]
print(f"Seed token range: {min(seed_data)} to {max(seed_data)}")

# Move model to CPU
model = GPT2LMHeadModel.from_pretrained("./midi-gpt2/checkpoint-2950")
model = model.to('cpu')
model.eval()

# Create seed on CPU (no device argument)
seed_tokens = torch.tensor([all_tokens[0][:50]])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.9,
        top_p=0.95,
        do_sample=True,
    )

generated_ids = output[0].tolist()
print(f"Generated {len(generated_ids)} tokens")

generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated.mid")

from symusic import Score
score = Score("generated.mid")
print(f"Notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("generated.mid")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Vocab size: 502
PAD: 0, EOS: 2, BOS: 1
Seed token range: 4 to 501
Generated 500 tokens
Notes: 8


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [21]:
# Start with just a Bar token
start_tokens = torch.tensor([[tokenizer['Bar_None']]])

with torch.no_grad():
    output = model.generate(
        start_tokens,
        max_length=500,
        temperature=0.95,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
print(f"Generated {len(generated_ids)} tokens")

generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_no_seed.mid")

from google.colab import files
files.download("generated_no_seed.mid")

Generated 500 tokens


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [22]:
def generate_long(model, tokenizer, num_chunks=4, chunk_length=450, overlap=50):
    """Generate longer sequences by chaining chunks together"""

    # Start with a bar token
    generated = [tokenizer['Bar_None']]

    for i in range(num_chunks):
        # Use last `overlap` tokens as context
        context = generated[-overlap:] if len(generated) > overlap else generated
        input_tokens = torch.tensor([context])

        with torch.no_grad():
            output = model.generate(
                input_tokens,
                max_length=len(context) + chunk_length,
                temperature=0.9,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer['PAD_None'],
                eos_token_id=tokenizer['EOS_None'],
            )

        # Add new tokens (skip the context we provided)
        new_tokens = output[0].tolist()[len(context):]
        generated.extend(new_tokens)
        print(f"Chunk {i+1}: {len(new_tokens)} new tokens, total: {len(generated)}")

    return generated

# Generate ~2000 tokens (roughly 16-32 bars)
generated_ids = generate_long(model, tokenizer, num_chunks=4, chunk_length=450, overlap=50)

generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_long.mid")

from symusic import Score
score = Score("generated_long.mid")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")
print(f"Duration: ~{score.end() / score.ticks_per_quarter / 4:.0f} bars")

from google.colab import files
files.download("generated_long.mid")

Chunk 1: 450 new tokens, total: 451
Chunk 2: 450 new tokens, total: 901
Chunk 3: 450 new tokens, total: 1351
Chunk 4: 450 new tokens, total: 1801
Total notes: 0
Duration: ~0 bars


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [24]:
start_tokens = torch.tensor([[tokenizer['Bar_None']]])

with torch.no_grad():
    output = model.generate(
        start_tokens,
        max_length=500,
        temperature=0.7,        # Lower = more conservative
        top_k=50,               # Limit choices
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2, # Penalize repetition
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()

# Check output
print("First 50 tokens:")
for i, tok_id in enumerate(generated_ids[:50]):
    print(f"  {i}: {tokenizer[tok_id]}")

generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_v2.mid")

from symusic import Score
score = Score("generated_v2.mid")
print(f"\nTotal notes: {sum(len(t.notes) for t in score.tracks)}")

First 50 tokens:
  0: Bar_None
  1: Position_0
  2: Pitch_29
  3: Duration_0.5.8
  4: Pitch_80
  5: Duration_0.3.8
  6: Pitch_68
  7: Duration_0.3.8
  8: Pitch_65
  9: Duration_0.3.8
  10: Pitch_72
  11: Duration_0.3.8
  12: PitchDrum_35
  13: Duration_0.1.8
  14: Pitch_72
  15: Duration_0.3.8
  16: Pitch_73
  17: Duration_0.3.8
  18: Pitch_70
  19: Duration_0.3.8
  20: Pitch_65
  21: Duration_0.3.8
  22: Pitch_61
  23: Duration_0.3.8
  24: Pitch_65
  25: Duration_0.3.8
  26: Pitch_61
  27: Duration_0.3.8
  28: Pitch_65
  29: Duration_4.0.4
  30: Pitch_58
  31: Duration_4.0.4
  32: Pitch_54
  33: Duration_4.0.4
  34: PitchDrum_35
  35: Duration_0.1.8
  36: Program_76
  37: Velocity_111
  38: Program_-1
  39: Velocity_127
  40: Program_99
  41: Velocity_111
  42: Program_99
  43: Velocity_111
  44: Program_0
  45: Velocity_7
  46: Program_0
  47: Velocity_7
  48: Program_0
  49: Velocity_7

Total notes: 0


In [25]:
# Use first 20 tokens from real data as seed
seed = all_tokens[0][:20]
print("Seed tokens:")
for i, tok_id in enumerate(seed):
    print(f"  {i}: {tokenizer[tok_id]}")

seed_tokens = torch.tensor([seed])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.8,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_seeded.mid")

from symusic import Score
score = Score("generated_seeded.mid")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("generated_seeded.mid")

Seed tokens:
  0: Bar_None
  1: TimeSig_4/4
  2: Position_0
  3: Tempo_140.0
  4: Bar_None
  5: TimeSig_4/4
  6: Bar_None
  7: TimeSig_4/4
  8: Position_16
  9: Program_-1
  10: PitchDrum_36
  11: Velocity_79
  12: Duration_1.0.8
  13: Position_24
  14: Program_-1
  15: PitchDrum_38
  16: Velocity_103
  17: Duration_0.2.8
  18: Program_-1
  19: PitchDrum_36
Total notes: 2


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Longer Seed more instruments

In [26]:
# Use first 100 tokens from a full song as seed
seed = all_tokens[0][:100]

seed_tokens = torch.tensor([seed])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.85,
        top_p=0.9,
        top_k=50,
        do_sample=True,
        repetition_penalty=1.2,
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_v3.mid")

from symusic import Score
score = Score("generated_v3.mid")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")
print(f"Duration: ~{score.end() / score.ticks_per_quarter / 4:.0f} bars")

# Check what instruments
for track in score.tracks:
    if len(track.notes) > 0:
        print(f"  Program {track.program}: {len(track.notes)} notes")

from google.colab import files
files.download("generated_v3.mid")

Total notes: 18
Duration: ~6 bars
  Program 0: 18 notes


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Exploring different songs as seeds

In [29]:
# Try seeds from different songs
for i, song_tokens in enumerate(all_tokens[:5]):
    seed = song_tokens[:100]
    seed_tokens = torch.tensor([seed])

    with torch.no_grad():
        output = model.generate(
            seed_tokens,
            max_length=500,
            temperature=0.85,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer['PAD_None'],
            eos_token_id=tokenizer['EOS_None'],
        )

    generated_ids = output[0].tolist()
    score = tokenizer.decode(generated_ids)  # This returns a Score directly

    total_notes = sum(len(t.notes) for t in score.tracks)
    print(f"Song {i} seed: {total_notes} notes")

Song 0 seed: 18 notes
Song 1 seed: 20 notes
Song 2 seed: 16 notes
Song 3 seed: 21 notes
Song 4 seed: 18 notes


### Issue with drums only - Exploring using middle of track as seed

In [30]:
# Seed from middle of songs (skip intro)
for i, song_tokens in enumerate(all_tokens[:5]):
    # Skip first 500 tokens (intro), take 100 from the middle
    start_pos = min(500, len(song_tokens) // 2)
    seed = song_tokens[start_pos:start_pos + 100]
    seed_tokens = torch.tensor([seed])

    with torch.no_grad():
        output = model.generate(
            seed_tokens,
            max_length=500,
            temperature=0.85,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer['PAD_None'],
            eos_token_id=tokenizer['EOS_None'],
        )

    generated_ids = output[0].tolist()
    score = tokenizer.decode(generated_ids)

    total_notes = sum(len(t.notes) for t in score.tracks)

    # Count by instrument type
    drum_notes = sum(len(t.notes) for t in score.tracks if t.program == -1 or t.is_drum)
    other_notes = total_notes - drum_notes

    print(f"Song {i}: {total_notes} notes ({drum_notes} drums, {other_notes} other)")

Song 0: 22 notes (6 drums, 16 other)
Song 1: 21 notes (7 drums, 14 other)
Song 2: 21 notes (2 drums, 19 other)
Song 3: 22 notes (13 drums, 9 other)
Song 4: 22 notes (3 drums, 19 other)


In [31]:
# Use middle of song 2 as seed (good drum/other ratio)
song_tokens = all_tokens[2]
start_pos = len(song_tokens) // 2
seed = song_tokens[start_pos:start_pos + 100]

seed_tokens = torch.tensor([seed])

# Generate longer with chunking
def generate_long(model, tokenizer, seed, num_chunks=6, chunk_length=400, overlap=50):
    generated = list(seed)

    for i in range(num_chunks):
        context = generated[-overlap:]
        input_tokens = torch.tensor([context])

        with torch.no_grad():
            output = model.generate(
                input_tokens,
                max_length=len(context) + chunk_length,
                temperature=0.85,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
                pad_token_id=tokenizer['PAD_None'],
                eos_token_id=tokenizer['EOS_None'],
            )

        new_tokens = output[0].tolist()[len(context):]
        generated.extend(new_tokens)

        # Progress check
        score = tokenizer.decode(generated)
        notes = sum(len(t.notes) for t in score.tracks)
        print(f"Chunk {i+1}: {notes} total notes")

    return generated

generated_ids = generate_long(model, tokenizer, seed, num_chunks=6)

# Save and download
score = tokenizer.decode(generated_ids)
score.dump_midi("generated_full.mid")

total_notes = sum(len(t.notes) for t in score.tracks)
drum_notes = sum(len(t.notes) for t in score.tracks if t.is_drum)
print(f"\nFinal: {total_notes} notes ({drum_notes} drums, {total_notes - drum_notes} other)")
print(f"Duration: ~{score.end() / score.ticks_per_quarter / 4:.0f} bars")

from google.colab import files
files.download("generated_full.mid")

Chunk 1: 23 total notes
Chunk 2: 23 total notes
Chunk 3: 23 total notes
Chunk 4: 23 total notes
Chunk 5: 23 total notes
Chunk 6: 23 total notes

Final: 23 notes (2 drums, 21 other)
Duration: ~2 bars


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Lower temperature, top 20 choices, higher repetition penalty

In [32]:
song_tokens = all_tokens[2]
start_pos = len(song_tokens) // 2
seed = song_tokens[start_pos:start_pos + 100]

seed_tokens = torch.tensor([seed])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.5,          # Much lower - more predictable
        top_k=20,                 # Only top 20 choices
        top_p=0.85,
        do_sample=True,
        repetition_penalty=1.3,   # Stronger repetition penalty
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
score = tokenizer.decode(generated_ids)
score.dump_midi("generated_conservative.mid")

print(f"Notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("generated_conservative.mid")

Notes: 23


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Greedy gen

In [33]:
with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        do_sample=False,          # Greedy - always pick most likely
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
score = tokenizer.decode(generated_ids)
score.dump_midi("generated_greedy.mid")

from google.colab import files
files.download("generated_greedy.mid")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## GPT-2 V2 - Trying regularization, larger model and more epochs

### Setup

In [1]:
!pip install miditok>=3.0.0 symusic transformers accelerate -q

import os
import torch
from pathlib import Path
from torch.utils.data import Dataset
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from miditok import REMI, TokenizerConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


### Tokenizer

In [2]:
config = TokenizerConfig(
    num_velocities=16,
    use_chords=False,
    use_programs=True,
    one_token_stream_for_programs=True,
    use_time_signatures=True,
    use_tempos=True,
    nb_tempos=32,
    tempo_range=(100, 140),
    beat_res={(0, 4): 8, (4, 12): 4},
)

tokenizer = REMI(config)
print(f"Vocab size: {len(tokenizer)}")

Vocab size: 502


  config = TokenizerConfig(
  super().__init__(tokenizer_config, params)


### Load and tokenize MIDIs

In [6]:
midi_files = list(Path('dance-midi').glob('*.mid')) + list(Path('dance-midi').glob('*.midi'))

all_tokens = []
for midi_path in midi_files:
    try:
        tokens = tokenizer(midi_path)
        all_tokens.append(tokens.ids)
    except Exception as e:
        print(f"Error with {midi_path.name}: {e}")

total_tokens = sum(len(t) for t in all_tokens)
print(f"Tokenized {len(all_tokens)} files")
print(f"Total tokens: {total_tokens:,}")

Tokenized 38 files
Total tokens: 1,041,915


### Dataset with more overlap

In [7]:
class MidiDataset(Dataset):
    def __init__(self, all_tokens, seq_length=512, stride=256):  # 50% overlap
        self.seq_length = seq_length
        self.data = []
        for tokens in all_tokens:
            self.data.extend(tokens)
        self.data = torch.tensor(self.data, dtype=torch.long)
        self.indices = list(range(0, len(self.data) - seq_length - 1, stride))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        start = self.indices[idx]
        chunk = self.data[start:start + self.seq_length + 1]
        return {'input_ids': chunk[:-1], 'labels': chunk[1:]}

dataset = MidiDataset(all_tokens, seq_length=512, stride=256)
print(f"Dataset samples: {len(dataset)}")

Dataset samples: 4068


### Bigger Model w/ Regularization (dropout)

In [9]:
vocab_size = max(max(t) for t in all_tokens) + 1

config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=512,
    n_embd=512,       # Doubled from 256
    n_layer=12,       # Doubled from 6
    n_head=8,
    resid_pdrop=0.1,  # Dropout on residual connections
    embd_pdrop=0.1,   # Dropout on embeddings
    attn_pdrop=0.1,   # Dropout on attention
)

model = GPT2LMHeadModel(config)
model.to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Parameters: 38,348,800


### Train with early stopping

In [10]:
training_args = TrainingArguments(
    output_dir="./midi-gpt2-large",
    overwrite_output_dir=True,
    num_train_epochs=30,          # Fewer epochs
    per_device_train_batch_size=8,
    learning_rate=1e-4,           # Lower learning rate
    warmup_steps=200,
    logging_steps=50,
    save_steps=500,
    save_total_limit=3,
    report_to="none",
    fp16=True,
    weight_decay=0.1,             # L2 regularization
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
50,4.9382
100,3.3836
150,2.8669
200,2.5719
250,2.4208
300,2.3135
350,2.2017
400,2.1524
450,2.0588
500,1.9795


KeyboardInterrupt: 

### Try model at different checkpoints to check overfitting

In [11]:
model = GPT2LMHeadModel.from_pretrained("./midi-gpt2-large/checkpoint-2000")
model = model.to('cpu')
model.eval()
# Use first 20 tokens from real data as seed
seed = all_tokens[0][:20]
print("Seed tokens:")
for i, tok_id in enumerate(seed):
    print(f"  {i}: {tokenizer[tok_id]}")

seed_tokens = torch.tensor([seed])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.8,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.2,
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_seeded.mid")

from symusic import Score
score = Score("generated_seeded.mid")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("generated_seeded.mid")

Seed tokens:
  0: Bar_None
  1: TimeSig_4/4
  2: Position_0
  3: Tempo_140.0
  4: Bar_None
  5: TimeSig_4/4
  6: Bar_None
  7: TimeSig_4/4
  8: Position_16
  9: Program_-1
  10: PitchDrum_36
  11: Velocity_79
  12: Duration_1.0.8
  13: Position_24
  14: Program_-1
  15: PitchDrum_38
  16: Velocity_103
  17: Duration_0.2.8
  18: Program_-1
  19: PitchDrum_36
Total notes: 2


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [15]:
song_tokens = all_tokens[0]
start_pos = len(song_tokens) // 2
seed = song_tokens[start_pos:start_pos + 100]

seed_tokens = torch.tensor([seed])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.5,          # Much lower - more predictable
        top_k=20,                 # Only top 20 choices
        top_p=0.85,
        do_sample=True,
        repetition_penalty=1.3,   # Stronger repetition penalty
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
score = tokenizer.decode(generated_ids)
score.dump_midi("tainted_love_seed.mid")

print(f"Notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("tainted_love_seed.mid")

Notes: 22


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [16]:
# Generate long unique sample
def generate_long_unique(model, tokenizer, all_tokens, num_chunks=12, chunk_length=400, overlap=50):
    """Generate longer sequences, starting from random positions in different songs"""
    import random

    # Start with a random seed from a random song
    song_idx = random.randint(0, len(all_tokens) - 1)
    start_pos = len(all_tokens[song_idx]) // 2
    generated = list(all_tokens[song_idx][start_pos:start_pos + 50])

    for i in range(num_chunks):
        context = generated[-overlap:]
        input_tokens = torch.tensor([context])

        with torch.no_grad():
            output = model.generate(
                input_tokens,
                max_length=len(context) + chunk_length,
                temperature=1.0,          # Higher = more creative/random
                top_k=100,                # More choices
                top_p=0.95,
                do_sample=True,
                repetition_penalty=1.3,   # Discourage repetition
                no_repeat_ngram_size=8,   # No repeated 8-token phrases
                pad_token_id=tokenizer['PAD_None'],
                eos_token_id=tokenizer['EOS_None'],
            )

        new_tokens = output[0].tolist()[len(context):]
        generated.extend(new_tokens)

        # Progress
        score = tokenizer.decode(generated)
        notes = sum(len(t.notes) for t in score.tracks)
        print(f"Chunk {i+1}/{num_chunks}: {notes} notes, {len(generated)} tokens")

    return generated

# Generate ~5000 tokens (roughly 30-60 seconds of music)
generated_ids = generate_long_unique(model, tokenizer, all_tokens, num_chunks=12, chunk_length=400, overlap=50)

# Save
score = tokenizer.decode(generated_ids)
score.dump_midi("generated_long_unique.mid")

total_notes = sum(len(t.notes) for t in score.tracks)
drum_notes = sum(len(t.notes) for t in score.tracks if t.is_drum)
duration_bars = score.end() / score.ticks_per_quarter / 4

print(f"\n=== Final ===")
print(f"Total notes: {total_notes} ({drum_notes} drums, {total_notes - drum_notes} other)")
print(f"Duration: ~{duration_bars:.0f} bars")
print(f"Tokens: {len(generated_ids)}")

from google.colab import files
files.download("generated_long_unique.mid")

Chunk 1/12: 11 notes, 450 tokens
Chunk 2/12: 11 notes, 850 tokens
Chunk 3/12: 11 notes, 1250 tokens
Chunk 4/12: 11 notes, 1650 tokens
Chunk 5/12: 11 notes, 2050 tokens
Chunk 6/12: 11 notes, 2450 tokens
Chunk 7/12: 11 notes, 2850 tokens
Chunk 8/12: 11 notes, 3250 tokens
Chunk 9/12: 11 notes, 3650 tokens
Chunk 10/12: 11 notes, 4050 tokens
Chunk 11/12: 11 notes, 4450 tokens
Chunk 12/12: 11 notes, 4850 tokens

=== Final ===
Total notes: 11 (2 drums, 9 other)
Duration: ~1 bars
Tokens: 4850


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Continuation Tests

In [18]:
# Test: continue a real song and see if it sounds coherent
song_idx = 1
seed = all_tokens[song_idx][500:600]  # Middle of song

seed_tokens = torch.tensor([seed])

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=500,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer['PAD_None'],
        eos_token_id=tokenizer['EOS_None'],
    )

generated_ids = output[0].tolist()
score = tokenizer.decode(generated_ids)
score.dump_midi("continuation_test.mid")

print(f"Notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("continuation_test.mid")

Notes: 21


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [19]:
def continue_song(model, tokenizer, seed_tokens, total_length=2000, chunk_length=400, overlap=100):
    """Continue from a seed, generating in chunks"""

    generated = list(seed_tokens)

    while len(generated) < total_length:
        context = generated[-overlap:]
        input_tokens = torch.tensor([context])

        with torch.no_grad():
            output = model.generate(
                input_tokens,
                max_length=len(context) + chunk_length,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1,
                pad_token_id=tokenizer['PAD_None'],
                eos_token_id=tokenizer['EOS_None'],
            )

        new_tokens = output[0].tolist()[len(context):]
        generated.extend(new_tokens)

        # Progress
        score = tokenizer.decode(generated)
        notes = sum(len(t.notes) for t in score.tracks)
        bars = score.end() / score.ticks_per_quarter / 4
        print(f"Tokens: {len(generated)}, Notes: {notes}, Bars: {bars:.0f}")

    return generated

# Continue from middle of song 0
song_idx = 0
seed = all_tokens[song_idx][500:600]

generated_ids = continue_song(model, tokenizer, seed, total_length=5000)

score = tokenizer.decode(generated_ids)
score.dump_midi("long_continuation.mid")

total_notes = sum(len(t.notes) for t in score.tracks)
bars = score.end() / score.ticks_per_quarter / 4
print(f"\n=== Final ===")
print(f"Notes: {total_notes}")
print(f"Bars: {bars:.0f}")

from google.colab import files
files.download("long_continuation.mid")

Tokens: 500, Notes: 22, Bars: 2
Tokens: 900, Notes: 22, Bars: 2
Tokens: 1300, Notes: 22, Bars: 2
Tokens: 1700, Notes: 22, Bars: 2
Tokens: 2100, Notes: 22, Bars: 2
Tokens: 2500, Notes: 22, Bars: 2
Tokens: 2900, Notes: 22, Bars: 2
Tokens: 3300, Notes: 22, Bars: 2
Tokens: 3700, Notes: 22, Bars: 2
Tokens: 4100, Notes: 22, Bars: 2
Tokens: 4500, Notes: 22, Bars: 2
Tokens: 4900, Notes: 22, Bars: 2
Tokens: 5300, Notes: 22, Bars: 2

=== Final ===
Notes: 22
Bars: 2


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [20]:
def continue_song_debug(model, tokenizer, seed_tokens, total_length=3000, chunk_length=400, overlap=100):
    """Continue with debugging to see what's happening"""

    generated = list(seed_tokens)

    # Check initial seed
    score = tokenizer.decode(generated)
    initial_notes = sum(len(t.notes) for t in score.tracks)
    print(f"Seed: {len(generated)} tokens, {initial_notes} notes")

    chunk_num = 0
    while len(generated) < total_length:
        chunk_num += 1
        context = generated[-overlap:]
        input_tokens = torch.tensor([context])

        with torch.no_grad():
            output = model.generate(
                input_tokens,
                max_length=len(context) + chunk_length,
                temperature=0.75,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
                pad_token_id=tokenizer['PAD_None'],
                eos_token_id=tokenizer['EOS_None'],
            )

        new_tokens = output[0].tolist()[len(context):]

        # Debug: check what tokens are being generated
        print(f"\nChunk {chunk_num}: {len(new_tokens)} new tokens")
        print(f"First 20 new tokens:")
        for i, tok_id in enumerate(new_tokens[:20]):
            print(f"  {tokenizer[tok_id]}")

        generated.extend(new_tokens)

        # Check if notes increased
        score = tokenizer.decode(generated)
        notes = sum(len(t.notes) for t in score.tracks)
        bars = score.end() / score.ticks_per_quarter / 4
        print(f"Total: {len(generated)} tokens, {notes} notes, {bars:.0f} bars")

        if chunk_num >= 3:  # Just check first 3 chunks
            break

    return generated

# Debug
seed = all_tokens[0][500:600]
generated_ids = continue_song_debug(model, tokenizer, seed, total_length=3000)

Seed: 100 tokens, 22 notes

Chunk 1: 400 new tokens
First 20 new tokens:
  Bar_None
  Position_0
  Pitch_39
  Duration_1.0.8
  Pitch_75
  Duration_1.0.8
  Pitch_75
  Duration_1.0.8
  Program_32
  Velocity_71
  Program_16
  Velocity_127
  Program_17
  Velocity_55
  Program_53
  Velocity_79
  Position_8
  Pitch_39
  Duration_1.0.8
  PitchDrum_38
Total: 500 tokens, 22 notes, 2 bars

Chunk 2: 400 new tokens
First 20 new tokens:
  Velocity_79
  Position_24
  PitchDrum_38
  Duration_0.3.8
  PitchDrum_36
  Duration_1.0.8
  TimeSig_4/4
  Program_32
  Velocity_71
  Program_16
  Velocity_127
  Program_17
  Velocity_55
  Program_27
  Velocity_111
  Program_27
  Velocity_79
  Position_28
  Pitch_55
  Duration_0.4.8
Total: 900 tokens, 22 notes, 2 bars

Chunk 3: 400 new tokens
First 20 new tokens:
  Program_53
  Velocity_79
  Position_13
  Pitch_63
  Duration_0.5.8
  PitchDrum_38
  Duration_0.2.8
  PitchDrum_36
  Duration_1.0.8
  Program_53
  Velocity_103
  Position_16
  PitchDrum_36
  Duration_1.0.

In [21]:
# Compare real vs generated token patterns
print("=== REAL DATA (from training) ===")
real_tokens = all_tokens[0][500:550]
for i, tok_id in enumerate(real_tokens):
    print(f"{i:3d}: {tokenizer[tok_id]}")

print("\n=== GENERATED ===")
for i, tok_id in enumerate(generated_ids[100:150]):
    print(f"{i:3d}: {tokenizer[tok_id]}")

=== REAL DATA (from training) ===
  0: Velocity_87
  1: Duration_0.3.8
  2: Position_24
  3: Program_-1
  4: PitchDrum_38
  5: Velocity_103
  6: Duration_0.3.8
  7: Program_-1
  8: PitchDrum_36
  9: Velocity_79
 10: Duration_1.0.8
 11: Bar_None
 12: TimeSig_4/4
 13: Position_0
 14: Program_32
 15: Pitch_31
 16: Velocity_71
 17: Duration_0.5.8
 18: Program_16
 19: Pitch_31
 20: Velocity_127
 21: Duration_0.5.8
 22: Program_61
 23: Pitch_55
 24: Velocity_127
 25: Duration_0.5.8
 26: Program_61
 27: Pitch_43
 28: Velocity_127
 29: Duration_0.5.8
 30: Program_17
 31: Pitch_67
 32: Velocity_63
 33: Duration_0.5.8
 34: Program_-1
 35: PitchDrum_36
 36: Velocity_79
 37: Duration_1.0.8
 38: Position_8
 39: Program_32
 40: Pitch_31
 41: Velocity_71
 42: Duration_1.1.8
 43: Program_16
 44: Pitch_31
 45: Velocity_127
 46: Duration_1.1.8
 47: Program_61
 48: Pitch_43
 49: Velocity_111

=== GENERATED ===
  0: Bar_None
  1: Position_0
  2: Pitch_39
  3: Duration_1.0.8
  4: Pitch_75
  5: Duration_1.0

In [22]:
def generate_constrained(model, tokenizer, seed, max_new_tokens=500):
    """Generate while enforcing valid token order"""

    generated = list(seed)

    # Track what token type we expect next
    # Valid order: Position -> Program -> Pitch -> Velocity -> Duration

    for _ in range(max_new_tokens):
        context = generated[-100:]
        input_tokens = torch.tensor([context])

        with torch.no_grad():
            outputs = model(input_tokens)
            logits = outputs.logits[0, -1, :]  # Last position

        # Get last token to determine what should come next
        last_tok = tokenizer[generated[-1]]

        # Mask invalid tokens based on grammar
        mask = torch.zeros_like(logits)

        if 'Position' in last_tok or 'Bar' in last_tok or 'TimeSig' in last_tok:
            # After Position/Bar, allow Program or Tempo
            for i in range(len(tokenizer)):
                tok = tokenizer[i]
                if 'Program' in tok or 'Tempo' in tok:
                    mask[i] = 1
        elif 'Program' in last_tok:
            # After Program, allow Pitch or PitchDrum
            for i in range(len(tokenizer)):
                tok = tokenizer[i]
                if 'Pitch' in tok:
                    mask[i] = 1
        elif 'Pitch' in last_tok:
            # After Pitch, allow Velocity
            for i in range(len(tokenizer)):
                tok = tokenizer[i]
                if 'Velocity' in tok:
                    mask[i] = 1
        elif 'Velocity' in last_tok:
            # After Velocity, allow Duration
            for i in range(len(tokenizer)):
                tok = tokenizer[i]
                if 'Duration' in tok:
                    mask[i] = 1
        elif 'Duration' in last_tok:
            # After Duration, allow Position, Bar, Program (next note same position)
            for i in range(len(tokenizer)):
                tok = tokenizer[i]
                if 'Position' in tok or 'Bar' in tok or 'Program' in tok:
                    mask[i] = 1
        else:
            mask = torch.ones_like(logits)  # Allow anything

        # Apply mask
        logits = logits + (mask - 1) * 10000

        # Sample
        probs = torch.softmax(logits / 0.8, dim=-1)
        next_token = torch.multinomial(probs, 1).item()
        generated.append(next_token)

        if len(generated) % 100 == 0:
            score = tokenizer.decode(generated)
            notes = sum(len(t.notes) for t in score.tracks)
            print(f"Tokens: {len(generated)}, Notes: {notes}")

    return generated

# Test constrained generation
seed = all_tokens[0][500:600]
generated_ids = generate_constrained(model, tokenizer, seed, max_new_tokens=500)

score = tokenizer.decode(generated_ids)
score.dump_midi("constrained.mid")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")

from google.colab import files
files.download("constrained.mid")

Tokens: 200, Notes: 47
Tokens: 300, Notes: 72
Tokens: 400, Notes: 97
Tokens: 500, Notes: 122
Tokens: 600, Notes: 147
Total notes: 147


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [14]:
print(midi_files)

[PosixPath('dance-midi/TaintedLove.mid'), PosixPath('dance-midi/SexyBitch.mid'), PosixPath('dance-midi/Rose_Royce_-_Car_Wash.mid'), PosixPath('dance-midi/YouSpinMeRound.mid'), PosixPath('dance-midi/Gloria_Gaynor_-_I_Will_Survive.mid'), PosixPath('dance-midi/People on the high line.mid'), PosixPath('dance-midi/KC_and_The_Sunshine_Band_-_KC_Medley.mid'), PosixPath('dance-midi/David_Bowie_-_Lets_Dance.mid'), PosixPath('dance-midi/Hot_Chocolate_-_You_Sexy_Thing.mid'), PosixPath('dance-midi/AroundTheWorld.mid'), PosixPath('dance-midi/New_Order_-_Blue_Monday.mid'), PosixPath('dance-midi/keep it comin love.mid'), PosixPath('dance-midi/Michael_Jackson_-_Thriller.mid'), PosixPath('dance-midi/Rick_James_-_Super_Freak.mid'), PosixPath('dance-midi/Kool_and_the_Gang_-_Ladies_Night.mid'), PosixPath('dance-midi/Frankie_Goes_to_Hollywood_-_Relax.mid'), PosixPath('dance-midi/KC_and_The_Sunshine_Band_-_Im_Your_Boogie_Man.mid'), PosixPath('dance-midi/Donna_Summer_-_I_Feel_Love.mid'), PosixPath('dance-mid