In [None]:
import os
import time
import json
from pathlib import Path

import torch
import torchaudio
from tqdm import tqdm
from transformers import (
    Qwen2_5OmniProcessor,
    Qwen2_5OmniThinkerForConditionalGeneration,  # OK to keep; we’ll also handle fallback
)

# ==== CONFIG ====
# Restrict this process to one physical GPU.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in this
# process/kernel — confirm the cell runs in a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Directory passed to `from_pretrained` for both the model and the processor.
MODEL_PATH = "/raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/qwen_omni_finetuned_common_voice_for_qwen_train_less_than_3_sec/final_model_1_epoch"
# JSON Lines file with one sample per line; each sample's "audio_path" is resolved
# relative to the directory that contains this file.
JSONL_PATH = "/raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/data/common_voice_for_qwen/less_than_3_sec/processed/streaming_chunks_padded.jsonl"
# A sample is predicted TRANSLATE when p(TRANSLATE) >= this value, otherwise WAIT.
TRANSLATE_THRESHOLD = 0.1
# Read by get_gate_prediction; note the mixed-case spelling is the name that function uses.
USE_AUTOCast = False  # set True if you want cuda autocast(bfloat16) during forward
# =================


def load_and_prep_audio(audio_path: str, target_sr: int) -> torch.Tensor:
    """
    Load an audio file and normalise it for the feature extractor.

    The clip is averaged over its first dimension when that dimension is larger
    than one (assumed to be the channel axis as returned by torchaudio.load —
    TODO confirm), resampled to `target_sr` if its native rate differs, and
    returned with the leading dimension squeezed away.

    audio_path: path of the audio file to read.
    target_sr:  sampling rate (Hz) the returned waveform should have.
    Returns:    waveform tensor, [S] for a single (or downmixed) channel.
    """
    waveform, native_sr = torchaudio.load(audio_path)

    # Downmix multi-channel audio to a single channel, keeping the leading axis.
    channel_count = waveform.shape[0]
    if channel_count > 1:
        waveform = waveform.mean(0, keepdim=True)

    # Resample only when the file's rate differs from the requested one.
    needs_resample = native_sr != target_sr
    if needs_resample:
        waveform = torchaudio.functional.resample(waveform, native_sr, target_sr)

    mono = waveform.squeeze(0)  # [S]
    return mono


def _pool_last_audio_token(hidden_states: torch.Tensor,
                           input_ids: torch.LongTensor,
                           attention_mask: torch.LongTensor,
                           audio_token_index: int) -> torch.Tensor:
    """
    hidden_states: [B, T, D] (last hidden layer from the text model)
    input_ids:     [B, T]
    attention_mask:[B, T]
    Returns:       [B, D] pooled at last audio token; fallback = last attended token.
    """
    B, T, D = hidden_states.shape
    device = hidden_states.device

    audio_mask = (input_ids == audio_token_index)  # [B, T]
    has_audio = audio_mask.any(dim=1)              # [B]

    # default: last attended token
    # (find last index where attention_mask==1)
    attn = attention_mask.to(torch.int8)
    rev = torch.flip(attn, dims=[1])                        # [B, T]
    last_from_end = torch.argmax(rev, dim=1)                # [B]
    last_attn_idx = (T - 1) - last_from_end                 # [B]
    last_idx = last_attn_idx.clone()

    # where audio exists, take last audio position instead
    if has_audio.any():
        rev_audio = torch.flip(audio_mask.to(torch.int8), dims=[1])
        last_audio_from_end = torch.argmax(rev_audio, dim=1)
        last_audio_idx = (T - 1) - last_audio_from_end
        last_idx = torch.where(has_audio, last_audio_idx, last_attn_idx)

    gather_idx = last_idx.view(B, 1, 1).expand(B, 1, D)     # [B,1,D]
    pooled = hidden_states.gather(dim=1, index=gather_idx).squeeze(1)
    return pooled  # [B, D]


