In [2]:
# %%
import torch
import sounddevice as sd
import numpy as np
from faster_whisper import WhisperModel
import time
from collections import deque
import threading
import json
import io

# =============================
# CONFIGURATION
# =============================
TRANSCRIPTION_MODE = "vosk"  # Options: "vosk" (fast partials) or "whisper"
PROJECT_ROOT = r"D:\Work\Projects\AI\interactive-chat-ai"

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# =============================
# LOAD MODELS (ALWAYS LOAD WHISPER FOR FINAL TRANSCRIPTION)
# =============================
print("Loading Silero VAD...")
vad_model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False)

print("Loading Whisper (for final transcription)...")
whisper = WhisperModel(
    "small.en",  # English-only = faster + more accurate
    device="cuda" if torch.cuda.is_available() else "cpu",
    compute_type="int8"
)

# Load Vosk only if needed
vosk_model = None
vosk_rec = None
if TRANSCRIPTION_MODE == "vosk":
    from vosk import Model, KaldiRecognizer
    print("Loading Vosk...")
    vosk_model = Model("models/vosk-model-small-en-us-0.15")
    vosk_rec = KaldiRecognizer(vosk_model, 16000)
    vosk_rec.SetWords(True)

print(f"ASR mode: {TRANSCRIPTION_MODE}")



Loading Silero VAD...


Using cache found in C:\Users\PC/.cache\torch\hub\snakers4_silero-vad_master


Loading Whisper (for final transcription)...
Loading Vosk...
ASR mode: vosk


In [None]:
# =============================
# LLM + TTS LOADING (SIMPLIFIED FOR WINDOWS)
# =============================
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import piper

# Use TinyLlama instead of Phi-3
LLM_PATH = os.path.join(PROJECT_ROOT, "models", "llm", "tinyllama")

_llm_model = None
_llm_tokenizer = None

def get_llm():
    global _llm_model, _llm_tokenizer
    if _llm_model is None:
        print("‚è≥ Loading TinyLlama (1.1B)...")
        _llm_model = AutoModelForCausalLM.from_pretrained(
            LLM_PATH,
            device_map="cpu",  # CPU-only for stability
            torch_dtype=torch.float32  # Avoid float16 on CPU
        )
        _llm_tokenizer = AutoTokenizer.from_pretrained(LLM_PATH)
        print("‚úÖ LLM loaded!")
    return _llm_model, _llm_tokenizer

In [5]:
# =============================
# AUDIO SETUP
# =============================
SAMPLE_RATE = 16000
audio_buffer = []
VOSK_MIN_SAMPLES = 3200  # 0.2 sec @ 16kHz

def audio_callback(indata, frames, time, status):
    audio_buffer.append(indata.copy())

stream = sd.InputStream(samplerate=SAMPLE_RATE, channels=1, callback=audio_callback)


# =============================
# ASR WORKER (STREAMING PARTIALS)
# =============================
asr_audio = deque()      # For streaming partials (trimmed)
turn_audio = deque()     # For final transcription (full turn)
asr_lock = threading.Lock()
turn_audio_lock = threading.Lock()
current_partial_text = ""
vosk_reset_requested = False

def float32_to_int16(audio):
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * 32767).astype(np.int16)

def asr_worker():
    global current_partial_text, vosk_reset_requested
    WHISPER_WINDOW_SEC = 1.2

    while True:
        time.sleep(0.05 if TRANSCRIPTION_MODE == "vosk" else 0.7)

        if TRANSCRIPTION_MODE == "whisper":
            with asr_lock:
                if not asr_audio:
                    continue
                now = time.time()
                recent = [frame for frame, t in asr_audio if now - t <= WHISPER_WINDOW_SEC]
            if not recent:
                continue
            audio_np = np.concatenate(recent)
            segments, _ = whisper.transcribe(
                audio_np, language="en", vad_filter=False, beam_size=1, temperature=0.0
            )
            text = " ".join(seg.text for seg in segments).strip()
            if text and text != current_partial_text:
                current_partial_text = text
                print("üìù Partial:", text)

        else:  # Vosk mode
            if vosk_reset_requested:
                vosk_rec.Reset()
                vosk_reset_requested = False
                current_partial_text = ""
            with asr_lock:
                if not asr_audio:
                    continue
                frame, _ = asr_audio.popleft()
                if len(frame) < VOSK_MIN_SAMPLES:
                    continue
                pcm16 = float32_to_int16(frame)
            try:
                if vosk_rec.AcceptWaveform(pcm16.tobytes()):
                    res = json.loads(vosk_rec.Result())
                    text = res.get("text", "").strip()
                    if text:
                        print("üìù Final:", text)
                        current_partial_text = ""
                else:
                    res = json.loads(vosk_rec.PartialResult())
                    partial = res.get("partial", "").strip()
                    if partial and partial != current_partial_text:
                        current_partial_text = partial
                        print("üìù Partial:", partial)
            except Exception:
                continue

