# 1) Preparing Dataset

## Tokenize MIDI Files

In [1]:
from miditok import REMI
from pathlib import Path
from tqdm import tqdm

# Choose a tokenizer
tokenizer = REMI()

# Path to MIDI dataset
MIDI_DIR = Path("../data/audio")

# Enocde all MIDI files into tokens
tokenized = []

for midi_path in tqdm(list(MIDI_DIR.glob("*.mid")), desc="Tokenizing MIDI files"):
    tokens = tokenizer.encode(midi_path) # pass path directly
    tokenized.append(tokens)

Tokenizing MIDI files: 100%|██████████| 182/182 [00:16<00:00, 11.13it/s]


In [2]:
# Save tokenizer
SAVE_DIR = Path("tokenizer")

# Save
tokenizer.save(SAVE_DIR)

print(f"Tokenizer saved to {SAVE_DIR}/")

Tokenizer saved to tokenizer/


In [3]:
# Load tokenizer (testing)
LOAD_DIR = Path("tokenizer/tokenizer.json")

tokenizer2 = REMI(params=LOAD_DIR)

print("Tokenizer reloaded!")

Tokenizer reloaded!


## Prepare Dataset for Training

In [None]:
def to_int_list(seq):
    """
    Convert any seq (TokSequence, list of TokSequence, or list of ints) 
    into a flat list of ints.
    """
    if hasattr(seq, "ids"):  # direct TokSequence
        return list(seq.ids)
    
    elif isinstance(seq, (list, tuple)):
        out = []
        for x in seq:
            if hasattr(x, "ids"):      # nested TokSequence
                out.extend(list(x.ids))
            elif isinstance(x, int):
                out.append(x)
            else:
                raise TypeError(f"Unexpected element type: {type(x)}")
        return out
    
    elif isinstance(seq, int):
        return [seq]
    
    else:
        raise TypeError(f"Unexpected type at top level: {type(seq)}")


# Apply to all your tokenized data
int_seqs = [to_int_list(seq) for seq in tokenized]

# Compute vocab size
vocab_size = max(max(seq) for seq in int_seqs) + 1

print(f"Vocab size: {vocab_size}")
print("First sequence:", int_seqs[0][:20])


Vocab size: 276
First sequence: [4, 190, 56, 108, 128, 32, 105, 127, 196, 56, 108, 127, 44, 102, 126, 202, 63, 115, 127, 48]


In [19]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
from torch.nn.functional import pad
from tqdm import tqdm


