In [7]:
# ‚ö†Ô∏è RUN THIS CELL TO RESET RUNTIME (only if you have dependency issues)
# This disconnects and deletes the runtime - you'll need to reconnect after
from google.colab import runtime
runtime.unassign()

# LahStats - ML Server (Transcription + Diarization)

Runs MERaLiON (transcription) and pyannote (speaker diarization) as a persistent API server.

**Setup:**
1. Run **Cell 2** (Install dependencies) - uses Colab's pre-installed torch
2. **Restart Runtime** (Runtime ‚Üí Restart runtime)
3. Run **Cell 3** onwards (skip Cell 1 & 2 after restart)
4. Set your tokens in **Cell 4** (ngrok + HuggingFace)
5. Copy the ngrok URL to your backend .env
6. Keep this notebook running during the demo

**‚ö†Ô∏è Important:** Do NOT uninstall Colab's pre-installed torch/torchvision - it causes version mismatch errors!

**Endpoints:**
- `GET /health` - Check server status
- `POST /transcribe` - Transcribe audio
- `POST /diarize` - Speaker diarization (who spoke when)

In [None]:
# Install dependencies - LOCK Colab's torch versions to prevent downgrades!
# pyannote.audio dependencies try to downgrade torch, breaking torchvision.

import subprocess

# Step 1: Get current torch versions and create constraints file
print("Step 1: Locking PyTorch versions...")
result = subprocess.run(['pip', 'show', 'torch'], capture_output=True, text=True)
torch_version = [line.split(': ')[1] for line in result.stdout.split('\n') if line.startswith('Version:')][0]
print(f"  torch=={torch_version}")

result = subprocess.run(['pip', 'show', 'torchvision'], capture_output=True, text=True)
torchvision_version = [line.split(': ')[1] for line in result.stdout.split('\n') if line.startswith('Version:')][0]
print(f"  torchvision=={torchvision_version}")

result = subprocess.run(['pip', 'show', 'torchaudio'], capture_output=True, text=True)
torchaudio_version = [line.split(': ')[1] for line in result.stdout.split('\n') if line.startswith('Version:')][0]
print(f"  torchaudio=={torchaudio_version}")

# Create constraints file to prevent torch downgrades
with open('/tmp/torch_constraints.txt', 'w') as f:
    f.write(f"torch=={torch_version}\n")
    f.write(f"torchvision=={torchvision_version}\n")
    f.write(f"torchaudio=={torchaudio_version}\n")
print("  ‚úÖ Constraints file created")

# Step 2: Install pyannote.audio with constraints (prevents torch downgrade)
print("\nStep 2: Installing pyannote.audio (with torch locked)...")
!pip install -q pyannote.audio -c /tmp/torch_constraints.txt 2>&1 | tail -5

# Step 3: Install transformers and other deps
print("\nStep 3: Installing transformers and other dependencies...")
!pip install -q "transformers>=4.40.0,<4.50.0" accelerate librosa flask pyngrok -c /tmp/torch_constraints.txt 2>&1 | tail -3

# Step 4: Verify torch versions weren't changed
print("\nStep 4: Verifying torch versions...")
!pip show torch torchvision 2>/dev/null | grep -E "^(Name|Version):"

# Step 5: Test pyannote import
print("\nStep 5: Testing pyannote import...")
!python -c "from pyannote.audio import Pipeline; print('‚úÖ pyannote.audio imported successfully')" 2>&1 || echo "‚ùå pyannote import failed - check errors above"

# Restart message
print("\n" + "="*60)
print("‚úÖ Dependencies installed!")
print("‚ö†Ô∏è  NOW RESTART RUNTIME:")
print("   Runtime ‚Üí Restart runtime")
print("   Then run Cell 3 onwards (skip this cell)")
print("="*60)

Checking pre-installed PyTorch versions...
Name: torch
Version: 2.9.1
Name: torchaudio
Version: 2.9.1

Installing pyannote.audio and dependencies...

Installing transformers and other dependencies...

