# Instrument Detection with Timestamps - 3-Step CoT

Chain-of-thought (CoT) pipeline, one model call per step:
1. Layer analysis (background / middle ground / foreground)
2. Timestamp analysis (when instruments enter and exit)
3. Structured JSON output

In [None]:
import torch
import json
from pathlib import Path
from loguru import logger

from src.data.qwen_omni import QwenOmniCoTDataset
from src.models import load_model_and_processor

In [None]:
# Config
# Model checkpoint and the precision/device it is loaded with
MODEL_NAME = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
DTYPE = torch.bfloat16
DEVICE = "cuda"

# Input audio location and how many files the batch cells below process
AUDIO_DIR = "../audio_files"
MAX_FILES = 2

In [None]:
# Load model
# load_model_and_processor is a project helper (src.models); it is given the
# configured MODEL_NAME, DTYPE and DEVICE and returns the model together with
# its matching processor.
model, processor = load_model_and_processor(MODEL_NAME, DTYPE, DEVICE)
model.eval()  # switch to eval mode: this notebook only runs inference
print("Model loaded")

In [None]:
# Load dataset (reuse existing class) and preview the files the batch cell
# below will process (at most MAX_FILES of them).
dataset = QwenOmniCoTDataset(AUDIO_DIR, processor)
print(f"Found {len(dataset)} files")
for audio_path in dataset.files[:MAX_FILES]:
    print(f"  - {Path(audio_path).name}")

## New Prompts for Timestamp Analysis

In [None]:
# Step 2: Timestamp analysis prompt
# Filled with str.format(step_1_response=...). The Step 1 answer is pasted in
# so the model builds its timeline on top of its own layer analysis.
STEP_2_TIMESTAMP_PROMPT = """Based on your layer analysis:

{step_1_response}

Now analyze WHEN each instrument appears. Listen and identify:
- When does each instrument first enter?
- Are there sections where instruments drop out?
- When do instruments return?

Format:
**Instrument Timeline:**
[instrument]: enters at [time], exits at [time], re-enters at [time]
...

Use timestamps like "0:00", "0:15", "1:30"."""

# Step 3: Final JSON output prompt
# Filled with str.format(step_1_response=..., step_2_response=...). The
# literal braces of the JSON example are doubled ({{ / }}) so str.format emits
# single braces instead of treating them as replacement fields.
STEP_3_JSON_PROMPT = """Based on your analysis:

Layer Analysis:
{step_1_response}

Timeline:
{step_2_response}

Output JSON with instruments and timestamps:

{{
  "instruments": [
    {{"name": "instrument", "layer": "background|middle_ground|foreground", "timestamps": [{{"start": "0:00", "end": "1:30"}}]}}
  ]
}}

Return ONLY the JSON:"""

