# Load the Maestro dataset, tokenize, split, collate, and make a dataloader

In [1]:
#Load the maestro and Lakh datasets
import os

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import zipfile
import requests
from pathlib import Path
import tarfile
import shutil

DATA_DIR = Path("data")
DATA_DIR.mkdir(exist_ok=True)

def download_file(url, dest):
    if not dest.exists():
        response = requests.get(url, stream = True)
        with open(dest, "wb")  as f:
            shutil.copyfileobj(response.raw, f)
    else:
        print(f"{dest} already exists. Skipping download.")
        
def extract_zip(zip_path, extract_to):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

def extract_tar(tar_path, extract_to):
    with tarfile.open(tar_path, 'r:gz') as tar_ref:
        tar_ref.extractall(extract_to)

maestro_url = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip"
maestro_zip_path = DATA_DIR / "maestro-v.3.0.0.zip"
maestro_extract_path = DATA_DIR / "maestro"


#Only download if not exist
if not maestro_extract_path.exists():
    download_file(maestro_url, maestro_zip_path)
    extract_zip(maestro_zip_path, maestro_extract_path)
    maestro_zip_path.unlink()  # Remove the zip file after extraction

lmd_url = "http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz"
lmd_tar_path = DATA_DIR / "lmd_full.tar.gz"
lmd_extract_path = DATA_DIR / "lakh"

if not lmd_extract_path.exists():
    download_file(lmd_url, lmd_tar_path)
    extract_tar(lmd_tar_path, lmd_extract_path)
    lmd_tar_path.unlink()  # Remove the tar file after extraction


In [2]:

import pandas as pd
from pathlib import Path
from miditok import REMI, TokenizerConfig
from miditok.utils import split_files_for_training
from miditok.pytorch_data import DatasetMIDI, DataCollator
from torch.utils.data import DataLoader, random_split


# ========== CONFIGURATION ==========
MAX_SEQ_LEN = 1024
BATCH_SIZE = 64
VOCAB_SIZE = 40000
DATA_DIR = Path("data")
TOKENIZER_CONFIG = TokenizerConfig(
    pitch_range=[21, 109],
    beat_res={(0, 4): 16},
    num_velocities=16,
    use_chords=True,
    use_tempos=True,
)
TOKENIZER_PATH = DATA_DIR / "combined_tokenizer.json"
CHUNKS_DIR = DATA_DIR / "combined_dataset_chunks"
# ===================================

def get_maestro_midi_paths():
    maestro_dir = (DATA_DIR / "maestro" / "maestro-v3.0.0").resolve()
    metadata_path = maestro_dir / "maestro-v3.0.0.csv"
    metadata = pd.read_csv(metadata_path)
    return [maestro_dir / x for x in metadata["midi_filename"]]

def get_lakh_midi_paths(limit=None):
    lakh_root = (DATA_DIR / "lakh" / "lmd_full").resolve()
    midi_paths = list(lakh_root.rglob("*.mid"))
    return midi_paths[:limit] if limit else midi_paths

def prepare_and_tokenize_dataset():
    if CHUNKS_DIR.exists():
        print("Tokenized dataset chunks already exist. Skipping creation.")
        tokenizer = REMI(params=TOKENIZER_PATH)
        return tokenizer

    # Get MIDI file paths
    maestro_paths = get_maestro_midi_paths()
    lakh_paths = get_lakh_midi_paths(limit=2000)  # Optional: limit for dev

    all_midi_paths = maestro_paths + lakh_paths

    # Initialize and train tokenizer
    tokenizer = REMI(tokenizer_config=TOKENIZER_CONFIG)
    print("Training tokenizer on combined dataset...")
    tokenizer.train(vocab_size=VOCAB_SIZE, files_paths=all_midi_paths)
    tokenizer.save(TOKENIZER_PATH)

    # Tokenize and split
    print(f"Tokenizing and splitting into chunks...")
    split_files_for_training(
        files_paths=all_midi_paths,
        tokenizer=tokenizer,
        save_dir=CHUNKS_DIR,
        max_seq_len=MAX_SEQ_LEN,
    )
    print("Tokenization complete.")

    return tokenizer