‚úÖ Dependencies installed!
‚ö†Ô∏è  NOW RESTART RUNTIME:
   Runtime ‚Üí Restart runtime
   Then run Cell 3 onwards (skip this cell)


In [1]:
# Check GPU availability (run after restart)
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Torch version: {torch.__version__}")

CUDA available: True
GPU: NVIDIA H100 80GB HBM3
Memory: 85.2 GB
Torch version: 2.9.1+cu128


In [2]:
# Set tokens - paste your tokens here (this notebook is gitignored)
import os

# Ngrok token for public URL
os.environ["NGROK_AUTHTOKEN"] = "YOUR_NGROK_TOKEN_HERE"  # <-- Paste your ngrok token here

# HuggingFace token for pyannote (get from https://huggingface.co/settings/tokens)
# Also accept pyannote license at https://huggingface.co/pyannote/speaker-diarization-3.1
os.environ["HUGGINGFACE_TOKEN"] = "YOUR_HUGGINGFACE_TOKEN_HERE"  # <-- Paste your HF token here

print("NGROK_AUTHTOKEN:", "‚úÖ Set" if os.environ.get("NGROK_AUTHTOKEN") else "‚ùå Missing")
print("HUGGINGFACE_TOKEN:", "‚úÖ Set" if os.environ.get("HUGGINGFACE_TOKEN") else "‚ùå Missing (needed for diarization)")

NGROK_AUTHTOKEN: ‚úÖ Set
HUGGINGFACE_TOKEN: ‚úÖ Set


In [3]:
# Load MERaLiON-2-3B-ASR model (smaller, fits on most GPUs)
# Use 3B instead of 10B to avoid OOM errors - matches backend service
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import torch
import gc

# Clear CUDA cache before loading
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

MODEL_NAME = "MERaLiON/MERaLiON-2-3B"  # Changed from 10B to 3B to avoid OOM

print("Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

print("Loading model (this may take a few minutes)...")
print(f"Using model: {MODEL_NAME}")

# Check GPU memory
if torch.cuda.is_available():
    total_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_mem:.1f} GB")
    
    # Load with float16 for GPU
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        attn_implementation="eager",
        low_cpu_mem_usage=True,
        device_map="auto",  # Automatic device placement
    )
    print(f"Model loaded on GPU (float16)")
else:
    # CPU fallback
    print("No GPU available, loading for CPU...")
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
        attn_implementation="eager",
    )
    print("Model loaded on CPU")

model.eval()  # Set to evaluation mode

print(f"Model loaded on {next(model.parameters()).device}!")

Loading processor...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


processor_config.json:   0%|          | 0.00/281 [00:00<?, ?B/s]

processing_meralion2.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/MERaLiON/MERaLiON-2-3B:
- processing_meralion2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


preprocessor_config.json:   0%|          | 0.00/443 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loading model (this may take a few minutes)...
Using model: MERaLiON/MERaLiON-2-3B
GPU Memory: 85.2 GB


config.json: 0.00B [00:00, ?B/s]

configuration_meralion2.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/MERaLiON/MERaLiON-2-3B:
- configuration_meralion2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_meralion2.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/MERaLiON/MERaLiON-2-3B:
- modeling_meralion2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/197 [00:00<?, ?B/s]

Model loaded on GPU (float16)
Model loaded on cuda:0!


In [6]:
# Load pyannote speaker diarization model
import os
import sys
import torch
import torch.serialization
import importlib

# PyTorch 2.6+ fix: pyannote checkpoints aren't compatible with weights_only=True
# We MUST patch before pyannote imports torch.load

# Step 1: Remove any cached pyannote imports so they pick up our patch
pyannote_modules = [key for key in sys.modules.keys() if 'pyannote' in key]
for mod in pyannote_modules:
    del sys.modules[mod]
print(f"üîÑ Cleared {len(pyannote_modules)} cached pyannote modules")

# Step 2: Patch torch.load at the lowest level
_original_torch_load = torch.serialization.load.__wrapped__ if hasattr(torch.serialization.load, '__wrapped__') else torch.serialization.load