In [None]:
def generate(conversation, max_new_tokens=512):
    """Greedily generate a text reply for a chat ``conversation``.

    Args:
        conversation: Chat messages accepted by
            ``processor.apply_chat_template`` (may contain audio content).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded text of the newly generated tokens only; the prompt
        tokens are sliced off before decoding.
    """
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    def to_model(value):
        # Tensors go to DEVICE. Floating-point tensors (presumably the audio
        # features) are also cast to DTYPE so they match the weights the
        # model was loaded with; integer tensors such as input_ids and
        # attention masks keep their dtype. Non-tensors pass through.
        if not isinstance(value, torch.Tensor):
            return value
        if value.is_floating_point():
            return value.to(DEVICE, dtype=DTYPE)
        return value.to(DEVICE)

    inputs = {k: to_model(v) for k, v in inputs.items()}

    with torch.no_grad():
        text_ids, _ = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            return_audio=False,  # text only; the second return value is unused
        )

    # Depending on generate() flags the text output is either a tensor of ids
    # or an output object exposing them as ``.sequences``; accept both.
    sequences = getattr(text_ids, "sequences", text_ids)
    generated_ids = sequences[:, inputs["input_ids"].shape[1] :]
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
def _extract_json_object(text):
    """Return the first JSON object that can be decoded from ``text``.

    Tries ``json.JSONDecoder.raw_decode`` at each ``{`` from left to right, so
    leading prose, markdown code fences, or stray braces *after* the object do
    not prevent parsing (a first-``{``-to-last-``}`` slice fails on those).

    Returns:
        The decoded ``dict``, or ``None`` if no ``{`` starts valid JSON.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return obj
    return None


def _followup_conversation(waveform, text):
    """Build a [system, user] conversation pairing ``waveform`` with ``text``.

    Steps 2 and 3 are single-turn: earlier answers are pasted into ``text``
    rather than sent as prior chat turns.
    """
    return [
        QwenOmniCoTDataset.get_system_message(QwenOmniCoTDataset.SYSTEM_PROMPT),
        QwenOmniCoTDataset.get_user_message(waveform, text),
    ]


def detect_with_timestamps(waveform, verbose=True):
    """Run the 3-step chain-of-thought instrument detection on one waveform.

    1. Layer analysis, using the dataset's existing CoT conversation.
    2. Timestamp analysis, conditioned on the Step 1 answer.
    3. Structured JSON output, conditioned on both earlier answers.

    Each step is a separate ``generate`` call that re-sends the audio.

    Args:
        waveform: One audio waveform as returned by ``QwenOmniCoTDataset``.
        verbose: If True, print progress for each step.

    Returns:
        dict with keys ``"step_1"``, ``"step_2"``, ``"step_3"`` (raw model
        text per step) and ``"parsed"`` (the JSON object extracted from the
        Step 3 text, or ``None`` if none could be decoded).
    """

    def log(message):
        if verbose:
            print(message)

    # Step 1: Layer analysis (reuse existing conversation builder)
    log("Step 1: Layer analysis...")
    step_1_conv = QwenOmniCoTDataset.get_conversation(waveform)
    step_1_response = generate(step_1_conv, max_new_tokens=512)
    log(f"  Done ({len(step_1_response)} chars)")

    # Step 2: Timestamp analysis
    log("Step 2: Timestamp analysis...")
    step_2_text = STEP_2_TIMESTAMP_PROMPT.format(step_1_response=step_1_response)
    step_2_response = generate(
        _followup_conversation(waveform, step_2_text), max_new_tokens=512
    )
    log(f"  Done ({len(step_2_response)} chars)")

    # Step 3: Structured JSON (given a larger token budget than steps 1-2)
    log("Step 3: JSON output...")
    step_3_text = STEP_3_JSON_PROMPT.format(
        step_1_response=step_1_response,
        step_2_response=step_2_response,
    )
    step_3_response = generate(
        _followup_conversation(waveform, step_3_text), max_new_tokens=1024
    )
    log(f"  Done ({len(step_3_response)} chars)")

    # Parse JSON; a failure is reported instead of silently ignored, but the
    # raw text is still returned so callers can inspect it.
    parsed = _extract_json_object(step_3_response)
    if parsed is None:
        log("  Could not parse JSON from Step 3 output")

    return {
        "step_1": step_1_response,
        "step_2": step_2_response,
        "step_3": step_3_response,
        "parsed": parsed,
    }

## Run Detection

In [None]:
# Get first sample: take the first entry of dataset item 0
filenames, waveforms, _ = dataset[0]
waveform, filename = waveforms[0], Path(filenames[0]).name

# Duration estimate assumes a 1-D waveform sampled at 16 kHz — TODO confirm
# against QwenOmniCoTDataset.
print(f"Processing: {filename}\nDuration: {len(waveform) / 16000:.1f}s")

In [None]:
# Run 3-step detection on the sample selected above, printing per-step progress
result = detect_with_timestamps(waveform, verbose=True)

In [None]:
# Step 1 output: raw layer-analysis text, under a banner
print("=" * 60, "STEP 1: LAYER ANALYSIS", "=" * 60, sep="\n")
print(result["step_1"])

In [None]:
# Step 2 output: raw timeline text, under a banner
print("=" * 60, "STEP 2: TIMESTAMP ANALYSIS", "=" * 60, sep="\n")
print(result["step_2"])

In [None]:
# Step 3 output: raw JSON-ish text exactly as the model produced it
print("=" * 60, "STEP 3: JSON OUTPUT", "=" * 60, sep="\n")
print(result["step_3"])

In [None]:
# Parsed output: the JSON object extracted from the Step 3 text, if any
print("=" * 60, "PARSED", "=" * 60, sep="\n")
# Compare against None: an empty-but-valid object ({}) did parse and should
# be shown, not reported as a parse failure.
if result["parsed"] is not None:
    print(json.dumps(result["parsed"], indent=2))
else:
    print("Failed to parse JSON")

## Process Multiple Files

In [None]:
# Process first few files (at most MAX_FILES, fewer if the dataset is smaller)
all_results = []
n_files = min(MAX_FILES, len(dataset))

for i in range(n_files):
    filenames, waveforms, _ = dataset[i]
    filename = Path(filenames[0]).name
    waveform = waveforms[0]

    # Use the actual number of files processed as the denominator, not
    # MAX_FILES, so the counter never reads e.g. "[1/2]" for a 1-file dataset.
    print(f"\n[{i + 1}/{n_files}] {filename}")
    result = detect_with_timestamps(waveform, verbose=True)
    result["filename"] = filename
    all_results.append(result)

    parsed = result["parsed"]
    if parsed is not None:
        # The model controls the JSON shape; only count a real list.
        instruments = parsed.get("instruments", [])
        n = len(instruments) if isinstance(instruments, list) else 0
        print(f"  -> {n} instruments detected")
    else:
        print("  -> Parse failed")

In [None]:
# Summary
def _format_span(span):
    """Render one timestamp entry as "start-end" (or its str() if not a dict)."""
    if isinstance(span, dict):
        return f"{span.get('start', '?')}-{span.get('end', 'end')}"
    return str(span)


print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

for r in all_results:
    print(f"\n{r['filename']}:")
    parsed = r["parsed"]
    instruments = parsed.get("instruments") if isinstance(parsed, dict) else None
    if not isinstance(instruments, list):
        print("  (no parsed output)")
        continue
    # The JSON comes from the model, so entries may be bare strings instead of
    # objects and "timestamps" may be missing, null, or not a list.
    for inst in instruments:
        if isinstance(inst, dict):
            name = inst.get("name", "?")
            layer = inst.get("layer", "?")
            spans = inst.get("timestamps", [])
        else:
            name, layer, spans = inst, "?", []
        if not isinstance(spans, list):
            spans = []
        ts_str = ", ".join(_format_span(t) for t in spans)
        print(f"  {name} [{layer}]: {ts_str}")