In [None]:
# @title 🔄 Resume From Checkpoint
# @markdown Resume a previous run or start fresh. Set `resume_run_id` to reload
# @markdown artifacts and skip completed stages.

resume_run_id = ""  # @param {type:"string"}
cache_root = "/content/drive/MyDrive/soundlab_cache"  # @param {type:"string"}
enable_drive_cache = True  # @param {type:"boolean"}

# === Execution ===
from __future__ import annotations

import os
from pathlib import Path

# Mount Google Drive if caching enabled
if enable_drive_cache:
    try:
        from google.colab import drive

        drive.mount("/content/drive", force_remount=False)
        print("✅ Google Drive mounted")
    except ImportError:
        print("⚠️ Not running in Colab, skipping Drive mount")
    except Exception as e:
        print(f"⚠️ Drive mount failed: {e}")

# Setup cache directories
CACHE_ROOT = Path(cache_root)
CACHE_ROOT.mkdir(parents=True, exist_ok=True)

RUNS_DIR = CACHE_ROOT / "runs"
CHECKPOINTS_DIR = CACHE_ROOT / "checkpoints"
MODELS_DIR = CACHE_ROOT / "models"

for d in [RUNS_DIR, CHECKPOINTS_DIR, MODELS_DIR]:
    d.mkdir(exist_ok=True)

# Resume logic
RUN_ID: str | None = None
RESUMED = False

if resume_run_id.strip():
    run_dir = RUNS_DIR / resume_run_id.strip()
    if run_dir.exists():
        RUN_ID = resume_run_id.strip()
        RESUMED = True
        print(f"✅ Resuming run: {RUN_ID}")
        print(f"   Run directory: {run_dir}")
    else:
        print(f"⚠️ Run '{resume_run_id}' not found, starting fresh")
else:
    print("ℹ️ No resume ID provided, will generate new run ID on upload")

print(f"\n📁 Cache root: {CACHE_ROOT}")
print(f"   Runs: {RUNS_DIR}")
print(f"   Checkpoints: {CHECKPOINTS_DIR}")
print(f"   Models: {MODELS_DIR}")

# 🎛️ SoundLab Studio

**Production-ready music processing pipeline**

---

## Features

| Feature | Description |
|---------|-------------|
| 🎚️ **Stem Separation** | Demucs HTDemucs/HTDemucs-FT models for vocals, drums, bass, other |
| 🎹 **Audio-to-MIDI** | Basic Pitch transcription with onset/frame thresholds |
| 🎨 **Effects Processing** | Pedalboard-based EQ, compression, reverb, and creative effects |
| 📊 **Audio Analysis** | Tempo, key, loudness (LUFS), spectral features |
| 🗣️ **Voice Generation** | XTTS-v2 TTS and RVC voice conversion (optional) |

## Workflow

1. **Setup** — Configure environment and install packages
2. **Upload** — Load audio file and compute metadata
3. **Separate** — Extract stems with quality-aware candidate selection
4. **Transcribe** — Convert stems to MIDI with cleanup
5. **Preview & QA** — Review results and re-run if needed
6. **Export** — Download stems, MIDI, and reports

## Requirements

- **GPU Runtime** recommended (T4 or better)
- ~8GB VRAM for Demucs separation
- ~4GB disk space for models

---

> 💡 **Tip:** Enable Google Drive caching in the first cell (🔄 Resume From Checkpoint) to persist models and resume runs.

In [None]:
# @title ⚙️ Environment Setup
# @markdown Configure the processing environment and runtime settings.

gpu_mode = "auto"  # @param ["auto", "force_gpu", "force_cpu"]
log_level = "INFO"  # @param ["DEBUG", "INFO", "WARNING", "ERROR"]
output_base = "/content/soundlab_outputs"  # @param {type:"string"}
deterministic = False  # @param {type:"boolean"}
random_seed = 42  # @param {type:"integer"}

# === Execution ===
# `os` is imported here as well so this cell does not depend on another cell
# having imported it first.
import os
from pathlib import Path

# Configure environment variables
os.environ["SOUNDLAB_LOG_LEVEL"] = log_level
os.environ["SOUNDLAB_GPU_MODE"] = gpu_mode

if deterministic:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    os.environ["SOUNDLAB_SEED"] = str(random_seed)
else:
    # Re-running this cell with `deterministic` switched off must not leave the
    # values from an earlier deterministic run behind in the process environment.
    os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
    os.environ.pop("SOUNDLAB_SEED", None)

# Setup output directories
OUTPUT_DIR = Path(output_base)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

STEMS_DIR = OUTPUT_DIR / "stems"
MIDI_DIR = OUTPUT_DIR / "midi"
EFFECTS_DIR = OUTPUT_DIR / "effects"
ANALYSIS_DIR = OUTPUT_DIR / "analysis"
VOICE_DIR = OUTPUT_DIR / "voice"
EXPORTS_DIR = OUTPUT_DIR / "exports"

for d in [STEMS_DIR, MIDI_DIR, EFFECTS_DIR, ANALYSIS_DIR, VOICE_DIR, EXPORTS_DIR]:
    d.mkdir(exist_ok=True)

# Runtime introspection
import torch

GPU_AVAILABLE = torch.cuda.is_available()
GPU_NAME = torch.cuda.get_device_name(0) if GPU_AVAILABLE else "N/A"
GPU_MEMORY = (
    f"{torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB" if GPU_AVAILABLE else "N/A"
)

print("🖥️ Runtime Information")
print(f"   GPU Available: {GPU_AVAILABLE}")
print(f"   GPU Name: {GPU_NAME}")
print(f"   GPU Memory: {GPU_MEMORY}")
print(f"   GPU Mode: {gpu_mode}")
print(f"   Deterministic: {deterministic}")
if deterministic:
    print(f"   Random Seed: {random_seed}")

# Surface an impossible request now rather than deep inside a later cell.
if gpu_mode == "force_gpu" and not GPU_AVAILABLE:
    print("⚠️ gpu_mode is 'force_gpu' but no CUDA device is available")

print(f"\n✅ Output directory: {OUTPUT_DIR}")

In [None]:
# @title 📦 Install SoundLab
# @markdown Install the soundlab package and dependencies.

# The three form values below are module-level globals; install_package() reads
# them directly instead of taking arguments.
# install_voice adds the "voice" extra to the requirement.
install_voice = False  # @param {type:"boolean"}
# install_from chooses between the PyPI release and the project's GitHub repository.
install_from = "pypi"  # @param ["pypi", "github_main", "github_dev"]
# force_reinstall adds pip's --force-reinstall flag.
force_reinstall = False  # @param {type:"boolean"}

# === Execution ===
import subprocess
import sys


def install_package() -> None:
    """Install soundlab with appropriate extras.

    Reads the notebook form globals ``install_voice``, ``install_from`` and
    ``force_reinstall`` and runs ``pip install`` with the interpreter of the
    running kernel, so the package lands in the kernel's own environment.

    Raises
    ------
    ValueError
        If ``install_from`` is not one of ``"pypi"``, ``"github_main"`` or
        ``"github_dev"``.
    subprocess.CalledProcessError
        If pip exits with a non-zero status.
    """
    extras = ["notebook"]
    if install_voice:
        extras.append("voice")
    requirement = f"soundlab[{','.join(extras)}]"

    # Suffix appended to the repository URL for each GitHub source.
    git_suffix = {"github_main": "", "github_dev": "@dev"}

    if install_from == "pypi":
        pkg = requirement
    elif install_from in git_suffix:
        pkg = (
            f"{requirement} @ git+https://github.com/wyattowalsh/soundlab.git"
            f"{git_suffix[install_from]}"
        )
    else:
        # Previously any unrecognised value silently installed the dev branch.
        raise ValueError(
            f"Unknown install_from {install_from!r}; "
            "expected 'pypi', 'github_main' or 'github_dev'"
        )

    cmd = [sys.executable, "-m", "pip", "install", "-q"]
    if force_reinstall:
        cmd.append("--force-reinstall")
    cmd.append(pkg)

    print(f"📦 Installing: {pkg}")
    subprocess.check_call(cmd)
    print(f"✅ Installed: {pkg}")


install_package()

# Verify installation
import importlib

# The package may have been installed after the kernel started; make sure the
# import system notices the new files before importing.
importlib.invalidate_caches()

import soundlab

print(f"\n✅ SoundLab version: {soundlab.__version__}")

# Configure logging
from soundlab.utils.logging import configure_logging

configure_logging(level=log_level)

# Show available features — each module is actually imported, so a check mark
# means the module loads in this runtime rather than being assumed present.
print("\n📋 Available modules:")
for module_name in ("separation", "transcription", "effects", "analysis", "pipeline"):
    try:
        importlib.import_module(f"soundlab.{module_name}")
        print(f"   ✅ soundlab.{module_name}")
    except ImportError as e:
        print(f"   ❌ soundlab.{module_name} (import failed: {e})")

try:
    from soundlab import voice

    print("   ✅ soundlab.voice (TTS + RVC)")
except ImportError:
    print("   ⚪ soundlab.voice (not installed)")

In [None]:
# @title 🎵 Upload Audio
# @markdown Upload your audio file for processing. Supported formats: WAV, MP3, FLAC, OGG, AIFF, M4A

# === Execution ===
from __future__ import annotations

import hashlib
from pathlib import Path
from typing import TYPE_CHECKING

import gradio as gr

from soundlab.io import load_audio

if TYPE_CHECKING:
    from soundlab.core.audio import AudioSegment

# Global state for current audio
# All three names are rebound by handle_upload(): set together after a
# successful load, and reset to None when no file is given or loading fails.
# Later cells read them (e.g. Canonical Decode reads CURRENT_AUDIO and SOURCE_PATH).
CURRENT_AUDIO: AudioSegment | None = None  # decoded audio returned by load_audio()
AUDIO_HASH: str | None = None  # short fingerprint from compute_audio_hash()
SOURCE_PATH: Path | None = None  # path of the uploaded file


