# Session 6b: Training Anticipatory Music Transformers (GPU Notebook)

## Download environment (if on Colab)

> **IMPORTANT: You will need to restart the kernel after running this cell.**

In [2]:
# Clone the course repository
!git clone https://github.com/lancelotblanchard/ai_music_course.git

# Install dependencies
!pip install -r ai_music_course/requirements.txt
!pip install --upgrade torchvision

Cloning into 'ai_music_course'...
remote: Enumerating objects: 75, done.[K
remote: Counting objects: 100% (23/23), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 75 (delta 4), reused 16 (delta 4), pack-reused 52 (from 3)[K
Receiving objects: 100% (75/75), 225.74 MiB | 17.15 MiB/s, done.
Resolving deltas: 100% (8/8), done.
Updating files: 100% (35/35), done.
Collecting torchinfo@ git+https://github.com/lancelotblanchard/torchinfo@87dd4eb (from -r ai_music_course/requirements.txt (line 11))
  Cloning https://github.com/lancelotblanchard/torchinfo (to revision 87dd4eb) to /tmp/pip-install-o4_8tl6t/torchinfo_2a5e5e5ff9264c32af823ed8d19d71f8
  Running command git clone --filter=blob:none --quiet https://github.com/lancelotblanchard/torchinfo /tmp/pip-install-o4_8tl6t/torchinfo_2a5e5e5ff9264c32af823ed8d19d71f8
[0m  Running command git checkout -q 87dd4eb
  Resolved https://github.com/lancelotblanchard/torchinfo to commit 87dd4eb
  Preparing metadata (setup.py) 

In [3]:
# Restart kernel
print("Restarting of kernel...")
get_ipython().kernel.do_shutdown(True)

Restarting of kernel...


{'status': 'ok', 'restart': True}

## Download repository & dataset (if on Colab)

In [3]:
# Clone the anticipation repository
!git clone https://github.com/lancelotblanchard/anticipation.git ../repositories/anticipation

Cloning into '../repositories/anticipation'...
remote: Enumerating objects: 580, done.[K
remote: Counting objects: 100% (224/224), done.[K
remote: Compressing objects: 100% (52/52), done.[K
remote: Total 580 (delta 195), reused 180 (delta 172), pack-reused 356 (from 1)[K
Receiving objects: 100% (580/580), 120.16 KiB | 5.46 MiB/s, done.
Resolving deltas: 100% (362/362), done.


In [4]:
# Let's clone the dataset repository and convert the data to MIDI
!git clone https://github.com/lancelotblanchard/JSB-Chorales-dataset-midi.git ../datasets/JSB-Chorales-midi
!cd ../datasets/JSB-Chorales-midi && python ./JsbToMidi.py 4

Cloning into '../datasets/JSB-Chorales-midi'...
remote: Enumerating objects: 60, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (16/16), done.[K
remote: Total 60 (delta 11), reused 17 (delta 7), pack-reused 36 (from 1)[K
Receiving objects: 100% (60/60), 2.78 MiB | 16.29 MiB/s, done.
Resolving deltas: 100% (19/19), done.
Converting jsb-chorales-quarter.json to MIDI files.
229it [00:00, 379.78it/s]
76it [00:00, 366.67it/s]
77it [00:00, 356.85it/s]


## Training Data Preprocessing

In [9]:
# We will first process our midi
import sys
sys.path.append("../repositories/anticipation")

import traceback
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from glob import glob

from tqdm import tqdm

from anticipation.convert import midi_to_compound
from anticipation.config import PREPROC_WORKERS, TIME_RESOLUTION

def convert_midi(filename, debug=False):
    try:
        tokens = midi_to_compound(filename, debug=debug)
    except Exception:
        if debug:
            print('Failed to process: ', filename)
            print(traceback.format_exc())

        return 1

    with open(f"{filename}.compound.txt", 'w') as f:
        f.write(' '.join(str(tok) for tok in tokens))

    return 0

data_dir = "../datasets/JSB-Chorales-midi/midi-outputs"

filenames = glob(data_dir + '/**/*.mid', recursive=True) \
        + glob(data_dir + '/**/*.midi', recursive=True)

convert_midi_partial = partial(convert_midi)

print(f'Preprocessing {len(filenames)} files with {PREPROC_WORKERS} workers')
with ProcessPoolExecutor(max_workers=PREPROC_WORKERS) as executor:
    results = list(tqdm(executor.map(convert_midi_partial, filenames), desc='Preprocess', total=len(filenames)))

discards = round(100*sum(results)/float(len(filenames)),2)
print(f'Successfully processed {len(filenames) - sum(results)} files (discarded {discards}%)')

Preprocessing 382 files with 16 workers


Preprocess: 100%|██████████| 382/382 [00:03<00:00, 114.18it/s]


Successfully processed 382 files (discarded 0.0%)


In [10]:
# Then we can tokenize it
import os
from multiprocessing import Pool, RLock
from glob import glob

from tqdm import tqdm

from anticipation.config import *
from anticipation.tokenize import tokenize, tokenize_ia

encoding = 'arrival'
AUGMENT_FACTOR = 10
DATA_DIR = "../datasets/JSB-Chorales-midi/midi-outputs"
split_names = ['train', 'test', 'valid']

print('Tokenizing Custom MIDI Dataset')
print(f'  encoding type: {encoding}')

print(f'  train split: {split_names[0]}')
print(f'  validation split: {split_names[2]}')
print(f'  test split: {split_names[1]}')

print('Tokenization parameters:')
print(f'  anticipation interval = {DELTA}s')
print(f'  augment = {AUGMENT_FACTOR}x')
print(f'  max track length = {MAX_TRACK_TIME_IN_SECONDS}s')
print(f'  min track length = {MIN_TRACK_TIME_IN_SECONDS}s')
print(f'  min track events = {MIN_TRACK_EVENTS}')

split_paths = [os.path.join(DATA_DIR, s) for s in split_names]
files = [glob(f'{p}/*.compound.txt') for p in split_paths]
outputs = [os.path.join(DATA_DIR, f'tokenized-events-{s}.txt') for s in split_names]

print(files)
# Augmentation settings
augment = [AUGMENT_FACTOR if s == 'Train' else 1 for s in split_names]

with Pool(processes=PREPROC_WORKERS, initargs=(RLock(),), initializer=tqdm.set_lock) as pool:
    results = pool.starmap(tokenize, zip(files, outputs, augment, range(len(split_names))))

seq_count, rest_count, too_short, too_long, too_manyinstr, discarded_seqs, truncations \
        = (sum(x) for x in zip(*results))
rest_ratio = round(100*float(rest_count)/(seq_count*M),2)

trunc_type = 'duration'
trunc_ratio = round(100*float(truncations)/(seq_count*M),2)

print('Tokenization complete.')
print(f'  => Processed {seq_count} sequences')
print(f'  => Inserted {rest_count} REST tokens ({rest_ratio}% of events)')
print(f'  => Discarded {too_short+too_long+too_manyinstr} sequences for being out of bounds')
print(f'      - {too_short} too short')
print(f'      - {too_long} too long')
print(f'      - {too_manyinstr} too many instruments')
print(f'  => Discarded {discarded_seqs} sequences for other reasons')
print(f'  => Truncated {truncations} {trunc_type} times ({trunc_ratio}% of {trunc_type}s)')
print('Remember to shuffle the training split!')

Tokenizing Custom MIDI Dataset
  encoding type: arrival
  train split: train
  validation split: valid
  test split: test
Tokenization parameters:
  anticipation interval = 5s
  augment = 10x
  max track length = 3600s
  min track length = 10s
  min track events = 100
[['../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_156.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_144.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_067.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_050.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_023.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_017.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_065.mid.compound.txt', '../datasets/JSB-Chorales-midi/midi-outputs/train/chorale_train_021.mid.compound.txt', '../datasets/JSB-Chorales-midi/mi



#1:   0%|          | 0/77 [00:00<?, ?it/s][A[A


#2:   0%|          | 0/76 [00:00<?, ?it/s][A[A[A
#1: 100%|██████████| 77/77 [00:00<00:00, 832.05it/s]



#2: 100%|██████████| 76/76 [00:00<00:00, 603.52it/s]

#0: 100%|██████████| 229/229 [00:00<00:00, 1144.75it/s]


Tokenization complete.
  => Processed 178 sequences
  => Inserted 543 REST tokens (0.89% of events)
  => Discarded 20 sequences for being out of bounds
      - 20 too short
      - 0 too long
      - 0 too many instruments
  => Discarded 0 sequences for other reasons
  => Truncated 3 duration times (0.0% of durations)
Remember to shuffle the training split!


In [11]:
import os

# We can merge the validation and train sequences in one file

DATA_DIR = "../datasets/JSB-Chorales-midi/midi-outputs"

train_file = os.path.join(DATA_DIR, "tokenized-events-train.txt")
valid_file = os.path.join(DATA_DIR, "tokenized-events-train.txt")

with open(train_file, 'r') as f:
    train_data = f.readlines()

with open(valid_file, 'r') as f:
    valid_data = f.readlines()

merged_data = train_data + valid_data

with open(os.path.join(DATA_DIR, "tokenized-events-merged.txt"), 'w') as f:
    f.writelines(merged_data)

## Training Script

In [12]:
# first login with WandB
!wandb login WANDB_API_KEY

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import os
import sys
import random
from datasets import load_dataset
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, GPT2LMHeadModel
from torch.utils.data import SequentialSampler, Subset
from datasets import Dataset, DatasetDict
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling
from torch.nn.utils.rnn import pad_sequence
import torch
import sys

def parse_amt_tokens(token_file):
    lines = open(token_file).readlines()

    all_tokens = []

    for l in lines:
        token_text = l.strip()
        tokens = [int(t) for t in token_text.split(" ")]

        all_tokens.append(tokens)

    return all_tokens

def load_tokenized_data(filename, train_ratio=0.8, seed=42, max_length=1024, vocab_size=55028):
    data = []
    all_tokens = parse_amt_tokens(filename)
    for tokens in all_tokens:
        if len(tokens) > 1 and not any([t >= vocab_size for t in tokens]):
            # Truncate if too long
            tokens = tokens[:max_length]

            # Pad with SEQ token
            if len(tokens) < max_length:
                tokens += [50256] * (max_length - len(tokens))

            data.append({"input_ids": tokens, "labels": tokens.copy()})
    else:
        print(len(tokens))
    if not data:
        raise ValueError("No valid tokenized data found!")

    random.seed(seed)
    random.shuffle(data)
    split_idx = int(len(data) * train_ratio)

    ds_train = Dataset.from_list(data[:split_idx])
    ds_valid = Dataset.from_list(data[split_idx:])

    return DatasetDict({"train": ds_train, "valid": ds_valid})

def debug_collate_fn(features, max_token_id=55028):
    """
    A collate function that checks if any token in input_ids or labels
    exceeds max_token_id, and if so, skips that entire sample.
    """
    valid_features = []
    for i, feature in enumerate(features):
        input_ids = feature["input_ids"]
        labels = feature["labels"]

        # Check input_ids
        if any(t > max_token_id for t in input_ids):
            print(f"[WARNING] Found out-of-range token in sample {i} (input_ids): {input_ids}")
            continue

        # Check labels
        if any(t > max_token_id for t in labels):
            print(f"[WARNING] Found out-of-range token in sample {i} (labels): {labels}")
            continue

        valid_features.append(feature)

    # If every feature in the batch is invalid, raise an error (or you could return an empty batch).
    if not valid_features:
        raise ValueError("All samples in this batch contained out-of-range tokens!")

    # Convert to tensors.
    batch_input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in valid_features]
    batch_labels    = [torch.tensor(f["labels"],    dtype=torch.long) for f in valid_features]

    # Stack into a single batch tensor.
    input_ids = torch.stack(batch_input_ids, dim=0)
    labels    = torch.stack(batch_labels, dim=0)

    return {"input_ids": input_ids, "labels": labels}


# Our base_model
GPT2_MODEL_NAME = "stanford-crfm/music-small-800k"
# ENTER PATH TO TOKENIZED MIDI FILES HERE
DATA_DIR = "../datasets/JSB-Chorales-midi/midi-outputs/tokenized-events-merged.txt"
CKPT_DIR = "jsb_amt_run01"
SEQLEN = 1024
LR = 1e-5

class SequentialTrainer(Trainer):
    """For fair comparison at same training steps, no shuffling."""
    def _get_train_sampler(self):
        return SequentialSampler(self.train_dataset)

if __name__ == "__main__":

    model = AutoModelForCausalLM.from_pretrained(GPT2_MODEL_NAME).to("cuda")

    embedding_size = model.get_input_embeddings().num_embeddings
    print("Model embedding size:", embedding_size)

    print("total trainable params:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    dataset_dict = load_tokenized_data(DATA_DIR)
    ds_train = dataset_dict["train"]
    ds_valid = dataset_dict["valid"]

    for i, sample in enumerate(ds_train):
      for t in sample["input_ids"]:
            if t >= embedding_size:
                print(f"Out-of-range token found in sample {i}, token={t}")

    optimizer = AdamW(model.parameters(), lr=LR)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=CKPT_DIR,
        learning_rate=LR,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=16,
        warmup_steps=500,
        lr_scheduler_type="cosine",
        max_steps=2000,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        logging_dir="./logs",
        bf16=True,  # Enable mixed precision
        report_to="wandb",
        dataloader_num_workers=4,
        do_eval=True,
        gradient_accumulation_steps=4,
        save_safetensors=False,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds_train,
        eval_dataset=ds_valid,
        optimizers=(optimizer, None),
        data_collator=debug_collate_fn,
    )

    # Train
    trainer.train()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.96k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/512M [00:00<?, ?B/s]

Model embedding size: 55028
total trainable params: 128103936
1024


[34m[1mwandb[0m: Currently logged in as: [33mlancelotblanchard[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
10,6.2429