@torch.inference_mode()
def get_gate_prediction(model, processor, wav_tensor: torch.Tensor, system_prompt: dict) -> torch.Tensor:
    """
    Score a single audio clip and return probabilities over [WAIT, TRANSLATE].

    Two routes are tried in order:
      1. call the model with `return_gate_logits=True` and use `outputs.gate_logits`
         when the forward accepts that keyword and returns a non-None value;
      2. otherwise request hidden states, pool the last layer at the last audio
         token and feed the pooled vector through `model.conaiki_gate`.

    model:         model to evaluate (must expose `.device` and `.config`).
    processor:     processor providing `apply_chat_template`, `feature_extractor`
                   and the callable that builds model inputs.
    wav_tensor:    waveform of the clip; converted to a NumPy array for the processor.
    system_prompt: chat message dict placed first in the conversation.
    Returns:       softmax of the gate logits with the leading (batch) dim squeezed.
    Raises:        RuntimeError if neither route can produce gate logits.
    """
    sampling_rate = processor.feature_extractor.sampling_rate

    # Minimal conversation: the system prompt followed by one user turn holding audio.
    audio_turn = {"role": "user", "content": [{"type": "audio", "audio_url": "placeholder.wav"}]}
    prompt = processor.apply_chat_template(
        [system_prompt, audio_turn],
        add_generation_prompt=False,
        tokenize=False,
    )

    # Build the model inputs for this single example and move them to the model's device.
    batch = processor(
        text=prompt,
        audio=[wav_tensor.cpu().numpy()],
        sampling_rate=sampling_rate,
        return_tensors="pt",
        padding=True,
    )
    batch = {name: value.to(model.device) for name, value in batch.items()}

    def _forward(**extra):
        # Autocast is only switched on for CUDA models when the config flag asks for it.
        amp_on = model.device.type == "cuda" and USE_AUTOCast
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=amp_on):
            return model(**batch, **extra)

    # Route 1: the forward returns gate logits itself.
    try:
        direct = _forward(return_gate_logits=True)
        native_logits = getattr(direct, "gate_logits", None)
        if native_logits is not None:
            return torch.softmax(native_logits.float(), dim=-1).squeeze(0)
    except TypeError:
        # The forward may not accept `return_gate_logits`; continue with route 2.
        pass

    # Route 2: rebuild the gate logits from hidden states and the gate module.
    if not hasattr(model, "conaiki_gate"):
        raise RuntimeError(
            "Model does not expose 'gate_logits' AND has no 'conaiki_gate' module. "
            "Load your custom class (with gate head) or re-export the model with that head."
        )

    with_states = _forward(output_hidden_states=True, return_dict=True)
    layer_states = getattr(with_states, "hidden_states", None)
    if layer_states is None:
        raise RuntimeError("Model did not return hidden_states; cannot build gate logits fallback.")

    final_layer = layer_states[-1]  # last entry of the returned hidden states, [B, T, D]
    token_ids = batch["input_ids"]
    attn_mask = batch["attention_mask"]

    audio_token_id = getattr(model.config, "audio_token_index", None)
    if audio_token_id is None:
        # Hard-coded stand-in used only when the config lacks the attribute.
        # NOTE(review): marked "(example)" by the author — verify against the tokenizer.
        audio_token_id = 151666

    pooled = _pool_last_audio_token(final_layer, token_ids, attn_mask, audio_token_id)  # [B, D]
    fallback_logits = model.conaiki_gate(pooled)
    return torch.softmax(fallback_logits.float(), dim=-1).squeeze(0)