threading.Thread(target=asr_worker, daemon=True).start()
print("ASR worker started")

# =============================
# TURN-TAKING RULES
# =============================
TRAILING_CONJUNCTIONS = {"and","or","but","because","so","that","which","who","when","if","though","while"}
OPEN_ENDED_PREFIXES = ("i think","i guess","i'm not sure","the thing is","it depends")
QUESTION_LEADINS = ("do you think","would you say","is it possible","can you")
SELF_REPAIR_MARKERS = ("i mean","actually","sorry","no wait")
FILLER_ENDINGS = ("uh","um","like","you know","kind of")

def lexical_bias(text: str) -> float:
    if not text: return 0.0
    t = text.lower().strip()
    words = t.split()
    score = 0.0
    if words[-1] in TRAILING_CONJUNCTIONS: score -= 1.0
    if any(t.startswith(p) for p in OPEN_ENDED_PREFIXES): score -= 0.6
    if any(t.startswith(q) for q in QUESTION_LEADINS): score -= 0.5
    if any(m in t[-20:] for m in SELF_REPAIR_MARKERS): score -= 0.4
    if words[-1] in FILLER_ENDINGS: score -= 0.7
    return score

def energy_decay_score(energy_history):
    if len(energy_history) < 5: return 0.0
    x = np.arange(len(energy_history))
    y = np.array(energy_history)
    slope = np.polyfit(x, y, 1)[0]
    return 0.8 if slope < -0.00015 else 0.0

ASR worker started


In [6]:
# =============================
# WINDOWS-RELIABLE TTS (POWER SHELL)
# =============================
import subprocess
import queue
import threading

response_queue = queue.Queue()

def speak(text):
    """Speak text using Windows PowerShell (100% reliable on Win 10/11)"""
    safe_text = text.replace('"', '""').replace('\n', ' ').replace('\r', '')
    cmd = f'Add-Type -AssemblyName System.Speech; $s=New-Object System.Speech.Synthesis.SpeechSynthesizer; $s.Speak("{safe_text}")'
    try:
        subprocess.run(["powershell", "-Command", cmd],
                       stdout=subprocess.DEVNULL,
                       stderr=subprocess.DEVNULL,
                       timeout=10)
    except Exception as e:
        print(f"üîä Speech error: {e}")

def tts_main_loop():
    """Main thread TTS loop (never fails on Windows)"""
    while True:
        try:
            text = response_queue.get(timeout=0.1)
            print(f"üó£Ô∏è Speaking: '{text}'")
            speak(text)
        except queue.Empty:
            pass

# Start TTS loop in background (non-daemon = survives between turns)
threading.Thread(target=tts_main_loop, daemon=False).start()
print("‚úÖ PowerShell TTS initialized")

‚úÖ PowerShell TTS initialized


In [7]:
# =============================
# MAIN LOOP
# =============================
import tempfile
import os
import wave
import time
import re

# CONFIG
VAD_MIN_SAMPLES = 512
PAUSE_MS = 600
END_MS = 1200
SAFETY_TIMEOUT_MS = 2500
ENERGY_FLOOR = 0.015
WHISPER_WINDOW_SEC = 3.0
CONFIDENCE_THRESHOLD = 1.2

# STATE
state = "IDLE"
last_voice_time = None
last_ai_interrupted = False
vad_buffer = np.zeros(0, dtype=np.float32)
energy_history = deque(maxlen=15)
pause_history = deque(maxlen=5)
micro_spike_times = deque(maxlen=5)

stream.start()
print("üéôÔ∏è Real-time conversation test started")