def _patched_load(f, map_location=None, pickle_module=None, *, weights_only=False, **kwargs):
    # Force weights_only=False for pyannote compatibility
    return _original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=False, **kwargs)

torch.serialization.load = _patched_load
torch.load = _patched_load
print("‚úÖ Patched torch.load with weights_only=False")

# Step 3: Now import pyannote fresh (will use patched torch.load)
from pyannote.audio import Pipeline
print("‚úÖ Imported pyannote.audio")

HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", "")

if HF_TOKEN:
    print("Loading pyannote speaker diarization model...")
    try:
        # Note: newer pyannote uses 'token' instead of deprecated 'use_auth_token'
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            token=HF_TOKEN
        )
        # Move to GPU if available
        if torch.cuda.is_available():
            diarization_pipeline = diarization_pipeline.to(torch.device("cuda"))
            print("‚úÖ Diarization model loaded on GPU")
        else:
            print("‚úÖ Diarization model loaded on CPU")
    except Exception as e:
        print(f"‚ùå Failed to load diarization model: {e}")
        print("   Make sure you accepted the license at:")
        print("   https://huggingface.co/pyannote/speaker-diarization-3.1")
        diarization_pipeline = None
else:
    print("‚ö†Ô∏è  HUGGINGFACE_TOKEN not set - diarization will be disabled")
    print("   Set it in Cell 4 to enable speaker diarization")
    diarization_pipeline = None

ModuleNotFoundError: No module named 'torch_audiomentations'

In [None]:
# Transcription function
import numpy as np

def transcribe(audio_data, sample_rate=16000):
    """Transcribe audio using MERaLiON."""
    # Ensure float32 numpy array
    if not isinstance(audio_data, np.ndarray):
        audio_data = np.array(audio_data)
    audio_data = audio_data.astype(np.float32)
    
    # Chat-style prompt for MERaLiON
    prompt_template = "Instruction: {query} \nFollow the text instruction based on the following audio: <SpeechHere>"
    transcribe_prompt = """Transcribe this Singlish speech using romanized text only. 
Do NOT use Chinese characters. 
Write Singlish words in romanized form: walao, shiok, lah, leh, lor, sia, paiseh, sian, etc.
Output format: Speaker labels with romanized transcription."""
    
    conversation = [[{"role": "user", "content": prompt_template.format(query=transcribe_prompt)}]]
    chat_prompt = processor.tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=False,
        add_generation_prompt=True
    )
    
    # Process inputs
    inputs = processor(text=chat_prompt, audios=[audio_data])
    
    # Move to device
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    
    def move_to_device(v):
        if not hasattr(v, 'to'):
            return v
        v = v.to(device)
        if v.is_floating_point():
            v = v.to(dtype)
        return v
    
    inputs = {k: move_to_device(v) for k, v in inputs.items()}
    
    # Generate
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=256)
    
    # Decode
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    # Clean up: extract only the model's response (matches backend _clean_model_output)
    if "model\n" in transcription:
        transcription = transcription.split("model\n", 1)[-1]
    
    # Remove speaker markers like <Speaker1>:, <Speaker2>:, etc.
    import re
    transcription = re.sub(r'<Speaker\d+>:\s*', '', transcription)
    # Remove <SpeechHere> tags
    transcription = re.sub(r'<SpeechHere>', '', transcription)
    # Clean bracketed words: !(walao)! -> walao, (lah) -> lah
    transcription = re.sub(r'!\(([^)]+)\)!', r'\1', transcription)
    transcription = re.sub(r'\(([a-zA-Z]+)\)', r'\1', transcription)
    # Remove filler markers
    transcription = re.sub(r'\(err\)', '', transcription, flags=re.IGNORECASE)
    transcription = re.sub(r'\(uh\)', '', transcription, flags=re.IGNORECASE)
    transcription = re.sub(r'\(um\)', '', transcription, flags=re.IGNORECASE)
    # Clean up extra whitespace
    transcription = re.sub(r'\s+', ' ', transcription).strip()
    
    return transcription

print("Transcription function ready!")