def evaluate_gate_model():
    """
    Evaluate the gate head at a single threshold and print a report.

    Loads the model and processor from MODEL_PATH, reads the samples listed in
    JSONL_PATH (blank lines are ignored), scores every audio chunk with
    get_gate_prediction and predicts TRANSLATE when
    p(TRANSLATE) >= TRANSLATE_THRESHOLD, otherwise WAIT.

    Prints overall accuracy, the confusion matrix (positive class = TRANSLATE),
    precision / recall / F1 and the total runtime. Samples that raise during
    loading or scoring are reported, counted and excluded from all metrics.

    Returns: None (results are printed only).
    """
    start = time.perf_counter()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from: {MODEL_PATH} ...")
    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        trust_remote_code=True,   # important if your class was saved with custom code
    ).to(device).eval()
    processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)
    print("Model and processor loaded.")

    system_prompt = {
        "role": "system",
        "content": [{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}],
    }

    # Load dataset
    print(f"Loading evaluation data from: {JSONL_PATH}")
    jsonl_path = Path(JSONL_PATH).expanduser().resolve()
    base_dir = jsonl_path.parent
    with open(jsonl_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a doubled trailing newline) are not JSON; skip them
        # instead of aborting the whole run with a JSONDecodeError.
        samples = [json.loads(line) for line in f if line.strip()]

    WAIT_IDX, TRANS_IDX = 0, 1
    label_map = {"WAIT": WAIT_IDX, "TRANSLATE": TRANS_IDX}
    # (predicted, true) -> confusion-matrix cell, with TRANSLATE as the positive class.
    cm_cell = {
        (TRANS_IDX, TRANS_IDX): "tp",
        (WAIT_IDX, WAIT_IDX): "tn",
        (TRANS_IDX, WAIT_IDX): "fp",
        (WAIT_IDX, TRANS_IDX): "fn",
    }

    correct = 0
    total = 0
    skipped = 0
    cm = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}

    # Loop-invariant: the sampling rate the feature extractor expects.
    target_sr = processor.feature_extractor.sampling_rate

    print(f"\nStarting evaluation on {len(samples)} audio chunks...")
    pbar = tqdm(samples, desc="Evaluating Gate Predictions")
    for sample in pbar:
        try:
            true_lbl = label_map[sample["gate_label"]]

            # Audio paths in the JSONL are relative to the JSONL's own directory.
            audio_path = (base_dir / sample["audio_path"]).as_posix()

            wav = load_and_prep_audio(audio_path, target_sr)
            probs = get_gate_prediction(model, processor, wav, system_prompt)

            p_translate = probs[TRANS_IDX].item()
            pred_lbl = TRANS_IDX if p_translate >= TRANSLATE_THRESHOLD else WAIT_IDX

            correct += int(pred_lbl == true_lbl)
            cm[cm_cell[(pred_lbl, true_lbl)]] += 1
            total += 1

        except Exception as e:
            # Best-effort evaluation: report the failure, count it, and keep going.
            skipped += 1
            print(f"\nSkipping sample due to error: {e} | Sample: {sample.get('audio_path')}")

    # Report
    print("\n--- Evaluation Complete ---")
    acc = (correct / total * 100.0) if total else 0.0
    print(f"Overall Accuracy: {acc:.2f}% ({correct} / {total})")
    if skipped:
        # Make it visible when the metrics cover only part of the dataset.
        print(f"Skipped samples: {skipped} of {len(samples)}")

    tp, tn, fp, fn = cm["tp"], cm["tn"], cm["fp"], cm["fn"]
    print("\nConfusion Matrix (Positive = TRANSLATE):")
    print(f"  TP: {tp:4d}   TN: {tn:4d}   FP: {fp:4d}   FN: {fn:4d}")

    # Each ratio falls back to 0.0 when its denominator is empty.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall    = tp / (tp + fn) if (tp + fn) else 0.0
    f1        = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

    print("\nMetrics (TRANSLATE):")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall:    {recall:.4f}")
    print(f"  F1-Score:  {f1:.4f}")

    print(f"\nTotal runtime: {time.perf_counter() - start:.2f} s")


# Entry point: run the single-threshold evaluation when this code is executed
# directly (script or notebook cell) rather than imported as a module.
if __name__ == "__main__":
    evaluate_gate_model()


  from .autonotebook import tqdm as notebook_tqdm


[2025-08-28 10:29:21,047] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 10:29:22,271] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


Loading model from: /raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/qwen_omni_finetuned_common_voice_for_qwen_train_less_than_3_sec/final_model ...


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  7.28it/s]


Model and processor loaded.
Loading evaluation data from: /raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/data/common_voice_for_qwen/less_than_3_sec/processed/streaming_chunks_padded.jsonl

Starting evaluation on 4793 audio chunks...


Evaluating Gate Predictions: 100%|██████████| 4793/4793 [10:34<00:00,  7.55it/s]


--- Evaluation Complete ---
Overall Accuracy: 63.59% (3048 / 4793)

Confusion Matrix (Positive = TRANSLATE):
  TP:  516   TN: 2532   FP: 1390   FN:  355

Metrics (TRANSLATE):
  Precision: 0.2707
  Recall:    0.5924
  F1-Score:  0.3716

Total runtime: 640.25 s





In [1]:
import os
import time
import json
from pathlib import Path
from datetime import datetime

import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from transformers import (
    Qwen2_5OmniProcessor,
    Qwen2_5OmniThinkerForConditionalGeneration,
)

