
# 03 — AudioCraft generator training (MusicGen, unconditional debug)

Stage 2: train the MusicGen token LM on the prepared FMA mini dataset, using the compression/codebook model trained in 02. Run this after 01b (dataset prep) and 02 (compression debug). No dataset download or prep happens here.


## 1) Imports + shared paths/constants

In [1]:
from pathlib import Path

# Ensure writable caches for numba/joblib in container environments
import os as _os
_os.environ.setdefault('NUMBA_CACHE_DIR', '/tmp/numba_cache')
_os.environ.setdefault('NUMBA_DISABLE_CACHING', '1')
_os.environ.setdefault('JOBLIB_TEMP_FOLDER', '/tmp')

import os, sys, subprocess, datetime, json
from typing import List, Tuple, Optional

AUDIOCRAFT_REPO = Path("/root/workspace/audiocraft")
AUDIOCRAFT_DORA_DIR = Path("/root/workspace/experiments/audiocraft")
OUTPUT_DIR = Path("/root/workspace/Training/outputs/musicgen_uncond_debug")

DSET = "audio/fma_small_mini"
CONFIG_PATH = Path("/root/workspace/Training/model_config/fma_small_mini.yaml")
EGS_DIR = Path("/root/workspace/data/fma_small_mini/egs")
EXPECTED_EGS = [EGS_DIR / "train", EGS_DIR / "valid"]

