In [1]:

!pip install miditok>=3.0.0 symusic torch tqdm matplotlib
     

In [2]:
from pathlib import Path
from symusic import Score

midi_files = list(Path('/kaggle/input/dance-midi').glob('*.mid')) + list(Path('/kaggle/input/dance-midi').glob('*.midi'))
print(f"Found {len(midi_files)} MIDI files\n")

# Analyze a few
total_bpm = 0
for midi_path in midi_files:
    try:
        score = Score(str(midi_path))
        print(f"{midi_path.name}")
        print(f"  Duration: {score.end() / score.ticks_per_quarter / 4:.1f} bars (assuming 4/4) BPM: ", score.tempos[0].qpm)
        print(f"  Tracks: {len(score.tracks)}")
        for track in score.tracks:
            print(f"    - {track.name or 'Unnamed'}: {len(track.notes)} notes, program {track.program}")
        print()
        if score.tempos[0].qpm < 50 or score.tempos[0].qpm > 150:
          total_bpm += 120
        else:
          total_bpm += score.tempos[0].qpm
    except Exception as e:
        print(f"{midi_path.name}: Error - {e}\n")


print(total_bpm / 38)

Found 38 MIDI files

KC_and_The_Sunshine_Band_-_Im_Your_Boogie_Man.mid
  Duration: 146.7 bars (assuming 4/4) BPM:  122.0000691333725
  Tracks: 8
    - Unnamed: 1714 notes, program 0
    - Unnamed: 646 notes, program 33
    - Unnamed: 100 notes, program 27
    - Unnamed: 468 notes, program 5
    - Unnamed: 1469 notes, program 61
    - Unnamed: 4007 notes, program 0
    - Unnamed: 506 notes, program 66
    - Unnamed: 3694 notes, program 27

faith.mid
  Duration: 27.0 bars (assuming 4/4) BPM:  158.0003476007647
  Tracks: 16
    - Faith: 52 notes, program 0
    - Faith: 133 notes, program 32
    - Faith: 132 notes, program 26
    - Faith: 99 notes, program 66
    - Faith: 90 notes, program 54
    - Faith: 18 notes, program 53
    - Faith: 52 notes, program 12
    - Faith: 48 notes, program 56
    - Faith: 16 notes, program 61
    - Faith: 362 notes, program 0
    - Faith: 32 notes, program 65
    - Faith: 16 notes, program 66
    - Faith: 28 notes, program 27
    - Faith: 466 notes, progra

In [3]:
from miditok import REMI, TokenizerConfig

# Configure tokenizer for dance music
config = TokenizerConfig(
    num_velocities=16,          # Quantize velocity into 16 bins
    use_chords=False,                    # Enable chord detection
    # chord_tokens_with_root_note=True,    # Include root note (e.g. "Chord_C:maj")
    use_programs=True,                  # Enable multi-instrument
    use_time_signatures=True,
    use_tempos=True,                    # Allows model to predict changes in tempo
    num_tempos=32,                      # Number of tempo bins
    tempo_range=(100, 140),             # Dance music tempo range
    one_token_stream_for_programs=True,  # Prepend instrument to Pitch, NoteOn, NoteOff (test with initially - might set to false to generate individual instruments if I run into issues)
    beat_res={(0, 4): 8, (4, 12): 4},
)

tokenizer = REMI(config)

  super().__init__(tokenizer_config, params)


In [4]:
all_tokens = []
total_tokens = 0

print(f"{'File':<60} {'Tokens':<10}")
print("-" * 70)

for midi_path in midi_files:
  try:
    tokens = tokenizer(midi_path)
    all_tokens.append(tokens)
    total_tokens += len(tokens.ids)
    print(f"{midi_path.name:<60} {len(tokens.ids):<10}")
  except Exception as e:
    print(f"{midi_path.name:<60} Error: {e}")

File                                                         Tokens    
----------------------------------------------------------------------
KC_and_The_Sunshine_Band_-_Im_Your_Boogie_Man.mid            52521     
faith.mid                                                    6910      
Im a believer.mid                                            3721      
CantGetYououtofMyHead(3).mid                                 22211     
New_Order_-_Blue_Monday.mid                                  37023     
Gloria_Gaynor_-_I_Will_Survive.mid                           44759     
fallinlove2nite.mid                                          6349      
AroundTheWorld.mid                                           11562     
Wild_Cherry_-_Play_That_Funky_Music.mid                      28671     
YouSpinMeRound.mid                                           52172     
Kool_and_the_Gang_-_Get_Down_On_It.mid                       39310     
DontYouWantMe.mid                                            3584

In [5]:
!pip install transformers accelerate -q

In [6]:
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch

2025-12-17 19:03:42.408491: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765998222.595236      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765998222.650059      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765998223.090909      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765998223.090943      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765998223.090946      55 computation_placer.cc:177] computation placer alr