# ==== CONFIG ====
# Restrict this process to one physical GPU.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in this
# process/kernel — confirm the cell runs in a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# Directory passed to `from_pretrained` for both the model and the processor.
MODEL_PATH = "/raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/qwen_omni_finetuned_common_voice_for_qwen_train_less_than_3_sec/final_model_1_epoch"
# JSON Lines file with one sample per line; each sample's "audio_path" is resolved
# relative to the directory that contains this file.
JSONL_PATH = "/raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/data/common_voice_for_qwen/less_than_3_sec/processed/streaming_chunks_padded.jsonl"

# Threshold range configuration
# The sweep is built as np.arange(START, END + STEP, STEP) and rounded to 2 decimals,
# so START/END/STEP are expected to have at most two decimal places.
THRESHOLD_START = 0.1
THRESHOLD_END = 0.99
THRESHOLD_STEP = 0.01

# Output file for results
# Relative path: the JSON report is written relative to the current working directory.
OUTPUT_JSON_PATH = "gate_threshold_evaluation_results.json"

USE_AUTOCAST = False  # set True if you want cuda autocast(bfloat16) during forward
# =================


def load_and_prep_audio(audio_path: str, target_sr: int) -> torch.Tensor:
    """
    Read an audio file and return it as a waveform at `target_sr`.

    Steps, in order: load with torchaudio; average over the first dimension when
    it holds more than one entry (presumably the channel axis — verify against
    torchaudio.load); resample if the file's rate is not `target_sr`; squeeze
    the leading dimension.

    audio_path: location of the audio file.
    target_sr:  desired sampling rate in Hz.
    Returns:    waveform tensor, [S] for a single (or downmixed) channel.
    """
    samples, file_sr = torchaudio.load(audio_path)

    is_multichannel = samples.shape[0] > 1
    if is_multichannel:
        # Collapse to one channel but keep a leading axis of size 1.
        samples = samples.mean(0, keepdim=True)

    if file_sr != target_sr:
        samples = torchaudio.functional.resample(samples, file_sr, target_sr)

    return samples.squeeze(0)  # [S]


def _pool_last_audio_token(hidden_states: torch.Tensor,
                           input_ids: torch.LongTensor,
                           attention_mask: torch.LongTensor,
                           audio_token_index: int) -> torch.Tensor:
    """
    hidden_states: [B, T, D] (last hidden layer from the text model)
    input_ids:     [B, T]
    attention_mask:[B, T]
    Returns:       [B, D] pooled at last audio token; fallback = last attended token.
    """
    B, T, D = hidden_states.shape
    device = hidden_states.device

    audio_mask = (input_ids == audio_token_index)  # [B, T]
    has_audio = audio_mask.any(dim=1)              # [B]

    # default: last attended token
    # (find last index where attention_mask==1)
    attn = attention_mask.to(torch.int8)
    rev = torch.flip(attn, dims=[1])                        # [B, T]
    last_from_end = torch.argmax(rev, dim=1)                # [B]
    last_attn_idx = (T - 1) - last_from_end                 # [B]
    last_idx = last_attn_idx.clone()

    # where audio exists, take last audio position instead
    if has_audio.any():
        rev_audio = torch.flip(audio_mask.to(torch.int8), dims=[1])
        last_audio_from_end = torch.argmax(rev_audio, dim=1)
        last_audio_idx = (T - 1) - last_audio_from_end
        last_idx = torch.where(has_audio, last_audio_idx, last_attn_idx)

    gather_idx = last_idx.view(B, 1, 1).expand(B, 1, D)     # [B,1,D]
    pooled = hidden_states.gather(dim=1, index=gather_idx).squeeze(1)
    return pooled  # [B, D]