try:
    while True:
        if not audio_buffer:
            time.sleep(0.01)
            continue

        # ---- COLLECT AUDIO CHUNK ----
        chunk = audio_buffer.pop(0).astype(np.float32).flatten()
        vad_buffer = np.concatenate([vad_buffer, chunk])

        if len(vad_buffer) < VAD_MIN_SAMPLES:
            continue

        frame = vad_buffer[:VAD_MIN_SAMPLES]
        vad_buffer = vad_buffer[VAD_MIN_SAMPLES:]
        if len(frame) < VAD_MIN_SAMPLES:
            continue

        now = time.time()
        rms = np.sqrt(np.mean(frame ** 2))
        energy_history.append(rms)

        # ---- VAD ----
        with torch.no_grad():
            vad_confidence = vad_model(torch.from_numpy(frame).unsqueeze(0), 16000).item()
        speech_started = vad_confidence > 0.5
        sustained = sum(e > ENERGY_FLOOR for e in energy_history) >= 3

        # ---- MICRO-SPIKE DETECTION ----
        if state == "PAUSING" and rms > ENERGY_FLOOR:
            micro_spike_times.append(now)

        # ---- STATE MACHINE ----
        if state == "IDLE":
            if speech_started or sustained:
                state = "SPEAKING"
                last_voice_time = now
                print("üü¢ Speech started")

        elif state == "SPEAKING":
            if speech_started or sustained:
                last_voice_time = now
            else:
                elapsed = (now - last_voice_time) * 1000
                if elapsed >= PAUSE_MS:
                    state = "PAUSING"
                    print(f"üü° Pause {int(elapsed)} ms")

        elif state == "PAUSING":
            elapsed = (now - last_voice_time) * 1000

            # SAFETY TIMEOUT
            if elapsed > SAFETY_TIMEOUT_MS:
                print(f"üî¥ SAFETY TIMEOUT: Force-ending turn after {elapsed:.0f}ms")
                state = "IDLE"
                last_voice_time = None
                energy_history.clear()
                pause_history.clear()
                micro_spike_times.clear()
                last_ai_interrupted = False
                with turn_audio_lock:
                    turn_audio.clear()
                current_partial_text = ""
                if TRANSCRIPTION_MODE == "vosk":
                    vosk_reset_requested = True
                continue

            # RESUME SPEECH?
            if speech_started or sustained:
                state = "SPEAKING"
                last_voice_time = now
                print("üü¢ Speech resumed")
            else:
                # CALCULATE CONFIDENCE
                confidence = 0.0
                if elapsed > END_MS:
                    confidence += 1.0
                if len(energy_history) >= 8:
                    recent_energies = list(energy_history)[-8:]
                    if max(recent_energies) < ENERGY_FLOOR * 1.8:
                        confidence += 0.7
                if elapsed < 1000:
                    recent_spikes = [t for t in micro_spike_times if now - t < 0.6]
                    if len(recent_spikes) >= 2:
                        confidence -= 0.5
                if elapsed < 900 and current_partial_text:
                    confidence += lexical_bias(current_partial_text) * 0.6
                if last_ai_interrupted:
                    confidence -= 0.5

                # END TURN?
                if confidence >= CONFIDENCE_THRESHOLD:
                    print(f"üî¥ Turn ended (confidence={confidence:.2f}, silence={elapsed:.0f}ms)")

                    # CAPTURE FULL TURN AUDIO
                    with turn_audio_lock:
                        turn_frames = list(turn_audio)
                        turn_audio.clear()

                    # RESET STATE
                    state = "IDLE"
                    last_voice_time = None
                    energy_history.clear()
                    pause_history.clear()
                    micro_spike_times.clear()
                    last_ai_interrupted = False
                    current_partial_text = ""
                    if TRANSCRIPTION_MODE == "vosk":
                        vosk_reset_requested = True

                    # GENERATE RESPONSE
                    def generate_response(frames):
                        start_time = time.time()
                        
                        if not frames:
                            print("‚ö†Ô∏è No audio captured ‚Äî skipping response")
                            return
                        full_audio = np.concatenate([frame for frame, _ in frames])
                        print(f"üîä Captured {len(frames)} frames ({full_audio.shape[0]/16000:.2f}s)")

                        # FINAL TRANSCRIPTION WITH WHISPER
                        segments, _ = whisper.transcribe(
                            full_audio,
                            language="en",
                            beam_size=5,
                            temperature=0.0,
                            condition_on_previous_text=False
                        )
                        user_text = " ".join(seg.text for seg in segments).strip()
                        if not user_text:
                            print("‚ö†Ô∏è Empty transcription ‚Äî skipping response")
                            return
                        print(f"üí¨ User: {user_text}")

                        try:
                            # LLM
                            llm_model, tokenizer = get_llm()
                            prompt = f"<|user|>\n{user_text}<|end|>\n<|assistant|>\n"
                            inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
                            with torch.no_grad():
                                outputs = llm_model.generate(
                                    **inputs,
                                    max_new_tokens=20,
                                    do_sample=False,
                                    pad_token_id=tokenizer.eos_token_id
                                )
                            response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                            response_text = response_text.split("<|assistant|>")[-1].strip()

                            # üî• ULTRA-AGGRESSIVE SANITIZATION (ASCII ONLY)
                            # Remove EVERYTHING except letters, spaces, and basic punctuation
                            response_text = re.sub(r'[^a-zA-Z\s.!?]', ' ', response_text)
                            response_text = re.sub(r'\s+', ' ', response_text).strip()
                            
                            # Keep ONLY first 6 words MAX (Piper struggles beyond this on CPU)
                            words = response_text.split()[:6]
                            if not words:
                                safe_text = "Okay"
                            else:
                                safe_text = ' '.join(words)
                                # Ensure ends with punctuation
                                if not safe_text.endswith(('.', '!', '?')):
                                    safe_text += "."
                            
                            # FINAL SAFETY: If still problematic, use hardcoded phrase
                            if len(safe_text) < 3 or len(safe_text) > 50:
                                safe_text = "Okay."
                            
                            print(f"ü§ñ AI: {safe_text}")

                            response_queue.put(safe_text)
                        except Exception as e:
                            print(f"‚ùå Error: {e}")
                            import traceback
                            traceback.print_exc()

                    # LAUNCH RESPONSE THREAD
                    threading.Thread(target=generate_response, args=(turn_frames,), daemon=True).start()

        # ---- BUFFER AUDIO FOR STREAMING AND FINAL TRANSCRIPTION ----
        if state in ("SPEAKING", "PAUSING"):
            # For final transcription (never trimmed until turn ends)
            with turn_audio_lock:
                turn_audio.append((frame.copy(), now))
            # For streaming partials
            with asr_lock:
                asr_audio.append((frame.copy(), now))
                if TRANSCRIPTION_MODE == "whisper":
                    cutoff = now - WHISPER_WINDOW_SEC
                    while asr_audio and asr_audio[0][1] < cutoff:
                        asr_audio.popleft()
            # Vosk internal buffer
            if TRANSCRIPTION_MODE == "vosk":
                if not hasattr(asr_worker, "vosk_buf"):
                    asr_worker.vosk_buf = np.zeros(0, dtype=np.float32)
                asr_worker.vosk_buf = np.concatenate([asr_worker.vosk_buf, frame])
                while len(asr_worker.vosk_buf) >= VOSK_MIN_SAMPLES:
                    chunk_to_send = asr_worker.vosk_buf[:VOSK_MIN_SAMPLES]
                    asr_worker.vosk_buf = asr_worker.vosk_buf[VOSK_MIN_SAMPLES:]