In [None]:
# Post-processing functions
import re
from typing import Dict

CORRECTIONS = {
    # Walao variations
    'while up': 'walao', 'wah lao eh': 'walao', 'wa lao eh': 'walao',
    'wah lao': 'walao', 'wa lao': 'walao', 'wah low': 'walao',
    'wa low': 'walao', 'while ah': 'walao', 'wah lau': 'walao',
    'wa lau': 'walao', 'wah liao': 'walao', 'wa liao': 'walao',
    'while low': 'walao', 'wah lei': 'walao', 'why lao': 'walao',
    'why low': 'walao', 'wah la': 'walao',
    # Vulgar - cheebai
    'cheap buy': 'cheebai', 'chee bye': 'cheebai', 'chi bye': 'cheebai',
    'chee bai': 'cheebai', 'chi bai': 'cheebai', 'chee by': 'cheebai',
    'chi by': 'cheebai', 'cb': 'cheebai', 'c b': 'cheebai', 'see bee': 'cheebai',
    # Vulgar - lanjiao
    'lunch hour': 'lanjiao', 'lan jiao': 'lanjiao', 'lan chow': 'lanjiao',
    'lan chiao': 'lanjiao', 'lun jiao': 'lanjiao', 'lan chio': 'lanjiao',
    'lunchow': 'lanjiao', 'lan jio': 'lanjiao',
    # Vulgar - kanina
    'can nina': 'kanina', 'kar ni na': 'kanina', 'ka ni na': 'kanina',
    'car nina': 'kanina', 'knn': 'kanina', 'k n n': 'kanina',
    # Vulgar - nabei
    'nah bay': 'nabei', 'na bei': 'nabei', 'nah bei': 'nabei',
    'na beh': 'nabei', 'nah beh': 'nabei',
    # Paiseh
    'pie say': 'paiseh', 'pai seh': 'paiseh', 'pie seh': 'paiseh',
    'pai say': 'paiseh', 'pie se': 'paiseh', 'pai se': 'paiseh', 'paise': 'paiseh',
    # Shiok
    'shook': 'shiok', 'she ok': 'shiok', 'shoe ok': 'shiok', 'shi ok': 'shiok',
    # Alamak
    'ala mak': 'alamak', 'allah mak': 'alamak', 'a la mak': 'alamak',
    'allamak': 'alamak', 'aller mak': 'alamak',
    # Aiyo/Aiyah
    'ai yo': 'aiyo', 'ai yoh': 'aiyo', 'aiya': 'aiyah', 'ai ya': 'aiyah',
    'eye yo': 'aiyo', 'aye yo': 'aiyo', 'ai yah': 'aiyah', 'eye yah': 'aiyah',
    # Jialat
    'jia lat': 'jialat', 'gia lat': 'jialat', 'jia lut': 'jialat',
    'jee ah lat': 'jialat', 'gia lut': 'jialat',
    # Bojio
    'bo jio': 'bojio', 'boh jio': 'bojio', 'bo gio': 'bojio',
    'never jio': 'bojio', 'boh gio': 'bojio',
    # Sia/Sian
    'see ya': 'sia', 'see ah': 'sia', 'siah': 'sia', 'si ah': 'sia',
    'see an': 'sian', 'si an': 'sian', 'see en': 'sian', 'si en': 'sian',
    # Other
    'kia su': 'kiasu', 'key ah su': 'kiasu', 'kia si': 'kiasi', 'key ah si': 'kiasi',
    'boh doh': 'bodoh', 'bo doh': 'bodoh', 'sua ku': 'suaku', 'swah ku': 'suaku',
    'le pak': 'lepak', 'lay pak': 'lepak', 'chop': 'chope', 'ma kan': 'makan',
    'go stan': 'gostan', 'go stun': 'gostan', 'si bei': 'sibei', 'see bay': 'sibei',
    'si bay': 'sibei', 'ah tas': 'atas', 'ar tas': 'atas', 'kay poh': 'kaypoh',
    'kae poh': 'kaypoh', 'kaypo': 'kaypoh', 'kpo': 'kaypoh',
    'steady pom pi pi': 'steady', 'goon du': 'goondu', 'gun du': 'goondu',
}