@torch.inference_mode()
def get_gate_prediction(model, processor, wav_tensor: torch.Tensor, system_prompt: dict) -> torch.Tensor:
    """
    Return the gate probabilities over [WAIT, TRANSLATE] for one audio clip.

    First asks the model for `gate_logits` directly (`return_gate_logits=True`).
    If the forward rejects that keyword with a TypeError, or the output carries
    no `gate_logits`, the logits are rebuilt instead: the last entry of the
    returned hidden states is pooled at the last audio token and passed through
    `model.conaiki_gate`.

    model:         model under evaluation (needs `.device` and `.config`).
    processor:     supplies the chat template, the feature extractor's sampling
                   rate, and the input builder.
    wav_tensor:    waveform for the clip; handed to the processor as a NumPy array.
    system_prompt: chat message dict used as the first conversation turn.
    Returns:       softmax over the gate logits, leading (batch) dim squeezed.
    Raises:        RuntimeError when no gate logits can be obtained either way.
    """
    expected_sr = processor.feature_extractor.sampling_rate

    # System prompt plus a single user turn that carries the audio placeholder.
    messages = [system_prompt]
    messages.append({"role": "user", "content": [{"type": "audio", "audio_url": "placeholder.wav"}]})
    chat_text = processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)

    # Single-example batch.
    encoded = processor(
        text=chat_text,
        audio=[wav_tensor.cpu().numpy()],
        sampling_rate=expected_sr,
        return_tensors="pt",
        padding=True,
    )

    # Put every input tensor on the model's device.
    model_inputs = {}
    for key, tensor in encoded.items():
        model_inputs[key] = tensor.to(model.device)

    def _autocast():
        # bfloat16 autocast, active only for CUDA models with USE_AUTOCAST set.
        return torch.autocast(
            device_type="cuda",
            dtype=torch.bfloat16,
            enabled=(model.device.type == "cuda" and USE_AUTOCAST),
        )

    # Preferred path: gate logits come straight out of the forward pass.
    try:
        with _autocast():
            primary = model(**model_inputs, return_gate_logits=True)
        if getattr(primary, "gate_logits", None) is not None:
            return torch.softmax(primary.gate_logits.float(), dim=-1).squeeze(0)
    except TypeError:
        # Forward does not take `return_gate_logits`; use the manual path below.
        pass

    # Manual path needs the gate module on the model.
    if not hasattr(model, "conaiki_gate"):
        raise RuntimeError(
            "Model does not expose 'gate_logits' AND has no 'conaiki_gate' module. "
            "Load your custom class (with gate head) or re-export the model with that head."
        )

    with _autocast():
        secondary = model(**model_inputs, output_hidden_states=True, return_dict=True)

    all_states = getattr(secondary, "hidden_states", None)
    if all_states is None:
        raise RuntimeError("Model did not return hidden_states; cannot build gate logits fallback.")

    top_layer = all_states[-1]  # final entry of the hidden-state sequence, [B, T, D]
    ids = model_inputs["input_ids"]
    mask = model_inputs["attention_mask"]

    audio_id = getattr(model.config, "audio_token_index", None)
    if audio_id is None:
        # Stand-in id used only when the config lacks `audio_token_index`.
        # NOTE(review): flagged "(example)" by the author — confirm the real token id.
        audio_id = 151666

    pooled_vec = _pool_last_audio_token(top_layer, ids, mask, audio_id)  # [B, D]
    logits = model.conaiki_gate(pooled_vec)
    return torch.softmax(logits.float(), dim=-1).squeeze(0)


def calculate_metrics(tp, tn, fp, fn):
    """Calculate precision, recall, F1, and other metrics from confusion counts.

    tp, tn, fp, fn: true-positive, true-negative, false-positive and
                    false-negative counts.
    Returns: dict with keys "accuracy", "precision", "recall", "f1_score",
             "specificity" and "mcc". Any ratio whose denominator is zero is
             reported as 0.0.
    """
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    # Matthews Correlation Coefficient.
    # The four marginals are multiplied as floats: their exact integer product
    # can exceed the 64-bit range for large datasets, which makes np.sqrt fail
    # on a Python int (object dtype) or overflow on fixed-width NumPy ints.
    mcc_num = float(tp) * float(tn) - float(fp) * float(fn)
    mcc_den = float(np.sqrt(
        float(tp + fp) * float(tp + fn) * float(tn + fp) * float(tn + fn)
    ))
    mcc = mcc_num / mcc_den if mcc_den > 0 else 0.0

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1,
        "specificity": specificity,
        "mcc": mcc,
    }