except KeyboardInterrupt:
    stream.stop()
    print("\nüõë Test stopped")

üéôÔ∏è Real-time conversation test started
üü¢ Speech started
üü° Pause 602 ms
üî¥ Turn ended (confidence=1.70, silence=1203ms)
üîä Captured 51 frames (1.63s)


`torch_dtype` is deprecated! Use `dtype` instead!


üí¨ User: Thank you.
‚è≥ Loading TinyLlama (1.1B)...


Loading weights: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 201/201 [00:00<00:00, 437.02it/s, Materializing param=model.norm.weight]                              


‚úÖ LLM loaded!
üü¢ Speech started
üü° Pause 620 ms
ü§ñ AI: I m glad I could help.
üó£Ô∏è Speaking: I m glad I could help.
üî¥ Turn ended (confidence=1.70, silence=1219ms)
üîä Captured 79 frames (2.53s)
üí¨ User: Tell me something.
ü§ñ AI: Sure here s something I love.
üó£Ô∏è Speaking: 'Sure here s something I love.'
üü¢ Speech started
üü° Pause 623 ms
üî¥ Turn ended (confidence=1.70, silence=1225ms)
üîä Captured 103 frames (3.30s)
üü¢ Speech started
üí¨ User: Why are you so robotic?
üü° Pause 622 ms
üî¥ Turn ended (confidence=1.70, silence=1214ms)
üîä Captured 90 frames (2.88s)
ü§ñ AI: I m sorry for the confusion.
üó£Ô∏è Speaking: I m sorry for the confusion.
‚ö†Ô∏è Empty transcription ‚Äî skipping response

üõë Test stopped
