<a href="https://colab.research.google.com/github/tom-howes/HouseMusicGPT/blob/main/HouseMusicGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab VM at /content/drive.
# NOTE(review): no cell visible in this notebook reads from /content/drive
# (work happens in /content/HouseMusicGPT); confirm whether this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Set the global git identity used for any commits made from this runtime.
# NOTE(review): a personal email is hardcoded in a shared notebook; consider reading it
# from an environment variable or Colab secret instead.
!git config --global user.name "tom-howes"
!git config --global user.email "thomas.howes01@gmail.com"

In [None]:
# Work from the repository directory: relative paths used later (e.g. dance_midi/) resolve here.
# NOTE(review): this path is outside the mounted Drive (/content/drive); presumably it does not
# persist across Colab sessions unless pushed — confirm that is intended.
%cd /content/HouseMusicGPT/

/content/HouseMusicGPT


In [None]:
# Sanity-check the repository state before adding/committing anything.
!git status

On branch main

No commits yet

nothing to commit (create/copy files and use "git add" to track)


In [2]:
!pip install miditok>=3.0.0 symusic torch tqdm matplotlib

In [8]:
# Create the corpus directory the analysis cells below read *.mid / *.midi files from.
# NOTE(review): nothing in this notebook populates it — files are presumably uploaded
# manually; document their source so the notebook can be reproduced.
!mkdir -p dance_midi

## MIDI File Exploratory Analysis


In [28]:
from pathlib import Path
from symusic import Score

# Exploratory pass over the corpus: for every MIDI file print its length, first
# tempo and track layout, then report a rough average tempo across the files.
MIDI_DIR = Path('dance_midi')
DEFAULT_BPM = 120                 # MIDI's implied tempo when a file has no tempo event
BPM_PLAUSIBLE_RANGE = (50, 150)   # first tempos outside this window are treated as unreliable


def sanitize_bpm(qpm, low=BPM_PLAUSIBLE_RANGE[0], high=BPM_PLAUSIBLE_RANGE[1], fallback=DEFAULT_BPM):
    """Return ``qpm`` if it lies in ``[low, high]`` (inclusive), otherwise ``fallback``.

    ``None`` (file has no tempo event) also maps to ``fallback``. This is a crude
    filter used only for the rough corpus average printed by this cell.
    """
    if qpm is None or qpm < low or qpm > high:
        return fallback
    return qpm


midi_files = list(MIDI_DIR.glob('*.mid')) + list(MIDI_DIR.glob('*.midi'))
print(f"Found {len(midi_files)} MIDI files\n")

# Analyze every file; files that fail to parse are reported and excluded from the average.
analysed_bpms = []
for midi_path in midi_files:
    try:
        score = Score(str(midi_path))
        # Guard against files with no tempo event instead of crashing on tempos[0].
        first_bpm = score.tempos[0].qpm if score.tempos else None
        bars = score.end() / score.ticks_per_quarter / 4
        print(f"{midi_path.name}")
        print(f"  Duration: {bars:.1f} bars (assuming 4/4) BPM: ",
              first_bpm if first_bpm is not None else "n/a (no tempo event)")
        print(f"  Tracks: {len(score.tracks)}")
        for track in score.tracks:
            print(f"    - {track.name or 'Unnamed'}: {len(track.notes)} notes, program {track.program}")
        print()
        analysed_bpms.append(sanitize_bpm(first_bpm))
    except Exception as e:
        print(f"{midi_path.name}: Error - {e}\n")

total_bpm = sum(analysed_bpms)
if analysed_bpms:
    # Divide by the number of files actually analysed, not a hard-coded corpus size.
    print(total_bpm / len(analysed_bpms))
else:
    print("No MIDI files analysed; average BPM is undefined")

Found 38 MIDI files

RockDj.mid
  Duration: 102.9 bars (assuming 4/4) BPM:  103.00163944276113
  Tracks: 19
    - CH0171: 246 notes, program 48
    - CH0171: 441 notes, program 33
    - CH0171: 432 notes, program 1
    - CH0171: 448 notes, program 65
    - CH0171: 363 notes, program 52
    - CH0171: 312 notes, program 81
    - CH0171: 30 notes, program 30
    - CH0171: 106 notes, program 59
    - CH0171: 14 notes, program 69
    - CH0171: 3 notes, program 122
    - CH0171: 909 notes, program 17
    - CH0171: 1432 notes, program 27
    - CH0171: 1832 notes, program 0
    - CH0171: 1127 notes, program 48
    - CH0171: 2392 notes, program 48
    - CH0171: 1284 notes, program 28
    - CH0171: 960 notes, program 27
    - CH0171: 1408 notes, program 1
    - CH0171: 206 notes, program 0

Rose_Royce_-_Car_Wash.mid
  Duration: 111.4 bars (assuming 4/4) BPM:  113.99995440001824
  Tracks: 15
    - Drums: 1425 notes, program 0
    - Perc: 273 notes, program 0
    - Bass Gtr: 637 notes, program 0
 