def evaluate_single_threshold(samples, model, processor, system_prompt, threshold, base_dir):
    """
    Score every sample and summarise performance at one decision threshold.

    A sample is predicted TRANSLATE when p(TRANSLATE) >= `threshold`, else WAIT.
    Samples that raise while being loaded or scored are recorded as errors and
    left out of the confusion matrix.

    samples:       iterable of dicts with "gate_label" and "audio_path" keys.
    model, processor, system_prompt: forwarded to get_gate_prediction.
    threshold:     decision threshold applied to p(TRANSLATE).
    base_dir:      directory that "audio_path" entries are joined onto.
    Returns:       dict with the threshold, confusion matrix, metrics, sample
                   counts and the error list (None when there were no errors).
    """
    WAIT_IDX, TRANS_IDX = 0, 1
    label_map = {"WAIT": WAIT_IDX, "TRANSLATE": TRANS_IDX}

    # (predicted, true) -> confusion-matrix cell; TRANSLATE is the positive class.
    cell_for = {
        (TRANS_IDX, TRANS_IDX): "tp",
        (WAIT_IDX, WAIT_IDX): "tn",
        (TRANS_IDX, WAIT_IDX): "fp",
        (WAIT_IDX, TRANS_IDX): "fn",
    }
    tally = dict.fromkeys(("tp", "tn", "fp", "fn"), 0)
    scored = []
    failures = []

    for item in samples:
        try:
            gold_name = item["gate_label"]
            gold = label_map[gold_name]

            rel_path = item["audio_path"]
            clip = load_and_prep_audio(
                (base_dir / rel_path).as_posix(),
                processor.feature_extractor.sampling_rate,
            )
            probs = get_gate_prediction(model, processor, clip, system_prompt)

            wait_prob = probs[WAIT_IDX].item()
            translate_prob = probs[TRANS_IDX].item()
            if translate_prob >= threshold:
                guess = TRANS_IDX
            else:
                guess = WAIT_IDX

            # Update confusion matrix
            tally[cell_for[(guess, gold)]] += 1

            scored.append({
                "audio_path": rel_path,
                "true_label": gold_name,
                "predicted_label": "TRANSLATE" if guess == TRANS_IDX else "WAIT",
                "p_wait": wait_prob,
                "p_translate": translate_prob,
                "correct": guess == gold,
            })

        except Exception as exc:
            failures.append({
                "audio_path": item.get('audio_path', 'unknown'),
                "error": str(exc),
            })

    metrics = calculate_metrics(tally["tp"], tally["tn"], tally["fp"], tally["fn"])

    summary = {
        "threshold": threshold,
        "confusion_matrix": dict(tally),
        "metrics": metrics,
        "total_samples": len(samples),
        "processed_samples": len(scored),
        "error_count": len(failures),
        "errors": failures or None,
    }
    return summary


def _read_jsonl_samples(jsonl_path):
    """Parse a JSON Lines file into a list of objects, ignoring blank lines."""
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_threshold_grid(start, end, step):
    """Return start, start+step, ... up to and including end (2-decimal rounding).

    The number of steps is computed as an integer first, so the grid can never
    overshoot `end` or gain/lose an element through float accumulation the way
    a float-stepped np.arange can.
    """
    if step <= 0:
        raise ValueError(f"THRESHOLD_STEP must be positive, got {step}")
    if end < start:
        raise ValueError(f"THRESHOLD_END ({end}) must be >= THRESHOLD_START ({start})")
    # Small epsilon absorbs representation error such as 0.89 / 0.01 -> 88.999...
    n_steps = int((end - start) / step + 1e-9)
    return np.round(start + step * np.arange(n_steps + 1), 2)


def _confusion_counts(p_translate, is_translate, threshold):
    """Return (tp, tn, fp, fn) for `p_translate >= threshold`, TRANSLATE = positive."""
    predicted = p_translate >= threshold
    tp = int((predicted & is_translate).sum())
    tn = int((~predicted & ~is_translate).sum())
    fp = int((predicted & ~is_translate).sum())
    fn = int((~predicted & is_translate).sum())
    return tp, tn, fp, fn