In [7]:
# Dataset for tokens
class MidiDataset(Dataset):
  def __init__(self, all_tokens, seq_length=512, stride=512):
    self.seq_length = seq_length

    self.data = []
    for token in all_tokens:
      self.data.extend(tokens)

    self.data = torch.tensor(self.data, dtype=torch.long)
    # Use stride to avoid non-overlapping chunks for faster training
    self.indices = list(range(0, len(self.data) - seq_length - 1, stride))

  def __len__(self):
    return len(self.indices)
    
  def __getitem__(self, idx):
    start = self.indices[idx]
    chunk = self.data[start:start + self.seq_length + 1]
    return {
        'input_ids': chunk[:-1], 
        'labels' : chunk[1:]
    }
# Create dataset
dataset = MidiDataset(all_tokens, seq_length=512, stride=256)
print(f"Dataset samples: {len(dataset)}")
print(f"Steps per epoch: {len(dataset) // 8}")
print(f"Total steps (10 epochs): {len(dataset) // 8 * 10}")

Dataset samples: 6586
Steps per epoch: 823
Total steps (10 epochs): 8230


In [8]:
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=512,
    n_embd=256,
    n_layer=6,
    n_head=8,
    bos_token_id=tokenizer['BOS_None'],
    eos_token_id=tokenizer['EOS_None']
)

model = GPT2LMHeadModel(config)
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")

Model Parameters: 4,998,656


### Attempt 1: 10 epochs, batch_size = 8

In [9]:
# # Load tensorboard extension
# %load_ext tensorboard

# training_args = TrainingArguments(
#     output_dir="./midi-gpt2",
#     overwrite_output_dir=True,
#     num_train_epochs=10,
#     per_device_train_batch_size=8,
#     learning_rate=5e-4,
#     warmup_steps=100,
#     logging_steps=50,
#     logging_dir="./logs",       # TensorBoard logs
#     save_steps=500,
#     save_total_limit=2,
#     report_to="tensorboard",    # Use tensorboard
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=dataset,
# )

# trainer.train()

In [10]:
# # Generate a sequence
# model.eval()

# # Start with a bar token
# start_tokens = torch.tensor([[tokenizer['Bar_None']]], device=device)

# with torch.no_grad():
#     output = model.generate(
#         start_tokens,
#         max_length=500,
#         temperature=1.0,
#         top_p=0.95,
#         do_sample=True,
#     )

# # Decode
# generated = [tokenizer[tok_id.item()] for tok_id in output[0]]
# print("Generated tokens:")
# for i, tok in enumerate(generated[:100]):
#     print(f"  {i}: {tok}")

In [11]:
# # Decode the generated tokens back to MIDI
# from miditok import TokSequence

# # Get the generated token IDs (not strings)
# generated_ids = output[0].tolist()

# # Create a TokSequence and decode to MIDI
# generated_midi = tokenizer.decode(generated_ids)
# generated_midi.dump_midi("generated_sample.mid")

# print("Saved to generated_sample.mid")

In [12]:
# from IPython.display import FileLink
# FileLink("generated_sample.mid")

In [13]:
# Train for more epochs
training_args = TrainingArguments(
    output_dir="./midi-gpt2",
    overwrite_output_dir=True,
    num_train_epochs=50,
    per_device_train_batch_size=16,
    learning_rate=3e-4,
    warmup_steps=100,
    logging_steps=100,
    save_steps=500,
    save_total_limit=2,
    report_to="none",
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
100,3.4937
200,1.6593
300,1.1385
400,0.6611
500,0.419
600,0.2855
700,0.2118
800,0.164
900,0.1325
1000,0.1123


TrainOutput(global_step=20600, training_loss=0.06102376694239459, metrics={'train_runtime': 2246.3236, 'train_samples_per_second': 146.595, 'train_steps_per_second': 9.171, 'total_flos': 4794090730291200.0, 'train_loss': 0.06102376694239459, 'epoch': 50.0})

In [20]:
torch.cuda.empty_cache()
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [17]:
model.eval()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# Seed with real music
seed_length = 50
seed_tokens = torch.tensor([all_tokens[0][:seed_length]], device=device)

with torch.no_grad():
    output = model.generate(
        seed_tokens,
        max_length=2000,        # Longer output
        temperature=0.9,
        top_p=0.95,
        do_sample=True,
        pad_token_id=tokenizer['PAD_None'],
    )

generated_ids = output[0].tolist()
generated_midi = tokenizer.decode(generated_ids)
generated_midi.dump_midi("generated_v2.mid")

# Check
score = Score("generated_v2.mid")
print(f"Tracks: {len(score.tracks)}")
print(f"Total notes: {sum(len(t.notes) for t in score.tracks)}")
print(f"Duration: {score.end() / score.ticks_per_quarter / 2:.1f} bars")

from IPython.display import FileLink
FileLink("generated_v2.mid")

AcceleratorError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