def compute_audio_hash(path: Path) -> str:
    """Return a short SHA-256 fingerprint of a file's bytes.

    The file is read in 8 KiB blocks so it never has to fit in memory at once.
    Only the first 16 hex characters of the digest are returned.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(8192)
            if not block:
                # An empty read means end of file
                break
            digest.update(block)
    full_hex = digest.hexdigest()
    return full_hex[:16]


def handle_upload(audio_path: str | None) -> tuple[str, str]:
    """Load an uploaded file into the notebook's shared audio state.

    Rebinds the globals ``CURRENT_AUDIO``, ``AUDIO_HASH`` and ``SOURCE_PATH``.
    They are set together on success and all reset to ``None`` when no path is
    given or when loading fails.

    Returns a ``(info, status)`` pair: a Markdown summary of the file (or the
    error text) and a one-line status message.
    """
    global CURRENT_AUDIO, AUDIO_HASH, SOURCE_PATH

    # Nothing selected: clear the shared state
    if audio_path is None:
        CURRENT_AUDIO = AUDIO_HASH = SOURCE_PATH = None
        return "No file uploaded", ""

    try:
        upload = Path(audio_path)
        CURRENT_AUDIO = load_audio(upload)
        AUDIO_HASH = compute_audio_hash(upload)
        SOURCE_PATH = upload

        details = CURRENT_AUDIO.metadata
        if details is None:
            # No metadata object: describe the file from the segment itself
            rows = [
                f"**File:** {upload.name}",
                f"**Duration:** {CURRENT_AUDIO.duration_seconds:.2f}s",
                f"**Sample Rate:** {CURRENT_AUDIO.sample_rate} Hz",
                f"**Channels:** {CURRENT_AUDIO.channels}",
                f"**Hash:** `{AUDIO_HASH}`",
            ]
        else:
            rows = [
                f"**File:** {upload.name}",
                f"**Duration:** {details.duration_str}",
                f"**Sample Rate:** {details.sample_rate} Hz",
                f"**Channels:** {details.channels} ({'Stereo' if details.is_stereo else 'Mono'})",
                f"**Bit Depth:** {details.bit_depth.value if details.bit_depth else 'Unknown'}-bit",
                f"**Format:** {details.format.value.upper() if details.format else 'Unknown'}",
                f"**Hash:** `{AUDIO_HASH}`",
            ]
        summary = "\n".join(rows)
        return summary.strip(), "✅ Audio loaded successfully!"

    except Exception as e:
        # Any failure leaves no half-initialised state behind
        CURRENT_AUDIO = AUDIO_HASH = SOURCE_PATH = None
        return f"**Error:** {e}", "❌ Failed to load audio"


# Build Gradio interface
# Layout: one row with a wider column (scale=2) holding the upload widget and a
# narrower column (scale=1) holding the two Markdown outputs that handle_upload()
# fills in.
with gr.Blocks(theme=gr.themes.Soft()) as upload_interface:
    gr.Markdown("## 🎵 Upload Audio File")
    gr.Markdown("Supported formats: WAV, MP3, FLAC, OGG, AIFF, M4A")

    with gr.Row():
        with gr.Column(scale=2):
            # type="filepath": the widget's value is handed to handle_upload()
            # as a path string; only file upload is offered as a source.
            audio_input = gr.Audio(
                label="Input Audio",
                type="filepath",
                sources=["upload"],
            )
        with gr.Column(scale=1):
            info_output = gr.Markdown(label="Audio Info", value="No file uploaded")
            status_output = gr.Markdown()

    # handle_upload() returns (info, status), matching the order of `outputs`.
    # It also has a branch for a None input — presumably what arrives when the
    # widget is cleared (TODO confirm against Gradio's change-event behaviour).
    audio_input.change(
        fn=handle_upload,
        inputs=[audio_input],
        outputs=[info_output, status_output],
    )

# Render the interface in the cell output.
upload_interface.launch(height=450, show_error=True)

In [None]:
# @title 🔊 Canonical Decode
# @markdown Decode audio to canonical format (44.1kHz stereo float32) for consistent processing.

# These form values are forwarded to to_canonical() at the bottom of this cell.
target_sample_rate = 44100  # @param {type:"integer"}
target_channels = 2  # @param [1, 2] {type:"raw"}
normalize_audio = True  # @param {type:"boolean"}
# Target peak level in dB; passed to to_canonical() as `level_db`.
normalization_level = -1.0  # @param {type:"number"}

# === Execution ===
from __future__ import annotations

import numpy as np
from scipy import signal


def to_canonical(
    audio: AudioSegment,
    sr: int = 44100,
    channels: int = 2,
    normalize: bool = True,
    level_db: float = -1.0,
) -> AudioSegment:
    """
    Convert audio to canonical format.

    Samples are treated as ``(n_samples,)`` for mono and as
    ``(n_channels, n_samples)`` otherwise. The source segment is not modified.

    Parameters
    ----------
    audio
        Source AudioSegment.
    sr
        Target sample rate (default 44100).
    channels
        Target channel count (1=mono, 2=stereo). Any other positive count is
        produced by duplicating a mono signal. Multi-channel input is first
        averaged down to mono whenever the channel count has to change.
    normalize
        Whether to normalize audio level.
    level_db
        Target peak level in dB (default -1.0).

    Returns
    -------
    AudioSegment
        Resampled and normalized audio as float32, carrying the source's
        ``source_path`` and ``metadata`` unchanged.

    Raises
    ------
    ValueError
        If ``sr`` is not positive or ``channels`` is less than 1.
    """
    from soundlab.core.audio import AudioSegment

    if sr <= 0:
        raise ValueError(f"sr must be positive, got {sr}")
    if channels < 1:
        raise ValueError(f"channels must be at least 1, got {channels}")

    samples = audio.samples.copy()
    current_sr = audio.sample_rate

    # Resample if needed
    if current_sr != sr:
        new_length = int(samples.shape[-1] * sr / current_sr)
        if samples.shape[-1] > 0 and new_length > 0:
            # axis=-1 resamples along time for both 1-D and (channels, samples) input
            samples = signal.resample(samples, new_length, axis=-1)
        else:
            # Nothing to resample (empty input, or too short to yield a sample)
            samples = samples[..., :0]
        print(f"   Resampled: {current_sr} Hz → {sr} Hz")

    # Convert channels
    current_channels = 1 if samples.ndim == 1 else samples.shape[0]

    if current_channels != channels:
        # Reduce to a single mono signal first, then fan out if required. This
        # covers every layout instead of only stereo<->mono, so the result
        # always has the requested channel count.
        if current_channels > 1:
            mono = np.mean(samples, axis=0)
        else:
            mono = samples if samples.ndim == 1 else samples[0]

        samples = mono if channels == 1 else np.stack([mono] * channels)

        if (current_channels, channels) == (2, 1):
            print("   Converted: Stereo → Mono")
        elif (current_channels, channels) == (1, 2):
            print("   Converted: Mono → Stereo")
        else:
            print(f"   Converted: {current_channels} channels → {channels} channels")

    # Normalize (skipped for empty or all-zero audio, where no peak exists)
    if normalize and samples.size:
        peak = np.abs(samples).max()
        if peak > 0:
            target_peak = 10 ** (level_db / 20.0)
            samples = samples * (target_peak / peak)
            print(f"   Normalized to {level_db} dB peak")

    return AudioSegment(
        samples=samples.astype(np.float32),
        sample_rate=sr,
        source_path=audio.source_path,
        metadata=audio.metadata,
    )


# Process current audio
if CURRENT_AUDIO is None:
    # Drop any result left over from an earlier upload so that later cells do
    # not keep processing audio that is no longer loaded.
    CANONICAL_AUDIO = None
    print("⚠️ Please upload an audio file first (Cell 4)")
else:
    print(f"🔊 Processing: {SOURCE_PATH.name if SOURCE_PATH else 'audio'}")
    print(f"   Original: {CURRENT_AUDIO.sample_rate} Hz, {CURRENT_AUDIO.channels} channels")

    CANONICAL_AUDIO = to_canonical(
        CURRENT_AUDIO,
        sr=target_sample_rate,
        channels=target_channels,
        normalize=normalize_audio,
        level_db=normalization_level,
    )

    print(f"   Canonical: {CANONICAL_AUDIO.sample_rate} Hz, {CANONICAL_AUDIO.channels} channels")
    print(f"   Duration: {CANONICAL_AUDIO.duration_seconds:.2f}s")
    print(f"   Samples: {CANONICAL_AUDIO.samples.shape}")
    print("\n✅ Audio decoded to canonical format")

In [None]:
# @title ✂️ Excerpt Selection (Optional)
# @markdown Extract a shorter excerpt for faster processing and candidate comparison.
# @markdown Set `use_excerpt=False` to process the full track.

use_excerpt = True  # @param {type:"boolean"}
# Manual start position; only used when auto_select_excerpt is False.
excerpt_start_seconds = 30.0  # @param {type:"number"}
excerpt_duration_seconds = 30.0  # @param {type:"number"}
# When True, the start position comes from find_energetic_region() instead.
auto_select_excerpt = False  # @param {type:"boolean"}

# === Execution ===
from __future__ import annotations


def extract_excerpt(
    audio: AudioSegment,
    start: float,
    duration: float,
) -> AudioSegment:
    """
    Cut a time window out of an audio segment.

    The requested window is clamped to the available samples, so a window that
    starts or ends outside the audio yields a shorter (possibly empty) excerpt
    rather than an error.

    Parameters
    ----------
    audio
        Source AudioSegment.
    start
        Window start in seconds.
    duration
        Window length in seconds.

    Returns
    -------
    AudioSegment
        New segment holding the selected samples, with the source's sample
        rate, ``source_path`` and ``metadata``.
    """
    from soundlab.core.audio import AudioSegment

    rate = audio.sample_rate
    data = audio.samples

    # Requested window expressed in sample indices
    first = int(start * rate)
    last = int((start + duration) * rate)

    # Number of samples along the time axis
    n_frames = data.shape[-1] if data.ndim > 1 else len(data)

    # Keep 0 <= first <= last <= n_frames
    first = max(0, min(first, n_frames))
    last = max(first, min(last, n_frames))

    window = data[first:last] if data.ndim == 1 else data[:, first:last]

    return AudioSegment(
        samples=window,
        sample_rate=rate,
        source_path=audio.source_path,
        metadata=audio.metadata,
    )


def find_energetic_region(audio: AudioSegment, duration: float) -> float:
    """
    Find the most energetic region of the audio for excerpt selection.

    A window of ``duration`` seconds is slid over the (mono-mixed) signal in
    one-second steps and the start of the window with the highest RMS level is
    returned. Ties keep the earliest window.

    Parameters
    ----------
    audio
        Source AudioSegment.
    duration
        Desired excerpt duration in seconds.

    Returns
    -------
    float
        Optimal start time in seconds. ``0.0`` when the audio is not longer
        than the requested window or the window is empty.
    """
    sr = audio.sample_rate
    samples = audio.samples

    # Convert to mono for analysis
    mono = np.mean(samples, axis=0) if samples.ndim > 1 else samples
    # Work in float64 so squaring integer samples cannot overflow
    mono = np.asarray(mono, dtype=np.float64)

    # Compute RMS energy in windows
    window_samples = int(duration * sr)
    hop_samples = max(1, int(sr))  # 1-second hop (at least one sample)

    if window_samples <= 0 or len(mono) <= window_samples:
        return 0.0

    best_energy = 0.0
    best_start = 0.0

    # The last valid window starts at len(mono) - window_samples; the +1 makes
    # the range include it (it used to be skipped).
    last_start = len(mono) - window_samples
    for start_idx in range(0, last_start + 1, hop_samples):
        window = mono[start_idx : start_idx + window_samples]
        energy = np.sqrt(np.mean(window**2))

        if energy > best_energy:
            best_energy = energy
            best_start = start_idx / sr

    return best_start


# Process excerpt
if "CANONICAL_AUDIO" not in dir() or CANONICAL_AUDIO is None:
    print("⚠️ Please run the Canonical Decode cell first (Cell 5)")
else:
    total_duration = CANONICAL_AUDIO.duration_seconds
    print(f"📊 Total duration: {total_duration:.2f}s")

    if not use_excerpt:
        EXCERPT_AUDIO = CANONICAL_AUDIO
        EXCERPT_START = 0.0
        EXCERPT_DURATION = total_duration
        print("ℹ️ Using full track (no excerpt)")
    else:
        # Determine start position
        if auto_select_excerpt:
            excerpt_start = find_energetic_region(CANONICAL_AUDIO, excerpt_duration_seconds)
            print(f"🎯 Auto-selected energetic region starting at {excerpt_start:.1f}s")
        else:
            excerpt_start = excerpt_start_seconds

        # Clamp to valid range. The lower bound matters too: extract_excerpt()
        # clamps a negative start to sample 0, so without it the reported start
        # and duration would not describe the audio that was actually extracted.
        max_start = max(0, total_duration - excerpt_duration_seconds)
        excerpt_start = max(0.0, min(excerpt_start, max_start))
        actual_duration = min(excerpt_duration_seconds, total_duration - excerpt_start)

        EXCERPT_AUDIO = extract_excerpt(CANONICAL_AUDIO, excerpt_start, actual_duration)
        EXCERPT_START = excerpt_start
        EXCERPT_DURATION = actual_duration

        print("\n✂️ Excerpt extracted:")
        print(f"   Start: {EXCERPT_START:.2f}s")
        print(f"   Duration: {EXCERPT_DURATION:.2f}s")
        print(f"   End: {EXCERPT_START + EXCERPT_DURATION:.2f}s")

    print("\n✅ Audio ready for processing")
    print(f"   Shape: {EXCERPT_AUDIO.samples.shape}")
    print(f"   Sample rate: {EXCERPT_AUDIO.sample_rate} Hz")

    # Initialize run ID if not resumed
    if RUN_ID is None:
        import uuid

        RUN_ID = f"{AUDIO_HASH}_{uuid.uuid4().hex[:8]}"
        print(f"\n🆔 Generated run ID: {RUN_ID}")

In [None]:
# @title 🎚️ Stem Separation
# @markdown Configure and run Demucs stem separation.

# --- Separation Model ---
separation_model = (
    "htdemucs_ft"  # @param ["htdemucs", "htdemucs_ft", "htdemucs_6s", "mdx_extra", "mdx_extra_q"]
)
two_stems = "none"  # @param ["none", "vocals", "drums", "bass", "other"]

# --- Processing Options ---
segment_length = 7.8  # @param {type:"slider", min:5.0, max:30.0, step:0.1}
overlap = 0.25  # @param {type:"slider", min:0.1, max:0.5, step:0.05}
shifts = 1  # @param {type:"slider", min:0, max:5, step:1}

# --- Output Format ---
output_format = "wav"  # @param ["wav", "mp3", "flac"]
mp3_bitrate = 320  # @param [128, 192, 256, 320] {type:"raw"}
float32_output = True  # @param {type:"boolean"}

# === Execution ===
from pathlib import Path

from soundlab.separation import StemSeparator
from soundlab.separation.models import DemucsModel, SeparationConfig

# Resolve the input file. The Upload cell publishes the uploaded file as
# SOURCE_PATH, so that is the source of truth whenever it exists (this also
# picks up a newly uploaded file on re-run). An AUDIO_PATH defined elsewhere is
# still honoured when SOURCE_PATH was never defined.
if "SOURCE_PATH" in globals():
    AUDIO_PATH = SOURCE_PATH
elif "AUDIO_PATH" not in globals():
    AUDIO_PATH = None

# Validate input exists
if AUDIO_PATH is None:
    raise RuntimeError("❌ No audio file loaded. Please run the Upload cell first.")

print(f"🎵 Input: {AUDIO_PATH}")
print(f"🎚️ Model: {separation_model}")

# Build configuration
model_enum = DemucsModel(separation_model)
two_stems_val = None if two_stems == "none" else two_stems

sep_config = SeparationConfig(
    model=model_enum,
    segment_length=segment_length,
    overlap=overlap,
    shifts=shifts,
    two_stems=two_stems_val,
    float32=float32_output,
    mp3_bitrate=mp3_bitrate if output_format == "mp3" else None,
)

print("\n⚙️ Configuration:")
print(f"   Segment length: {segment_length}s")
print(f"   Overlap: {overlap}")
print(f"   Shifts: {shifts}")
print(f"   Two stems: {two_stems_val or 'disabled'}")

# Initialize separator
separator = StemSeparator(config=sep_config)

# Run separation
print("\n🔄 Separating stems...")
stem_result = separator.separate(
    audio_path=AUDIO_PATH,
    output_dir=STEMS_DIR,
)

# Store results for downstream cells
STEM_RESULT = stem_result
STEM_PATHS = dict(stem_result.stems.items())

# Display results
print("\n✅ Separation complete!")
print(f"📁 Output directory: {STEMS_DIR}")
print("\n🎼 Extracted stems:")
for name, path in STEM_PATHS.items():
    # Wrap once so both the size lookup and the name work for str or Path values
    stem_file = Path(path)
    size_mb = stem_file.stat().st_size / (1024 * 1024)
    print(f"   {name}: {stem_file.name} ({size_mb:.1f} MB)")

if stem_result.vocals:
    print(f"\n🎤 Vocals stem: {stem_result.vocals}")
if stem_result.instrumental:
    print(f"🎸 Instrumental stem: {stem_result.instrumental}")

In [None]:
# @title 🎯 Candidate Selection (Multi-Strategy Separation)
# @markdown Run multiple separation strategies on an excerpt and select the best one.

# --- Excerpt Settings ---
enable_candidate_selection = True  # @param {type:"boolean"}
excerpt_start = 30.0  # @param {type:"number"}
excerpt_duration = 30.0  # @param {type:"slider", min:10.0, max:60.0, step:5.0}
max_candidates = 3  # @param {type:"slider", min:1, max:5, step:1}

# --- QA Thresholds ---
min_overall_score = 0.7  # @param {type:"slider", min:0.5, max:0.95, step:0.05}
max_reconstruction_error = 0.15  # @param {type:"slider", min:0.05, max:0.30, step:0.05}
max_clipping_ratio = 0.01  # @param {type:"slider", min:0.001, max:0.05, step:0.005}

# --- Full Run Options ---
rerun_full_on_best = True  # @param {type:"boolean"}

# === Execution ===
from pathlib import Path

from soundlab.pipeline import (
    CandidateScore,
    PipelineConfig,
    QAConfig,
    build_candidate_plans,
    choose_best_candidate,
    init_run,
    score_separation,
)
from soundlab.separation import StemSeparator
from soundlab.separation.models import SeparationConfig

if not enable_candidate_selection:
    print("ℹ️ Candidate selection disabled. Using single-run separation from previous cell.")
    if "STEM_RESULT" in dir():
        print(f"✅ Using existing stem result: {len(STEM_RESULT.stems)} stems")
    else:
        print("⚠️ No stem result found. Run the Separation cell first.")
else:
    # Resolve the input file. The Upload cell publishes the uploaded file as
    # SOURCE_PATH, so that is the source of truth whenever it exists; an
    # AUDIO_PATH defined elsewhere is honoured when SOURCE_PATH was never defined.
    if "SOURCE_PATH" in globals():
        AUDIO_PATH = SOURCE_PATH
    elif "AUDIO_PATH" not in globals():
        AUDIO_PATH = None

    # Validate input
    if AUDIO_PATH is None:
        raise RuntimeError("❌ No audio file loaded. Please run the Upload cell first.")

    print(f"🎵 Input: {AUDIO_PATH}")
    print(f"⏱️ Excerpt: {excerpt_start}s - {excerpt_start + excerpt_duration}s")
    print(f"🎯 Max candidates: {max_candidates}")

    # Build QA config
    qa_config = QAConfig(
        min_overall_score=min_overall_score,
        max_reconstruction_error=max_reconstruction_error,
        max_clipping_ratio=max_clipping_ratio,
    )

    # Build pipeline config
    pipeline_config = PipelineConfig(
        excerpt_start=excerpt_start,
        excerpt_duration=excerpt_duration,
        max_candidates=max_candidates,
        qa=qa_config,
    )

    # Generate candidate plans
    plans = build_candidate_plans(pipeline_config)
    print(f"\n📋 Candidate plans ({len(plans)}):")
    for plan in plans:
        print(f"   • {plan.name}: {plan.notes or 'default settings'}")

    # Initialize run
    artifacts = init_run(pipeline_config, Path(AUDIO_PATH), root=RUNS_DIR)
    print(f"\n🆔 Run ID: {artifacts.run_id}")

    # Run candidates on excerpt
    print(f"\n🔄 Running {len(plans)} candidates on excerpt...")
    candidate_scores: list[CandidateScore] = []
    # Excerpt separation result per plan name, kept so the winner's excerpt can
    # be used directly when the full re-run is disabled.
    excerpt_results = {}

    for i, plan in enumerate(plans, 1):
        print(f"\n--- Candidate {i}/{len(plans)}: {plan.name} ---")

        # Create separator for this plan
        separator = StemSeparator(config=plan.separation)

        try:
            # Separate excerpt
            excerpt_result = separator.separate(
                audio_path=AUDIO_PATH,
                output_dir=artifacts.cache_dir / f"excerpt_{plan.name}",
                start_time=excerpt_start,
                duration=excerpt_duration,
            )
            excerpt_results[plan.name] = excerpt_result

            # Score separation quality
            qa_result = score_separation(
                stems=excerpt_result.stems,
                original_path=Path(AUDIO_PATH),
                config=qa_config,
            )

            score = CandidateScore(
                name=plan.name,
                score=qa_result.score,
                metrics=qa_result.metrics,
                passed=qa_result.passed,
            )
            candidate_scores.append(score)

            status = "✅ PASS" if qa_result.passed else "⚠️ WARN"
            print(f"   {status} Score: {qa_result.score:.3f}")
            for metric, value in qa_result.metrics.items():
                print(f"      {metric}: {value:.4f}")

        except Exception as e:
            # A failing candidate is recorded with a zero score so the remaining
            # candidates still run and the summary table stays complete.
            print(f"   ❌ Failed: {e}")
            candidate_scores.append(CandidateScore(name=plan.name, score=0.0, passed=False))

    # Select best candidate
    best = choose_best_candidate(candidate_scores, qa=qa_config)

    if best:
        print(f"\n🏆 Best candidate: {best.name} (score: {best.score:.3f})")

        # Store for downstream
        BEST_CANDIDATE = best
        BEST_PLAN = next(p for p in plans if p.name == best.name)

        if rerun_full_on_best:
            print("\n🔄 Re-running full separation with best candidate...")

            separator = StemSeparator(config=BEST_PLAN.separation)
            stem_result = separator.separate(
                audio_path=AUDIO_PATH,
                output_dir=STEMS_DIR,
            )

            # Update globals
            STEM_RESULT = stem_result
            STEM_PATHS = dict(stem_result.stems.items())

            print("✅ Full separation complete!")
            print(f"📁 Output: {STEMS_DIR}")
            for name, path in STEM_PATHS.items():
                print(f"   {name}: {path.name}")
        else:
            # Actually publish the winning excerpt result; previously the message
            # was printed but STEM_RESULT / STEM_PATHS were left untouched.
            best_excerpt = excerpt_results.get(best.name)
            if best_excerpt is not None:
                STEM_RESULT = best_excerpt
                STEM_PATHS = dict(best_excerpt.stems.items())
            print("ℹ️ Full re-run disabled. Using excerpt result.")
    else:
        print("❌ No candidates passed QA thresholds. Consider adjusting settings.")

    # Summary table
    print("\n📊 Candidate Summary:")
    print(f"{'Name':<15} {'Score':<10} {'Status':<10}")
    print("-" * 35)
    for score in sorted(candidate_scores, key=lambda x: x.score, reverse=True):
        status = "PASS" if score.passed else "FAIL"
        marker = "→" if best and score.name == best.name else " "
        print(f"{marker}{score.name:<14} {score.score:<10.3f} {status:<10}")

In [None]:
# @title 🧹 Stem Post-Processing
# @markdown Clean and prepare stems for transcription. Applies alignment-safe filtering,
# @markdown mono conversion for AMT, and clipping detection.

# --- Silence Filtering ---
# Passed to clean_stems() as `silence_threshold` when filtering is enabled.
silence_threshold = 1e-4  # @param {type:"number"}
enable_silence_filtering = True  # @param {type:"boolean"}

# --- Clipping Detection ---
# Absolute sample value at or above which detect_clipping() counts a sample as clipped.
clipping_threshold = 0.99  # @param {type:"number"}
# Largest clipped fraction a stem may have before it is flagged. NOTE: this
# rebinds the global of the same name set by the Candidate Selection form.
max_clipping_ratio = 0.01  # @param {type:"number"}

# --- AMT Preparation ---
# prepare_mono_for_amt runs mono_amt_exports(); save_mono_stems additionally
# writes each result to STEMS_DIR / "mono_amt".
prepare_mono_for_amt = True  # @param {type:"boolean"}
save_mono_stems = True  # @param {type:"boolean"}

# === Execution ===
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

from soundlab.io import load_audio, save_audio
from soundlab.pipeline.postprocess import clean_stems, mono_amt_exports

if TYPE_CHECKING:
    from soundlab.core.audio import AudioSegment


def detect_clipping(
    samples: np.ndarray,
    threshold: float = 0.99,
) -> dict[str, float]:
    """
    Detect clipping in audio samples.

    A sample counts as clipped when its absolute value is at or above
    ``threshold``. All elements of the array are considered, whatever its shape.

    Parameters
    ----------
    samples
        Audio samples (mono or stereo).
    threshold
        Amplitude threshold for clipping detection.

    Returns
    -------
    dict
        Clipping metrics: ``peak`` (largest absolute sample), ``clipped_count``
        (number of clipped samples) and ``clipping_ratio`` (clipped samples
        divided by total samples). An empty array yields zeros for all three.
    """
    total_samples = samples.size
    if total_samples == 0:
        # np.max raises on an empty array; report "no signal, no clipping" instead
        return {"peak": 0.0, "clipped_count": 0, "clipping_ratio": 0.0}

    abs_samples = np.abs(samples)
    peak = float(np.max(abs_samples))
    clipped_count = int(np.sum(abs_samples >= threshold))

    return {
        "peak": peak,
        "clipped_count": clipped_count,
        "clipping_ratio": clipped_count / total_samples,
    }


def load_stem_arrays(stem_paths: dict[str, Path]) -> dict[str, np.ndarray]:
    """Read every stem file and return its samples keyed by stem name.

    Files are loaded with ``load_audio`` in the mapping's iteration order; only
    the ``samples`` of each loaded segment are kept.
    """
    return {
        stem_name: load_audio(stem_file).samples
        for stem_name, stem_file in stem_paths.items()
    }


# Validate prerequisites
if "STEM_PATHS" not in dir() or not STEM_PATHS:
    raise RuntimeError("❌ No stem paths found. Please run the Separation cell first.")

print(f"🧹 Post-processing {len(STEM_PATHS)} stems...")
print(f"   Silence threshold: {silence_threshold}")
print(f"   Clipping threshold: {clipping_threshold}")

# Load stem arrays
print("\n📂 Loading stems...")
STEM_ARRAYS = load_stem_arrays(STEM_PATHS)

# Check for clipping
print("\n📊 Clipping analysis:")
CLIPPING_RESULTS: dict[str, dict[str, float]] = {}
any_clipping_issues = False

for name, samples in STEM_ARRAYS.items():
    clip_info = detect_clipping(samples, threshold=clipping_threshold)
    CLIPPING_RESULTS[name] = clip_info

    exceeds_limit = clip_info["clipping_ratio"] > max_clipping_ratio
    status = "⚠️" if exceeds_limit else "✅"
    if exceeds_limit:
        any_clipping_issues = True

    print(
        f"   {status} {name}: peak={clip_info['peak']:.4f}, "
        f"clipping={clip_info['clipping_ratio'] * 100:.3f}%"
    )

if any_clipping_issues:
    print(
        "\n⚠️ Some stems exceed clipping threshold. Consider re-running separation "
        "with different settings or applying limiting during export."
    )

# Apply silence filtering
if enable_silence_filtering:
    print(f"\n🔇 Applying silence filtering (threshold={silence_threshold})...")
    STEM_ARRAYS_CLEANED = clean_stems(STEM_ARRAYS, silence_threshold=silence_threshold)
    print("   ✅ Silence filtering complete")
else:
    STEM_ARRAYS_CLEANED = STEM_ARRAYS
    print("\nℹ️ Silence filtering disabled")

# Always (re)define the saved-path mapping. Without this it stayed undefined —
# or kept the paths of a previous run — whenever mono preparation was enabled
# but saving was switched off.
MONO_STEM_PATHS: dict[str, Path] = {}

# Prepare mono stems for AMT
if prepare_mono_for_amt:
    print("\n🎹 Preparing mono stems for AMT...")
    MONO_STEMS = mono_amt_exports(STEM_ARRAYS_CLEANED)

    # Display mono stem info
    for name, samples in MONO_STEMS.items():
        print(f"   {name}: {samples.shape} ({samples.dtype})")

    # Save mono stems if requested
    if save_mono_stems:
        MONO_STEMS_DIR = STEMS_DIR / "mono_amt"
        MONO_STEMS_DIR.mkdir(exist_ok=True)

        for name, samples in MONO_STEMS.items():
            # Infer sample rate from original stem
            sr = load_audio(STEM_PATHS[name]).sample_rate
            mono_path = MONO_STEMS_DIR / f"{name}_mono.wav"
            save_audio(samples, mono_path, sample_rate=sr)
            MONO_STEM_PATHS[name] = mono_path

        print(f"\n💾 Saved mono stems to: {MONO_STEMS_DIR}")
        for name, path in MONO_STEM_PATHS.items():
            size_kb = path.stat().st_size / 1024
            print(f"   {name}: {path.name} ({size_kb:.1f} KB)")
else:
    MONO_STEMS = {}
    print("\nℹ️ Mono AMT preparation disabled")

# Summary
print("\n" + "=" * 50)
print("✅ Post-processing complete!")
print(f"   Cleaned stems: {len(STEM_ARRAYS_CLEANED)}")
print(f"   Mono AMT stems: {len(MONO_STEMS)}")
print(f"   Clipping issues: {'Yes' if any_clipping_issues else 'None'}")

In [None]:
# @title 🎹 Transcription Routing
# @markdown Configure per-stem transcription backends and run audio-to-MIDI conversion.
# @markdown Each stem type can use different thresholds optimized for that instrument.

# --- Stem Selection ---
# Each flag becomes the "enabled" entry of the matching stem in STEM_ROUTING.
transcribe_vocals = False  # @param {type:"boolean"}
transcribe_drums = False  # @param {type:"boolean"}
transcribe_bass = True  # @param {type:"boolean"}
transcribe_other = True  # @param {type:"boolean"}
transcribe_piano = True  # @param {type:"boolean"}
transcribe_guitar = True  # @param {type:"boolean"}

# --- Global Thresholds ---
onset_threshold = 0.5  # @param {type:"slider", min:0.1, max:0.9, step:0.05}
frame_threshold = 0.3  # @param {type:"slider", min:0.1, max:0.9, step:0.05}
min_note_length = 0.058  # @param {type:"number"}

# --- Frequency Ranges ---
# The bass_* limits apply to the "bass" stem only; every other stem in
# STEM_ROUTING uses the default_* limits. (Units are presumably Hz — TODO confirm.)
bass_min_freq = 32.7  # @param {type:"number"}
bass_max_freq = 500.0  # @param {type:"number"}
default_min_freq = 32.7  # @param {type:"number"}
default_max_freq = 2093.0  # @param {type:"number"}

# --- Fallback Behavior ---
enable_fallback = True  # @param {type:"boolean"}
fallback_lower_threshold = 0.1  # @param {type:"number"}

# === Execution ===
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

from soundlab.transcription import MIDIResult, MIDITranscriber, TranscriptionConfig

# Per-stem routing table: whether a stem is transcribed and which frequency
# limits apply to it. Only "bass" has its own limits; all other stems share the
# defaults.
STEM_ROUTING: dict[str, dict[str, bool | float]] = {
    stem: {"enabled": wanted, "min_freq": low, "max_freq": high}
    for stem, wanted, low, high in (
        ("vocals", transcribe_vocals, default_min_freq, default_max_freq),
        ("drums", transcribe_drums, default_min_freq, default_max_freq),
        ("bass", transcribe_bass, bass_min_freq, bass_max_freq),
        ("other", transcribe_other, default_min_freq, default_max_freq),
        ("piano", transcribe_piano, default_min_freq, default_max_freq),
        ("guitar", transcribe_guitar, default_min_freq, default_max_freq),
    )
}


@dataclass
class TranscriptionResult:
    """Result from transcription routing for a single stem.

    Attributes
    ----------
    stem_name
        Name of the stem this result belongs to.
    midi_result
        Transcription output, or ``None`` when transcription raised.
    confidence
        Heuristic score in ``[0, 1]`` as returned by ``compute_confidence``;
        ``0.0`` for failed stems.
    fallback_used
        ``True`` when the lower-threshold retry produced the kept result.
    error
        String form of the exception raised during transcription, or ``None``
        on success.
    """

    stem_name: str
    midi_result: MIDIResult | None = None
    confidence: float = 0.0
    fallback_used: bool = False
    error: str | None = None


@dataclass
class TranscriptionBatch:
    """Per-stem transcription results together with running totals.

    Attributes
    ----------
    results
        Mapping of stem name to its ``TranscriptionResult``.
    total_notes
        Sum of note counts over the successfully transcribed stems.
    successful_stems
        Number of stems that produced a MIDI result.
    failed_stems
        Number of stems whose transcription reported an error.
    """

    results: dict[str, TranscriptionResult] = field(default_factory=dict)
    total_notes: int = 0
    successful_stems: int = 0
    failed_stems: int = 0

    @property
    def confidence_matrix(self) -> dict[str, float]:
        """Map each stem name to the confidence stored on its result."""
        scores: dict[str, float] = {}
        for stem_name, stem_result in self.results.items():
            scores[stem_name] = stem_result.confidence
        return scores


def build_stem_config(
    stem_name: str,
    routing: dict[str, dict[str, bool | float]],
    onset_thresh: float,
    frame_thresh: float,
    min_note_len: float,
) -> TranscriptionConfig | None:
    """
    Derive the transcription config for one stem from the routing table.

    Parameters
    ----------
    stem_name
        Stem to look up in ``routing``.
    routing
        Routing table keyed by stem name; each entry may carry ``enabled``,
        ``min_freq`` and ``max_freq``.
    onset_thresh
        Onset threshold passed through to the config.
    frame_thresh
        Frame threshold passed through to the config.
    min_note_len
        Minimum note length passed through to the config.

    Returns
    -------
    TranscriptionConfig | None
        ``None`` when the stem has no (or an empty) routing entry or is not
        enabled; otherwise a config using the entry's frequency band, falling
        back to the notebook's default band for missing keys.
    """
    rules = routing.get(stem_name)
    if not rules:
        return None
    if not rules.get("enabled", False):
        return None

    low_hz = float(rules.get("min_freq", default_min_freq))
    high_hz = float(rules.get("max_freq", default_max_freq))
    return TranscriptionConfig(
        onset_thresh=onset_thresh,
        frame_thresh=frame_thresh,
        min_note_length=min_note_len,
        min_freq=low_hz,
        max_freq=high_hz,
    )


def compute_confidence(result: MIDIResult) -> float:
    """
    Score a transcription in ``[0, 1]`` from note count and note consistency.

    The score is a weighted blend of three factors:

    - note count, saturating at 50 notes (weight 0.4)
    - velocity consistency: ``1 - variance / 2000``, floored at 0 (weight 0.3)
    - duration consistency: ``1 - variance / 1``, floored at 0 (weight 0.3)

    A result without notes scores ``0.0``.
    """
    notes = result.notes
    if not notes:
        return 0.0

    def _population_variance(values: list) -> float:
        """Mean squared deviation from the mean."""
        size = len(values)
        centre = sum(values) / size
        return sum((value - centre) ** 2 for value in values) / size

    # Note count factor (saturates at 50 notes)
    count_score = min(1.0, len(notes) / 50.0)

    # Lower spread in velocity / duration means higher confidence
    velocity_spread = _population_variance([note.velocity for note in notes])
    velocity_score = max(0.0, 1.0 - (velocity_spread / 2000.0))

    duration_spread = _population_variance([note.end - note.start for note in notes])
    duration_score = max(0.0, 1.0 - (duration_spread / 1.0))

    blended = count_score * 0.4 + velocity_score * 0.3 + duration_score * 0.3
    return min(1.0, max(0.0, blended))


def transcribe_stem_with_fallback(
    stem_path: Path,
    stem_name: str,
    config: TranscriptionConfig,
    output_dir: Path,
    enable_fallback: bool = True,
    fallback_delta: float = 0.1,
) -> TranscriptionResult:
    """
    Transcribe a single stem, optionally retrying with lower thresholds.

    The retry runs only when the first pass is both sparse (fewer than 5
    notes) and low-confidence (score below 0.3). Its result replaces the
    first pass only when it contains more notes. A failure inside the retry
    never discards a successful first pass.

    Parameters
    ----------
    stem_path
        Path to the stem audio file.
    stem_name
        Name of the stem.
    config
        Transcription configuration.
    output_dir
        Directory to save MIDI output; a ``stem_name`` subdirectory is created.
    enable_fallback
        Whether to retry with lower thresholds when the first pass is sparse
        and low-confidence.
    fallback_delta
        Amount subtracted from the onset and frame thresholds for the retry
        (each floored at 0.1).

    Returns
    -------
    TranscriptionResult
        Transcription result with confidence and fallback status. When the
        first pass raises, ``midi_result`` is ``None`` and ``error`` holds the
        exception text.
    """
    transcriber = MIDITranscriber(config=config)
    stem_output = output_dir / stem_name
    stem_output.mkdir(parents=True, exist_ok=True)

    # --- Primary pass: any failure here is reported as a failed stem. ---
    try:
        midi_result = transcriber.transcribe(stem_path, stem_output)
        confidence = compute_confidence(midi_result)
        # Retry only when the result is BOTH sparse and low-confidence.
        needs_retry = enable_fallback and len(midi_result.notes) < 5 and confidence < 0.3
    except Exception as e:
        return TranscriptionResult(
            stem_name=stem_name,
            midi_result=None,
            confidence=0.0,
            error=str(e),
        )

    primary = TranscriptionResult(
        stem_name=stem_name,
        midi_result=midi_result,
        confidence=confidence,
        fallback_used=False,
    )
    if not needs_retry:
        return primary

    # --- Fallback pass: best effort, isolated from the primary result. ---
    # NOTE(review): the retry writes into the same directory as the primary
    # pass. If transcribe() reuses the output file name, the primary MIDI on
    # disk is replaced even when the primary result is kept — confirm.
    try:
        fallback_config = TranscriptionConfig(
            onset_thresh=max(0.1, config.onset_thresh - fallback_delta),
            frame_thresh=max(0.1, config.frame_thresh - fallback_delta),
            min_note_length=config.min_note_length,
            min_freq=config.min_freq,
            max_freq=config.max_freq,
        )
        fallback_transcriber = MIDITranscriber(config=fallback_config)
        fallback_result = fallback_transcriber.transcribe(stem_path, stem_output)
        fallback_confidence = compute_confidence(fallback_result)
        fallback_is_better = len(fallback_result.notes) > len(midi_result.notes)
    except Exception as e:
        # The first pass succeeded, so keep it instead of failing the stem.
        print(f"   ⚠️ Fallback transcription failed, keeping primary result: {e}")
        return primary

    if fallback_is_better:
        return TranscriptionResult(
            stem_name=stem_name,
            midi_result=fallback_result,
            confidence=fallback_confidence,
            fallback_used=True,
        )
    return primary


# Validate prerequisites
if "MONO_STEM_PATHS" not in dir() or not MONO_STEM_PATHS:
    # Fall back to original stems if mono not available
    if "STEM_PATHS" not in dir() or not STEM_PATHS:
        raise RuntimeError("❌ No stems found. Please run the Separation cell first.")
    TRANSCRIPTION_INPUTS = STEM_PATHS
    print("ℹ️ Using original stereo stems (mono not available)")
else:
    TRANSCRIPTION_INPUTS = MONO_STEM_PATHS
    print("✅ Using mono stems for transcription")

# Start every run of this cell from an empty batch. Creating the batch only
# when stems are available would leave a previous run's batch in the kernel,
# and the MIDI Cleanup cell would then process those stale results.
TRANSCRIPTION_BATCH = TranscriptionBatch()

# Build enabled stems list
enabled_stems = [name for name, cfg in STEM_ROUTING.items() if cfg.get("enabled", False)]
available_stems = [name for name in enabled_stems if name in TRANSCRIPTION_INPUTS]
unavailable_stems = [name for name in enabled_stems if name not in TRANSCRIPTION_INPUTS]

print("\n🎹 Transcription routing:")
print(f"   Enabled stems: {enabled_stems}")
print(f"   Available stems: {available_stems}")
if unavailable_stems:
    print(f"   Enabled but not found in inputs: {unavailable_stems}")
print(f"   Global onset threshold: {onset_threshold}")
print(f"   Global frame threshold: {frame_threshold}")

if not enabled_stems:
    print("\n⚠️ No stems selected for transcription. Enable at least one stem above.")
elif not available_stems:
    # Stems are enabled, but none of them exists among the separated stems.
    print("\n⚠️ None of the enabled stems are available as transcription inputs.")
else:
    # Run transcription
    print(f"\n🔄 Transcribing {len(available_stems)} stems...")

    for stem_name in available_stems:
        stem_path = TRANSCRIPTION_INPUTS[stem_name]
        config = build_stem_config(
            stem_name,
            STEM_ROUTING,
            onset_threshold,
            frame_threshold,
            min_note_length,
        )

        if config is None:
            continue

        print(f"\n--- {stem_name} ---")
        print(f"   Path: {stem_path}")
        print(f"   Freq range: {config.min_freq:.1f} - {config.max_freq:.1f} Hz")

        result = transcribe_stem_with_fallback(
            stem_path=stem_path,
            stem_name=stem_name,
            config=config,
            output_dir=MIDI_DIR,
            enable_fallback=enable_fallback,
            fallback_delta=fallback_lower_threshold,
        )

        TRANSCRIPTION_BATCH.results[stem_name] = result

        if result.error:
            TRANSCRIPTION_BATCH.failed_stems += 1
            print(f"   ❌ Error: {result.error}")
        elif result.midi_result:
            TRANSCRIPTION_BATCH.successful_stems += 1
            TRANSCRIPTION_BATCH.total_notes += len(result.midi_result.notes)
            status = "✅"
            if result.fallback_used:
                status = "⚠️ (fallback)"
            print(
                f"   {status} Notes: {len(result.midi_result.notes)}, "
                f"Confidence: {result.confidence:.2f}"
            )
            print(f"   MIDI: {result.midi_result.path}")

    # Summary
    print("\n" + "=" * 50)
    print("📊 Transcription Summary:")
    print(f"   Successful: {TRANSCRIPTION_BATCH.successful_stems}")
    print(f"   Failed: {TRANSCRIPTION_BATCH.failed_stems}")
    print(f"   Total notes: {TRANSCRIPTION_BATCH.total_notes}")

    # Confidence matrix: a 20-cell bar per stem, filled in proportion to confidence
    print("\n🎯 Confidence Matrix:")
    for stem, conf in TRANSCRIPTION_BATCH.confidence_matrix.items():
        bar = "█" * int(conf * 20) + "░" * (20 - int(conf * 20))
        print(f"   {stem:<10} [{bar}] {conf:.2f}")

In [None]:
# @title 🎼 MIDI Cleanup
# @markdown Clean up transcribed MIDI: filter short notes, detect tempo, apply quantization,
# @markdown and assign General MIDI program numbers to stems.

# --- Note Filtering ---
min_note_duration = 0.02  # @param {type:"number"}
min_velocity = 1  # @param {type:"slider", min:1, max:127, step:1}
filter_short_notes = True  # @param {type:"boolean"}

# --- Tempo Detection ---
detect_tempo_from_audio = True  # @param {type:"boolean"}
fallback_bpm = 120.0  # @param {type:"number"}

# --- Quantization ---
enable_quantization = True  # @param {type:"boolean"}
quantize_strength = 0.5  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
quantize_grid = "1/16"  # @param ["1/4", "1/8", "1/16", "1/32"]

# --- Program Mapping ---
enable_program_mapping = True  # @param {type:"boolean"}

# === Execution ===
from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

import numpy as np

from soundlab.analysis.tempo import detect_tempo
from soundlab.io import load_audio
from soundlab.io.midi_io import MIDIData, MIDINote, TimeSignature, save_midi
from soundlab.pipeline.postprocess import cleanup_midi_notes
from soundlab.transcription.models import NoteEvent

# General MIDI program mapping for stems.
# Program numbers are 0-based (Acoustic Grand Piano = 0), matching the other
# entries in this table. Voice Oohs is therefore 53; 54 is Synth Voice.
GM_PROGRAMS: dict[str, int] = {
    "vocals": 53,  # Voice Oohs
    "drums": 0,  # Drums on channel 10 (program ignored)
    "bass": 33,  # Electric Bass (finger)
    "other": 0,  # Acoustic Grand Piano
    "piano": 0,  # Acoustic Grand Piano
    "guitar": 25,  # Acoustic Guitar (steel)
}

# Grid values in beats: a quarter note is one beat, so a "1/d" grid label
# maps to 4 / d beats (1/4 -> 1.0, 1/8 -> 0.5, 1/16 -> 0.25, 1/32 -> 0.125).
GRID_VALUES: dict[str, float] = {
    f"1/{division}": 4.0 / division for division in (4, 8, 16, 32)
}


@dataclass
class CleanedMIDI:
    """Container for cleaned MIDI data.

    Attributes
    ----------
    stem_name
        Name of the stem the notes were transcribed from.
    original_notes
        Number of notes before cleanup.
    cleaned_notes
        Number of notes remaining after filtering and quantization.
    midi_data
        MIDI data built from the cleaned notes.
    output_path
        Location of the saved cleaned MIDI file, if one was written.
    tempo_bpm
        Tempo (BPM) used for quantization and stored in the MIDI data.
    quantized
        Whether the quantization step was applied.
    """

    stem_name: str
    original_notes: int
    cleaned_notes: int
    midi_data: MIDIData
    output_path: Path | None = None
    tempo_bpm: float = 120.0
    quantized: bool = False


@dataclass
class MIDICleanupBatch:
    """Cleanup results for all stems plus tempo and note-count totals.

    Attributes
    ----------
    results
        Mapping of stem name to its ``CleanedMIDI``.
    detected_tempo
        Tempo in BPM used for the batch (detected or fallback).
    tempo_confidence
        Confidence reported by tempo detection; ``0.0`` when not detected.
    total_original_notes
        Sum of note counts before cleanup.
    total_cleaned_notes
        Sum of note counts after cleanup.
    """

    results: dict[str, CleanedMIDI] = field(default_factory=dict)
    detected_tempo: float = 120.0
    tempo_confidence: float = 0.0
    total_original_notes: int = 0
    total_cleaned_notes: int = 0

    @property
    def notes_removed(self) -> int:
        """Number of notes dropped by cleanup across the whole batch."""
        before, after = self.total_original_notes, self.total_cleaned_notes
        return before - after


def quantize_notes(
    notes: list[NoteEvent],
    bpm: float,
    grid: float,
    strength: float,
) -> list[NoteEvent]:
    """
    Apply soft quantization to notes.

    Each note start is moved towards the nearest grid point by ``strength``;
    the note's duration is preserved (with a 1 ms minimum).

    Parameters
    ----------
    notes
        List of note events to quantize.
    bpm
        Tempo in beats per minute. Must be positive.
    grid
        Grid size in beats. Must be positive.
    strength
        Quantization strength (0.0 = no change, 1.0 = full snap). Values
        above 1.0 are treated as 1.0 so notes never overshoot the grid.

    Returns
    -------
    list[NoteEvent]
        Quantized notes. The input list itself is returned unchanged when
        ``strength <= 0`` or ``notes`` is empty.

    Raises
    ------
    ValueError
        If ``bpm`` or ``grid`` is not a positive number (this includes NaN),
        since the grid spacing would be zero, negative or undefined.
    """
    if strength <= 0 or not notes:
        return notes

    # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
    if not (bpm > 0 and grid > 0):
        raise ValueError(
            f"bpm and grid must be positive to quantize (got bpm={bpm!r}, grid={grid!r})"
        )

    blend = min(strength, 1.0)

    # Convert grid to seconds
    beat_duration = 60.0 / bpm
    grid_seconds = grid * beat_duration

    quantized: list[NoteEvent] = []
    for note in notes:
        # Nearest grid point for the note start
        snapped_start = round(note.start / grid_seconds) * grid_seconds

        # Blend between original and snapped start, never before time zero
        new_start = max(0.0, note.start + (snapped_start - note.start) * blend)

        # Derive the end from the clamped start so the duration is preserved
        # and the note can never end before it begins.
        duration = max(note.end - note.start, 0.001)

        quantized.append(
            NoteEvent(
                start=new_start,
                end=new_start + duration,
                pitch=note.pitch,
                velocity=note.velocity,
            )
        )

    return quantized


def notes_to_midi_data(
    notes: list[NoteEvent],
    tempo: float,
    time_signature: TimeSignature | None = None,
) -> MIDIData:
    """
    Wrap note events in a ``MIDIData`` structure.

    Parameters
    ----------
    notes
        Note events; each contributes one ``MIDINote`` with the same pitch,
        start, end and velocity.
    tempo
        Tempo in BPM stored on the result.
    time_signature
        Time signature to store; 4/4 is used when omitted.

    Returns
    -------
    MIDIData
        MIDI data holding the converted notes, the tempo and the time
        signature.
    """
    converted = []
    for event in notes:
        converted.append(
            MIDINote(
                pitch=event.pitch,
                start_seconds=event.start,
                end_seconds=event.end,
                velocity=event.velocity,
            )
        )

    # A falsy signature (including None) falls back to common time.
    signature = time_signature or TimeSignature(numerator=4, denominator=4)

    return MIDIData(notes=converted, tempo=tempo, time_signature=signature)


def cleanup_stem_midi(
    stem_name: str,
    midi_result: "MIDIResult",
    tempo: float,
    min_duration: float,
    min_vel: int,
    enable_filter: bool,
    enable_quant: bool,
    quant_strength: float,
    quant_grid: float,
    output_dir: Path,
) -> CleanedMIDI:
    """
    Filter, quantize and save the transcribed notes of one stem.

    Parameters
    ----------
    stem_name
        Name of the stem; also used for the output file name.
    midi_result
        Transcription result whose ``notes`` are cleaned.
    tempo
        Tempo in BPM used for quantization and stored in the MIDI data.
    min_duration
        Minimum note duration in seconds passed to the note filter.
    min_vel
        Minimum velocity passed to the note filter.
    enable_filter
        Whether to run the short/quiet note filter.
    enable_quant
        Whether to apply quantization.
    quant_strength
        Quantization strength; quantization is skipped when not positive.
    quant_grid
        Quantization grid in beats.
    output_dir
        Directory for the cleaned MIDI file (created if missing).

    Returns
    -------
    CleanedMIDI
        Note counts before/after cleanup, the MIDI data, and the path of the
        saved ``<stem_name>_cleaned.mid`` file.
    """
    source_notes = midi_result.notes
    count_before = len(source_notes)

    # Step 1: drop short/quiet notes, or just take a copy of the input
    kept_notes = (
        cleanup_midi_notes(
            source_notes,
            min_duration=min_duration,
            min_velocity=min_vel,
        )
        if enable_filter
        else list(source_notes)
    )

    # Step 2: quantize only when requested with a positive strength
    apply_quantization = bool(enable_quant and quant_strength > 0)
    final_notes = (
        quantize_notes(kept_notes, tempo, quant_grid, quant_strength)
        if apply_quantization
        else kept_notes
    )

    # Step 3: build the MIDI structure
    midi_data = notes_to_midi_data(final_notes, tempo)

    # Step 4: write it next to the other cleaned stems
    output_dir.mkdir(parents=True, exist_ok=True)
    destination = output_dir / f"{stem_name}_cleaned.mid"
    save_midi(midi_data, destination)

    return CleanedMIDI(
        stem_name=stem_name,
        original_notes=count_before,
        cleaned_notes=len(final_notes),
        midi_data=midi_data,
        output_path=destination,
        tempo_bpm=tempo,
        quantized=apply_quantization,
    )


# Validate prerequisites
if "TRANSCRIPTION_BATCH" not in dir() or not TRANSCRIPTION_BATCH.results:
    raise RuntimeError(
        "❌ No transcription results found. Please run the Transcription cell first."
    )

# The fallback tempo is used whenever detection is disabled or fails, and the
# quantizer divides by the tempo, so it must be a positive finite number.
if not (np.isfinite(fallback_bpm) and fallback_bpm > 0):
    raise ValueError(f"❌ fallback_bpm must be a positive number, got {fallback_bpm!r}")

print("🎼 MIDI Cleanup Pipeline")
print(f"   Min note duration: {min_note_duration}s")
print(f"   Min velocity: {min_velocity}")
print(f"   Quantization: {'enabled' if enable_quantization else 'disabled'}")
if enable_quantization:
    print(f"   Quantize grid: {quantize_grid}, strength: {quantize_strength}")

# Step 1: Detect tempo from audio
CLEANUP_BATCH = MIDICleanupBatch()

if detect_tempo_from_audio:
    print("\n⏱️ Detecting tempo from audio...")

    # Loading the audio is inside the try block as well: a file that cannot
    # be read should fall back to the manual BPM, not abort the cell.
    try:
        # Use canonical audio if available, otherwise try to load from source
        if "CANONICAL_AUDIO" in dir() and CANONICAL_AUDIO is not None:
            tempo_audio = CANONICAL_AUDIO
        elif "AUDIO_PATH" in dir() and AUDIO_PATH is not None:
            tempo_audio = load_audio(AUDIO_PATH)
        else:
            tempo_audio = None

        if tempo_audio is None:
            CLEANUP_BATCH.detected_tempo = fallback_bpm
            print("   ⚠️ No audio available for tempo detection")
            print(f"   Using fallback: {fallback_bpm} BPM")
        else:
            tempo_result = detect_tempo(tempo_audio.samples, tempo_audio.sample_rate)
            # A zero, negative or NaN tempo cannot be used for quantization.
            if not (np.isfinite(tempo_result.bpm) and tempo_result.bpm > 0):
                raise ValueError(f"detector returned unusable BPM {tempo_result.bpm!r}")
            CLEANUP_BATCH.detected_tempo = tempo_result.bpm
            CLEANUP_BATCH.tempo_confidence = tempo_result.confidence
            print(
                f"   Detected BPM: {tempo_result.bpm:.1f} (confidence: {tempo_result.confidence:.2f})"
            )
    except Exception as e:
        CLEANUP_BATCH.detected_tempo = fallback_bpm
        print(f"   ⚠️ Tempo detection failed: {e}")
        print(f"   Using fallback: {fallback_bpm} BPM")
else:
    CLEANUP_BATCH.detected_tempo = fallback_bpm
    print(f"\nℹ️ Using manual BPM: {fallback_bpm}")

# Step 2: Process each transcribed stem
grid_beats = GRID_VALUES.get(quantize_grid, 0.25)
CLEANED_MIDI_DIR = MIDI_DIR / "cleaned"

print(f"\n🔄 Cleaning {len(TRANSCRIPTION_BATCH.results)} MIDI files...")

for stem_name, result in TRANSCRIPTION_BATCH.results.items():
    if result.midi_result is None:
        print(f"\n--- {stem_name} ---")
        print("   ⚠️ Skipped (no MIDI result)")
        continue

    print(f"\n--- {stem_name} ---")

    cleaned = cleanup_stem_midi(
        stem_name=stem_name,
        midi_result=result.midi_result,
        tempo=CLEANUP_BATCH.detected_tempo,
        min_duration=min_note_duration,
        min_vel=min_velocity,
        enable_filter=filter_short_notes,
        enable_quant=enable_quantization,
        quant_strength=quantize_strength,
        quant_grid=grid_beats,
        output_dir=CLEANED_MIDI_DIR,
    )

    CLEANUP_BATCH.results[stem_name] = cleaned
    CLEANUP_BATCH.total_original_notes += cleaned.original_notes
    CLEANUP_BATCH.total_cleaned_notes += cleaned.cleaned_notes

    removed = cleaned.original_notes - cleaned.cleaned_notes
    status = "✅" if cleaned.cleaned_notes > 0 else "⚠️"
    quant_label = " (quantized)" if cleaned.quantized else ""

    print(
        f"   {status} Notes: {cleaned.original_notes} → {cleaned.cleaned_notes} (-{removed}){quant_label}"
    )
    print(f"   Output: {cleaned.output_path}")

# Step 3: Display program mapping info
if enable_program_mapping:
    print("\n🎹 General MIDI Program Mapping:")
    for stem_name in CLEANUP_BATCH.results:
        program = GM_PROGRAMS.get(stem_name, 0)
        channel = 10 if stem_name == "drums" else 1
        print(f"   {stem_name}: Program {program}, Channel {channel}")

# Summary
print("\n" + "=" * 50)
print("✅ MIDI Cleanup Complete!")
print(f"   Tempo: {CLEANUP_BATCH.detected_tempo:.1f} BPM")
print(f"   Total notes: {CLEANUP_BATCH.total_original_notes} → {CLEANUP_BATCH.total_cleaned_notes}")
print(f"   Notes removed: {CLEANUP_BATCH.notes_removed}")
print(f"   Output directory: {CLEANED_MIDI_DIR}")

In [None]:
# @title 📊 QA Dashboard + Preview
# @markdown Interactive dashboard for reviewing separation and transcription results.
# @markdown Includes audio previews, QA metrics, and rerun controls.

# === Execution ===
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import gradio as gr
import numpy as np

if TYPE_CHECKING:
    from soundlab.core.audio import AudioSegment

# --- Helper Functions ---


def load_audio_for_preview(path: Path | str | None) -> tuple[int, np.ndarray] | None:
    """
    Load an audio file in the ``(sample_rate, data)`` form used for previews.

    Parameters
    ----------
    path
        Audio file to read. ``None`` or an empty path means "no such stem"
        (callers pass ``dict.get`` results) and yields ``None`` quietly.

    Returns
    -------
    tuple[int, np.ndarray] | None
        Sample rate and float32 samples, or ``None`` when there is no path
        or the file cannot be read.
    """
    # A missing stem is an expected case, not an error worth reporting.
    if path is None or not str(path):
        return None

    try:
        import soundfile as sf

        data, sr = sf.read(str(path), dtype="float32")
        # Gradio expects (sr, data) with data as numpy array
        return (sr, data)
    except Exception as e:
        # Best effort: the preview stays empty, but the reason is shown.
        print(f"⚠️ Could not load audio preview for {path}: {e}")
        return None


def compute_residual(
    mix_path: Path | str,
    stem_paths: dict[str, Path],
) -> tuple[int, np.ndarray] | None:
    """
    Compute the ``mix - sum(stems)`` residual for QA visualization.

    Stems whose first dimension differs from the mix are left out of the
    sum. A mono stem is duplicated to two columns for a 2-D mix, and a 2-D
    stem is averaged across columns for a 1-D mix.

    Returns
    -------
    tuple[int, np.ndarray] | None
        ``(sample_rate, residual)`` using the mix's sample rate, or ``None``
        (after printing a warning) when anything fails.
    """
    try:
        import soundfile as sf

        mix, sr = sf.read(str(mix_path), dtype="float32")
        stem_sum = np.zeros_like(mix)

        for path in stem_paths.values():
            stem, _ = sf.read(str(path), dtype="float32")

            if stem.shape != mix.shape:
                if stem.shape[0] != mix.shape[0]:
                    # Different first dimension: cannot be aligned, leave it out
                    continue
                # Same first dimension: match mono/stereo layout to the mix
                if mix.ndim == 2 and stem.ndim == 1:
                    stem = np.column_stack([stem, stem])
                elif mix.ndim == 1 and stem.ndim == 2:
                    stem = np.mean(stem, axis=1)

            stem_sum = stem_sum + stem

        return (sr, mix - stem_sum)
    except Exception as e:
        print(f"⚠️ Residual computation failed: {e}")
        return None


def format_qa_metrics(metrics: dict[str, float]) -> str:
    """
    Render QA metrics as a markdown table with a pass/warn status column.

    Known metrics are checked against a fixed limit: "lower is better"
    metrics pass at or below their limit, "higher is better" metrics pass at
    or above it. Unknown metrics are listed with an informational status.
    An empty mapping yields a short placeholder sentence instead of a table.
    """
    if not metrics:
        return "No metrics available"

    # Lower is better: pass when value <= limit
    ceilings = {
        "reconstruction_error": 0.15,
        "clipping_ratio": 0.01,
        "leakage_ratio": 0.2,
    }
    # Higher is better: pass when value >= limit
    floors = {
        "spectral_flatness": 0.1,
        "stereo_coherence": 0.2,
    }

    def _status(metric: str, value: float) -> str:
        if metric in ceilings:
            passed = value <= ceilings[metric]
        elif metric in floors:
            passed = value >= floors[metric]
        else:
            return "ℹ️"
        return "✅" if passed else "⚠️"

    header = ["| Metric | Value | Status |", "|--------|-------|--------|"]
    body = [
        f"| {metric.replace('_', ' ').title()} | {value:.4f} | {_status(metric, value)} |"
        for metric, value in metrics.items()
    ]
    return "\n".join(header + body)


def format_candidate_table(candidates: list[dict]) -> str:
    """
    Render candidate scores as a markdown table, best score first.

    Each candidate is a dict with optional ``name``, ``score`` and ``passed``
    keys; missing keys fall back to ``'Unknown'``, ``0`` and ``False``.
    An empty list yields a short placeholder sentence instead of a table.
    """
    if not candidates:
        return "No candidates to display"

    def _row(candidate: dict) -> str:
        verdict = "✅ PASS" if candidate.get("passed", False) else "❌ FAIL"
        label = candidate.get("name", "Unknown")
        score = candidate.get("score", 0)
        return f"| {label} | {score:.3f} | {verdict} |"

    ranked = sorted(candidates, key=lambda item: item.get("score", 0), reverse=True)
    lines = ["| Candidate | Score | Status |", "|-----------|-------|--------|"]
    lines.extend(_row(candidate) for candidate in ranked)
    return "\n".join(lines)


# --- Dashboard Builder ---


def build_qa_dashboard():
    """
    Build the QA dashboard interface.

    Notebook-level state (``RUN_ID``, ``STEM_RESULT``, ``STEM_PATHS``,
    ``SOURCE_PATH``, ...) is looked up through ``globals()``. A bare ``dir()``
    inside a function lists only that function's local names, so membership
    tests against it never see notebook variables and every widget would stay
    on its placeholder value.

    Returns
    -------
    gr.Blocks
        The assembled dashboard; launching it is left to the caller.
    """

    def _notebook_var(name):
        """Return the notebook-level variable ``name``, or None when unset."""
        return globals().get(name)

    with gr.Blocks(theme=gr.themes.Soft()) as dashboard:
        gr.Markdown("## 📊 QA Dashboard + Preview")
        gr.Markdown("Review separation quality, preview stems/MIDI, and re-run if needed.")

        # --- Status Section ---
        with gr.Row():
            with gr.Column(scale=1):
                run_id_display = gr.Textbox(
                    label="Run ID",
                    # RUN_ID may exist but still be None before an upload
                    value=_notebook_var("RUN_ID") or "N/A",
                    interactive=False,
                )
            with gr.Column(scale=1):
                status_display = gr.Textbox(
                    label="Status",
                    value=(
                        "Ready"
                        if _notebook_var("STEM_RESULT") is not None
                        else "No stems available"
                    ),
                    interactive=False,
                )

        # --- QA Metrics Section ---
        gr.Markdown("### 📈 Quality Metrics")

        with gr.Row():
            with gr.Column(scale=1):
                stem_qa_markdown = gr.Markdown(
                    value="Run separation first to see QA metrics.",
                    label="Stem QA",
                )
            with gr.Column(scale=1):
                midi_qa_markdown = gr.Markdown(
                    value="Run transcription to see MIDI QA metrics.",
                    label="MIDI QA",
                )

        # --- Audio Preview Section ---
        gr.Markdown("### 🔊 Audio Previews")

        with gr.Tabs() as audio_tabs:
            with gr.TabItem("🎵 Original Mix"):
                mix_audio = gr.Audio(
                    label="Original Mix",
                    interactive=False,
                )

            with gr.TabItem("🎤 Vocals"):
                vocals_audio = gr.Audio(label="Vocals Stem", interactive=False)

            with gr.TabItem("🥁 Drums"):
                drums_audio = gr.Audio(label="Drums Stem", interactive=False)

            with gr.TabItem("🎸 Bass"):
                bass_audio = gr.Audio(label="Bass Stem", interactive=False)

            with gr.TabItem("🎹 Other"):
                other_audio = gr.Audio(label="Other Stem", interactive=False)

            with gr.TabItem("📉 Residual"):
                residual_audio = gr.Audio(
                    label="Residual (Mix - Stems)",
                    interactive=False,
                )
                gr.Markdown(
                    "*Residual should be near-silent for good separation. "
                    "Audible content indicates reconstruction error.*"
                )

            with gr.TabItem("🔄 Reconstruction"):
                reconstruction_audio = gr.Audio(
                    label="Reconstructed (Sum of Stems)",
                    interactive=False,
                )

        # --- Candidate Comparison (if available) ---
        gr.Markdown("### 🎯 Candidate Comparison")
        candidate_table = gr.Markdown(
            value="Run candidate selection to see comparison.",
        )

        # --- Rerun Controls ---
        gr.Markdown("### 🔄 Rerun Controls")
        with gr.Row():
            rerun_separation_btn = gr.Button(
                "🎚️ Re-run Separation",
                variant="secondary",
            )
            rerun_transcription_btn = gr.Button(
                "🎹 Re-run Transcription",
                variant="secondary",
            )
            refresh_btn = gr.Button(
                "🔄 Refresh Dashboard",
                variant="primary",
            )

        rerun_output = gr.Markdown(value="")

        # --- Load Data Function ---
        def refresh_dashboard():
            """Collect the current notebook state for every dashboard widget.

            Returns one value per widget, in the order of ``dashboard_outputs``.
            """
            # Snapshot the notebook state once so all widgets agree.
            source_path = _notebook_var("SOURCE_PATH")
            stem_result = _notebook_var("STEM_RESULT")
            stem_paths = _notebook_var("STEM_PATHS")
            stem_qa_result = _notebook_var("STEM_QA_RESULT")
            midi_qa_result = _notebook_var("MIDI_QA_RESULT")
            scores = _notebook_var("candidate_scores")

            outputs = {}

            # Run ID
            outputs["run_id"] = _notebook_var("RUN_ID") or "N/A"
            outputs["status"] = "Ready" if stem_result is not None else "No stems available"

            # Stem QA metrics
            if stem_qa_result is not None:
                outputs["stem_qa"] = format_qa_metrics(stem_qa_result.metrics)
            elif stem_result is not None:
                # Compute QA on the fly
                try:
                    from soundlab.pipeline import QAConfig, score_separation
                    import soundfile as sf

                    if source_path and stem_result.stems:
                        mix, sr = sf.read(str(source_path), dtype="float32")
                        stems_data = {}
                        for name, path in stem_result.stems.items():
                            stem_audio, _ = sf.read(str(path), dtype="float32")
                            stems_data[name] = stem_audio

                        qa_result = score_separation(mix, stems_data, sr)
                        outputs["stem_qa"] = format_qa_metrics(qa_result.metrics)
                        outputs["status"] = "✅ Pass" if qa_result.passed else "⚠️ Check QA"
                    else:
                        outputs["stem_qa"] = "Mix path not available for QA."
                except Exception as e:
                    outputs["stem_qa"] = f"QA computation error: {e}"
            else:
                outputs["stem_qa"] = "Run separation first to see QA metrics."

            # MIDI QA metrics
            if midi_qa_result is not None:
                outputs["midi_qa"] = format_qa_metrics(midi_qa_result.metrics)
            else:
                outputs["midi_qa"] = "Run transcription to see MIDI QA metrics."

            # Audio previews
            outputs["mix"] = load_audio_for_preview(source_path) if source_path else None

            if stem_paths:
                outputs["vocals"] = load_audio_for_preview(stem_paths.get("vocals"))
                outputs["drums"] = load_audio_for_preview(stem_paths.get("drums"))
                outputs["bass"] = load_audio_for_preview(stem_paths.get("bass"))
                outputs["other"] = load_audio_for_preview(stem_paths.get("other"))

                if source_path:
                    # Compute residual
                    outputs["residual"] = compute_residual(source_path, stem_paths)

                    # Compute reconstruction (plain sum of all stems)
                    try:
                        import soundfile as sf

                        total = None
                        sr = None
                        for path in stem_paths.values():
                            audio, sr = sf.read(str(path), dtype="float32")
                            if total is None:
                                total = audio
                            else:
                                total = total + audio
                        outputs["reconstruction"] = (sr, total) if total is not None else None
                    except Exception:
                        outputs["reconstruction"] = None
                else:
                    outputs["residual"] = None
                    outputs["reconstruction"] = None
            else:
                outputs["vocals"] = None
                outputs["drums"] = None
                outputs["bass"] = None
                outputs["other"] = None
                outputs["residual"] = None
                outputs["reconstruction"] = None

            # Candidate table
            if scores:
                candidates = [
                    {"name": c.name, "score": c.score, "passed": c.passed} for c in scores
                ]
                outputs["candidates"] = format_candidate_table(candidates)
            else:
                outputs["candidates"] = "Run candidate selection to see comparison."

            return (
                outputs["run_id"],
                outputs["status"],
                outputs["stem_qa"],
                outputs["midi_qa"],
                outputs["mix"],
                outputs["vocals"],
                outputs["drums"],
                outputs["bass"],
                outputs["other"],
                outputs["residual"],
                outputs["reconstruction"],
                outputs["candidates"],
            )

        def handle_rerun_separation():
            return "ℹ️ To re-run separation, modify settings in the Stem Separation cell and re-execute."

        def handle_rerun_transcription():
            return "ℹ️ To re-run transcription, modify settings in the Transcription cell and re-execute."

        # Widgets filled by refresh_dashboard, in the order of its return tuple.
        # Shared by the refresh button and the on-load hook so they cannot drift.
        dashboard_outputs = [
            run_id_display,
            status_display,
            stem_qa_markdown,
            midi_qa_markdown,
            mix_audio,
            vocals_audio,
            drums_audio,
            bass_audio,
            other_audio,
            residual_audio,
            reconstruction_audio,
            candidate_table,
        ]

        # Wire up refresh button
        refresh_btn.click(fn=refresh_dashboard, outputs=dashboard_outputs)

        rerun_separation_btn.click(fn=handle_rerun_separation, outputs=[rerun_output])
        rerun_transcription_btn.click(fn=handle_rerun_transcription, outputs=[rerun_output])

        # Auto-refresh on load
        dashboard.load(fn=refresh_dashboard, outputs=dashboard_outputs)

    return dashboard


# Build and launch dashboard
# The Blocks app is kept in the notebook-level name `qa_dashboard`, then
# launched with height=800 and show_error=True.
print("📊 Building QA Dashboard...")
qa_dashboard = build_qa_dashboard()
qa_dashboard.launch(height=800, show_error=True)

In [None]:
# @title 🗣️ Voice Generation (Optional)
# @markdown Generate speech or convert singing voice using extracted stems.
# @markdown **Requires:** `install_voice=True` in the installation cell.

# --- TTS Settings ---
enable_tts = False  # @param {type:"boolean"}
tts_text = "Hello, this is a voice generation test."  # @param {type:"string"}
tts_language = "en"  # @param ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "hu", "ko", "hi"]
tts_speaker_wav = ""  # @param {type:"string"}

# --- SVC Settings ---
enable_svc = False  # @param {type:"boolean"}
svc_source_stem = "vocals"  # @param ["vocals", "other"]
svc_pitch_shift = 0  # @param {type:"slider", min:-12, max:12, step:1}
svc_model_path = ""  # @param {type:"string"}
svc_index_path = ""  # @param {type:"string"}

# === Execution ===
from __future__ import annotations

from pathlib import Path

# Probe for the optional voice extras. VOICE_AVAILABLE starts out False and is only
# flipped once the import has succeeded; the rest of the cell branches on it.
VOICE_AVAILABLE = False
try:
    from soundlab.voice import TTSGenerator, VoiceConverter, TTSConfig, SVCConfig
except ImportError:
    print(
        "⚠️ Voice module not installed.",
        "   To enable voice features, set install_voice=True in the installation cell",
        "   and re-run the installation.",
        sep="\n",
    )
else:
    VOICE_AVAILABLE = True
    print("✅ Voice module available")

if VOICE_AVAILABLE:
    # STEM_PATHS is produced by an earlier cell. It can be absent (that cell was not
    # run) or None (the Cleanup cell's "reset global state" option assigns None to
    # it), so both cases collapse to an empty mapping instead of failing on
    # `"vocals" in None`.
    available_stems = globals().get("STEM_PATHS") or {}

    # --- TTS Generation ---
    # TTS_OUTPUT_PATH ends up as the generated file path, or None when TTS is
    # disabled or generation fails.
    TTS_OUTPUT_PATH = None
    if enable_tts:
        # Only show an ellipsis when the preview was actually cut short.
        text_preview = tts_text[:50] + ("..." if len(tts_text) > 50 else "")
        print("\n🗣️ Text-to-Speech Generation")
        print(f"   Text: {text_preview}")
        print(f"   Language: {tts_language}")

        # Determine speaker reference. An explicit path wins; if it is given but
        # missing we deliberately do NOT fall back to the vocals stem.
        speaker_ref = None
        if tts_speaker_wav:
            speaker_ref = Path(tts_speaker_wav)
            if not speaker_ref.exists():
                print(f"   ⚠️ Speaker reference not found: {speaker_ref}")
                speaker_ref = None
            else:
                print(f"   Speaker ref: {speaker_ref.name}")
        elif "vocals" in available_stems:
            # Use extracted vocals as speaker reference (coerced to Path so both
            # branches hand the same type to TTSConfig).
            speaker_ref = Path(available_stems["vocals"])
            print("   Speaker ref: vocals stem (auto-selected)")

        try:
            tts_config = TTSConfig(
                language=tts_language,
                speaker_wav=speaker_ref,
            )

            generator = TTSGenerator(config=tts_config)

            output_path = VOICE_DIR / "tts_output.wav"
            print("\n🔄 Generating speech...")

            result = generator.generate(
                text=tts_text,
                output_path=output_path,
            )

            TTS_OUTPUT_PATH = result.output_path
            print(f"✅ TTS output saved: {TTS_OUTPUT_PATH}")

        except Exception as e:
            # Best-effort: a failed generation must not abort the rest of the cell.
            print(f"❌ TTS generation failed: {e}")
            TTS_OUTPUT_PATH = None
    else:
        print("\nℹ️ TTS generation disabled (set enable_tts=True to enable)")

    # --- Singing Voice Conversion ---
    # SVC_OUTPUT_PATH ends up as the converted file path, or None when SVC is
    # disabled, its inputs are incomplete, or conversion fails.
    SVC_OUTPUT_PATH = None
    if enable_svc:
        print("\n🎤 Singing Voice Conversion")
        print(f"   Source stem: {svc_source_stem}")
        print(f"   Pitch shift: {svc_pitch_shift} semitones")

        # Validate source stem. Stem paths may be stored as str or Path, so coerce
        # before using Path-only attributes such as .name.
        source_path = None
        if svc_source_stem in available_stems:
            source_path = Path(available_stems[svc_source_stem])
            print(f"   Source: {source_path.name}")
        else:
            print(f"   ⚠️ Source stem '{svc_source_stem}' not found")
            print("      Run stem separation first")

        # Validate model path
        model_path = Path(svc_model_path) if svc_model_path else None
        if model_path and not model_path.exists():
            print(f"   ⚠️ Model not found: {model_path}")
            print("      RVC models require manual setup. See documentation.")
            model_path = None

        if source_path and model_path:
            try:
                svc_config = SVCConfig(
                    model_path=model_path,
                    index_path=Path(svc_index_path) if svc_index_path else None,
                    pitch_shift_semitones=svc_pitch_shift,
                )

                converter = VoiceConverter(config=svc_config)

                output_path = VOICE_DIR / f"svc_{svc_source_stem}.wav"
                print("\n🔄 Converting voice...")

                result = converter.convert(
                    audio_path=source_path,
                    output_path=output_path,
                )

                SVC_OUTPUT_PATH = result.output_path
                print(f"✅ SVC output saved: {SVC_OUTPUT_PATH}")

            except Exception as e:
                # Best-effort: report the failure and continue to the summary.
                print(f"❌ Voice conversion failed: {e}")
                SVC_OUTPUT_PATH = None
        elif not model_path:
            print("\n💡 RVC Setup Instructions:")
            print("   1. Download a compatible RVC model (.pth file)")
            print("   2. Optionally download the index file (.index)")
            print("   3. Upload to Colab or mount from Drive")
            print("   4. Set svc_model_path and svc_index_path above")
    else:
        print("\nℹ️ SVC disabled (set enable_svc=True to enable)")

    # --- Summary ---
    print("\n" + "=" * 50)
    print("🗣️ Voice Generation Summary")
    print("=" * 50)
    if TTS_OUTPUT_PATH:
        print(f"   TTS Output: {TTS_OUTPUT_PATH}")
    if SVC_OUTPUT_PATH:
        print(f"   SVC Output: {SVC_OUTPUT_PATH}")
    if not TTS_OUTPUT_PATH and not SVC_OUTPUT_PATH:
        print("   No voice outputs generated")

else:
    print("\n💡 To enable voice features:")
    print("   1. Go to the 'Install SoundLab' cell")
    print("   2. Set install_voice = True")
    print("   3. Re-run the installation cell")
    print("   4. Return here and configure voice settings")

In [None]:
# @title 💾 Export & Download
# @markdown Export all processed files and download as a ZIP archive.

# --- Export Settings ---
# Each toggle gates one collection section further down in this cell.
export_stems = True  # @param {type:"boolean"}
export_midi = True  # @param {type:"boolean"}
export_voice = True  # @param {type:"boolean"}
export_analysis = True  # @param {type:"boolean"}
export_config = True  # @param {type:"boolean"}
export_qa_report = True  # @param {type:"boolean"}

# --- Format Options ---
# NOTE: stem_format is only applied when stems are re-exported through the
# normalization path (normalize_stems enabled AND a non-zero normalization_lufs).
# Otherwise the existing stem files are collected unchanged, whatever their format.
stem_format = "wav"  # @param ["wav", "mp3", "flac"]
normalize_stems = True  # @param {type:"boolean"}
normalization_lufs = -14.0  # @param {type:"number"}

# --- Archive Options ---
create_zip_archive = True  # @param {type:"boolean"}
# Base name only: the first 8 characters of RUN_ID (when set) and a timestamp are
# appended before the .zip extension.
zip_filename = "soundlab_export"  # @param {type:"string"}
# When enabled, the archive is handed to google.colab's files.download(); outside
# Colab the archive path is printed instead.
auto_download = True  # @param {type:"boolean"}

# === Execution ===
from __future__ import annotations

import csv
import json
import zipfile
from datetime import datetime
from pathlib import Path

# Accumulators filled in by every export section below:
#   export_files    - paths that will be written into the archive
#   export_manifest - file names grouped by category, stored as manifest.json
export_files: list[Path] = []
export_manifest: dict[str, list[str]] = {
    category: [] for category in ("stems", "midi", "voice", "analysis", "reports")
}

print("📦 Collecting files for export...")
print("=" * 50)

# --- Stems ---
if export_stems and "STEM_PATHS" in dir() and STEM_PATHS:
    print("\n🎼 Stems:")

    # Raw stem files are collected either when normalization is not requested or
    # when the normalized re-export raises ImportError.
    collect_raw_stems = not (normalize_stems and normalization_lufs)

    if not collect_raw_stems:
        try:
            from soundlab.io import load_audio
            from soundlab.io.export import export_audio

            # Re-export stems with normalization
            normalized_dir = EXPORTS_DIR / "stems"
            normalized_dir.mkdir(exist_ok=True)

            for name, path in STEM_PATHS.items():
                stem_audio = load_audio(path)
                export_path = normalized_dir / f"{name}.{stem_format}"
                export_audio(
                    stem_audio,
                    export_path,
                    format=stem_format,
                    normalize_lufs=normalization_lufs,
                )
                export_files.append(export_path)
                export_manifest["stems"].append(export_path.name)
                print(f"   ✅ {name}: {export_path.name} (normalized to {normalization_lufs} LUFS)")
        except ImportError as e:
            print(f"   ⚠️ Normalization requires pyloudnorm: {e}")
            # Fall back to raw stems
            collect_raw_stems = True

    if collect_raw_stems:
        for name, path in STEM_PATHS.items():
            raw_stem = Path(path)
            export_files.append(raw_stem)
            export_manifest["stems"].append(raw_stem.name)
            print(f"   ✅ {name}: {raw_stem.name}")
else:
    print("\n🎼 Stems: (none available)")

# --- MIDI ---
if export_midi:
    if "MIDI_PATHS" in dir() and MIDI_PATHS:
        # Preferred source: the name -> path mapping kept in the session.
        print("\n🎹 MIDI:")
        for name, path in MIDI_PATHS.items():
            midi_file = Path(path)
            export_files.append(midi_file)
            export_manifest["midi"].append(midi_file.name)
            print(f"   ✅ {name}: {midi_file.name}")
    else:
        # No usable MIDI_PATHS: check MIDI directory for any files
        midi_dir_present = MIDI_DIR.exists()
        if midi_dir_present:
            midi_files = [*MIDI_DIR.glob("*.mid"), *MIDI_DIR.glob("*.midi")]

        if midi_dir_present and midi_files:
            print("\n🎹 MIDI:")
            for midi_file in midi_files:
                export_files.append(midi_file)
                export_manifest["midi"].append(midi_file.name)
                print(f"   ✅ {midi_file.name}")
        else:
            print("\n🎹 MIDI: (none available)")

# --- Voice ---
if export_voice:
    # Candidates come from two sources: the explicit TTS/SVC output variables and a
    # scan of VOICE_DIR. The Voice cell writes its outputs into VOICE_DIR as .wav
    # files, so the scan normally finds the same files again.
    voice_candidates: list[Path] = []
    if "TTS_OUTPUT_PATH" in dir() and TTS_OUTPUT_PATH:
        voice_candidates.append(Path(TTS_OUTPUT_PATH))
    if "SVC_OUTPUT_PATH" in dir() and SVC_OUTPUT_PATH:
        voice_candidates.append(Path(SVC_OUTPUT_PATH))
    if VOICE_DIR.exists():
        voice_candidates.extend(VOICE_DIR.glob("*.wav"))

    # Keep only files that exist, and keep each file once (first occurrence wins),
    # so nothing is listed twice in the manifest or written twice into the archive.
    voice_files = []
    seen_voice_files = set()
    for candidate in voice_candidates:
        if not candidate.exists():
            continue
        identity = candidate.resolve()
        if identity in seen_voice_files:
            continue
        seen_voice_files.add(identity)
        voice_files.append(candidate)

    if voice_files:
        print("\n🗣️ Voice:")
        for path in voice_files:
            export_files.append(path)
            export_manifest["voice"].append(path.name)
            print(f"   ✅ {path.name}")
    else:
        print("\n🗣️ Voice: (none available)")

# --- Analysis ---
if export_analysis:
    # Top-level entries of ANALYSIS_DIR; sub-directories are listed here but
    # skipped when collecting.
    analysis_files = list(ANALYSIS_DIR.glob("*")) if ANALYSIS_DIR.exists() else []

    if not analysis_files:
        print("\n📊 Analysis: (none available)")
    else:
        print("\n📊 Analysis:")
        for entry in (item for item in analysis_files if item.is_file()):
            export_files.append(entry)
            export_manifest["analysis"].append(entry.name)
            print(f"   ✅ {entry.name}")

# --- Config Export ---
if export_config:
    print("\n⚙️ Configuration:")
    config_path = EXPORTS_DIR / "config.json"

    # Session values are optional: a name is simply missing when the cell that
    # defines it was never run, so each lookup carries its own default.
    source_path = globals().get("SOURCE_PATH")
    stem_result = globals().get("STEM_RESULT")

    # Add separation config if available
    separation_settings = {}
    if stem_result:
        stem_config = stem_result.config
        separation_settings = {
            "model": str(stem_config.model.value),
            "segment_length": stem_config.segment_length,
            "overlap": stem_config.overlap,
            "shifts": stem_config.shifts,
            "processing_time_seconds": stem_result.processing_time_seconds,
        }

    config_data = {
        "run_id": globals().get("RUN_ID"),
        "audio_hash": globals().get("AUDIO_HASH"),
        "source_file": str(source_path.name) if source_path else None,
        "timestamp": datetime.now().isoformat(),
        "separation": separation_settings,
        "transcription": {},
        "voice": {},
        "environment": {
            "gpu_available": globals().get("GPU_AVAILABLE", False),
            "gpu_name": globals().get("GPU_NAME"),
        },
    }

    # default=str stringifies values json cannot encode natively.
    with config_path.open("w") as config_file:
        json.dump(config_data, config_file, indent=2, default=str)

    export_files.append(config_path)
    export_manifest["reports"].append(config_path.name)
    print(f"   ✅ {config_path.name}")

# --- QA Report ---
if export_qa_report:
    print("\n📋 QA Report:")
    qa_report_path = EXPORTS_DIR / "qa_report.csv"

    # First row is the CSV header; a report is only written if more rows follow.
    qa_rows = [["Metric", "Value", "Threshold", "Status"]]

    # Add stem QA metrics
    stem_qa = globals().get("STEM_QA_RESULT")
    if stem_qa:
        qa_rows.extend(
            [metric, f"{value:.4f}", "", ""] for metric, value in stem_qa.metrics.items()
        )
        stem_status = "PASS" if stem_qa.passed else "FAIL"
        qa_rows.append(["Overall Score", f"{stem_qa.score:.4f}", "0.70", stem_status])

    # Add MIDI QA metrics
    midi_qa = globals().get("MIDI_QA_RESULT")
    if midi_qa:
        qa_rows.append(["", "", "", ""])  # Separator
        qa_rows.append(["MIDI QA", "", "", ""])
        qa_rows.extend(
            [metric, f"{value:.4f}", "", ""] for metric, value in midi_qa.metrics.items()
        )
        midi_status = "PASS" if midi_qa.passed else "FAIL"
        qa_rows.append(["MIDI Score", f"{midi_qa.score:.4f}", "0.60", midi_status])

    if len(qa_rows) <= 1:
        print("   ℹ️ No QA data available")
    else:
        # newline="" lets the csv module control line endings itself.
        with qa_report_path.open("w", newline="") as report_file:
            csv.writer(report_file).writerows(qa_rows)

        export_files.append(qa_report_path)
        export_manifest["reports"].append(qa_report_path.name)
        print(f"   ✅ {qa_report_path.name}")

# --- Create ZIP Archive ---
if create_zip_archive and export_files:
    print("\n📦 Creating ZIP archive...")

    # Generate timestamped filename
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_suffix = f"_{RUN_ID[:8]}" if "RUN_ID" in dir() and RUN_ID else ""
    zip_path = EXPORTS_DIR / f"{zip_filename}{run_suffix}_{timestamp}.zip"

    # Files collected from the analysis directory go under analysis/, matching the
    # "analysis" category of the manifest. Without this they would be filed under
    # reports/ (json/csv) or the archive root, where an analysis "config.json" would
    # collide with the exported configuration.
    analysis_root = ANALYSIS_DIR.resolve() if "ANALYSIS_DIR" in dir() else None

    # export_files may contain the same file more than once (for example when a file
    # is collected both explicitly and through a directory scan); write each source
    # file only once so the archive has no duplicate entries.
    archived_sources = set()

    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for file_path in export_files:
            if not file_path.exists():
                continue

            resolved_path = file_path.resolve()
            if resolved_path in archived_sources:
                continue
            archived_sources.add(resolved_path)

            # Determine subdirectory based on origin / file type
            if analysis_root is not None and resolved_path.parent == analysis_root:
                arcname = f"analysis/{file_path.name}"
            elif (
                file_path.suffix in [".wav", ".mp3", ".flac"]
                and "stem" in str(file_path.parent).lower()
            ):
                arcname = f"stems/{file_path.name}"
            elif file_path.suffix in [".mid", ".midi"]:
                arcname = f"midi/{file_path.name}"
            elif (
                file_path.parent.name == "voice"
                or "tts" in file_path.stem
                or "svc" in file_path.stem
            ):
                arcname = f"voice/{file_path.name}"
            elif file_path.suffix == ".json" or file_path.suffix == ".csv":
                arcname = f"reports/{file_path.name}"
            else:
                arcname = file_path.name

            zf.write(file_path, arcname=arcname)

        # Add manifest (names de-duplicated per category, original order kept)
        manifest_json = json.dumps(
            {category: list(dict.fromkeys(names)) for category, names in export_manifest.items()},
            indent=2,
        )
        zf.writestr("manifest.json", manifest_json)

    EXPORT_ZIP_PATH = zip_path
    zip_size_mb = zip_path.stat().st_size / (1024 * 1024)
    print(f"   ✅ {zip_path.name} ({zip_size_mb:.1f} MB)")

    # Auto-download in Colab
    if auto_download:
        try:
            from google.colab import files as colab_files

            print("\n📥 Initiating download...")
            colab_files.download(str(zip_path))
            print("✅ Download started!")
        except ImportError:
            print("\nℹ️ Not running in Colab. Download manually from:")
            print(f"   {zip_path}")
        except Exception as e:
            # Best-effort: the archive already exists on disk, so just report where.
            print(f"⚠️ Auto-download failed: {e}")
            print(f"   Manual download path: {zip_path}")

# --- Summary ---
# Per-category counts come from the manifest; "Total files" counts collected paths.
print("\n" + "=" * 50)
print("📦 Export Summary")
print("=" * 50)
print(f"   Total files: {len(export_files)}")
print(f"   Stems: {len(export_manifest['stems'])}")
print(f"   MIDI: {len(export_manifest['midi'])}")
print(f"   Voice: {len(export_manifest['voice'])}")
print(f"   Analysis: {len(export_manifest['analysis'])}")
print(f"   Reports: {len(export_manifest['reports'])}")
# EXPORT_ZIP_PATH can exist but hold None (the Cleanup cell's global-state reset
# assigns None rather than deleting the name), so check the value as well as the
# name to avoid printing "Archive: None".
if "EXPORT_ZIP_PATH" in dir() and EXPORT_ZIP_PATH:
    print(f"\n   📁 Archive: {EXPORT_ZIP_PATH}")

In [None]:
# @title 🧹 Cleanup
# @markdown Clean up temporary files and free GPU memory.

# --- Cleanup Options ---
# Each of the first five toggles empties one working directory; the directory itself
# is recreated empty afterwards.
cleanup_stems = False  # @param {type:"boolean"}
cleanup_midi = False  # @param {type:"boolean"}
cleanup_voice = False  # @param {type:"boolean"}
cleanup_analysis = False  # @param {type:"boolean"}
cleanup_exports = False  # @param {type:"boolean"}
# Empties the checkpoints directory and CACHE_ROOT/"cache"; the models directory is
# left untouched.
cleanup_cache = False  # @param {type:"boolean"}
# Calls torch.cuda.empty_cache() when PyTorch is importable and CUDA is available.
clear_gpu_cache = True  # @param {type:"boolean"}
# Assigns None to the session variables listed further down (names are kept, not
# deleted), so later cells must check values, not just name presence.
reset_global_state = False  # @param {type:"boolean"}

# === Execution ===
from __future__ import annotations

import gc
import shutil
from pathlib import Path


def safe_rmtree(path: Path, name: str) -> int:
    """Empty a directory and return how many files were actually removed.

    The whole tree under ``path`` is deleted and ``path`` is recreated as an empty
    directory. Failures are reported on stdout instead of being raised, so one
    directory that cannot be cleaned does not abort the rest of the cleanup.

    Args:
        path: Directory to empty. A path that does not exist is a no-op.
        name: Human-readable label used in the printed status line.

    Returns:
        The number of regular files removed. When removal fails part-way, this is
        the number of files that were deleted before the failure (previously 0 was
        reported even though files had been removed).
    """
    if not path.exists():
        return 0

    def count_files() -> int:
        # Regular files anywhere below ``path``; directories are not counted.
        return sum(1 for entry in path.rglob("*") if entry.is_file())

    files_before = count_files()
    try:
        shutil.rmtree(path)
        path.mkdir(exist_ok=True)  # Recreate empty directory
        print(f"   ✅ {name}: Removed {files_before} files")
        return files_before
    except Exception as e:
        print(f"   ⚠️ {name}: Failed to clean - {e}")
        # rmtree may have deleted part of the tree before failing; report what is
        # really gone so the caller's running total stays accurate.
        try:
            files_left = count_files()
        except OSError:
            return 0
        return max(files_before - files_left, 0)


print("🧹 Cleanup")
print("=" * 50)

total_removed = 0

# --- Directory Cleanup ---
# (form toggle, name of the directory variable, label for the status line).
# A directory is only cleaned when its toggle is on AND its variable is defined.
directory_targets = [
    (cleanup_stems, "STEMS_DIR", "Stems"),
    (cleanup_midi, "MIDI_DIR", "MIDI"),
    (cleanup_voice, "VOICE_DIR", "Voice"),
    (cleanup_analysis, "ANALYSIS_DIR", "Analysis"),
    (cleanup_exports, "EXPORTS_DIR", "Exports"),
]
for enabled, dir_var, label in directory_targets:
    if enabled and dir_var in dir():
        total_removed += safe_rmtree(globals()[dir_var], label)

if cleanup_cache and "CACHE_ROOT" in dir():
    # Clean checkpoints and cache, preserve models
    cache_dir = CACHE_ROOT / "cache"
    for cache_target, label in ((CHECKPOINTS_DIR, "Checkpoints"), (cache_dir, "Cache")):
        if cache_target.exists():
            total_removed += safe_rmtree(cache_target, label)

# --- GPU Memory ---
if clear_gpu_cache:
    print("\n🎮 GPU Memory:")
    try:
        import torch

        if torch.cuda.is_available():
            # empty_cache() returns unused *cached* blocks to the driver; it does not
            # change memory_allocated(), which only tracks live tensors. The amount
            # released is therefore measured with memory_reserved().
            before = torch.cuda.memory_reserved() / 1e9
            # Collect first so tensors kept alive only by reference cycles are freed
            # back to the caching allocator, then release the now-unused blocks.
            gc.collect()
            torch.cuda.empty_cache()
            after = torch.cuda.memory_reserved() / 1e9
            freed = before - after
            allocated = torch.cuda.memory_allocated() / 1e9
            print(f"   ✅ Cleared GPU cache: {freed:.2f} GB freed")
            print(f"   Current allocation: {allocated:.2f} GB")
        else:
            print("   ℹ️ No GPU available")
    except ImportError:
        print("   ℹ️ PyTorch not available")

# --- Python Garbage Collection ---
# Runs unconditionally, independent of the GPU toggle above.
print("\n🐍 Python Memory:")
gc.collect()
print("   ✅ Garbage collection complete")

# --- Reset Global State ---
if reset_global_state:
    print("\n🔄 Resetting Global State:")

    # Session variables to blank out. Only names that currently exist are touched,
    # and they are set to None rather than deleted.
    globals_to_reset = [
        "CURRENT_AUDIO",
        "CANONICAL_AUDIO",
        "EXCERPT_AUDIO",
        "AUDIO_HASH",
        "SOURCE_PATH",
        "RUN_ID",
        "STEM_RESULT",
        "STEM_PATHS",
        "STEM_QA_RESULT",
        "MIDI_PATHS",
        "MIDI_QA_RESULT",
        "TTS_OUTPUT_PATH",
        "SVC_OUTPUT_PATH",
        "EXPORT_ZIP_PATH",
        "BEST_CANDIDATE",
        "BEST_PLAN",
        "candidate_scores",
    ]

    reset_count = 0
    for var_name in globals_to_reset:
        if var_name not in globals():
            continue
        globals()[var_name] = None
        reset_count += 1

    print(f"   ✅ Reset {reset_count} global variables")
    print("   ℹ️ Re-run the Upload cell to start a new session")

# --- Summary ---
print("\n" + "=" * 50)
print("🧹 Cleanup Summary")
print("=" * 50)
print(f"   Files removed: {total_removed}")
for summary_label, was_enabled in (
    ("GPU cache cleared", clear_gpu_cache),
    ("Global state reset", reset_global_state),
):
    print(f"   {summary_label}: {'Yes' if was_enabled else 'No'}")

# Show the hint only when no directory/cache cleanup toggle was switched on.
no_cleanup_selected = not (
    cleanup_stems
    or cleanup_midi
    or cleanup_voice
    or cleanup_analysis
    or cleanup_exports
    or cleanup_cache
)
if no_cleanup_selected:
    print("\n💡 Tip: Enable cleanup options above to remove temporary files")
    print("   before downloading to free up disk space.")