def evaluate_gate_model_multi_threshold():
    """Main evaluation function testing multiple thresholds.

    Runs the model once over every sample in JSONL_PATH, caches p(TRANSLATE),
    then sweeps the thresholds THRESHOLD_START..THRESHOLD_END (step
    THRESHOLD_STEP) over the cached probabilities. For each threshold the
    confusion matrix and calculate_metrics output are stored; the threshold
    with the highest F1 (lowest threshold on ties) is reported as best.

    Results are written as JSON to OUTPUT_JSON_PATH and a summary plus a
    per-threshold table is printed. Samples that fail to load or score are
    reported and excluded from every metric.

    Raises: ValueError if the threshold configuration is invalid (checked
            before the model is loaded).
    Returns: None.
    """
    overall_start = time.perf_counter()

    # Validate the sweep configuration up front so a bad range fails before the
    # expensive model load and inference pass.
    thresholds = _build_threshold_grid(THRESHOLD_START, THRESHOLD_END, THRESHOLD_STEP)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from: {MODEL_PATH} ...")
    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
    ).to(device).eval()
    processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)
    print("Model and processor loaded.")

    system_prompt = {
        "role": "system",
        "content": [{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}],
    }

    # Load dataset
    print(f"Loading evaluation data from: {JSONL_PATH}")
    jsonl_path = Path(JSONL_PATH).expanduser().resolve()
    base_dir = jsonl_path.parent
    samples = _read_jsonl_samples(jsonl_path)

    print(f"Loaded {len(samples)} samples for evaluation.")

    # Store results
    all_results = {
        "model_path": MODEL_PATH,
        "dataset_path": JSONL_PATH,
        "evaluation_timestamp": datetime.now().isoformat(),
        "total_samples": len(samples),
        "threshold_results": [],
        "best_threshold": None,
        "summary": {}
    }

    # First pass: run inference once per sample and cache the probabilities, so
    # the threshold sweep below needs no further model calls.
    print("\nGenerating predictions for all samples...")
    WAIT_IDX, TRANS_IDX = 0, 1
    label_map = {"WAIT": WAIT_IDX, "TRANSLATE": TRANS_IDX}
    target_sr = processor.feature_extractor.sampling_rate  # loop-invariant

    cached_p_translate = []
    cached_is_translate = []

    for sample in tqdm(samples, desc="Computing predictions"):
        try:
            true_lbl = label_map[sample["gate_label"]]

            # Audio paths in the JSONL are relative to the JSONL's own directory.
            audio_path = (base_dir / sample["audio_path"]).as_posix()

            wav = load_and_prep_audio(audio_path, target_sr)
            probs = get_gate_prediction(model, processor, wav, system_prompt)
            p_translate = probs[TRANS_IDX].item()
        except Exception as e:
            # Best-effort: report the failing sample and leave it out of the sweep.
            print(f"\nError processing {sample.get('audio_path')}: {e}")
            continue

        cached_p_translate.append(p_translate)
        cached_is_translate.append(true_lbl == TRANS_IDX)

    # Explicit dtypes keep the arrays well-typed even if every sample failed.
    p_translate_arr = np.array(cached_p_translate, dtype=np.float64)
    is_translate_arr = np.array(cached_is_translate, dtype=bool)
    processed_samples = len(cached_p_translate)

    # Evaluate each threshold using cached predictions
    print(f"\nEvaluating {len(thresholds)} different thresholds...")
    best_f1 = -1
    best_threshold_info = None

    for threshold in tqdm(thresholds, desc="Testing thresholds"):
        tp, tn, fp, fn = _confusion_counts(p_translate_arr, is_translate_arr, threshold)
        metrics = calculate_metrics(tp, tn, fp, fn)

        threshold_result = {
            "threshold": float(threshold),
            "confusion_matrix": {
                "tp": tp,
                "tn": tn,
                "fp": fp,
                "fn": fn
            },
            "metrics": {k: float(v) for k, v in metrics.items()},
            "processed_samples": processed_samples
        }

        all_results["threshold_results"].append(threshold_result)

        # Track best threshold by F1 score (strict '>' keeps the lowest on ties)
        if metrics["f1_score"] > best_f1:
            best_f1 = metrics["f1_score"]
            best_threshold_info = threshold_result

    # Set best threshold
    all_results["best_threshold"] = best_threshold_info

    # Create summary statistics
    all_accuracies = [r["metrics"]["accuracy"] for r in all_results["threshold_results"]]
    all_f1_scores = [r["metrics"]["f1_score"] for r in all_results["threshold_results"]]
    all_precisions = [r["metrics"]["precision"] for r in all_results["threshold_results"]]
    all_recalls = [r["metrics"]["recall"] for r in all_results["threshold_results"]]

    all_results["summary"] = {
        "accuracy_range": [float(min(all_accuracies)), float(max(all_accuracies))],
        "f1_range": [float(min(all_f1_scores)), float(max(all_f1_scores))],
        "precision_range": [float(min(all_precisions)), float(max(all_precisions))],
        "recall_range": [float(min(all_recalls)), float(max(all_recalls))],
        "best_f1_threshold": float(best_threshold_info["threshold"]),
        "best_f1_score": float(best_f1),
        "total_runtime_seconds": float(time.perf_counter() - overall_start)
    }

    # Save results to JSON
    output_path = Path(OUTPUT_JSON_PATH)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*60}")
    print("EVALUATION COMPLETE")
    print(f"{'='*60}")
    print(f"Results saved to: {output_path.absolute()}")
    print(f"\nBest threshold (by F1): {best_threshold_info['threshold']:.2f}")
    print(f"  - F1 Score: {best_threshold_info['metrics']['f1_score']:.4f}")
    print(f"  - Accuracy: {best_threshold_info['metrics']['accuracy']:.4f}")
    print(f"  - Precision: {best_threshold_info['metrics']['precision']:.4f}")
    print(f"  - Recall: {best_threshold_info['metrics']['recall']:.4f}")
    print(f"\nTotal runtime: {time.perf_counter() - overall_start:.2f} seconds")

    # Print a quick summary table
    print(f"\n{'='*60}")
    print("THRESHOLD COMPARISON TABLE")
    print(f"{'='*60}")
    print(f"{'Threshold':<10} {'Accuracy':<10} {'F1':<10} {'Precision':<10} {'Recall':<10}")
    print("-" * 60)
    for result in all_results["threshold_results"]:
        t = result["threshold"]
        m = result["metrics"]
        print(f"{t:<10.2f} {m['accuracy']:<10.4f} {m['f1_score']:<10.4f} {m['precision']:<10.4f} {m['recall']:<10.4f}")