class MIDIDataset(Dataset):
    def __init__(self, sequences, seq_len=32, pad_id=0, stride=1):
        """
        Args:
            sequences: list of token ID lists
            seq_len: length of each training sequence
            pad_id: padding token ID
            stride: step size for sliding window
                    - 1 = fully overlapping
                    - seq_len = non-overlapping
                    - e.g. 8 = partial overlap
        """
        self.data = []
        self.seq_len = seq_len
        self.pad_id = pad_id

        for seq in tqdm(sequences, desc="Creating dataset"):
            ids = list(seq)

            if len(ids) < 2:  # skip too-short sequences
                continue

            # Sliding windows with configurable stride
            for i in range(0, len(ids) - 1, stride):
                x = ids[i:i+seq_len]
                y = ids[i+1:i+seq_len+1]

                # Pad if shorter than seq_len
                if len(x) < seq_len:
                    x = pad(torch.tensor(x, dtype=torch.long),
                            (0, seq_len - len(x)), value=pad_id)
                    y = pad(torch.tensor(y, dtype=torch.long),
                            (0, seq_len - len(y)), value=pad_id)
                else:
                    x = torch.tensor(x, dtype=torch.long)
                    y = torch.tensor(y, dtype=torch.long)

                self.data.append((x, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [20]:
def collate_fn(batch, pad_id=0):
    """Pads a batch of variable-length sequences to the max length in batch."""
    xs, ys = zip(*batch)  # unzip
    xs = pad_sequence(xs, batch_first=True, padding_value=pad_id)
    ys = pad_sequence(ys, batch_first=True, padding_value=pad_id)
    return xs, ys

In [21]:
dataset = MIDIDataset(int_seqs, seq_len=32, pad_id=0, stride=8)
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    collate_fn=lambda b: collate_fn(b, pad_id=0))

print("Dataset size:", len(dataset))

# Grab a single batch
x, y = next(iter(loader))

print("x shape:", x.shape)  # (batch_size, max_len_in_batch)
print("y shape:", y.shape)  # (batch_size, max_len_in_batch)
print("x[0]:", x[0])        # first input sequence
print("y[0]:", y[0])        # corresponding target sequence

Creating dataset: 100%|██████████| 182/182 [00:07<00:00, 23.80it/s]

Dataset size: 488913
x shape: torch.Size([32, 32])
y shape: torch.Size([32, 32])
x[0]: tensor([112, 129, 206,  65, 112, 137, 218,  68, 112, 129,   4, 190,  65, 112,
        129, 194,  63, 112, 129, 198,  62, 112, 129, 202,  63, 112, 129, 206,
         56, 112, 126,  68])
y[0]: tensor([129, 206,  65, 112, 137, 218,  68, 112, 129,   4, 190,  65, 112, 129,
        194,  63, 112, 129, 198,  62, 112, 129, 202,  63, 112, 129, 206,  56,
        112, 126,  68, 112])





# 2) Create Transformer

## Define a Simple Transformer Model

In [17]:
import torch
import torch.nn as nn

class MusicTransformer(nn.Module):
    """
    A simple Transformer model for music token prediction.
    Takes a sequence of token IDs (from MIDI tokenization)
    and predicts the next token in the sequence.
    """
    
    def __init__(self, 
                vocab_size: int,   # number of unique tokens (notes, events, etc.)
                embed_dim: int = 256, 
                n_heads: int = 4, 
                n_layers: int = 4,
                ff_dim: int = 512, # feedforward dimension inside Transformer
                dropout: float = 0.1,
                max_seq_len: int = 1024):
        
        super().__init__()
        
        # Token embedding layer
        self.embed = nn.Embedding(vocab_size, embed_dim)
        
        # Positional encoding (learnable, not sinusoidal)
        self.positional_encoding = nn.Parameter(
            torch.zeros(1, max_seq_len, embed_dim)
        )
        
        # Transformer encoder (batch_first=True removes permutes!)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, 
            nhead=n_heads, 
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True   # keep batch dim first
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        
        # Final projection to vocab
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len) of token IDs
        Returns:
            (batch, seq_len, vocab_size) of logits
        """
        # Embed tokens: (batch, seq_len, embed_dim)
        x = self.embed(x)
        
        # Add positional encoding
        x = x + self.positional_encoding[:, :x.size(1), :]
        
        # Run through Transformer
        out = self.transformer(x)  # (batch, seq_len, embed_dim)
        
        # Project to vocab
        return self.fc(out)  # (batch, seq_len, vocab_size)


# 3) Training Loop

In [25]:
import torch
import torch.nn as nn
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 5

# Initialize model, optimizer, and loss
model = MusicTransformer(vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(enumerate(loader), 
                        desc=f"Epoch {epoch+1}/{EPOCHS}", 
                        total=len(loader),
                        leave=True)

    for step, (x, y) in progress_bar:
        x, y = x.to(device), y.to(device)

        # Forward pass (shift inputs/targets)
        logits = model(x[:, :-1])
        loss = loss_fn(
            logits.reshape(-1, logits.size(-1)),
            y[:, 1:].reshape(-1)
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        avg_loss = total_loss / (step + 1)

        # Update progress bar live
        progress_bar.set_postfix({
            "step": f"{step+1}/{len(loader)}",
            "loss": f"{loss.item():.4f}",
            "avg_loss": f"{avg_loss:.4f}"
        })

    print(f"Epoch {epoch+1} finished | Average Loss: {avg_loss:.4f}")

    # Save checkpoint
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "loss": avg_loss,
    }, f"checkpoint_epoch{epoch+1}.pt")


Epoch 1/5:   0%|          | 0/15279 [00:00<?, ?it/s]

Epoch 1/5: 100%|██████████| 15279/15279 [05:52<00:00, 43.40it/s, step=15279/15279, loss=0.0840, avg_loss=0.6239]


Epoch 1 finished | Average Loss: 0.6239


Epoch 2/5: 100%|██████████| 15279/15279 [05:51<00:00, 43.41it/s, step=15279/15279, loss=0.0901, avg_loss=0.0942]


Epoch 2 finished | Average Loss: 0.0942


Epoch 3/5: 100%|██████████| 15279/15279 [05:41<00:00, 44.78it/s, step=15279/15279, loss=0.0646, avg_loss=0.0862]


Epoch 3 finished | Average Loss: 0.0862


Epoch 4/5: 100%|██████████| 15279/15279 [05:53<00:00, 43.27it/s, step=15279/15279, loss=0.0838, avg_loss=0.0823]


Epoch 4 finished | Average Loss: 0.0823


Epoch 5/5: 100%|██████████| 15279/15279 [05:57<00:00, 42.77it/s, step=15279/15279, loss=0.0596, avg_loss=0.0794]


Epoch 5 finished | Average Loss: 0.0794


# 4) Generate

## vocab utils

In [58]:
# note name -> MIDI number (C4 = 60)
NOTE_TO_SEMITONE = {
    'C': 0, 'C#': 1, 'Db': 1, 'D': 2, 'D#': 3, 'Eb': 3, 'E': 4, 'F': 5,
    'F#': 6, 'Gb': 6, 'G': 7, 'G#': 8, 'Ab': 8, 'A': 9, 'A#': 10, 'Bb': 10, 'B': 11
}


def build_vocab_maps(tokenizer):
    # tokenizer.vocab is a dict token->id as you showed
    vocab = tokenizer.vocab
    id_to_token = {int(v): k for k, v in vocab.items()}
    return vocab, id_to_token


def find_token_key_starting_with(vocab, prefix):
    # returns token key matching prefix exactly or starting with prefix
    if prefix in vocab:
        return prefix
    # if multiple tokens share prefix (e.g. Duration_1.0.8) pick the first found
    for k in vocab:
        if k.startswith(prefix):
            return k
    return None


def choose_velocity_token(vocab, target_vel=90):
    # collect velocity numeric tokens
    vel_tokens = []
    for k in vocab:
        if k.startswith("Velocity_"):
            try:
                n = int(k.split("_", 1)[1])
                vel_tokens.append((n, k))
            except Exception:
                pass
    if not vel_tokens:
        return None
    # pick token with closest numeric velocity to target
    vel_tokens.sort(key=lambda x: abs(x[0] - target_vel))
    return vel_tokens[0][1]

## note utils

In [59]:
def note_name_to_midi(note):
    # Accept formats like C4, C#4, Db4
    # Split letter(s) and octave
    import re
    m = re.match(r'^([A-G][#b]?)(-?\d+)$', note)
    if not m:
        raise ValueError(f"Invalid note name: {note}")
    name, octave = m.group(1), int(m.group(2))
    return 12 * (octave + 1) + NOTE_TO_SEMITONE[name]  # MIDI formula

In [60]:
def note_token_ids_from_symbolic(tokenizer, note_symbol, default_velocity=90):
    """
    Convert a single symbolic like "C4_q" -> list of token ids [Pitch_x, Velocity_y, Duration_z]
    duration codes supported: w,h,q,e (whole, half, quarter, eighth)
    """
    vocab, _ = build_vocab_maps(tokenizer)
    # parse like "C4_q" or "C#4_q"
    if "_" not in note_symbol:
        raise ValueError("Symbolic note must be like 'C4_q'")
    note_part, dur_code = note_symbol.split("_", 1)

    midi = note_name_to_midi(note_part)   # 60 for C4
    pitch_key = f"Pitch_{midi}"
    pitch_token = find_token_key_starting_with(vocab, pitch_key)
    if pitch_token is None:
        raise RuntimeError(f"Pitch token for MIDI {midi} not found in vocab")

    # duration mapping (assumes vocab durations use numeric beats like 1.0, 0.5, 2.0, ...)
    dur_map = {
        'w': '4.0',   # whole
        'h': '2.0',   # half
        'q': '1.0',   # quarter
        'e': '0.5',   # eighth
        's': '0.25',  # sixteenth (if available)
    }
    dur_value = dur_map.get(dur_code, None)
    if dur_value is None:
        # maybe user passed exact duration like "1.0.8"
        dur_value = dur_code

    duration_prefix = f"Duration_{dur_value}"
    duration_token = find_token_key_starting_with(vocab, duration_prefix)
    if duration_token is None:
        # fallback: pick a default duration token (quarter-ish)
        duration_token = find_token_key_starting_with(vocab, "Duration_1.0")
        if duration_token is None:
            raise RuntimeError("No suitable Duration token found in vocab")

    velocity_token = choose_velocity_token(vocab, target_vel=default_velocity)
    if velocity_token is None:
        # fallback: try a specific velocity token name
        velocity_token = find_token_key_starting_with(vocab, "Velocity_95")
        if velocity_token is None:
            raise RuntimeError("No Velocity token found in vocab")

    # return ids as integers
    return [int(vocab[pitch_token]), int(vocab[velocity_token]), int(vocab[duration_token])]

In [61]:
import random

DUR_CODES = ["w", "h", "q", "e", "s"]  # whole, half, quarter, eighth, sixteenth

def random_note_symbol(octave_range=(3,5)):
    """
    Generate a random symbolic note like 'C4_q' or 'F#5_e'
    """
    note_names = ["C","C#","D","D#","E","F","F#","G","G#","A","A#","B"]
    note = random.choice(note_names)
    octave = random.randint(octave_range[0], octave_range[1])
    dur = random.choice(DUR_CODES)
    return f"{note}{octave}_{dur}"

def random_start_symbols(n=3, octave_range=(3,5)):
    """
    Generate N random symbolic notes
    """
    return [random_note_symbol(octave_range) for _ in range(n)]


## Generator

In [86]:
# Determine the type of token
def token_type(token_name):
    for t in ["Pitch", "Velocity", "Duration", "Bar", "Position"]:
        if token_name.startswith(t):
            return t
    return None

In [90]:
import torch
from miditok import TokSequence

def generate_with_rhythm(
    model, tokenizer,
    start_symbols=["C4_q", "E4_q", "G4_q"],
    max_len=2000,
    out_path="model_outputs/generated_rhythm.mid",
    max_input_len_window=1024,
    default_velocity=90,
    stop_on_eos=True
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    # Convert start symbols to token ids
    start_ids = []
    for sym in start_symbols:
        start_ids.extend(note_token_ids_from_symbolic(tokenizer, sym, default_velocity))

    ids = start_ids.copy()
    vocab, id_to_token = build_vocab_maps(tokenizer)

    # Try to find EOS id
    eos_id = int(vocab["EOS_None"]) if "EOS_None" in vocab else None

    with torch.no_grad():
        for step in range(max_len):
            window = ids[-max_input_len_window:]
            x = torch.tensor([window], dtype=torch.long, device=device)
            out = model(x)
            logits = out[0] if isinstance(out, (tuple, list)) else out

            if not isinstance(logits, torch.Tensor):
                logits = torch.tensor(logits, device=device)

            next_id = int(torch.argmax(logits[0, -1]).item())
            token_name = id_to_token.get(next_id, "")

            # --- Smarter Rhythm Filtering ---
            # Get last non-Bar/Position token to determine expected note type
            last_note_tokens = [id_to_token.get(t, "") for t in reversed(ids) 
                                if not id_to_token.get(t, "").startswith(("Bar", "Position"))]
            last_type = token_type(last_note_tokens[0]) if last_note_tokens else None

            # Determine expected type for note triplet
            if last_type == "Pitch":
                expected = "Velocity"
            elif last_type == "Velocity":
                expected = "Duration"
            else:
                expected = "Pitch"  # start of new triplet

            # Always accept Bar/Position tokens; otherwise enforce triplet order
            if token_name.startswith(("Bar", "Position")) or token_name.startswith(expected):
                ids.append(next_id)

            if stop_on_eos and eos_id is not None and next_id == eos_id:
                print(f"Hit EOS at step {step+1}")
                break

            if (step + 1) % 500 == 0:
                print(f"Generated {step+1} tokens (current length {len(ids)})...")

    # Decode with TokSequence
    seq = TokSequence(ids=ids)
    decoded = tokenizer.decode([seq])
    midi_obj = decoded[0] if isinstance(decoded, (list, tuple)) else decoded

    # Save MIDI
    if hasattr(midi_obj, "dump_midi"):
        midi_obj.dump_midi(out_path)
    elif hasattr(midi_obj, "dumps_midi"):
        data = midi_obj.dumps_midi()
        with open(out_path, "wb") as f:
            f.write(data if isinstance(data, bytes) else data.encode())
    else:
        raise RuntimeError("Decoded object has no known dump method")

    print(f"✅ Saved MIDI to {out_path}")
    return ids


In [92]:
# OR use random start notes
rand_syms = random_start_symbols(n=3)
print("🎵 Random start notes:", rand_syms)
ids = generate_with_rhythm(model, tokenizer, start_symbols=rand_syms, max_len=5000)

🎵 Random start notes: ['F#5_h', 'B4_q', 'F3_h']
Generated 500 tokens (current length 9)...
Generated 1000 tokens (current length 9)...
Generated 1500 tokens (current length 9)...
Generated 2000 tokens (current length 9)...
Generated 2500 tokens (current length 9)...
Generated 3000 tokens (current length 9)...
Generated 3500 tokens (current length 9)...
Generated 4000 tokens (current length 9)...
Generated 4500 tokens (current length 9)...
Generated 5000 tokens (current length 9)...
✅ Saved MIDI to model_outputs/generated_rhythm.mid


In [96]:
import os
import subprocess
from pydub import AudioSegment

def midi_to_wav(midi_path, 
                soundfont_path="../soundfonts/AegeanSymphonicOrchestra-SND.sf2", 
                output_dir="model_outputs/", 
                output_name=None, 
                sample_rate=44100, 
                gain=2.0,
                normalize=True,
                fluidsynth_path=r"C:\tools\fluidsynth\bin\fluidsynth.exe",
                print_details=True):
    """
    Convert a MIDI file to WAV using FluidSynth and a specified SoundFont.

    Args:
        midi_path (str): Path to the input MIDI file.
        soundfont_path (str): Path to the SoundFont (.sf2) file.
        output_dir (str): Directory to save the output WAV.
        output_name (str): Optional output filename (without extension). Defaults to MIDI basename.
        sample_rate (int): Sampling rate for WAV.
        gain (float): Gain multiplier for FluidSynth.
        normalize (bool): Normalize the audio using pydub.
        fluidsynth_path (str): Path to FluidSynth executable.
        print_details (bool): Whether to print logs.

    Returns:
        str: Path to the generated WAV file.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Determine output filename
    if output_name is None:
        base = os.path.splitext(os.path.basename(midi_path))[0]
        output_name = base
    output_wav = os.path.join(output_dir, f"{output_name}.wav")

    if print_details:
        print(f"Rendering '{midi_path}' to WAV...")
        print(f"Using SoundFont: {soundfont_path}")
        print(f"Output path: {output_wav}")

    # FluidSynth command
    cmd = [
        fluidsynth_path,
        "-a", "file",          # write to file
        "-F", output_wav,      # output WAV
        "-ni", soundfont_path, # load SoundFont
        "-r", str(sample_rate),
        "-g", str(gain),
        midi_path
    ]

    # Run FluidSynth
    try:
        subprocess.run(cmd, check=True)
    except Exception as e:
        print(f"Error running FluidSynth: {e}")
        raise

    # Normalize audio if requested
    if normalize:
        sound = AudioSegment.from_wav(output_wav)
        normalized = sound.normalize()
        normalized.export(output_wav, format="wav")

    if print_details:
        print(f"✅ Successfully converted MIDI to WAV: {output_wav}")

    return output_wav