def load_dataset(tokenizer, max_seq_len):
    # Load dataset
    files = list(CHUNKS_DIR.glob("**/*.midi"))
    print(f"Loading {len(files)} MIDI chunks...")

    dataset = DatasetMIDI(
        files_paths=files,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
    )

    return dataset


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
    
tokenizer = None

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

tokenizer = prepare_and_tokenize_dataset()
dataset = load_dataset(tokenizer, MAX_SEQ_LEN)

BOS_TOKEN_ID = tokenizer["BOS_None"]
EOS_TOKEN_ID = tokenizer["EOS_None"]
PAD_TOKEN_ID = tokenizer.pad_token_id

collator = DataCollator(
    pad_token_id=tokenizer.pad_token_id,
    copy_inputs_as_labels=True,
    shift_labels=False,
    pad_on_left=False,
)


# Split into train and validation
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collator, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collator)

print("Dataloaders created successfully.")

Tokenized dataset chunks already exist. Skipping creation.
Loading 12061 MIDI chunks...
Dataloaders created successfully.


# XL-transformer decoder time

In [None]:
# Load transformer XL model
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, Trainer, TrainingArguments

tokenizer = REMI(params=TOKENIZER_PATH)

OUTPUT_DIR = Path("checkpoint/transformer_xl")

config = TransfoXLConfig(
    vocab_size = tokenizer.vocab_size,
    d_embed = 64,
    d_model = 64,
    n_layer=2,
    n_head=2,
    mem_len=256,
    clamp_len=0,
    cutoffs = [],
    adaptive = False,
    eos_token_id=tokenizer["EOS_None"]
)

model = TransfoXLLMHeadModel(config)
model.to(device)

TransfoXLLMHeadModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


'\ntraining_arg = TrainingArguments(\n    output_dir = OUTPUT_DIR,\n    per_device_train_batch_size=BATCH_SIZE,\n    per_device_eval_batch_size=BATCH_SIZE,\n    eval_strategy="epoch",\n    save_strategy="epoch",\n    num_train_epochs=NUM_EPOCHS,\n    logging_dir="./logs",\n    logging_steps=50,\n    save_total_limit=2,\n    learning_rate=5e-5,\n    weight_decay=0.01,\n)\n'

In [None]:
#Training loop
from torch.optim import AdamW
from tqdm import tqdm
from torch.nn import CrossEntropyLoss

OUTPUT_DIR = Path("checkpoint/transformer_xl")
LOG_DIR  = Path("logs")
BATCH_SIZE=1
NUM_EPOCHS = 1
LR = 5e-5
WEIGHT_DECAY = 0.01

device = torch.device("cpu")
model = model.to(device)

optimizer = AdamW(model.parameters(), lr = LR, weight_decay=WEIGHT_DECAY)

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0

    for batch in train_loader:
        optimizer.zero_grad()

        inputs = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(inputs, labels=labels)
        loss = outputs.loss

        total_loss += loss.item()

        print(total_loss)

        #logits of size [batch_size*seq_len, vocab_size]
        #labels of size [batch_size*seq_len]

        loss.backward()
        optimizer.step()

        print("iteration")
    
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}")

    #Eval
    model.eval()
    total_eval_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(inputs, labels=labels)
            loss = outputs.loss
            total_eval_loss += loss.item()
    
    avg_eval_loss = total_eval_loss / len(val_loader)
    print(f"Validation Loss: {avg_eval_loss:.4f}")

    # Save model
    model.save_pretrained(OUTPUT_DIR / f"epoch_{epoch + 1}")

print("training done")


10.596830368041992
iteration
21.19341278076172
iteration


KeyboardInterrupt: 