WORD_CORRECTIONS = {
    'la': 'lah', 'laa': 'lah', 'laaa': 'lah',
    'low': 'lor', 'loh': 'lor',
    # 'leh' is a distinct particle - don't convert to 'lah'
    'ler': 'lah',
    'seh': 'sia',
    'arh': 'ah',
    'err': 'eh',
    'shio': 'shiok',
}

TARGET_WORDS = [
    # Vulgar
    'walao', 'cheebai', 'lanjiao', 'kanina', 'nabei',
    # Particles
    'lah', 'lor', 'sia', 'meh', 'leh', 'hor', 'ah', 'one', 'what', 'lei', 'ma',
    # Exclamations
    'wah', 'eh', 'huh', 'aiyo', 'aiyah', 'alamak',
    # Colloquial
    'can', 'cannot', 'paiseh', 'shiok', 'sian', 'bodoh', 'kiasu', 'kiasi',
    'bojio', 'suaku', 'lepak', 'blur', 'goondu', 'cheem', 'chim',
    # Actions
    'chope', 'kena', 'makan', 'tahan', 'gostan', 'cabut', 'sabo', 'arrow',
    # Intensifiers
    'sibei', 'buay', 'jialat',
    # Food/Drink
    'kopi', 'teh', 'peng',
    # Misc
    'atas', 'kaypoh', 'steady', 'power', 'liao',
]

def apply_corrections(text: str) -> str:
    if not text:
        return text
    result = text
    for wrong, correct in sorted(CORRECTIONS.items(), key=lambda x: len(x[0]), reverse=True):
        pattern = re.compile(re.escape(wrong), re.IGNORECASE)
        result = pattern.sub(correct, result)
    for wrong, correct in WORD_CORRECTIONS.items():
        pattern = re.compile(r'\b' + re.escape(wrong) + r'\b', re.IGNORECASE)
        result = pattern.sub(correct, result)
    return result

def count_target_words(text: str) -> Dict[str, int]:
    if not text:
        return {}
    normalized = text.lower()
    counts = {}
    for word in TARGET_WORDS:
        pattern = re.compile(r'(?<![a-zA-Z])' + re.escape(word) + r'(?![a-zA-Z])', re.IGNORECASE)
        matches = pattern.findall(normalized)
        if matches:
            counts[word] = len(matches)
    return counts

print("Post-processing functions ready!")

In [None]:
# üî• MANUALLY KILL NGROK SESSIONS (Run this if ngrok errors persist)
# This cell helps you kill ngrok sessions that pyngrok.kill() can't reach

import subprocess
import os

print("=" * 60)
print("KILLING ALL NGROK SESSIONS")
print("=" * 60)

# Method 1: Kill via pyngrok
try:
    from pyngrok import ngrok
    ngrok.kill()
    print("‚úÖ Killed via pyngrok")
except Exception as e:
    print(f"‚ùå pyngrok.kill() failed: {e}")

# Method 2: Kill processes (Linux/Mac/Colab)
try:
    result = subprocess.run(['pkill', '-9', '-f', 'ngrok'], 
                           capture_output=True, text=True, timeout=5)
    if result.returncode == 0:
        print("‚úÖ Killed ngrok processes via pkill")
    else:
        print("‚ÑπÔ∏è  No ngrok processes found (or pkill not available)")
except Exception as e:
    print(f"‚ÑπÔ∏è  pkill failed (normal on Windows/Colab): {e}")

# Method 3: Try psutil if available
try:
    import psutil
    killed = 0
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            cmdline = ' '.join(proc.info['cmdline'] or [])
            if 'ngrok' in cmdline.lower():
                proc.kill()
                killed += 1
                print(f"‚úÖ Killed process PID {proc.info['pid']}")
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass
    if killed == 0:
        print("‚ÑπÔ∏è  No ngrok processes found via psutil")
except ImportError:
    print("‚ÑπÔ∏è  psutil not installed (optional)")