In [29]:
# Re-read each file's first tempo and correct likely half/double-time errors
# (e.g. a file stored at 234 BPM that is really 117 BPM), then summarise the corpus.
TEMPO_LOW, TEMPO_HIGH = 80, 160   # outside this window we assume a factor-of-2 error
NO_TEMPO_BPM = 120                # MIDI's implied tempo when no tempo event exists


def correct_bpm(bpm, low=TEMPO_LOW, high=TEMPO_HIGH):
    """Return ``bpm`` halved if above ``high``, doubled if below ``low``, else unchanged."""
    if bpm > high:
        return bpm / 2
    if bpm < low:
        return bpm * 2
    return bpm


print(f"{'File':<45} {'BPM':<10} {'Likely Actual BPM'}")
print("-" * 70)

tempos = []
for midi_path in midi_files:
    try:
        score = Score(str(midi_path))
        bpm = score.tempos[0].qpm if score.tempos else NO_TEMPO_BPM
        likely_bpm = correct_bpm(bpm)
        # Only annotate rows where a correction was applied.
        flag = f"← probably {likely_bpm:.0f}" if likely_bpm != bpm else ""
        tempos.append(likely_bpm)
        print(f"{midi_path.name:<45} {bpm:<10.1f} {flag}")
    except Exception as e:
        print(f"{midi_path.name:<45} Error: {e}")

print("\n--- Summary ---")
if tempos:
    print(f"Tempo range (corrected): {min(tempos):.0f} - {max(tempos):.0f} BPM")
    print(f"Average: {sum(tempos)/len(tempos):.0f} BPM")
else:
    # min()/max()/mean of an empty list would raise; report instead.
    print("No tempos collected; nothing to summarise")

File                                          BPM        Likely Actual BPM
----------------------------------------------------------------------
RockDj.mid                                    103.0      
Rose_Royce_-_Car_Wash.mid                     114.0      
Hot_Chocolate_-_You_Sexy_Thing.mid            107.0      
Wild_Cherry_-_Play_That_Funky_Music.mid       111.0      
Kool_and_the_Gang_-_Ladies_Night.mid          112.0      
DontYouWantMe.mid                             234.0      ← probably 117
TaintedLove.mid                               148.0      
NaturalBlues.mid                              108.0      
Mustang Sally.mid                             155.0      
CantGetYououtofMyHead(3).mid                  125.0      
keep it comin love.mid                        150.0      
Rick_James_-_Super_Freak.mid                  132.0      
Kiki_Dee_-_I've_Got_the_Music_in_Me.mid       122.0      
Donna_Summer_-_I_Feel_Love.mid                130.0      
David_Bowie_-_Let's_Dance.mi

## Tokenization

In [32]:
from miditok import REMI, TokenizerConfig
from symusic import Score
from pathlib import Path

# Configure tokenizer for dance music
config = TokenizerConfig(
    num_velocities=16,                   # Quantize velocity into 16 bins
    use_chords=False,                    # Chord tokens disabled (set True, optionally with
                                         # chord_tokens_with_root_note=True, to add them)
    use_programs=True,                   # Enable multi-instrument (Program_* tokens)
    use_time_signatures=True,
    use_tempos=True,                     # Allows model to predict changes in tempo
    num_tempos=32,                       # Number of tempo bins
    # Dance music tempo range. NOTE(review): the tempo table shows files above 140 BPM
    # (e.g. 148-155); presumably these are snapped to the nearest bin — confirm.
    tempo_range=(100, 140),
    # One token stream for all instruments, with a Program token emitted before each note
    # (test with initially - might set to False to generate individual instruments)
    one_token_stream_for_programs=True,
    beat_res={(0, 4): 8, (4, 12): 4},    # beat range -> samples per beat
)

tokenizer = REMI(config)

VOCAB_PREVIEW_SIZE = 32
print("Vocabulary size:", len(tokenizer))
print("\nSample tokens from vocabulary:")
# Look tokens up by id so the printed number is the real token id, not an enumeration index.
for token_id in range(min(VOCAB_PREVIEW_SIZE, len(tokenizer))):
    print(f"  {token_id}: {tokenizer[token_id]}")
if len(tokenizer) > VOCAB_PREVIEW_SIZE:
    print("  ...")

Vocabulary size: 502

Sample tokens from vocabulary:
  0: PAD_None
  1: BOS_None
  2: EOS_None
  3: MASK_None
  4: Bar_None
  5: Pitch_21
  6: Pitch_22
  7: Pitch_23
  8: Pitch_24
  9: Pitch_25
  10: Pitch_26
  11: Pitch_27
  12: Pitch_28
  13: Pitch_29
  14: Pitch_30
  15: Pitch_31
  16: Pitch_32
  17: Pitch_33
  18: Pitch_34
  19: Pitch_35
  20: Pitch_36
  21: Pitch_37
  22: Pitch_38
  23: Pitch_39
  24: Pitch_40
  25: Pitch_41
  26: Pitch_42
  27: Pitch_43
  28: Pitch_44
  29: Pitch_45
  30: Pitch_46
  31: Pitch_47
  ...


  super().__init__(tokenizer_config, params)