In [97]:
midi_path = "model_outputs/demo_generated_long.mid"
wav_path = midi_to_wav(midi_path)

Rendering 'model_outputs/demo_generated_long.mid' to WAV...
Using SoundFont: ../soundfonts/AegeanSymphonicOrchestra-SND.sf2
Output path: model_outputs/demo_generated_long.wav
✅ Successfully converted MIDI to WAV: model_outputs/demo_generated_long.wav


# DEBUGGING

In [79]:
# Look at the first 3 sequences
for i, seq in enumerate(int_seqs[:3]):
    print(f"Sequence {i} (first 50 tokens):")
    print(seq[:50])

Sequence 0 (first 50 tokens):
[4, 190, 56, 108, 128, 32, 105, 127, 196, 56, 108, 127, 44, 102, 126, 202, 63, 115, 127, 48, 106, 126, 207, 63, 114, 127, 44, 105, 126, 213, 65, 112, 127, 49, 104, 126, 218, 65, 108, 127, 44, 103, 126, 4, 192, 63, 108, 126, 48, 100]
Sequence 1 (first 50 tokens):
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 190, 61, 107, 139, 206, 60, 107, 129, 210, 61, 106, 129, 214, 60, 107, 129, 218, 61, 106, 129, 4, 190, 63, 107, 129, 194, 61, 106, 129, 198, 59, 106, 129, 202, 58, 106, 129, 206, 58, 107]
Sequence 2 (first 50 tokens):
[4, 4, 190, 65, 116, 149, 214, 63, 116, 133, 4, 190, 62, 116, 132, 198, 63, 117, 137, 210, 65, 116, 129, 214, 66, 116, 129, 218, 68, 116, 129, 4, 190, 69, 113, 132, 4, 190, 63, 116, 149, 214, 61, 116, 133, 4, 190, 60, 116, 133]


