In [1]:
##Train the GPT2 model on NES-MDB MIDI dataset

In [7]:
# %% [markdown]
# # NES-MDB Chiptune Transformer — Minimal, Robust Notebook (2025)
# - Tokenize raw MIDIs directly (monophonic skyline) with **MIDILike** (no durations)
# - Build dataset with short-clip support and fallback seq_len
# - Train a small GPT-like model (Transformers)
# - Generate continuation + try to write MIDI
#
# Works with:
#   torch (CUDA if available), transformers 4.4x, datasets 2.2x,
#   miditok 3.0.6, miditoolkit 0.1.16, pretty_midi 0.2.9, numpy 2.x

# =========================
# 0) Compatibility & env check
# =========================
# %%
import sys, os, json, random, numpy as np
from importlib.metadata import version, PackageNotFoundError

# NumPy 2.x removed np.int / np.bool etc; some libs still reference them
if not hasattr(np, "int"):    np.int = int
if not hasattr(np, "bool"):   np.bool = bool
if not hasattr(np, "float"):  np.float = float
if not hasattr(np, "object"): np.object = object

def pkgver(name: str) -> str:
    try: return version(name)
    except PackageNotFoundError: return "not-found"

import torch
print("Python:", sys.version)
print("Torch:", torch.__version__, "| CUDA build:", torch.version.cuda, "| cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

import transformers, datasets
print("Transformers:", pkgver("transformers"))
print("Datasets:",    pkgver("datasets"))
print("miditok:",     pkgver("miditok"))
print("miditoolkit:", pkgver("miditoolkit"))
print("pretty_midi:", pkgver("pretty_midi"))
print("numpy:",       np.__version__)



Python: 3.10.18 | packaged by Anaconda, Inc. | (main, Jun  5 2025, 13:08:55) [MSC v.1929 64 bit (AMD64)]
Torch: 2.4.0+cu121 | CUDA build: 12.1 | cuda available: True
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
Transformers: 4.44.2
Datasets: 2.21.0
miditok: 3.0.6
miditoolkit: 0.1.16
pretty_midi: 0.2.9
numpy: 2.2.6


In [8]:

# =========================
# 1) Paths & knobs
# =========================
# %%
from pathlib import Path

REPO      = Path.cwd()
DATA_DIR  = REPO / "data" / "nesmdb_midi"   # <--- put your raw MIDIs here
WORK      = REPO / "nes_transformer"

TOK_DIR   = WORK / "tokens"       # token JSONs { "ids": [...] }
RUN_DIR   = WORK / "hf_runs"      # HF checkpoints / logs
SAMPLES   = WORK / "samples"      # generated MIDIs
for p in [WORK, TOK_DIR, RUN_DIR, SAMPLES]:
    p.mkdir(parents=True, exist_ok=True)

print("DATA_DIR:", DATA_DIR.resolve(), "exists:", DATA_DIR.exists())

# --- tokenization speed/quality knobs ---
SUBSET_N   = 2000      # process only this many raw MIDIs now (None = all)  << adjust for time
SEED       = 42
LO_PITCH   = 48        # C3
HI_PITCH   = 96        # C7
MAX_TICKS  = None      # e.g. 20000 to crop early for speed; None = full file (more robust)
USE_THREADS= False     # start serial for stability; flip to True for speed once it works
MAX_WORKERS= max(4, os.cpu_count() or 4)



DATA_DIR: C:\Users\rohit\Downloads\hacknc2025-1\hacknc2025\src\ai\data\nesmdb_midi exists: False


In [9]:

# =========================
# 2) Tokenize directly from RAW (MIDILike, skyline to 1-voice)
# =========================
# %%
# Patch miditoolkit.Note with a .duration property (needed by miditok's converter under the hood)
import miditoolkit
try:
    from miditoolkit.midi.containers import Note as MTKNote
except Exception:
    from miditoolkit import Note as MTKNote
if not hasattr(MTKNote, "duration"):
    MTKNote.duration = property(lambda self: self.end - self.start)

from miditoolkit import MidiFile, Instrument, Note
from miditok import MIDILike, TokenizerConfig, TokSequence
from tqdm import tqdm

# Minimal tokenizer config (avoid durations/tempos/rets complexity)
tok_config = TokenizerConfig(
    beat_res={(0,0):4},     # 16th grid for positions
    use_chords=False,
    use_rests=False,
    use_tempos=False,
    use_time_signatures=False,
    use_programs=False
)
tokenizer = MIDILike(tok_config)
print("Tokenizer: MIDILike | vocab_size:", tokenizer.vocab_size)

# Skyline melody extractor (ticks): always monophonic
def skyline_ticks(notes, min_dur=1):
    if not notes: return []
    times = sorted({n.start for n in notes} | {n.end for n in notes})
    out = []
    cur_pitch, cur_start = None, None
    for t in times:
        active = [n for n in notes if n.start <= t < n.end]
        if active:
            pitch = max(active, key=lambda n: n.pitch).pitch
            if pitch != cur_pitch:
                if cur_pitch is not None and (t - cur_start) >= min_dur:
                    out.append(Note(velocity=90, pitch=cur_pitch, start=cur_start, end=t))
                cur_pitch, cur_start = pitch, t
        else:
            if cur_pitch is not None and (t - cur_start) >= min_dur:
                out.append(Note(velocity=90, pitch=cur_pitch, start=cur_start, end=t))
            cur_pitch, cur_start = None, None
    return out

# Gather raw files (+ optional subset)
raw_all = sorted([p for ext in ("*.mid","*.midi","*.MID","*.MIDI") for p in DATA_DIR.rglob(ext)])
print("Raw MIDIs found:", len(raw_all))
if SUBSET_N is not None and SUBSET_N < len(raw_all):
    random.seed(SEED)
    raw = sorted(random.sample(raw_all, SUBSET_N))
    print(f"Using subset: {len(raw)} / {len(raw_all)}")
else:
    raw = raw_all
    print("Using ALL raw files")

def tokenize_one(path: Path) -> Path | None:
    out = TOK_DIR / (path.stem + ".json")
    if out.exists():
        return out
    try:
        mf = MidiFile(str(path))
        # collect notes (non-drum, in range), optional crop for speed
        cand = []
        for inst in mf.instruments:
            if inst.is_drum or not inst.notes:
                continue
            notes = inst.notes
            if MAX_TICKS is not None:
                notes = [n for n in notes if n.start < MAX_TICKS]
            notes = [n for n in notes if LO_PITCH <= n.pitch <= HI_PITCH]
            cand.extend(notes)
        if len(cand) < 4:
            return None

        mono = skyline_ticks(cand, min_dur=1)
        if len(mono) < 4:
            return None

        # Build tiny 1-track MIDI in memory (keeps original grid)
        one = MidiFile(ticks_per_beat=mf.ticks_per_beat)
        one.tempo_changes = mf.tempo_changes
        one.time_signature_changes = mf.time_signature_changes
        inst = Instrument(program=80, is_drum=False, name="lead")
        inst.notes = mono
        one.instruments = [inst]

        # Tokenize (MIDILike)
        toks = tokenizer.tokenize(one) if hasattr(tokenizer, "tokenize") else tokenizer(one)
        ids  = toks.ids if hasattr(toks, "ids") else toks
        if not ids:
            return None

        out.write_text(json.dumps({"ids": ids}))
        return out
    except Exception:
        return None

token_files = []
if USE_THREADS:
    from concurrent.futures import ThreadPoolExecutor, as_completed
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures = [ex.submit(tokenize_one, p) for p in raw]
        for f in tqdm(as_completed(futures), total=len(futures), desc="Tokenizing (threads)"):
            res = f.result()
            if res is not None:
                token_files.append(res)
else:
    for p in tqdm(raw, desc="Tokenizing (serial)"):
        res = tokenize_one(p)
        if res is not None:
            token_files.append(res)

print("Token files written:", len(token_files), "→", TOK_DIR)

# Simple peek
if token_files:
    sample_ids = json.loads(token_files[0].read_text())["ids"]
    print("Sample token length:", len(sample_ids))



Tokenizer: MIDILike | vocab_size: 338
Raw MIDIs found: 10556
Using subset: 2000 / 10556


  toks = tokenizer.tokenize(one) if hasattr(tokenizer, "tokenize") else tokenizer(one)
Tokenizing (serial):   7%|▋         | 136/2000 [00:09<02:16, 13.62it/s]


KeyboardInterrupt: 

In [5]:

# =========================
# 3) Simple 16th-grid tokenizer (no MidiTok)
# =========================
from pathlib import Path
from tqdm import tqdm
from dataclasses import dataclass
import json, os, random, math

from miditoolkit import MidiFile, Instrument, Note

# ---- Paths (redefine if kernel was restarted) ----
REPO      = Path.cwd()
DATA_DIR  = REPO / "data" / "nesmdb_midi"   # raw MIDIs here
WORK      = REPO / "nes_transformer"
TOK_DIR   = WORK / "tokens_simple"
SAMPLES   = WORK / "samples"
for p in [WORK, TOK_DIR, SAMPLES]:
    p.mkdir(parents=True, exist_ok=True)

print("DATA_DIR:", DATA_DIR.resolve(), "| exists:", DATA_DIR.exists())

# ---- Token vocabulary: PAD=0, REST=1, HOLD=2, P0..P127=3..130 ----
PAD_ID  = 0
REST_ID = 1
HOLD_ID = 2
PITCH_BASE = 3  # P0 maps to 3, P127 maps to 130
VOCAB_SIZE = PITCH_BASE + 128

def pitch_to_id(p: int) -> int:
    p = max(0, min(127, int(p)))
    return PITCH_BASE + p

def id_to_pitch(i: int) -> int | None:
    if i >= PITCH_BASE and i < PITCH_BASE + 128:
        return i - PITCH_BASE
    return None  # REST/HOLD/PAD

# ---- Monophonic skyline over ticks ----
def skyline_ticks(notes, min_dur_ticks: int = 1):
    if not notes:
        return []
    times = sorted({n.start for n in notes} | {n.end for n in notes})
    out = []
    cur_pitch, cur_start = None, None
    for t in times:
        active = [n for n in notes if n.start <= t < n.end]
        if active:
            pitch = max(active, key=lambda n: n.pitch).pitch
            if pitch != cur_pitch:
                if cur_pitch is not None and (t - cur_start) >= min_dur_ticks:
                    out.append((cur_pitch, cur_start, t))
                cur_pitch, cur_start = pitch, t
        else:
            if cur_pitch is not None and (t - cur_start) >= min_dur_ticks:
                out.append((cur_pitch, cur_start, t))
            cur_pitch, cur_start = None, None
    return out

# ---- Quantize monophonic notes to a 16th grid ----
def quantize_to_grid_16th(mf: MidiFile, mono_notes, max_steps: int | None = None):
    tpq = max(1, int(mf.ticks_per_beat))
    ticks_per_step = max(1, tpq // 4)  # 16th = TPQ/4
    if not mono_notes:
        return []

    max_tick = 0
    for pitch, s, e in mono_notes:
        max_tick = max(max_tick, e)

    total_steps = math.ceil(max_tick / ticks_per_step)
    if max_steps is not None:
        total_steps = min(total_steps, max_steps)

    seq = [REST_ID] * max(1, total_steps)

    cur_idx = 0
    for pitch, s, e in mono_notes:
        start_idx = int(round(s / ticks_per_step))
        end_idx   = int(max(start_idx + 1, round(e / ticks_per_step)))
        if max_steps is not None:
            start_idx = min(start_idx, max_steps - 1)
            end_idx   = min(end_idx,   max_steps)

        # fill leading rest if any
        while cur_idx < start_idx and cur_idx < len(seq):
            seq[cur_idx] = REST_ID
            cur_idx += 1

        if start_idx < len(seq):
            seq[start_idx] = pitch_to_id(pitch)
            cur_idx = start_idx + 1

        # fill holds
        while cur_idx < end_idx and cur_idx < len(seq):
            seq[cur_idx] = HOLD_ID
            cur_idx += 1

    return seq

# ---- File → tokens pipeline ----
def midi_to_token_ids(path: Path, lo_pitch: int | None = None, hi_pitch: int | None = None,
                      crop_ticks: int | None = None, max_steps: int | None = 2048):
    try:
        mf = MidiFile(str(path))
    except Exception:
        return None

    # collect non-drum notes (optionally filter pitch / crop)
    notes = []
    for inst in mf.instruments:
        if inst.is_drum or not inst.notes:
            continue
        ns = inst.notes
        if crop_ticks is not None:
            ns = [n for n in ns if n.start < crop_ticks]
        if lo_pitch is not None and hi_pitch is not None:
            ns = [n for n in ns if lo_pitch <= n.pitch <= hi_pitch]
        notes.extend(ns)

    if not notes:
        return None

    mono = skyline_ticks(notes, min_dur_ticks=1)
    if not mono:
        return None

    ids = quantize_to_grid_16th(mf, mono, max_steps=max_steps)
    # keep only reasonably sized sequences
    return ids if len(ids) >= 16 else None

# ---- Batch tokenization ----
SUBSET_N   = 2000          # start small to confirm; bump to 1500–3000 when it works
SEED       = 42
LO_PITCH   = None         # None means keep all pitches (widest, safest)
HI_PITCH   = None
CROP_TICKS = None         # set e.g. 20000 to speed up
MAX_STEPS  = 1024         # cap per piece

raw_all = sorted([p for ext in ("*.mid","*.midi","*.MID","*.MIDI") for p in DATA_DIR.rglob(ext)])
print("Raw files:", len(raw_all))
random.seed(SEED)
raw = sorted(random.sample(raw_all, min(SUBSET_N, len(raw_all))))
print("Using subset:", len(raw))

written = 0
for p in tqdm(raw, desc="Tokenizing (simple)"):
    out = TOK_DIR / f"{p.stem}.json"
    if out.exists():
        written += 1
        continue
    ids = midi_to_token_ids(p, lo_pitch=LO_PITCH, hi_pitch=HI_PITCH,
                            crop_ticks=CROP_TICKS, max_steps=MAX_STEPS)
    if ids is None:
        continue
    out.write_text(json.dumps({"ids": ids}))
    written += 1

print("Token files written:", written, "→", TOK_DIR)
print("VOCAB_SIZE:", VOCAB_SIZE, "| PAD/REST/HOLD ids:", PAD_ID, REST_ID, HOLD_ID)





DATA_DIR: C:\Users\rohit\Downloads\hacknc2025-1\hacknc2025\src\ai\data\nesmdb_midi | exists: False
Raw files: 0
Using subset: 0


Tokenizing (simple): 0it [00:00, ?it/s]

Token files written: 0 → c:\Users\rohit\Downloads\hacknc2025-1\hacknc2025\src\ai\nes_transformer\tokens_simple
VOCAB_SIZE: 131 | PAD/REST/HOLD ids: 0 1 2





In [6]:

# =========================
# 4) Build dataset (keeps short clips; 512→256→128 fallback)
# =========================
from datasets import Dataset
import json, random

def build_ds(token_dir: Path, seq_len: int, keep_short_min: int = 8, step_frac: float = 0.5):
    files = sorted(token_dir.glob("*.json"))
    sequences = []
    step = max(1, int(seq_len * step_frac))
    for p in files:
        try:
            ids = json.loads(p.read_text())["ids"]
        except Exception:
            continue
        if not ids:
            continue
        if len(ids) <= seq_len:
            if len(ids) >= keep_short_min:
                sequences.append({"input_ids": ids, "labels": ids.copy()})
            continue
        # sliding window
        for i in range(0, len(ids) - seq_len + 1, step):
            seq = ids[i:i+seq_len]
            sequences.append({"input_ids": seq, "labels": seq.copy()})
    return sequences

SEQ_TRY = [512, 256, 128]
final_sequences, final_len = None, None
for L in SEQ_TRY:
    seqs = build_ds(TOK_DIR, seq_len=L, keep_short_min=8, step_frac=0.5)
    print(f"SEQ_LEN={L} -> sequences: {len(seqs)}")
    if seqs:
        final_sequences, final_len = seqs, L
        break

if final_sequences is None:
    raise RuntimeError("No sequences produced. Increase SUBSET_N, set CROP_TICKS=None, or lower keep_short_min to 4, then re-run.")

random.shuffle(final_sequences)
ds = Dataset.from_list(final_sequences).train_test_split(test_size=0.05, seed=42)
print(f"USING SEQ_LEN={final_len} | train={len(ds['train'])} | test={len(ds['test'])}")
ds




SEQ_LEN=512 -> sequences: 0
SEQ_LEN=256 -> sequences: 0
SEQ_LEN=128 -> sequences: 0


RuntimeError: No sequences produced. Increase SUBSET_N, set CROP_TICKS=None, or lower keep_short_min to 4, then re-run.

In [None]:

# =========================
# 5) Define & train model (no eval during training)
# =========================
import torch
from transformers import GPT2Config, AutoModelForCausalLM, Trainer, TrainingArguments

vocab_size = VOCAB_SIZE
gpt_cfg = GPT2Config(
    vocab_size=vocab_size,
    n_positions=max(1024, final_len * 2),
    n_embd=256,
    n_layer=4,
    n_head=8,
    n_inner=1024,
)
model = AutoModelForCausalLM.from_config(gpt_cfg)

def collate(batch):
    PAD = PAD_ID
    maxlen = max(len(x["input_ids"]) for x in batch)
    input_ids, labels, attn = [], [], []
    for x in batch:
        seq = x["input_ids"]
        pad = [PAD] * (maxlen - len(seq))
        inp = seq + pad
        lab = seq + pad
        for j in range(len(seq), maxlen):
            lab[j] = -100  # mask pads
        input_ids.append(inp)
        labels.append(lab)
        attn.append([1]*len(seq) + [0]*len(pad))
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "attention_mask": torch.tensor(attn, dtype=torch.long),
    }

BATCH = 8
args = TrainingArguments(
    output_dir=str(WORK / "hf_runs_simple"),
    per_device_train_batch_size=BATCH,
    per_device_eval_batch_size=BATCH,
    learning_rate=3e-4,
    warmup_steps=200,
    num_train_epochs=3,
    logging_steps=50,
    # IMPORTANT: turn OFF eval to avoid numpy conversion in the eval loop
    eval_strategy="no",          # (use this new arg name on 4.44+)
    save_strategy="steps",
    save_steps=1000,
    report_to=[],
    fp16=torch.cuda.is_available(),
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    data_collator=collate,
    # NOTE: don't pass eval_dataset here
)

trainer.train()



  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
 15%|█▌        | 51/339 [00:03<00:18, 15.50it/s]

{'loss': 3.2979, 'grad_norm': 2.6805660724639893, 'learning_rate': 7.35e-05, 'epoch': 0.44}


 30%|███       | 102/339 [00:06<00:13, 17.70it/s]

{'loss': 2.6101, 'grad_norm': 1.4782524108886719, 'learning_rate': 0.00014849999999999998, 'epoch': 0.88}


 45%|████▍     | 152/339 [00:08<00:09, 18.83it/s]

{'loss': 2.2788, 'grad_norm': 2.8734097480773926, 'learning_rate': 0.00022349999999999998, 'epoch': 1.33}


 60%|█████▉    | 202/339 [00:11<00:06, 21.47it/s]

{'loss': 2.354, 'grad_norm': 3.0835440158843994, 'learning_rate': 0.0002985, 'epoch': 1.77}


 75%|███████▍  | 254/339 [00:13<00:03, 22.45it/s]

{'loss': 2.2819, 'grad_norm': 1.3314687013626099, 'learning_rate': 0.00019424460431654675, 'epoch': 2.21}


 89%|████████▉ | 302/339 [00:15<00:01, 19.01it/s]

{'loss': 2.154, 'grad_norm': 1.2306898832321167, 'learning_rate': 8.633093525179855e-05, 'epoch': 2.65}


100%|██████████| 339/339 [00:17<00:00, 19.00it/s]

{'train_runtime': 17.8337, 'train_samples_per_second': 152.072, 'train_steps_per_second': 19.009, 'train_loss': 2.454839003121255, 'epoch': 3.0}





TrainOutput(global_step=339, training_loss=2.454839003121255, metrics={'train_runtime': 17.8337, 'train_samples_per_second': 152.072, 'train_steps_per_second': 19.009, 'total_flos': 25352295800832.0, 'train_loss': 2.454839003121255, 'epoch': 3.0})

In [None]:

# =========================
# 6) Generate continuation and write MIDI
# =========================
def tokens_to_midi(ids, out_path: Path, bpm: int = 140):
    """Decode simple tokens back to a single-track MIDI."""
    tpq = 480
    ticks_per_step = tpq // 4  # 16th grid
    cur_idx = 0
    notes = []

    i = 0
    while i < len(ids):
        tid = ids[i]
        pitch = id_to_pitch(tid)
        if pitch is None:
            i += 1
            continue
        # start note
        start = i
        j = i + 1
        while j < len(ids) and ids[j] == HOLD_ID:
            j += 1
        end = j
        start_tick = start * ticks_per_step
        end_tick   = max(start_tick + ticks_per_step, end * ticks_per_step)
        notes.append((pitch, start_tick, end_tick))
        i = j

    # write MIDI
    mf = MidiFile(ticks_per_beat=tpq)
    inst = Instrument(program=80, is_drum=False, name="lead")
    inst.notes = [Note(velocity=90, pitch=p, start=s, end=e) for (p, s, e) in notes]
    mf.instruments = [inst]
    # Add a constant tempo (approximate)
    from miditoolkit.midi.containers import TempoChange
    mf.tempo_changes = [TempoChange(tempo=bpm, time=0)]
    mf.dump(str(out_path))

# Pick a short seed (or any)
files = sorted(TOK_DIR.glob("*.json"))
assert files, "No token files found—rerun Step 3."
seed_path = min(files, key=lambda p: len(json.loads(p.read_text())["ids"]))
seed_ids  = json.loads(seed_path.read_text())["ids"]

inp = torch.tensor(seed_ids, dtype=torch.long)[None, :]
if torch.cuda.is_available():
    model.to("cuda"); inp = inp.to("cuda")

model.eval()
with torch.no_grad():
    gen = model.generate(
        input_ids=inp,
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=0.95,
        pad_token_id=PAD_ID
    )

out_ids = gen[0].tolist()
out_mid = SAMPLES / "sample_simple.mid"
tokens_to_midi(out_ids, out_mid, bpm=140)
print("Wrote:", out_mid)


Wrote: c:\Users\rohit\Downloads\hacknc2025\hacknc2025\src\ai\nes_transformer\samples\sample_simple.mid


In [None]:
# Add these cells at the end of train.ipynb AFTER trainer.train()

# Save HF format for the Python server
model_dir = WORK / "model_export"
model.save_pretrained(model_dir)
print("Saved model to", model_dir)

# (Optional) Export ONNX for Node runtime
import torch
onnx_path = WORK / "model.onnx"
dummy = torch.randint(0, PITCH_BASE+128, (1, 32), dtype=torch.long)
attn = torch.ones_like(dummy)
torch.onnx.export(
    model, (dummy, attn),
    str(onnx_path),
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"}, "attention_mask": {0:"batch", 1:"seq"}, "logits": {0:"batch", 1:"seq"}},
    opset_version=17
)
print("Wrote", onnx_path)