In [46]:
# Inspect the tokenization of one file: bar structure and the raw token stream.
N_BARS_PREVIEW = 20
N_TOKENS_PREVIEW = 100

test_file = midi_files[0]
tokens = tokenizer(test_file)

print(f"File: {test_file.name}")
print(f"Total tokens: {len(tokens.ids)}")

# Index of every Bar token. Match on the token type prefix rather than a substring
# search, so no other token whose name merely contains "Bar" is counted.
bar_positions = [i for i, tok_id in enumerate(tokens.ids) if tokenizer[tok_id].startswith('Bar_')]

# Ensure bars follow song structure and tokens are evenly spread.
# Each bar runs up to the next Bar token; the last bar runs to the end of the sequence,
# so short files still report their final bar.
bar_bounds = bar_positions + [len(tokens.ids)]
print(f"Total bars: {len(bar_positions)}")
print(f"\nTokens per bar (first {N_BARS_PREVIEW} bars):")
for bar_idx, (start, end) in enumerate(zip(bar_bounds[:N_BARS_PREVIEW], bar_bounds[1:N_BARS_PREVIEW + 1])):
    print(f"  Bar {bar_idx + 1}: {end - start} tokens")

# Check that individual tokens look correct
print(f"\nFirst {N_TOKENS_PREVIEW} tokens:")
for i, tok_id in enumerate(tokens.ids[:N_TOKENS_PREVIEW]):
    print(f"  {i:3d}: {tokenizer[tok_id]}")

File: RockDj.mid
Total tokens: 57426
Total bars: 103

Tokens per bar (first 20 bars):
  Bar 1: 4 tokens
  Bar 2: 22 tokens
  Bar 3: 314 tokens
  Bar 4: 314 tokens
  Bar 5: 314 tokens
  Bar 6: 314 tokens
  Bar 7: 561 tokens
  Bar 8: 548 tokens
  Bar 9: 561 tokens
  Bar 10: 532 tokens
  Bar 11: 547 tokens
  Bar 12: 540 tokens
  Bar 13: 557 tokens
  Bar 14: 564 tokens
  Bar 15: 526 tokens
  Bar 16: 511 tokens
  Bar 17: 523 tokens
  Bar 18: 587 tokens
  Bar 19: 526 tokens
  Bar 20: 514 tokens

First 100 tokens:
    0: Bar_None
    1: TimeSig_4/4
    2: Position_0
    3: Tempo_102.58
    4: Bar_None
    5: TimeSig_4/4
    6: Position_0
    7: Program_-1
    8: PitchDrum_37
    9: Velocity_71
   10: Duration_0.2.8
   11: Position_8
   12: Program_-1
   13: PitchDrum_37
   14: Velocity_47
   15: Duration_0.2.8
   16: Position_16
   17: Program_-1
   18: PitchDrum_37
   19: Velocity_47
   20: Duration_0.2.8
   21: Position_24
   22: Program_-1
   23: PitchDrum_37
   24: Velocity_47
   25: Dura

In [53]:
# Tokenize the whole corpus and report per-file and total token counts.
all_tokens = []
total_tokens = 0

print(f"{'File':<60} {'Tokens':<10}")
print("-" * 70)

for midi_path in midi_files:
    try:
        tokens = tokenizer(midi_path)
        n_tokens = len(tokens.ids)
    except Exception as e:
        print(f"{midi_path.name:<60} Error: {e}")
        continue
    # Record only fully successful files so all_tokens and total_tokens stay consistent.
    all_tokens.append(tokens)
    total_tokens += n_tokens
    print(f"{midi_path.name:<60} {n_tokens:<10}")

print("-" * 70)
print(f"{'TOTAL':<60} {total_tokens:<10}")
print(f"Files tokenized: {len(all_tokens)}")
if all_tokens:
    print(f"Average tokens per song: {total_tokens / len(all_tokens):.0f}")
else:
    # Avoid ZeroDivisionError when every file failed (or the corpus is empty).
    print("Average tokens per song: n/a (no files tokenized)")
print(f"Vocab size: {len(tokenizer)}")

File                                                         Tokens    
----------------------------------------------------------------------
RockDj.mid                                                   57426     
Rose_Royce_-_Car_Wash.mid                                    31590     
Hot_Chocolate_-_You_Sexy_Thing.mid                           22440     
Wild_Cherry_-_Play_That_Funky_Music.mid                      28671     
Kool_and_the_Gang_-_Ladies_Night.mid                         38592     
DontYouWantMe.mid                                            35849     
TaintedLove.mid                                              11568     
NaturalBlues.mid                                             28066     
Mustang Sally.mid                                            2987      
CantGetYououtofMyHead(3).mid                                 22211     
keep it comin love.mid                                       8500      
Rick_James_-_Super_Freak.mid                                 1416