In [80]:
vocab, id_to_token = build_vocab_maps(tokenizer)
for t in seq[:50]:
    print(id_to_token[t], end=", ")
print()

Bar_None, Bar_None, Position_0, Pitch_81, Velocity_91, Duration_3.0.8, Position_24, Pitch_79, Velocity_91, Duration_1.0.8, Bar_None, Position_0, Pitch_78, Velocity_91, Duration_0.7.8, Position_8, Pitch_79, Velocity_95, Duration_1.4.8, Position_20, Pitch_81, Velocity_91, Duration_0.4.8, Position_24, Pitch_82, Velocity_91, Duration_0.4.8, Position_28, Pitch_84, Velocity_91, Duration_0.4.8, Bar_None, Position_0, Pitch_85, Velocity_79, Duration_0.7.8, Bar_None, Position_0, Pitch_79, Velocity_91, Duration_3.0.8, Position_24, Pitch_77, Velocity_91, Duration_1.0.8, Bar_None, Position_0, Pitch_76, Velocity_91, Duration_1.0.8, 


In [93]:
import torch
from miditok import TokSequence

def generate_raw(model, tokenizer, start_symbols=["C4_q", "E4_q", "G4_q"], max_len=2000, out_path="model_outputs/raw_output.mid"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    # Convert start symbols to token IDs
    start_ids = []
    for sym in start_symbols:
        start_ids.extend(note_token_ids_from_symbolic(tokenizer, sym))

    ids = start_ids.copy()
    vocab, id_to_token = build_vocab_maps(tokenizer)

    with torch.no_grad():
        for step in range(max_len):
            x = torch.tensor([ids[-1024:]], dtype=torch.long, device=device)
            out = model(x)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            next_id = int(torch.argmax(logits[0, -1]).item())
            ids.append(next_id)

            if (step + 1) % 500 == 0:
                print(f"Generated {step+1} tokens (current length {len(ids)})...")

    # Map IDs back to token names
    token_names = [id_to_token.get(i, str(i)) for i in ids]

    # Decode to MIDI (without filtering)
    seq = TokSequence(ids=ids)
    try:
        decoded = tokenizer.decode([seq])
    except Exception:
        decoded = tokenizer(seq)

    midi_obj = decoded[0] if isinstance(decoded, (list, tuple)) else decoded
    if hasattr(midi_obj, "dump_midi"):
        midi_obj.dump_midi(out_path)
    elif hasattr(midi_obj, "dumps_midi"):
        data = midi_obj.dumps_midi()
        with open(out_path, "wb") as f:
            if isinstance(data, str):
                f.write(data.encode())
            else:
                f.write(data)

    print(f"✅ Saved raw MIDI to {out_path}")
    return ids, token_names

# Example usage
rand_syms = random_start_symbols(n=3)
print("🎵 Random start notes:", rand_syms)
ids, tokens = generate_raw(model, tokenizer, start_symbols=rand_syms, max_len=2000)


🎵 Random start notes: ['B3_q', 'D#5_e', 'B4_s']
Generated 500 tokens (current length 509)...
Generated 1000 tokens (current length 1009)...
Generated 1500 tokens (current length 1509)...
Generated 2000 tokens (current length 2009)...
✅ Saved raw MIDI to model_outputs/raw_output.mid