# Entry point: run the multi-threshold sweep when this code is executed
# directly (script or notebook cell) rather than imported as a module.
if __name__ == "__main__":
    evaluate_gate_model_multi_threshold()

  from .autonotebook import tqdm as notebook_tqdm


[2025-08-28 12:10:11,567] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2025-08-28 12:10:13,061] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


Loading model from: /raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/qwen_omni_finetuned_common_voice_for_qwen_train_less_than_3_sec/final_model_1_epoch ...


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  7.91it/s]


Model and processor loaded.
Loading evaluation data from: /raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/data/common_voice_for_qwen/less_than_3_sec/processed/streaming_chunks_padded.jsonl
Loaded 4793 samples for evaluation.

Generating predictions for all samples...


Computing predictions: 100%|██████████| 4793/4793 [10:01<00:00,  7.96it/s]



Evaluating 90 different thresholds...


Testing thresholds: 100%|██████████| 90/90 [00:00<00:00, 803.92it/s]


EVALUATION COMPLETE
Results saved to: /raid/vladimir_albrekht/projects/conaiki/qwen_omni/conaiki/evaluation/gate_threshold_evaluation_results.json

Best threshold (by F1): 0.10
  - F1 Score: 0.3953
  - Accuracy: 0.5838
  - Precision: 0.2685
  - Recall: 0.7486

Total runtime: 607.37 seconds

THRESHOLD COMPARISON TABLE
Threshold  Accuracy   F1         Precision  Recall    
------------------------------------------------------------
0.10       0.5838     0.3953     0.2685     0.7486    
0.11       0.5884     0.3950     0.2695     0.7394    
0.12       0.5904     0.3928     0.2688     0.7290    
0.13       0.5927     0.3911     0.2685     0.7199    
0.14       0.5944     0.3883     0.2674     0.7084    
0.15       0.6003     0.3875     0.2685     0.6958    
0.16       0.6005     0.3836     0.2665     0.6843    
0.17       0.6038     0.3828     0.2670     0.6762    
0.18       0.6082     0.3847     0.2691     0.6739    
0.19       0.6090     0.3799     0.2669     0.6590    
0.20       0.6




In [None]:
threshold = 0.5 

--- Evaluation Complete ---
Overall Accuracy: 54.83% (2628 / 4793)

Confusion Matrix (Positive = TRANSLATE):
  TP:  428   TN: 2200   FP: 1722   FN:  443

Metrics (TRANSLATE):
  Precision: 0.1991
  Recall:    0.4914
  F1-Score:  0.2833

Total runtime: 616.64 s