# Debug-friendly hyperparams; tweak as needed
SEGMENT_SECONDS = 10
BATCH_SIZE = 4
NUM_WORKERS = max(2, min(8, (os.cpu_count() or 8) // 4))
UPDATES_PER_EPOCH = 30
EPOCHS = 1
GENERATE_EVERY = 10
EVALUATE_EVERY = 10
SEED = 1234
GENERATE_SAMPLES = 2
GENERATE_SECONDS = 8  # duration per sample for inference

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"AUDIOCRAFT_REPO exists: {AUDIOCRAFT_REPO.exists()}")
print(f"AUDIOCRAFT_DORA_DIR: {AUDIOCRAFT_DORA_DIR}")
try:
    ffmpeg_ver = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
    first_line = (ffmpeg_ver.stdout or ffmpeg_ver.stderr).splitlines()[0]
    print("ffmpeg:", first_line)
except FileNotFoundError:
    print("ffmpeg not found (install via apt if you need extra formats)")


AUDIOCRAFT_REPO exists: True
AUDIOCRAFT_DORA_DIR: /root/workspace/experiments/audiocraft
ffmpeg: ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers


## 2) Verify prerequisites exist

In [2]:

missing = []
for p in [CONFIG_PATH, AUDIOCRAFT_REPO, EGS_DIR]:
    if not p.exists():
        missing.append(str(p))

missing_egs = [str(p) for p in EXPECTED_EGS if not p.exists()]
if missing:
    raise FileNotFoundError(f"Missing required paths: {missing}. Run 01b to prepare data and 02 for codec training.")
if missing_egs:
    raise FileNotFoundError(f"Missing egs folders: {missing_egs}. Re-run 01b_fma_small_mini_downloader.ipynb.")

xps_root = AUDIOCRAFT_DORA_DIR / "xps"
xps_root.mkdir(parents=True, exist_ok=True)
print("All prerequisite paths are present.")


All prerequisite paths are present.


## 3) Find the latest compression checkpoint automatically

In [3]:

import yaml
from audiocraft.solvers import CompressionSolver

def read_solver_from_config(xp_dir: Path):
    for candidate in [xp_dir / "config.yaml", xp_dir / ".hydra" / "config.yaml", xp_dir / "hydra-config.yaml"]:
        if candidate.exists():
            try:
                cfg = yaml.safe_load(candidate.read_text())
            except Exception as exc:  # noqa: BLE001
                print(f"[warn] Could not parse {candidate}: {exc}")
                continue
            solver = None
            if isinstance(cfg, dict):
                solver = cfg.get("solver")
                if solver is None and isinstance(cfg.get("xp"), dict):
                    solver = cfg["xp"].get("solver")
            return solver, cfg
    return None, None

def find_xps(filter_keyword: Optional[str] = None):
    xp_root = AUDIOCRAFT_DORA_DIR / "xps"
    if not xp_root.exists():
        return []
    results = []
    for xp in xp_root.iterdir():
        if xp.is_dir():
            solver, cfg = read_solver_from_config(xp)
            solver_str = str(solver) if solver is not None else ""
            if filter_keyword and filter_keyword not in solver_str:
                continue
            results.append((xp, solver_str))
    results.sort(key=lambda t: t[0].stat().st_mtime)
    return results

def pick_checkpoint(xp_dir: Path):
    priority = [
        "*best*.pt", "*best*.pth", "*best*.th",
        "*latest*.pt", "*latest*.pth", "*latest*.th",
        "checkpoint*.pt", "checkpoint*.pth", "checkpoint*.th",
    ]
    def grab(patterns):
        for pat in patterns:
            files = sorted(xp_dir.rglob(pat), key=lambda p: p.stat().st_mtime, reverse=True)
            if files:
                return files[0]
        return None
    ckpt = grab(priority)
    if ckpt is None:
        pool = sorted(list(xp_dir.rglob("*.pt")) + list(xp_dir.rglob("*.pth")) + list(xp_dir.rglob("*.th")),
                      key=lambda p: p.stat().st_mtime, reverse=True)
        if pool:
            ckpt = pool[0]
    return ckpt

compression_xps = find_xps("compression")
if not compression_xps:
    raise FileNotFoundError("No compression Dora runs found. Run notebook 02_audiocraft_train_compression_debug.ipynb first.")

compression_dir, compression_solver_name = compression_xps[-1]
compression_ckpt = pick_checkpoint(compression_dir)
if compression_ckpt is None:
    raise FileNotFoundError(f"No checkpoint file found under {compression_dir}")

ckpt_time = datetime.datetime.fromtimestamp(compression_ckpt.stat().st_mtime)
print(f"Using compression XP: {compression_dir.name} | solver: {compression_solver_name}")
print(f"Checkpoint: {compression_ckpt}")
print(f"Timestamp: {ckpt_time:%Y-%m-%d %H:%M:%S}")

compression_model = CompressionSolver.model_from_checkpoint(str(compression_ckpt), device="cpu")
COMPRESSION_META = dict(
    xp_dir=compression_dir,
    ckpt_path=compression_ckpt,
    sample_rate=compression_model.sample_rate,
    channels=compression_model.channels,
    n_q=getattr(compression_model, "num_codebooks", None),
    cardinality=getattr(compression_model, "cardinality", None),
    frame_rate=getattr(compression_model, "frame_rate", None),
)
if COMPRESSION_META["n_q"] is None and hasattr(compression_model, "quantizer"):
    COMPRESSION_META["n_q"] = getattr(compression_model.quantizer, "n_q", None)
print("Compression meta:", {k: v for k, v in COMPRESSION_META.items() if k not in ("xp_dir", "ckpt_path")})
del compression_model


Using compression XP: 060c08dd | solver: compression
Checkpoint: /root/workspace/experiments/audiocraft/xps/060c08dd/checkpoint.th
Timestamp: 2026-01-26 12:23:15


Using compression XP: 060c08dd | solver: compression
Checkpoint: /root/workspace/experiments/audiocraft/xps/060c08dd/checkpoint.th
Timestamp: 2026-01-26 12:23:15


Dora directory: /tmp/audiocraft_root


Using compression XP: 060c08dd | solver: compression
Checkpoint: /root/workspace/experiments/audiocraft/xps/060c08dd/checkpoint.th
Timestamp: 2026-01-26 12:23:15


Dora directory: /tmp/audiocraft_root


Compression meta: {'sample_rate': 16000, 'channels': 1, 'n_q': 32, 'cardinality': 1024, 'frame_rate': 50}




## 4) Identify generator solver/config (unconditional)

In [4]:

solver_root = AUDIOCRAFT_REPO / "config" / "solver"
musicgen_solvers = sorted([p.relative_to(solver_root).as_posix() for p in solver_root.glob("musicgen/*.yaml")])
audiogen_solvers = sorted([p.relative_to(solver_root).as_posix() for p in solver_root.glob("audiogen/*.yaml")])
print("Available musicgen solvers:", musicgen_solvers)
print("Available audiogen solvers:", audiogen_solvers)

# Choose a minimal, unconditional solver
LM_SOLVER = "musicgen/default"
if not (solver_root / (LM_SOLVER + ".yaml")).exists():
    raise FileNotFoundError(f"Expected solver config missing: {LM_SOLVER}.yaml")
print(f"Selected solver: {LM_SOLVER} (conditioner=none by default)")

LM_OVERRIDES = {
    "model.lm.model_scale": "xsmall",
    "conditioner": "none",
    "generate.lm.use_sampling": False,
    "generate.lm.prompted_samples": False,
    "generate.lm.unprompted_samples": True,
    "generate.lm.no_text_conditioning": True,
    "generate.lm.top_k": 0,
    "generate.lm.top_p": 0.0,
}
print("Overrides for unconditional debug:", LM_OVERRIDES)


Available musicgen solvers: ['musicgen/debug.yaml', 'musicgen/debug_mini.yaml', 'musicgen/default.yaml', 'musicgen/musicgen_base_32khz.yaml', 'musicgen/musicgen_melody_32khz.yaml', 'musicgen/musicgen_style_32khz.yaml']
Available audiogen solvers: ['audiogen/audiogen_base_16khz.yaml', 'audiogen/debug.yaml', 'audiogen/default.yaml']
Selected solver: musicgen/default (conditioner=none by default)
Overrides for unconditional debug: {'model.lm.model_scale': 'xsmall', 'conditioner': 'none', 'generate.lm.use_sampling': False, 'generate.lm.prompted_samples': False, 'generate.lm.unprompted_samples': True, 'generate.lm.no_text_conditioning': True, 'generate.lm.top_k': 0, 'generate.lm.top_p': 0.0}


## 5) Run a small debug training job via Dora

In [None]:
import shlex, subprocess
import tempfile

# Record existing generator runs to identify the new one after training
pre_gen_dirs = {xp.name for xp, _ in find_xps("musicgen")}

env = os.environ.copy()
env["AUDIOCRAFT_TEAM"] = env.get("AUDIOCRAFT_TEAM", "default")
env["AUDIOCRAFT_DORA_DIR"] = str(AUDIOCRAFT_DORA_DIR)
env["USER"] = env.get("USER", "root")
env["PYTHONWARNINGS"] = "ignore::FutureWarning,ignore::UserWarning"
env["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"
env["NUMBA_DISABLE_CACHING"] = "1"
env["JOBLIB_TEMP_FOLDER"] = "/tmp"

# Create delay pattern for 32 codebooks
delays = list(range(COMPRESSION_META['n_q']))  # [0, 1, 2, ..., 31]

# Create a temporary solver config that extends musicgen/default with our specific values
temp_solver_config = AUDIOCRAFT_REPO / "config" / "solver" / "musicgen" / "debug_mini.yaml"
solver_config_content = f"""# @package __global__

defaults:
  - musicgen/default
  - _self_

sample_rate: {COMPRESSION_META['sample_rate']}
channels: {COMPRESSION_META['channels']}
compression_model_checkpoint: {COMPRESSION_META['ckpt_path']}

lm_model: transformer_lm

codebooks_pattern:
  modeling: delay
  delay:
    delays: {delays}
    flatten_first: 0
    empty_initial: 0

transformer_lm:
  n_q: {COMPRESSION_META['n_q']}
  card: {COMPRESSION_META['cardinality']}
  dim: 128
  num_heads: 4
  hidden_scale: 2
  num_layers: 3
  causal: true
  memory_efficient: true
  bias_proj: false
  bias_ff: false
  bias_attn: false
  norm_first: true
  layer_scale: null
  weight_init: gaussian
  depthwise_init: current
  zero_bias_init: true
  attention_as_float32: false

dataset:
  segment_duration: {SEGMENT_SECONDS}
  batch_size: {BATCH_SIZE}
  num_workers: {NUM_WORKERS}
  min_segment_ratio: 1.0

generate:
  every: null

evaluate:
  every: {EVALUATE_EVERY}

checkpoint:
  save_last: true
  save_every: null

optim:
  epochs: {EPOCHS}
  updates_per_epoch: {UPDATES_PER_EPOCH}

tokens:
  padding_with_special_token: false

seed: {SEED}
fsdp:
  use: false
logging:
  log_tensorboard: false
"""

temp_solver_config.write_text(solver_config_content)
print(f"✓ Created solver config: {temp_solver_config.name}")
print(f"  • {COMPRESSION_META['n_q']} codebooks × {COMPRESSION_META['cardinality']} cardinality")
print(f"  • Mini transformer: 3 layers, 128 dim, 4 heads, causal=true")
print(f"  • {EPOCHS} epoch × {UPDATES_PER_EPOCH} updates, batch {BATCH_SIZE}")
print(f"  • Checkpoints: save_last=true, generation: disabled during training")

# Now run with the simplified command
cmd = [
    "python",
    "-m", "dora", "run",
    f"solver=musicgen/debug_mini",
    f"dset={DSET}",
    "conditioner=none",
]
print(f"\nStarting training...")

# Capture output
result = subprocess.run(cmd, cwd=str(AUDIOCRAFT_REPO), env=env, check=False, capture_output=True, text=True)

if result.returncode != 0:
    print("\n❌ Training failed")
    print("\n=== STDERR (last 8000 chars) ===")
    print(result.stderr[-8000:] if len(result.stderr) > 8000 else result.stderr)
    if temp_solver_config.exists():
        temp_solver_config.unlink()
    raise subprocess.CalledProcessError(result.returncode, cmd, result.stdout, result.stderr)
else:
    # Show training progress
    print("\n✓ Training completed!")
    stderr_lines = result.stderr.splitlines()
    if stderr_lines:
        progress_lines = [l for l in stderr_lines if any(x in l for x in ['Train', 'Valid', 'Evaluate', 'Model size', 'checkpoint', 'Saving'])]
        if progress_lines:
            print("\nKey training logs:")
            for line in progress_lines[-25:]:
                # Strip ANSI codes for cleaner output
                clean_line = line.replace('[36m', '').replace('[34m', '').replace('[32m', '').replace('[0m', '')
                print(clean_line)

# Clean up temp config
if temp_solver_config.exists():
    temp_solver_config.unlink()

post_gen = find_xps("musicgen")
new_gen = [xp for xp, _ in post_gen if xp.name not in pre_gen_dirs]
if not post_gen:
    raise FileNotFoundError("No MusicGen runs found after training.")
GEN_XP_DIR = new_gen[-1] if new_gen else post_gen[-1][0]
print(f"\n✓ Generator XP created: {GEN_XP_DIR.name}")
print(f"  Location: {GEN_XP_DIR}")


## 5b) Train a MusicGen model (FULL RUN)

In [20]:
from pathlib import Path
import os
import subprocess

# Single switch for all paths (defaults to your new location)
BASE_DIR = Path(os.environ.get("WORKSPACE_DIR", "/root/workspace"))

AUDIOCRAFT_REPO_DIR = BASE_DIR / "audiocraft"
EXPERIMENTS_DIR     = BASE_DIR / "experiments" / "audiocraft"

DSET = "audio/fma_small_mini"
SOLVER = "musicgen/musicgen_base_32khz"  # MusicGen solver, not compression

SEGMENT_SECONDS = 10
BATCH_SIZE = 64  # Increased from 8 to improve GPU utilization

# Auto-pick workers: reduce to 4-6 to avoid CPU contention
_cpu_count = os.cpu_count() or 32
# NUM_WORKERS = min(16, max(8, _cpu_count // 4))   # 8–16 is usually the sweet spot
NUM_WORKERS = 16 # this worked well on v84 cpu

UPDATES_PER_EPOCH = 50
VALID_NUM_SAMPLES = 30
GENERATE_EVERY = 10
EVALUATE_EVERY = 10
AUTOCAST = True
NUM_THREADS = 16  # Set number of threads for PyTorch operations
MP_START_METHOD = "fork" # Use 'fork' to reduce overhead on Linux systems
# MP_START_METHOD = "forkserver" # Alternative method if issues arise with 'fork'

CONFIG_PATH = AUDIOCRAFT_REPO_DIR / "config" / "dset" / "audio" / "fma_small_mini.yaml"
TRAIN_JSONL  = BASE_DIR / "data" / "fma_small_mini" / "egs" / "train" / "data.jsonl"
VALID_JSONL  = BASE_DIR / "data" / "fma_small_mini" / "egs" / "valid" / "data.jsonl"

# Use the compression checkpoint from the earlier training
COMPRESSION_CHECKPOINT = str(COMPRESSION_META['ckpt_path'])

# Create delay pattern for all codebooks
delays_str = "[" + ",".join(str(i) for i in range(COMPRESSION_META['n_q'])) + "]"

print(NUM_WORKERS)
print(CONFIG_PATH)
print(TRAIN_JSONL)
print(VALID_JSONL)
print(AUDIOCRAFT_REPO_DIR)
print(f"Using compression checkpoint: {COMPRESSION_CHECKPOINT}")
print(f"Compression model: {COMPRESSION_META['sample_rate']}Hz, {COMPRESSION_META['n_q']} codebooks, card={COMPRESSION_META['cardinality']}")

# Setup environment
env = os.environ.copy()
env['AUDIOCRAFT_TEAM'] = 'default'
env['AUDIOCRAFT_DORA_DIR'] = str(EXPERIMENTS_DIR)
env['USER'] = env.get('USER', 'root')
env['PYTHONWARNINGS'] = 'ignore::FutureWarning,ignore::UserWarning'

print(f"\nUsing config: dset={DSET}, solver={SOLVER}")
print(f"Training params: segment_duration={SEGMENT_SECONDS}, batch_size={BATCH_SIZE}, num_workers={NUM_WORKERS}")
print(f"Optimizer: updates_per_epoch={UPDATES_PER_EPOCH}")
print(f"Validation: num_samples={VALID_NUM_SAMPLES}")
print(f"Evaluate every: {EVALUATE_EVERY}")
print(f"Autocast: {AUTOCAST}")
print(f"Generate every: {GENERATE_EVERY}")
print(f"Num threads: {NUM_THREADS}")
print(f"MP start method: {MP_START_METHOD}")

cmd = f"""
cd {AUDIOCRAFT_REPO_DIR} && \
python -m dora run \
  solver={SOLVER} \
  dset={DSET} \
  compression_model_checkpoint={COMPRESSION_CHECKPOINT} \
  sample_rate={COMPRESSION_META['sample_rate']} \
  channels={COMPRESSION_META['channels']} \
  transformer_lm.card={COMPRESSION_META['cardinality']} \
  transformer_lm.n_q={COMPRESSION_META['n_q']} \
  'codebooks_pattern.delay.delays={delays_str}' \
  conditioner=none \
  dataset.segment_duration={SEGMENT_SECONDS} \
  dataset.batch_size={BATCH_SIZE} \
  dataset.num_workers={NUM_WORKERS} \
  optim.updates_per_epoch={UPDATES_PER_EPOCH} \
  dataset.valid.num_samples={VALID_NUM_SAMPLES} \
  generate.every={GENERATE_EVERY} \
  evaluate.every={EVALUATE_EVERY} \
  autocast={str(AUTOCAST).lower()} \
  num_threads={NUM_THREADS} \
  mp_start_method={MP_START_METHOD}
"""

subprocess.run(cmd, shell=True, check=True, env=env)


16
/root/workspace/audiocraft/config/dset/audio/fma_small_mini.yaml
/root/workspace/data/fma_small_mini/egs/train/data.jsonl
/root/workspace/data/fma_small_mini/egs/valid/data.jsonl
/root/workspace/audiocraft
Using compression checkpoint: /root/workspace/experiments/audiocraft/xps/060c08dd/checkpoint.th
Compression model: 16000Hz, 32 codebooks, card=1024

Using config: dset=audio/fma_small_mini, solver=musicgen/musicgen_base_32khz
Training params: segment_duration=10, batch_size=64, num_workers=16
Optimizer: updates_per_epoch=50
Validation: num_samples=30
Evaluate every: 10
Autocast: True
Generate every: 10
Num threads: 16
MP start method: fork


Dora directory: /root/workspace/experiments/audiocraft
[[36m01-26 13:37:41[0m][[34mdora.distrib[0m][[32mINFO[0m] - world_size is 1, skipping init.[0m
[[36m01-26 13:37:41[0m][[34mflashy.solver[0m][[32mINFO[0m] - Instantiating solver MusicGenSolver for XP bd119d3b[0m
[[36m01-26 13:37:41[0m][[34mflashy.solver[0m][[32mINFO[0m] - All XP logs are stored in /root/workspace/experiments/audiocraft/xps/bd119d3b[0m
[[36m01-26 13:37:41[0m][[34maudiocraft.solvers.builders[0m][[32mINFO[0m] - Loading audio data split train: /root/workspace/data/fma_small_mini/egs/train[0m
[[36m01-26 13:37:41[0m][[34maudiocraft.solvers.builders[0m][[32mINFO[0m] - Loading audio data split valid: /root/workspace/data/fma_small_mini/egs/valid[0m
[[36m01-26 13:37:41[0m][[34maudiocraft.solvers.builders[0m][[32mINFO[0m] - Loading audio data split evaluate: /root/workspace/data/fma_small_mini/egs/valid[0m
[[36m01-26 13:37:41[0m][[34maudiocraft.solvers.builders[0m][[32mINFO[0m] 

KeyboardInterrupt: 

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.11/dist-packages/dora/__main__.py", line 174, in <module>
    main()
  File "/usr/local/lib/python3.11/dist-packages/dora/__main__.py", line 170, in main
    args.action(args, main)
  File "/usr/local/lib/python3.11/dist-packages/dora/run.py", line 69, in run_action
    main()
  File "/usr/local/lib/python3.11/dist-packages/dora/main.py", line 86, in __call__
    return self._main()
           ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/dora/hydra.py", line 228, in _main
    return hydra.main(
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/usr/local/lib/python3.11/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/usr/local/lib/python3.11/dist-packages/hydra/_internal/utils.

## 6) Locate generator checkpoint and generate audio

In [21]:
from audiocraft.solvers import CompressionSolver
from audiocraft.utils import checkpoint
from audiocraft.models.builders import get_lm_model
import torch, torchaudio
import omegaconf

# Re-discover in case the session was reloaded
musicgen_xps = find_xps("musicgen")
if not musicgen_xps:
    musicgen_xps = find_xps("audiogen")
if not musicgen_xps:
    raise FileNotFoundError("No generator XP found. Run the training cell above.")

GEN_XP_DIR = musicgen_xps[-1][0]
GEN_CKPT = pick_checkpoint(GEN_XP_DIR)
if GEN_CKPT is None:
    raise FileNotFoundError(f"No checkpoint found under {GEN_XP_DIR}")

print("Generator checkpoint info:")
print(f"  XP: {GEN_XP_DIR.name}")
print(f"  Checkpoint: {GEN_CKPT.name}")
print(f"  Compression: {COMPRESSION_META['ckpt_path'].name}")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"\nLoading model on {device}...")

try:
    # Load the config
    hydra_config = GEN_XP_DIR / ".hydra" / "config.yaml"
    if not hydra_config.exists():
        raise FileNotFoundError(f"Config not found at {hydra_config}")
    
    cfg = omegaconf.OmegaConf.load(hydra_config)
    cfg.device = device
    
    # Load compression model
    print(f"Loading compression model from {COMPRESSION_META['ckpt_path']}...")
    compression_model = CompressionSolver.model_from_checkpoint(COMPRESSION_META['ckpt_path'], device=device)
    frame_rate = compression_model.frame_rate
    
    # Load LM model
    print(f"Building LM model...")
    lm_model = get_lm_model(cfg)
    
    # Load model state from checkpoint
    print(f"Loading LM weights from {GEN_CKPT.name}...")
    checkpoint_data = checkpoint.load_checkpoint(GEN_CKPT, is_sharded=False)
    lm_model.load_state_dict(checkpoint_data['model'])
    lm_model.to(device)
    lm_model.eval()
    
    max_gen_len = int(GENERATE_SECONDS * frame_rate)
    
    print(f"✓ Model loaded: {max_gen_len} tokens ({GENERATE_SECONDS}s at {frame_rate} Hz)")
    
    print(f"\nGenerating {GENERATE_SAMPLES} samples...")
    with torch.no_grad():
        tokens = lm_model.generate(
            prompt=None,
            num_samples=GENERATE_SAMPLES,
            max_gen_len=max_gen_len,
            use_sampling=False,
            top_k=0,
            top_p=0.0,
        )
        audio = compression_model.decode(tokens)
    
    audio = audio.detach().cpu()
    sample_rate = int(compression_model.sample_rate)
    saved = []
    
    print(f"Saving audio files to {OUTPUT_DIR}...")
    for i in range(audio.shape[0]):
        wav = audio[i]
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        out_path = OUTPUT_DIR / f"generated_{i:03d}.wav"
        torchaudio.save(str(out_path), wav, sample_rate)
        duration = wav.shape[-1] / sample_rate
        rms = torch.sqrt(torch.mean(wav ** 2)).item()
        peak = torch.max(torch.abs(wav)).item()
        saved.append({"path": str(out_path), "duration_sec": round(duration, 2), "rms": round(rms, 4), "peak": round(peak, 4)})
    
    print("\n✓ Generated samples:")
    for item in saved:
        print(f"  {Path(item['path']).name}: {item['duration_sec']}s, RMS={item['rms']}, peak={item['peak']}")
    
    SUMMARY = {
        "compression_checkpoint": str(COMPRESSION_META["ckpt_path"]),
        "generator_checkpoint": str(GEN_CKPT),
        "sample_rate": sample_rate,
        "num_codebooks": COMPRESSION_META["n_q"],
        "cardinality": COMPRESSION_META["cardinality"],
        "generated": saved,
    }
    print("\nSummary:")
    print(json.dumps(SUMMARY, indent=2))

except Exception as e:
    print(f"\n❌ Generation failed: {type(e).__name__}: {e}")
    import traceback
    traceback.print_exc()
    raise


Generator checkpoint info:
  XP: bd119d3b
  Checkpoint: checkpoint.th
  Compression: checkpoint.th

Loading model on cuda...
Loading compression model from /root/workspace/experiments/audiocraft/xps/060c08dd/checkpoint.th...




Building LM model...
Loading LM weights from checkpoint.th...
✓ Model loaded: 400 tokens (8s at 50 Hz)

Generating 2 samples...
Saving audio files to /root/workspace/Training/outputs/musicgen_uncond_debug...

✓ Generated samples:
  generated_000.wav: 8.0s, RMS=0.0142, peak=0.1826
  generated_001.wav: 8.0s, RMS=0.0142, peak=0.1826

Summary:
{
  "compression_checkpoint": "/root/workspace/experiments/audiocraft/xps/060c08dd/checkpoint.th",
  "generator_checkpoint": "/root/workspace/experiments/audiocraft/xps/bd119d3b/checkpoint.th",
  "sample_rate": 16000,
  "num_codebooks": 32,
  "cardinality": 1024,
  "generated": [
    {
      "path": "/root/workspace/Training/outputs/musicgen_uncond_debug/generated_000.wav",
      "duration_sec": 8.0,
      "rms": 0.0142,
      "peak": 0.1826
    },
    {
      "path": "/root/workspace/Training/outputs/musicgen_uncond_debug/generated_001.wav",
      "duration_sec": 8.0,
      "rms": 0.0142,
      "peak": 0.1826
    }
  ]
}


## 7) Minimal sanity evaluation

In [None]:
import torch
from pathlib import Path
import math
import torchaudio

check_paths = [Path(item["path"]) for item in saved]
for p in check_paths:
    wav, sr = torchaudio.load(p)
    duration = wav.shape[-1] / sr
    finite = torch.isfinite(wav).all().item()
    print(f"{p.name}: sr={sr}, duration={duration:.2f}s, finite={finite}, rms={wav.pow(2).mean().sqrt().item():.4f}, peak={wav.abs().max().item():.4f}")
    if sr != sample_rate:
        raise ValueError(f"Unexpected sample rate in {p}: {sr} (expected {sample_rate})")
    if not finite:
        raise ValueError(f"NaNs detected in {p}")
print("Sanity checks complete.")
