# Symbolic Music LLM Scaling Laws - Complete Pipeline

This notebook contains the complete pipeline for:
1. Data collection and preprocessing (10,000 MIDI files → ABC notation)
2. Tokenization and train/val/test splitting
3. Training transformer models for scaling laws study

**Note**: This notebook is designed for Google Colab. Make sure to enable GPU runtime (Runtime → Change runtime type → GPU).


In [None]:
# Install required packages
%pip install -q music21 pretty_midi librosa mir_eval numpy pandas scipy torch transformers datasets matplotlib seaborn tqdm pyyaml wandb


## 1. Setup and Imports


In [None]:
import os
import sys
import json
import pickle
import warnings
from pathlib import Path
from collections import Counter
from typing import List, Tuple, Dict, Optional
import random
import gc
import io
import contextlib

# Aggressively suppress all warnings, especially from music21
warnings.filterwarnings('ignore')
os.environ['PYTHONWARNINGS'] = 'ignore'
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=RuntimeWarning)

import music21
music21.environment.UserSettings()['warnings'] = 0

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math

# Setup paths. On Colab, work in a fixed directory under /content; elsewhere use
# the current working directory. Checking sys.modules (rather than calling
# get_ipython()) also works when this code runs outside IPython, where
# get_ipython is not defined.
if 'google.colab' in sys.modules:
    BASE_DIR = Path('/content/symbolic-music-llm')
    BASE_DIR.mkdir(parents=True, exist_ok=True)
    os.chdir(BASE_DIR)
else:
    BASE_DIR = Path.cwd()

DATA_DIR = BASE_DIR / "data"
OUTPUT_DIR = DATA_DIR / "processed"
LMD_DIR = DATA_DIR / "lmd_matched"

# Create directories
DATA_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
(OUTPUT_DIR / "abc").mkdir(exist_ok=True)
(OUTPUT_DIR / "tokenized").mkdir(exist_ok=True)

print(f"Working directory: {BASE_DIR}")
print(f"Data directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")


## 2–3. Model Architecture, MIDI→ABC Converter, and Tokenizer


In [None]:
# Transformer Model Classes
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention block.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights (training
            mode only) and to the output projection.
    """
    
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        # Single projection producing q, k and v concatenated on the last dim.
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        self.c_proj = nn.Linear(d_model, d_model)
        # scaled_dot_product_attention applies attention dropout internally; this
        # module only carries the configured probability (see forward).
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        """Map ``x`` of shape (B, T, d_model) to an output of the same shape."""
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.d_model, dim=2)
        # (B, T, C) -> (B, n_heads, T, head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Use the configured dropout probability (it used to be hard-coded to 0.1,
        # silently ignoring the ``dropout`` argument).
        attn_p = self.attn_dropout.p if self.training else 0.0
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_p, is_causal=True)
        # Re-assemble heads: (B, n_heads, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    """Position-wise feed-forward sub-layer.

    Computes ``Dropout(Linear_out(GELU(Linear_in(x))))``, expanding the last
    dimension from ``d_model`` to ``d_ff`` and projecting it back.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Registration order is kept so parameter initialisation and state_dict
        # keys are unchanged.
        self.c_fc = nn.Linear(d_model, d_ff)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        projected = self.c_proj(hidden)
        return self.dropout(projected)


class Block(nn.Module):
    """Pre-LayerNorm transformer layer.

    Applies ``x + attn(ln_1(x))`` and then ``h + mlp(ln_2(h))`` to the result.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Construction order is kept so random initialisation is unchanged.
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, d_ff, dropout)

    def forward(self, x):
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))


class MusicTransformer(nn.Module):
    """Decoder-only Transformer for symbolic music generation.

    Token ids are embedded and summed with learned absolute position embeddings.
    The result passes through ``n_layers`` pre-LayerNorm blocks and a final
    LayerNorm, then goes to an output head whose weight is tied to the token
    embedding.

    Args:
        vocab_size: Number of token ids.
        d_model: Embedding / hidden width.
        n_layers: Number of transformer blocks.
        n_heads: Attention heads per block (must divide ``d_model``).
        d_ff: Feed-forward width; defaults to ``4 * d_model``.
        max_seq_length: Longest input the position embedding supports.
        dropout: Dropout probability used throughout the model.
    """
    
    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8, 
                 d_ff=None, max_seq_length=5000, dropout=0.1):
        super().__init__()
        
        if d_ff is None:
            d_ff = 4 * d_model
        
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.max_seq_length = max_seq_length
        
        self.wte = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.wpe = nn.Embedding(max_seq_length, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            Block(d_model, n_heads, d_ff, dropout) 
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Tie weights: embedding and output head share one parameter tensor.
        # NOTE(review): _init_weights below re-initialises this shared tensor, so
        # the padding_idx=0 row is not zero; padding_idx only blocks gradients
        # to that row coming from the embedding lookup.
        self.wte.weight = self.lm_head.weight
        
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
    
    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the language-model loss.

        Args:
            idx: (B, T) LongTensor of token ids with T <= max_seq_length.
            targets: Optional (B, T) LongTensor; positions equal to -1 are ignored.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size); ``loss`` is
            None when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``max_seq_length``.
        """
        B, T = idx.size()
        if T > self.max_seq_length:
            # Fail clearly instead of an opaque index error in the position embedding.
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_length={self.max_seq_length}"
            )
        tok_emb = self.wte(idx)
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        pos_emb = self.wpe(pos)
        x = self.drop(tok_emb + pos_emb)
        
        for block in self.blocks:
            x = block(x)
        
        x = self.ln_f(x)
        logits = self.lm_head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), 
                targets.view(-1), 
                ignore_index=-1
            )
        
        return logits, loss
    
    def count_parameters(self):
        """Number of trainable parameters (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` tokens to the context ``idx``.

        Args:
            idx: (B, T) LongTensor context.
            max_new_tokens: Number of tokens to generate.
            temperature: Softmax temperature; values <= 0 select the most likely
                token (greedy decoding) instead of dividing by zero.
            top_k: If given, sample only among the ``top_k`` most likely tokens.

        Returns:
            (B, T + max_new_tokens) LongTensor.

        The model runs in eval mode during generation; its previous train/eval
        mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context to what the position embedding supports.
                idx_cond = idx[:, -self.max_seq_length:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature <= 0:
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = -float('Inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

print("Model classes defined!")


In [None]:
# MIDI to ABC Converter
class MIDIToABCConverter:
    """Convert MIDI files to a simplified, single-voice ABC notation using music21.

    The conversion is intentionally lossy:

    * notes/rests/chords of the flattened score are emitted one after another,
      so simultaneous notes from different parts are serialised sequentially;
    * chords are reduced to their first note;
    * a bar line is inserted after every 4 emitted events, not at real measures;
    * durations are quantised to 1/8 of an eighth note (unit length ``L:1/8``).

    ``conversion_stats`` counts successes/failures and stores the messages of
    exceptions raised while converting.
    """

    # Maximum length of an ABC body line; wrapping only happens after a bar line.
    LINE_WIDTH = 80
    
    def __init__(self):
        self.conversion_stats = {'success': 0, 'failed': 0, 'errors': []}
    
    def convert_midi_to_abc(self, midi_path: Path) -> Optional[str]:
        """Convert a MIDI file to ABC notation.

        Returns the cleaned ABC string, or None if parsing raised or produced
        empty output. Failures are counted in ``conversion_stats`` instead of
        being raised.
        """
        try:
            # music21 is noisy on real-world MIDI; silence warnings and stderr.
            null_stderr = io.StringIO()
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                with contextlib.redirect_stderr(null_stderr):
                    score = music21.converter.parse(str(midi_path))
            
            abc_str = self._score_to_abc_manual(score)
            
            if abc_str:
                abc_str = self._clean_abc(abc_str)
                if abc_str.strip():
                    self.conversion_stats['success'] += 1
                    return abc_str
            
            self.conversion_stats['failed'] += 1
            return None
        except Exception as e:
            self.conversion_stats['failed'] += 1
            self.conversion_stats['errors'].append(str(e))
            return None
    
    def _score_to_abc_manual(self, score) -> str:
        """Build an ABC tune (header + body) from a parsed music21 score.

        Returns "" if anything raises during the conversion.
        """
        try:
            # Flatten once and reuse it: `.flat` is deprecated in favour of
            # `flatten()` in recent music21 releases, and flattening is not free.
            flat = score.flatten() if hasattr(score, 'flatten') else score.flat

            meter = "4/4"  # default time signature
            for ts in flat.getElementsByClass('TimeSignature'):
                if ts.numerator and ts.denominator:
                    meter = f"{ts.numerator}/{ts.denominator}"
                    break

            key_field = "C"  # default key
            try:
                key_field = self._key_to_abc(score.analyze('key')) or key_field
            except Exception:
                pass  # key analysis is best-effort; keep the default key

            abc_body = []
            event_count = 0
            for element in flat.notesAndRests:
                if isinstance(element, music21.note.Note):
                    abc_body.append(self._note_to_abc(element))
                elif isinstance(element, music21.note.Rest):
                    abc_body.append("z" + self._duration_to_abc(element.duration.quarterLength))
                elif isinstance(element, music21.chord.Chord):
                    # Simplification: keep only the chord's first note.
                    if len(element.notes) > 0:
                        abc_body.append(self._note_to_abc(element.notes[0]))

                # Pseudo bar line every 4 events (not aligned with real measures).
                event_count += 1
                if event_count % 4 == 0:
                    abc_body.append("|")

            # ABC requires K: to be the last header field.
            header = ["X:1", f"M:{meter}", "L:1/8", f"K:{key_field}"]
            body = self._wrap_bars("".join(abc_body), self.LINE_WIDTH)
            return "\n".join(header + [body])
        except Exception:
            return ""

    @staticmethod
    def _key_to_abc(key) -> Optional[str]:
        """Format a music21 key as an ABC ``K:`` value, e.g. "G", "Ebm", "F#m".

        music21 spells flats with "-" (e.g. "E-"), whereas ABC uses "b". Major keys
        need no suffix and minor keys get "m". (Previously ``'maj'[0]`` and
        ``'min'[0]`` were both "m", so every key was written as minor.)
        Returns None when ``key`` is falsy.
        """
        if not key:
            return None
        tonic = key.tonic.name.replace('-', 'b')
        return tonic + ('m' if key.mode == 'minor' else '')

    @staticmethod
    def _wrap_bars(body: str, width: int = 80) -> str:
        """Break an ABC body into lines of at most ``width`` chars, only after "|".

        Joining the returned lines with "" reproduces ``body`` exactly, so the
        token stream is unchanged. A single bar longer than ``width`` stays on
        its own over-long line.
        """
        if len(body) <= width:
            return body
        segments = body.split("|")
        lines = []
        current = ""
        for idx, segment in enumerate(segments):
            # Re-attach the bar line that split() removed (none after the last segment).
            piece = segment if idx == len(segments) - 1 else segment + "|"
            if current and len(current) + len(piece) > width:
                lines.append(current)
                current = ""
            current += piece
        if current:
            lines.append(current)
        return "\n".join(lines)
    
    def _note_to_abc(self, note) -> str:
        """Convert a music21 note to an ABC note such as "^F,2" or "C'".

        Single sharps/flats become "^"/"_" prefixes (other accidentals are
        dropped). The uppercase letter denotes octave 4. Each octave above adds
        an apostrophe and each octave below adds a comma. Upper case is kept
        throughout because MusicTokenizer upper-cases note letters.
        Returns "" if the note cannot be converted.
        """
        try:
            pitch = note.pitch
            letter = pitch.name[0]

            prefix = ""
            if pitch.accidental:
                if pitch.accidental.alter == 1:
                    prefix = "^"
                elif pitch.accidental.alter == -1:
                    prefix = "_"

            octave = pitch.octave
            if octave >= 4:
                octave_marks = "'" * (octave - 4)
            else:
                # Commas lower the octave. The old code repeated the lowercase
                # letter instead ("cc"), which ABC reads as extra notes.
                octave_marks = "," * (4 - octave)

            return prefix + letter + octave_marks + self._duration_to_abc(note.duration.quarterLength)
        except Exception:
            return ""
    
    def _duration_to_abc(self, quarter_length: float) -> str:
        """Convert a duration in quarter notes to an ABC length (unit note L:1/8).

        The length in eighths is quantised to a multiple of 1/8. It is written as
        "" for one eighth, "n" for n eighths, "/" for half an eighth, "/d" for
        1/d and "n/d" otherwise (a dotted eighth is "3/2"). Non-positive
        durations give "".
        """
        from fractions import Fraction

        eighths = Fraction(round(float(quarter_length) * 2 * 8), 8)
        if eighths <= 0:
            return ""
        num, den = eighths.numerator, eighths.denominator
        if den == 1:
            return "" if num == 1 else str(num)
        if num == 1:
            return "/" if den == 2 else f"/{den}"
        return f"{num}/{den}"
    
    def _clean_abc(self, abc_str: str) -> str:
        """Strip each line and drop empty lines and '%' comment lines."""
        lines = abc_str.split('\n')
        cleaned_lines = []
        for line in lines:
            line = line.strip()
            if line and not line.startswith('%'):
                cleaned_lines.append(line)
        return '\n'.join(cleaned_lines)

print("MIDI to ABC converter defined!")


In [None]:
# Music Tokenizer
class MusicTokenizer:
    """Word-level tokenizer for the ABC strings produced by MIDIToABCConverter.

    ``_tokenize_abc`` emits these token kinds:

    * ``|`` for bar lines;
    * notes: an optional accidental prefix (up to two of ``^``, ``_``, ``=``),
      the note letter upper-cased, then any octave marks (``,`` / ``'``),
      e.g. ``^F``, ``_B,``, ``C''``;
    * ``DUR:<n>`` for each run of digits;
    * ``z`` for rests;
    * every other non-whitespace character as its own token (whitespace is dropped).

    Ids 0-4 are reserved for the special tokens; tokens missing from the
    vocabulary encode to ``<UNK>``.
    """

    # Note letters (matched case-insensitively) and ABC accidental prefix characters.
    _NOTE_LETTERS = frozenset('ABCDEFG')
    _ACCIDENTALS = '^_='
    # ABC allows at most two accidental symbols before a note (e.g. ^^ = double sharp).
    _MAX_ACCIDENTALS = 2
    
    def __init__(self):
        self.vocab = {}
        self.vocab_size = 0
        self.token_to_id = {}
        self.id_to_token = {}
        self.special_tokens = {
            '<PAD>': 0,
            '<UNK>': 1,
            '<START>': 2,
            '<END>': 3,
            '<SEP>': 4,
        }
    
    def build_vocab(self, abc_strings: List[str], min_freq: int = 2):
        """Build the vocabulary from ABC strings.

        Tokens occurring at least ``min_freq`` times get ids after the special
        tokens, in first-seen order.
        """
        print("Building vocabulary...")
        token_counter = Counter()
        
        for abc_str in tqdm(abc_strings, desc="Tokenizing for vocab"):
            tokens = self._tokenize_abc(abc_str)
            token_counter.update(tokens)
        
        vocab = dict(self.special_tokens)
        current_id = len(self.special_tokens)
        
        for token, count in token_counter.items():
            if count >= min_freq:
                vocab[token] = current_id
                current_id += 1
        
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.token_to_id = vocab
        self.id_to_token = {v: k for k, v in vocab.items()}
        
        print(f"Vocabulary size: {self.vocab_size}")
        print(f"  Special tokens: {len(self.special_tokens)}")
        print(f"  Regular tokens: {self.vocab_size - len(self.special_tokens)}")
    
    def _tokenize_abc(self, abc_str: str) -> List[str]:
        """Split an ABC string into music-aware tokens (see the class docstring)."""
        tokens = []
        i = 0
        n = len(abc_str)
        
        while i < n:
            char = abc_str[i]
            
            if char.isspace():
                i += 1
                continue
            
            if char == '|':
                tokens.append('|')
                i += 1
                continue
            
            # ABC writes accidentals BEFORE the note letter ("^F", "_B"). Scan an
            # optional prefix and attach it only when a note letter follows.
            letter_pos = i
            while (letter_pos < n and letter_pos - i < self._MAX_ACCIDENTALS
                   and abc_str[letter_pos] in self._ACCIDENTALS):
                letter_pos += 1
            if letter_pos < n and abc_str[letter_pos].upper() in self._NOTE_LETTERS:
                note_token = abc_str[i:letter_pos] + abc_str[letter_pos].upper()
                i = letter_pos + 1
                while i < n and abc_str[i] in ",'":
                    note_token += abc_str[i]
                    i += 1
                tokens.append(note_token)
                continue
            
            if char.isdigit():
                duration = char
                i += 1
                while i < n and abc_str[i].isdigit():
                    duration += abc_str[i]
                    i += 1
                tokens.append(f"DUR:{duration}")
                continue
            
            if char == 'z':
                tokens.append('z')
                i += 1
                continue
            
            tokens.append(char)
            i += 1
        
        return tokens
    
    def encode(self, abc_str: str) -> List[int]:
        """Encode an ABC string to token ids (unknown tokens -> ``<UNK>``)."""
        tokens = self._tokenize_abc(abc_str)
        token_ids = []
        for token in tokens:
            if token in self.token_to_id:
                token_ids.append(self.token_to_id[token])
            else:
                token_ids.append(self.token_to_id['<UNK>'])
        return token_ids
    
    def decode(self, token_ids: List[int]) -> str:
        """Decode token ids to a space-separated token string (not raw ABC)."""
        tokens = []
        for token_id in token_ids:
            if token_id in self.id_to_token:
                tokens.append(self.id_to_token[token_id])
            else:
                tokens.append('<UNK>')
        return ' '.join(tokens)
    
    def save(self, path: Path):
        """Save the tokenizer state to ``path`` with pickle."""
        with open(path, 'wb') as f:
            pickle.dump({
                'vocab': self.vocab,
                'token_to_id': self.token_to_id,
                'id_to_token': self.id_to_token,
                'vocab_size': self.vocab_size,
                'special_tokens': self.special_tokens
            }, f)
    
    @classmethod
    def load(cls, path: Path):
        """Load a tokenizer saved by ``save``.

        Security: ``pickle.load`` can execute arbitrary code. Only load files you
        created yourself or otherwise trust.
        """
        with open(path, 'rb') as f:
            data = pickle.load(f)
        tokenizer = cls()
        tokenizer.vocab = data['vocab']
        tokenizer.token_to_id = data['token_to_id']
        tokenizer.id_to_token = data['id_to_token']
        tokenizer.vocab_size = data['vocab_size']
        tokenizer.special_tokens = data['special_tokens']
        return tokenizer

print("Music tokenizer defined!")


## 4. Data Collection

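This section downloads the Lakh MIDI Dataset (LMD-matched) and extracts it into the `data/` folder under the working directory (`DATA_DIR`). If MIDI files are already present in `data/lmd_matched`, both the download and the extraction are skipped.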

**Note**: The dataset is ~1.7GB compressed and ~2-3GB extracted. This may take several minutes to download.


In [None]:
# Data Collection: Download and Extract Lakh MIDI Dataset
import urllib.request
import tarfile
import shutil

LMD_URL = "http://hog.ee.columbia.edu/craffel/lmd/lmd_matched.tar.gz"
TAR_PATH = DATA_DIR / "lmd_matched.tar.gz"

def download_with_progress(url: str, destination: Path):
    """Download ``url`` to ``destination``, printing progress as it goes.

    Returns True on success. On failure it prints the error and returns False.
    If the download does not complete (error or interrupt), any partially
    written file is removed, so a later run does not mistake it for a complete
    archive.
    """
    def reporthook(count, block_size, total_size):
        downloaded = count * block_size
        if total_size > 0:
            # urlretrieve reports whole blocks, so clamp the last report to 100%.
            downloaded = min(downloaded, total_size)
            percent = int(downloaded * 100 / total_size)
            print(f"\rDownloading... {percent}% ({downloaded / 1024 / 1024:.1f} MB / {total_size / 1024 / 1024:.1f} MB)", end='', flush=True)
        else:
            # The server did not report a size (total_size is -1): show bytes only.
            print(f"\rDownloading... {downloaded / 1024 / 1024:.1f} MB", end='', flush=True)
    
    completed = False
    try:
        urllib.request.urlretrieve(url, destination, reporthook=reporthook)
        completed = True
    except Exception as e:
        print(f"\n✗ Download failed: {e}")
        return False
    finally:
        if not completed:
            Path(destination).unlink(missing_ok=True)
    print("\n✓ Download complete!")
    return True

# Check if dataset already exists. Scan once and reuse the result: a recursive
# glob over the whole dataset is slow.
existing_midi_files = list(LMD_DIR.rglob("*.mid")) if LMD_DIR.exists() else []
if existing_midi_files:
    print(f"✓ LMD dataset already exists in {LMD_DIR}")
    print(f"  Found {len(existing_midi_files)} MIDI files")
else:
    print("LMD dataset not found. Starting download...")
    print(f"URL: {LMD_URL}")
    print(f"Destination: {TAR_PATH}")
    print(f"Note: This is a large file (~1.7GB), download may take 5-15 minutes depending on connection speed.\n")
    
    # Download the dataset
    if not TAR_PATH.exists():
        if download_with_progress(LMD_URL, TAR_PATH):
            print(f"✓ Downloaded to {TAR_PATH}")
        else:
            print("✗ Download failed. Please check your internet connection and try again.")
            raise RuntimeError("Failed to download LMD dataset")
    else:
        print(f"✓ Tar file already exists: {TAR_PATH}")
    
    # Extract the dataset
    print(f"\nExtracting {TAR_PATH} to {DATA_DIR}...")
    print("This may take a few minutes...")
    
    try:
        with tarfile.open(TAR_PATH, 'r:gz') as tar:
            # Get total members for progress
            members = tar.getmembers()
            total = len(members)
            
            # The archive is fetched over plain HTTP. Where this Python supports
            # it, the 'data' filter rejects members with absolute paths, '..'
            # components or links that would write outside DATA_DIR.
            extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
            
            # Extract with progress
            for i, member in enumerate(members):
                tar.extract(member, DATA_DIR, **extract_kwargs)
                if (i + 1) % 1000 == 0:
                    print(f"  Extracted {i+1}/{total} files...", end='\r', flush=True)
            
            print(f"\n✓ Extracted {total} files")
        
        # The tar file might extract to a folder with a different name
        # Check for extracted folders and rename if needed
        extracted_folders = [d for d in DATA_DIR.iterdir() 
                           if d.is_dir() and 'lmd' in d.name.lower() and d != LMD_DIR]
        
        if extracted_folders and not LMD_DIR.exists():
            if len(extracted_folders) == 1:
                print(f"Renaming {extracted_folders[0]} to {LMD_DIR}")
                extracted_folders[0].rename(LMD_DIR)
        
        # Clean up tar file to save space (optional - comment out if you want to keep it)
        # TAR_PATH.unlink()
        # print(f"✓ Removed tar file to save space")
        
    except Exception as e:
        print(f"✗ Extraction failed: {e}")
        raise RuntimeError(f"Failed to extract dataset: {e}") from e

# Verify MIDI files are available. Sort them so that the MAX_FILES subset taken
# below is the same on every machine (rglob order depends on the filesystem).
if existing_midi_files:
    midi_files = sorted(existing_midi_files)
else:
    midi_files = sorted(LMD_DIR.rglob("*.mid"))
print(f"\n{'='*60}")
print(f"Data Collection Summary:")
print(f"{'='*60}")
print(f"MIDI files found: {len(midi_files):,}")
print(f"Location: {LMD_DIR}")

if len(midi_files) == 0:
    print("\n⚠ WARNING: No MIDI files found!")
    print("Please check:")
    print(f"  1. Extraction completed successfully")
    print(f"  2. Files are in: {LMD_DIR}")
    print(f"  3. Directory structure is correct")
    raise RuntimeError("No MIDI files found after download/extraction")

# Limit to 10,000 files for this run
MAX_FILES = 10000
if len(midi_files) > MAX_FILES:
    print(f"\nLimiting to {MAX_FILES} files for processing")
    midi_files = midi_files[:MAX_FILES]

print(f"\n✓ Ready to process {len(midi_files):,} MIDI files")
print(f"{'='*60}")


## 5. Data Processing: Convert MIDI to ABC Notation

This section processes the downloaded MIDI files and converts them to ABC notation format.

**Processing**: Up to 10,000 MIDI files will be converted to ABC notation.


In [None]:
# Convert MIDI files to ABC notation
converter = MIDIToABCConverter()
abc_data = []

print(f"Converting {len(midi_files)} MIDI files to ABC notation...")
print("This may take a while. Progress will be shown below.")

# Create the output directory once instead of on every file.
abc_dir = OUTPUT_DIR / "abc"
abc_dir.mkdir(parents=True, exist_ok=True)

for file_idx, midi_file in enumerate(tqdm(midi_files, desc="Converting MIDI to ABC"), start=1):
    abc_str = converter.convert_midi_to_abc(midi_file)
    
    if abc_str:
        # Save ABC file. NOTE: files are named by stem, so MIDI files that share
        # a filename in different directories overwrite each other's .abc
        # (abc_data still keeps every result).
        abc_path = abc_dir / f"{midi_file.stem}.abc"
        with open(abc_path, 'w', encoding='utf-8') as f:
            f.write(abc_str)
        
        abc_data.append((midi_file, abc_str))
    
    # Periodic garbage collection every 100 processed files. The count includes
    # failures; keying on len(abc_data) used to call gc.collect() on every
    # iteration while conversions kept failing.
    if file_idx % 100 == 0:
        gc.collect()

print(f"\nConversion complete!")
print(f"  Successful: {converter.conversion_stats['success']}")
print(f"  Failed: {converter.conversion_stats['failed']}")
print(f"  Success rate: {converter.conversion_stats['success'] / len(midi_files) * 100:.1f}%")
print(f"  Total ABC files: {len(abc_data)}")


## 6. Build Vocabulary and Tokenize


In [None]:
# Build vocabulary from ABC strings
# NOTE(review): the vocabulary is fitted on every converted piece, before the
# train/val/test split below, so tokens that only occur in val/test get real ids
# instead of <UNK>. Consider fitting it on the training split only.
print("Building vocabulary...")
tokenizer = MusicTokenizer()
abc_strings = [abc_str for _, abc_str in abc_data]
# Tokens seen fewer than min_freq times are left out of the vocabulary and
# encode to <UNK>.
tokenizer.build_vocab(abc_strings, min_freq=2)

# Save tokenizer (pickle format; see MusicTokenizer.save / load)
tokenizer.save(OUTPUT_DIR / "tokenizer.pkl")
print(f"Tokenizer saved to {OUTPUT_DIR / 'tokenizer.pkl'}")


## 7. Filter Sequences and Create Splits


In [None]:
# Filter sequences by length and tokenize.
# Sequences outside [MIN_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH] tokens are dropped
# entirely (not truncated). MAX_SEQUENCE_LENGTH equals the default
# max_seq_length (5000) of MusicDataset and MusicTransformer.
MIN_SEQUENCE_LENGTH = 50
MAX_SEQUENCE_LENGTH = 5000

print(f"Filtering sequences (length: {MIN_SEQUENCE_LENGTH}-{MAX_SEQUENCE_LENGTH} tokens)...")

# Kept records are (midi_path, abc_str, token_ids); stats counts each outcome.
filtered_data = []
stats = {'too_short': 0, 'too_long': 0, 'valid': 0}

for midi_path, abc_str in tqdm(abc_data, desc="Filtering and tokenizing"):
    token_ids = tokenizer.encode(abc_str)
    seq_length = len(token_ids)
    
    if seq_length < MIN_SEQUENCE_LENGTH:
        stats['too_short'] += 1
        continue
    elif seq_length > MAX_SEQUENCE_LENGTH:
        stats['too_long'] += 1
        continue
    else:
        stats['valid'] += 1
        filtered_data.append((midi_path, abc_str, token_ids))

print(f"\nFiltering statistics:")
print(f"  Too short (<{MIN_SEQUENCE_LENGTH}): {stats['too_short']}")
print(f"  Too long (>{MAX_SEQUENCE_LENGTH}): {stats['too_long']}")
print(f"  Valid: {stats['valid']}")


In [None]:
# Create train/val/test splits
TRAIN_SPLIT = 0.98
VAL_SPLIT = 0.01
TEST_SPLIT = 0.01
assert abs(TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT - 1.0) < 1e-9, "Split fractions must sum to 1"

print("Creating train/val/test splits...")

# Drop exact duplicates. Identical ABC text encodes to identical token ids, so
# without this the same piece could land in both train and val/test. The first
# occurrence is kept.
seen_abc = set()
unique_data = []
for record in filtered_data:
    if record[1] not in seen_abc:
        seen_abc.add(record[1])
        unique_data.append(record)
if len(unique_data) < len(filtered_data):
    print(f"  Removed {len(filtered_data) - len(unique_data)} duplicate sequences")

# Shuffle into a NEW list. Reassigning filtered_data in place meant that
# re-running this cell reshuffled already-shuffled data and silently changed
# the splits.
np.random.seed(42)
indices = np.random.permutation(len(unique_data))
shuffled_data = [unique_data[i] for i in indices]

# Calculate split sizes. val/test get at least one sequence whenever there are
# enough sequences (>= 3), so evaluation never runs on an empty split.
n_total = len(shuffled_data)
n_val = int(n_total * VAL_SPLIT)
n_test = int(n_total * TEST_SPLIT)
if n_total >= 3:
    n_val = max(1, n_val)
    n_test = max(1, n_test)
n_train = n_total - n_val - n_test

train_data = shuffled_data[:n_train]
val_data = shuffled_data[n_train:n_train + n_val]
test_data = shuffled_data[n_train + n_val:]

print(f"  Train: {len(train_data)} sequences")
print(f"  Val: {len(val_data)} sequences")
print(f"  Test: {len(test_data)} sequences")


In [None]:
# Save splits to disk
print("Saving splits...")

def save_split(data, split_name, output_dir: Optional[Path] = None):
    """Write one split to ``<output_dir>/tokenized/<split_name>/data.json``.

    Args:
        data: Iterable of ``(midi_path, abc_str, token_ids)`` records.
        split_name: Sub-directory name, e.g. "train", "val" or "test".
        output_dir: Root output directory; defaults to ``OUTPUT_DIR``.

    The file is a JSON list of objects with keys ``midi_path``, ``abc``,
    ``tokens`` and ``length``.
    """
    root = OUTPUT_DIR if output_dir is None else Path(output_dir)
    split_dir = root / "tokenized" / split_name
    split_dir.mkdir(parents=True, exist_ok=True)
    
    json_data = [
        {
            'midi_path': str(midi_path),
            'abc': abc_str,
            'tokens': token_ids,
            'length': len(token_ids),
        }
        for midi_path, abc_str, token_ids in data
    ]
    
    # Compact JSON. indent=2 put every token id on its own line, which inflated
    # multi-million-token splits many times over without changing what loaders read.
    with open(split_dir / "data.json", 'w', encoding='utf-8') as f:
        json.dump(json_data, f)

# Persist each split to OUTPUT_DIR / "tokenized" / <split> / "data.json".
save_split(train_data, "train")
save_split(val_data, "val")
save_split(test_data, "test")

# Calculate statistics: token counts are the lengths of the stored id lists
# (MusicTokenizer.encode adds no <START>/<END> markers).
train_tokens = sum(len(tokens) for _, _, tokens in train_data)
val_tokens = sum(len(tokens) for _, _, tokens in val_data)
test_tokens = sum(len(tokens) for _, _, tokens in test_data)

print(f"\nToken counts:")
print(f"  Train: {train_tokens:,} tokens ({train_tokens/1e6:.1f}M)")
print(f"  Val: {val_tokens:,} tokens")
print(f"  Test: {test_tokens:,} tokens")
print(f"  Total: {train_tokens + val_tokens + test_tokens:,} tokens")

print(f"\n✓ Data preprocessing complete!")
print(f"  Tokenizer: {OUTPUT_DIR / 'tokenizer.pkl'}")
print(f"  Train data: {OUTPUT_DIR / 'tokenized' / 'train' / 'data.json'}")
print(f"  Val data: {OUTPUT_DIR / 'tokenized' / 'val' / 'data.json'}")
print(f"  Test data: {OUTPUT_DIR / 'tokenized' / 'test' / 'data.json'}")


## 8. Data Loading Utilities


In [None]:
# Data loading classes
class MusicDataset(Dataset):
    """Dataset of tokenized music sequences for next-token prediction.

    Reads either a JSON list of records or JSON Lines (one record per line).
    Token ids come from a record's ``token_ids`` (if present and non-empty) or
    its ``tokens`` key. Sequences are truncated to ``max_seq_length``. Sequences
    with fewer than 2 tokens are skipped, since they yield no (input, target) pair.

    ``__getitem__`` returns ``(input_ids, target_ids)`` as 1-D LongTensors, where
    ``target_ids`` is the sequence shifted left by one position.
    """
    
    def __init__(self, data_path: Path, max_seq_length: int = 5000):
        self.data_path = data_path
        self.max_seq_length = max_seq_length
        self.sequences = []
        
        print(f"Loading sequences from {data_path}...")
        with open(data_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        for record in self._parse_records(content):
            self._add_record(record)
        
        print(f"Loaded {len(self.sequences)} sequences")
        if len(self.sequences) > 0:
            avg_len = sum(len(s) for s in self.sequences) / len(self.sequences)
            print(f"Average sequence length: {avg_len:.1f}")

    @staticmethod
    def _parse_records(content: str) -> list:
        """Parse file content as one JSON document, falling back to JSON Lines."""
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            return [json.loads(line) for line in content.splitlines() if line.strip()]
        if isinstance(data, list):
            return data
        if isinstance(data, dict):
            # A JSON Lines file with a single record parses as one JSON object.
            return [data]
        return []

    def _add_record(self, record) -> None:
        """Extract, truncate and store one record's token ids if usable."""
        token_ids = record.get('token_ids') or record.get('tokens', [])
        token_ids = token_ids[:self.max_seq_length]
        if len(token_ids) >= 2:
            self.sequences.append(token_ids)
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        tokens = torch.tensor(sequence, dtype=torch.long)
        input_ids = tokens[:-1]
        target_ids = tokens[1:]
        return input_ids, target_ids


class MusicDataLoader:
    """Iterate over a MusicDataset in batches bounded by a token budget.

    Sequences are accumulated until adding the next one would exceed
    ``batch_size_tokens`` (a single over-long sequence still forms its own
    batch). Each batch is then right-padded to its longest sequence and stacked
    into (B, T) tensors.

    Input padding uses ``pad_token_id`` (default 0, the tokenizer's <PAD> id,
    which is a valid embedding index). Target padding uses ``ignore_index``
    (default -1, the ignore_index MusicTransformer passes to cross_entropy).
    """
    
    def __init__(self, data_path: Path, batch_size_tokens: int, 
                 max_seq_length: int = 5000, shuffle: bool = True,
                 pad_token_id: int = 0, ignore_index: int = -1):
        self.data_path = data_path
        self.batch_size_tokens = batch_size_tokens
        self.max_seq_length = max_seq_length
        self.shuffle = shuffle
        self.pad_token_id = pad_token_id
        self.ignore_index = ignore_index
        
        self.dataset = MusicDataset(data_path, max_seq_length)
        # batch_size=1 plus unwrapping yields one (input, target) pair at a time
        # (shuffled if requested); token-budget batching happens in __iter__.
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=1,
            shuffle=shuffle,
            num_workers=0,
            collate_fn=lambda batch: batch[0]
        )

    def _collate(self, inputs, targets):
        """Right-pad lists of 1-D input/target tensors and stack them to (B, T)."""
        max_len = max(seq.size(0) for seq in inputs)
        padded_inputs = []
        padded_targets = []
        for inp, tgt in zip(inputs, targets):
            pad_len = max_len - inp.size(0)
            if pad_len > 0:
                inp = torch.cat([inp, torch.full((pad_len,), self.pad_token_id, dtype=inp.dtype)])
                tgt = torch.cat([tgt, torch.full((pad_len,), self.ignore_index, dtype=tgt.dtype)])
            padded_inputs.append(inp)
            padded_targets.append(tgt)
        return torch.stack(padded_inputs), torch.stack(padded_targets)
    
    def __iter__(self):
        """Yield ``(input_ids, target_ids)`` batches grouped by token count."""
        batch_inputs = []
        batch_targets = []
        current_batch_tokens = 0
        
        for input_ids, target_ids in self.dataloader:
            seq_len = input_ids.size(0)
            
            if batch_inputs and current_batch_tokens + seq_len > self.batch_size_tokens:
                yield self._collate(batch_inputs, batch_targets)
                batch_inputs = []
                batch_targets = []
                current_batch_tokens = 0
            
            batch_inputs.append(input_ids)
            batch_targets.append(target_ids)
            current_batch_tokens += seq_len
        
        if batch_inputs:
            yield self._collate(batch_inputs, batch_targets)

print("Data loading utilities defined!")


In [None]:
# Training functions
def get_lr_schedule(optimizer, num_steps, warmup_steps=0):
    """Create a cosine-decay learning-rate schedule, optionally with linear warmup.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_steps: Total number of scheduler steps the schedule spans
            (warmup included).
        warmup_steps: Number of initial steps over which the LR multiplier
            ramps linearly from 0 to 1 before the cosine decay starts.

    Returns:
        A ``LambdaLR`` when ``warmup_steps > 0``, otherwise a
        ``CosineAnnealingLR``.
    """
    if warmup_steps > 0:
        # Guard against a zero-length decay phase (num_steps <= warmup_steps).
        decay_steps = max(1, num_steps - warmup_steps)

        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            # Clamp so stepping past num_steps holds the LR at 0 instead of
            # climbing back up the cosine curve.
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            # Return a plain float: LambdaLR multiplies this into the base LR,
            # and a tensor multiplier would turn the optimizer's lr into a tensor.
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # T_max must be >= 1; CosineAnnealingLR divides by it when stepping.
    return CosineAnnealingLR(optimizer, T_max=max(1, num_steps))


def train_one_epoch(model, train_loader, optimizer, scheduler, device, log_interval=100):
    """Train ``model`` for one pass over ``train_loader``.

    Batches are ``(input_ids, target_ids)`` integer tensors whose padded
    positions hold ``-1`` (the value the token-budget batching above pads
    with). Padded target positions are excluded from the loss. Clamping them
    to a real token id would train the model to predict that token on padding.

    Args:
        model: Module exposing ``vocab_size`` and called as
            ``model(input_ids, target_ids) -> (logits, loss)``.
        train_loader: Iterable of ``(input_ids, target_ids)`` tensor pairs.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimizer step.
        device: Device the batches are moved to.
        log_interval: Print running loss and LR every this many batches.

    Returns:
        Mean of the per-batch losses over all optimizer steps taken, or
        ``float('inf')`` if no step was taken (e.g. an empty loader).
    """
    model.train()
    vocab_size = model.vocab_size
    total_loss = 0.0
    num_steps = 0

    for step, (input_ids, target_ids) in enumerate(train_loader):
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        # Record padding (-1) before clamping erases it.
        pad_mask = target_ids < 0
        if pad_mask.all():
            continue  # an all-padding batch carries no training signal

        # Clamp ids into [0, vocab_size) so the model can index them safely
        # (padding -> 0, out-of-range ids -> vocab_size - 1).
        input_ids = input_ids.clamp(0, vocab_size - 1)
        safe_targets = target_ids.clamp(0, vocab_size - 1)

        # Forward pass. The model's own loss is discarded because it would
        # count padded positions; the loss is recomputed with padding ignored.
        # Assumes the model's loss is plain token-level cross-entropy and the
        # last logits dimension is the vocabulary.
        logits, _ = model(input_ids, safe_targets)
        loss_targets = safe_targets.masked_fill(pad_mask, -1)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            loss_targets.reshape(-1),
            ignore_index=-1,
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        num_steps += 1

        if step % log_interval == 0:
            avg_loss = total_loss / num_steps
            current_lr = scheduler.get_last_lr()[0] if hasattr(scheduler, 'get_last_lr') else optimizer.param_groups[0]['lr']
            print(f"Step {step:6d} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e}")

    return total_loss / num_steps if num_steps > 0 else float('inf')


@torch.no_grad()
def evaluate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean per-token loss.

    Padded positions (``-1`` in the batches) are excluded. The result is
    averaged over real target tokens rather than over batches, so batches
    of different sizes are weighted by how many tokens they actually score.

    Args:
        model: Module exposing ``vocab_size`` and called as
            ``model(input_ids, target_ids) -> (logits, loss)``.
        val_loader: Iterable of ``(input_ids, target_ids)`` tensor pairs.
        device: Device the batches are moved to.

    Returns:
        Token-weighted mean cross-entropy, or ``float('inf')`` when the loader
        yields no non-padding targets.
    """
    model.eval()
    vocab_size = model.vocab_size
    total_loss = 0.0
    total_tokens = 0

    for input_ids, target_ids in val_loader:
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        # Record padding (-1) before clamping erases it.
        pad_mask = target_ids < 0
        num_valid = int((~pad_mask).sum())
        if num_valid == 0:
            continue

        # Clamp ids into [0, vocab_size) so the model can index them safely.
        input_ids = input_ids.clamp(0, vocab_size - 1)
        safe_targets = target_ids.clamp(0, vocab_size - 1)

        # The loss is recomputed with padding ignored instead of using the
        # model's own loss. Assumes that loss is plain token-level
        # cross-entropy and the last logits dimension is the vocabulary.
        logits, _ = model(input_ids, safe_targets)
        loss_targets = safe_targets.masked_fill(pad_mask, -1)
        total_loss += F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            loss_targets.reshape(-1),
            ignore_index=-1,
            reduction='sum',
        ).item()
        total_tokens += num_valid

    return total_loss / total_tokens if total_tokens > 0 else float('inf')

# Signal that the training-functions cell finished executing.
print("Training functions defined!")


In [None]:
# Load tokenizer.
# NOTE(review): tokenizer.pkl is presumably unpickled by MusicTokenizer.load;
# only load files this notebook (or another trusted source) produced.
tokenizer = MusicTokenizer.load(OUTPUT_DIR / "tokenizer.pkl")
print(f"Tokenizer loaded: vocab_size={tokenizer.vocab_size}")

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the Python, NumPy and PyTorch RNGs before building the model so weight
# initialisation is reproducible. This matters when comparing runs across
# model sizes. If the data loader shuffles with these RNGs, its order is
# reproducible too.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model configuration (Tiny model for quick testing)
# You can modify these for different model sizes
MODEL_CONFIG = {
    'd_model': 128,
    'n_layers': 2,
    'n_heads': 2,
    'd_ff': 512,
    'max_seq_length': 5000,
    'dropout': 0.1
}

# Initialize model
model = MusicTransformer(
    vocab_size=tokenizer.vocab_size,
    **MODEL_CONFIG
).to(device)

num_params = model.count_parameters()
print("\nModel initialized:")
print(f"  Parameters: {num_params:,} ({num_params/1e6:.2f}M)")
print(f"  Config: {MODEL_CONFIG}")


In [None]:
# Load data: token-budget batches over the tokenized train/val splits.
# Batching limits are set first so both loaders share them.
BATCH_SIZE_TOKENS = 50000  # Target tokens per batch
MAX_SEQ_LENGTH = MODEL_CONFIG['max_seq_length']

train_path = OUTPUT_DIR.joinpath("tokenized", "train", "data.json")
val_path = OUTPUT_DIR.joinpath("tokenized", "val", "data.json")

# Only the training split is shuffled; validation order stays fixed.
train_loader = MusicDataLoader(train_path, batch_size_tokens=BATCH_SIZE_TOKENS,
                               max_seq_length=MAX_SEQ_LENGTH, shuffle=True)
val_loader = MusicDataLoader(val_path, batch_size_tokens=BATCH_SIZE_TOKENS,
                             max_seq_length=MAX_SEQ_LENGTH, shuffle=False)

print("Data loaders created:")
print(f"  Batch size (tokens): {BATCH_SIZE_TOKENS:,}")
print(f"  Max sequence length: {MAX_SEQ_LENGTH}")


## 11. Train Model


In [None]:
# Training configuration
LEARNING_RATE = 3e-4
NUM_EPOCHS = 1  # For scaling laws study, typically train for 1 epoch
WARMUP_STEPS = 0
LOG_INTERVAL = 100
FALLBACK_STEPS_PER_EPOCH = 1000  # Used only when the loader has no __len__

# Estimate optimizer steps per epoch (approximate).
# NOTE(review): if MusicDataLoader.__len__ counts sequences rather than
# token-budget batches, this over-estimates; confirm against the loader.
estimated_steps = len(train_loader) if hasattr(train_loader, '__len__') else FALLBACK_STEPS_PER_EPOCH
# The scheduler is stepped once per batch across every epoch, so the schedule
# must span all epochs. Sizing it for a single epoch makes the cosine LR climb
# back up (or sit at its floor) whenever NUM_EPOCHS > 1.
total_steps = estimated_steps * NUM_EPOCHS

# Setup optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = get_lr_schedule(optimizer, total_steps, warmup_steps=WARMUP_STEPS)

print("Training configuration:")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Estimated steps per epoch: {estimated_steps}")
print(f"  Scheduled steps (all epochs): {total_steps}")
print(f"  Log interval: {LOG_INTERVAL}")

print(f"\n{'='*60}")
print(f"Training for {NUM_EPOCHS} epoch(s)...")
print(f"{'='*60}\n")

# Per-epoch losses, kept so they can be saved or plotted for the scaling-law fit.
history = []

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 60)

    train_loss = train_one_epoch(
        model, train_loader, optimizer, scheduler, device,
        log_interval=LOG_INTERVAL
    )

    # Evaluate on validation set
    val_loss = evaluate(model, val_loader, device)
    history.append({'epoch': epoch + 1, 'train_loss': train_loss, 'val_loss': val_loss})

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")

print(f"\n{'='*60}")
print("Training complete!")
print(f"{'='*60}")


## 12. Save Model and Results

The next cell saves the trained weights and model configuration to `data/processed/model.pt`. It then samples a short sequence from the model as a quick sanity check.


In [None]:
# Save model weights plus the configuration needed to rebuild the model.
model_save_path = OUTPUT_DIR / "model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': MODEL_CONFIG,
    'vocab_size': tokenizer.vocab_size,
    'num_parameters': num_params,
}, model_save_path)

print(f"Model saved to: {model_save_path}")

# Optional: smoke-test generation from a single-token prompt.
MAX_PREVIEW_CHARS = 200  # The preview is truncated by characters, not tokens
print("\nTesting generation...")
model.eval()
with torch.no_grad():
    # Prompt with the 'C' token if the vocabulary has one, otherwise id 0.
    start_tokens = torch.tensor([[tokenizer.token_to_id.get('C', 0)]], device=device)
    generated = model.generate(start_tokens, max_new_tokens=50, temperature=1.0)

    # Decode generated tokens (first sequence of the batch)
    generated_tokens = generated[0].cpu().tolist()
    generated_abc = tokenizer.decode(generated_tokens)
    # Only add an ellipsis when the text was actually cut off.
    suffix = "..." if len(generated_abc) > MAX_PREVIEW_CHARS else ""
    print(f"Generated ABC (first {MAX_PREVIEW_CHARS} chars): {generated_abc[:MAX_PREVIEW_CHARS]}{suffix}")

print("\n✓ All done! You can now:")
print("  1. Modify MODEL_CONFIG to train different model sizes")
print("  2. Train multiple models and collect validation losses")
print("  3. Plot scaling laws: validation loss vs. model size")
print("  4. Generate music samples from trained models")