except Exception as e:
    print(f"‚ÑπÔ∏è  psutil failed: {e}")

print("\n" + "=" * 60)
print("‚ö†Ô∏è  IF ERRORS STILL PERSIST:")
print("=" * 60)
print("1. Visit: https://dashboard.ngrok.com/agents")
print("2. Click 'Stop' on ALL active sessions")
print("3. Wait 10 seconds")
print("4. Then re-run the ngrok cell below")
print("=" * 60)

In [None]:
# Start Flask API server with transcription + diarization
from flask import Flask, request, jsonify
import librosa
import io
import base64
import threading
import tempfile
import os

app = Flask(__name__)

@app.route('/health', methods=['GET'])
def health():
    diarization_status = "available" if diarization_pipeline else "disabled"
    return jsonify({
        "status": "ok", 
        "model": "MERaLiON-2-3B",
        "diarization": diarization_status
    })

@app.route('/transcribe', methods=['POST'])
def transcribe_endpoint():
    try:
        # Accept audio as base64 or file upload
        if request.is_json:
            data = request.get_json()
            audio_b64 = data.get('audio')
            audio_bytes = base64.b64decode(audio_b64)
        else:
            audio_file = request.files.get('audio')
            audio_bytes = audio_file.read()
        
        # Load audio
        audio_data, sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)
        
        # Transcribe
        raw_text = transcribe(audio_data)
        
        # Post-process
        corrected = apply_corrections(raw_text)
        counts = count_target_words(corrected)
        
        return jsonify({
            "raw_transcription": raw_text,
            "corrected": corrected,
            "word_counts": counts,
            "total_singlish_words": sum(counts.values())
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/diarize', methods=['POST'])
def diarize_endpoint():
    """Speaker diarization - segments audio by who spoke when."""
    if not diarization_pipeline:
        return jsonify({
            "error": "Diarization not available. Set HUGGINGFACE_TOKEN in cell-3."
        }), 503
    
    temp_path = None
    try:
        # Accept audio as base64 or file upload
        if request.is_json:
            data = request.get_json()
            audio_b64 = data.get('audio')
            audio_bytes = base64.b64decode(audio_b64)
        else:
            audio_file = request.files.get('audio')
            audio_bytes = audio_file.read()
        
        # Save to temp file (pyannote needs file path)
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
            f.write(audio_bytes)
        
        # Run diarization
        diarization_result = diarization_pipeline(temp_path)
        
        # Convert to list of segments
        segments = []
        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            segments.append({
                "speaker_id": speaker,
                "start_time": round(turn.start, 3),
                "end_time": round(turn.end, 3),
                "duration": round(turn.end - turn.start, 3)
            })
        
        # Get unique speakers
        speakers = list(set(s["speaker_id"] for s in segments))
        
        return jsonify({
            "segments": segments,
            "speakers": speakers,
            "num_speakers": len(speakers),
            "num_segments": len(segments)
        })
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
    finally:
        # Clean up temp file
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except:
                pass

# Run Flask in background thread
# Try different ports if 5000 is in use
import socket

def find_free_port(start_port=5000):
    for port in range(start_port, start_port + 10):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', port)) != 0:
                return port
    return None

PORT = find_free_port(5000)
if PORT is None:
    raise RuntimeError("Could not find a free port")

print(f"Starting Flask server on port {PORT}...")
threading.Thread(target=lambda: app.run(host='0.0.0.0', port=PORT, use_reloader=False)).start()
print(f"Flask server started on port {PORT}!")
print(f"\nEndpoints available:")
print(f"  GET  /health    - Check server status")
print(f"  POST /transcribe - Transcribe audio")
print(f"  POST /diarize   - Speaker diarization")

In [None]:
# Expose server via ngrok
import os
from pyngrok import ngrok

# Kill any existing ngrok tunnels first (free tier allows only 1 session)
import subprocess
import time

# Method 1: Use pyngrok kill
try:
    ngrok.kill()
    print("Killed existing ngrok sessions (pyngrok)")
    time.sleep(1)  # Wait a bit
except Exception as e:
    print(f"pyngrok.kill() failed: {e}")

# Method 2: Kill ngrok processes directly
try:
    # Find and kill all ngrok processes
    result = subprocess.run(['pkill', '-f', 'ngrok'], capture_output=True, text=True)
    if result.returncode == 0:
        print("Killed ngrok processes via pkill")
    time.sleep(1)
except Exception as e:
    print(f"pkill failed (might be Windows/Colab): {e}")
    # Try Windows/alternative method
    try:
        import psutil
        for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
            try:
                if 'ngrok' in ' '.join(proc.info['cmdline'] or []).lower():
                    proc.kill()
                    print(f"Killed ngrok process PID {proc.info['pid']}")
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                pass
        time.sleep(1)
    except ImportError:
        print("psutil not available, skipping process kill")
    except Exception as e:
        print(f"Process kill failed: {e}")

print("\n‚ö†Ô∏è  If errors persist, manually kill ngrok:")
print("  1. Visit: https://dashboard.ngrok.com/agents")
print("  2. Stop all active sessions")
print("  3. Or run: ngrok kill (if ngrok CLI installed)")
print("  4. Or run: pkill -f ngrok (Linux/Mac)")
print("  5. Or run: taskkill /F /IM ngrok.exe (Windows)")

# Set authtoken from environment or paste directly
NGROK_TOKEN = os.environ.get("NGROK_AUTHTOKEN", "YOUR_TOKEN_HERE")
ngrok.set_auth_token(NGROK_TOKEN)

# Start tunnel (use PORT variable from previous cell, or default to 5000)
# Check if PORT is defined, otherwise use default or try to detect Flask port
import socket

if 'PORT' not in globals():
    # Try to find which port Flask is running on
    flask_port = None
    for port in range(5000, 5010):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex(('localhost', port)) == 0:
                flask_port = port
                break
    
    if flask_port:
        PORT = flask_port
        print(f"Detected Flask running on port {PORT}")
    else:
        PORT = 5000
        print(f"PORT not defined, using default port {PORT}")
        print("‚ö†Ô∏è  Make sure Flask server cell was run first!")
else:
    print(f"Using PORT={PORT} from Flask server cell")

try:
    public_url = ngrok.connect(PORT).public_url
except Exception as e:
    error_msg = str(e)
    print(f"\n{'='*60}")
    print(f"‚ùå NGROK CONNECTION FAILED")
    print(f"{'='*60}")
    print(f"Error: {error_msg}")
    print(f"Tried to connect to port {PORT}")
    print(f"\n{'='*60}")
    print("üîß TO FIX THIS:")
    print(f"{'='*60}")
    print("1. Run the cell ABOVE this one (the kill helper cell)")
    print("2. OR manually visit: https://dashboard.ngrok.com/agents")
    print("3. Stop ALL active ngrok sessions")
    print("4. Wait 10 seconds")
    print("5. Re-run THIS cell")
    print(f"{'='*60}")
    print("\nüí° The free ngrok tier only allows 1 session at a time.")
    print("   You must kill the existing session before starting a new one.")
    print(f"{'='*60}\n")
    
    # Don't raise - let user fix it and retry
    public_url = None

if public_url:
    print(f"\n{'='*60}")
    print(f"‚úÖ ML API READY!")
    print(f"{'='*60}")
    print(f"\nPublic URL: {public_url}")
    print(f"\nSet in your backend .env:")
    print(f"TRANSCRIPTION_API_URL={public_url}")
    print(f"DIARIZATION_API_URL={public_url}")
    print(f"\nEndpoints:")
    print(f"  GET  {public_url}/health     - Check server status")
    print(f"  POST {public_url}/transcribe - Transcribe audio")
    print(f"  POST {public_url}/diarize    - Speaker diarization")
else:
    print("\n‚ùå ngrok tunnel not created. Fix the error above and re-run this cell.")
print(f"\nKeep this notebook running!")
print(f"{'='*60}")

In [None]:
  !nvidia-smi