# Vox Amelior - GPU Offload Mode

This notebook runs ONLY the GPU-heavy components:
- **Transcription** (NeMo Parakeet ASR)
- **Gemma LLM** (llama.cpp)

Your local machine runs everything else. This notebook exposes simple HTTP endpoints via Tailscale.

**Setup:**
1. Set `TAILSCALE_AUTHKEY` in Colab Secrets
2. Run all cells
3. Point your local services to the Tailscale IP printed by the Tailscale setup cell (all endpoints are served on port 8000)

In [None]:
# @title 1. Install Dependencies (GPU)
!pip install -q fastapi uvicorn nest_asyncio httpx
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q nemo_toolkit[asr]
!pip install -q llama-cpp-python --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu121
!apt-get install -y ffmpeg -qq
print("Dependencies installed.")

In [None]:
# @title 2. Setup Tailscale
import os
import subprocess
import time
from google.colab import userdata

!curl -fsSL https://tailscale.com/install.sh | sh

AUTHKEY = userdata.get('TAILSCALE_AUTHKEY')
if not AUTHKEY:
    raise ValueError("Set TAILSCALE_AUTHKEY in Colab Secrets!")

subprocess.Popen(
    ['sudo', 'tailscaled', '--tun=userspace-networking', '--state=/tmp/tailscale.state'],
    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
time.sleep(5)

result = subprocess.run(
    ['sudo', 'tailscale', 'up', f'--authkey={AUTHKEY}', '--hostname=vox-gpu'],
    capture_output=True, text=True
)
if result.returncode != 0:
    print(f"Tailscale error: {result.stderr}")
else:
    time.sleep(2)
    ip_result = subprocess.run(['tailscale', 'ip', '-4'], capture_output=True, text=True)
    TAILSCALE_IP = ip_result.stdout.strip()
    print(f"Tailscale IP: {TAILSCALE_IP}")
    print(f"Transcription: http://{TAILSCALE_IP}:8003/transcribe")
    print(f"Gemma:         http://{TAILSCALE_IP}:8001/generate")

In [None]:
# @title 3. Load Transcription Model (NeMo Parakeet)
import torch
import nemo.collections.asr as nemo_asr

# Report the accelerator that torch can see before loading any weights.
has_cuda = torch.cuda.is_available()
print(f"GPU Available: {has_cuda}")
if has_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

# Fetch the pretrained Parakeet CTC 1.1B checkpoint, move it to the GPU when
# one is available (CPU otherwise), and switch it to inference mode.
print("Loading Parakeet ASR model...")
target_device = 'cuda' if has_cuda else 'cpu'
asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-ctc-1.1b")
asr_model = asr_model.to(target_device)
asr_model.eval()
print("Transcription model loaded.")

In [None]:
# @title 4. Load Gemma Model (Optional - Comment out if not needed)
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# Fetch the Gemma 2B instruction-tuned GGUF file from the Hugging Face Hub into
# /content/models.
# NOTE(review): this file was previously described as 4-bit quantized; the
# filename does not indicate a quantization level - confirm against the repo.
print("Downloading Gemma model...")
model_path = hf_hub_download(
    repo_id="google/gemma-2b-it-GGUF",
    filename="gemma-2b-it.gguf",
    local_dir="/content/models",
)

# Load the weights through llama.cpp with a 4096-token context window.
print("Loading Gemma...")
llama_options = dict(
    n_ctx=4096,
    n_gpu_layers=-1,  # -1: offload all layers to the GPU
    verbose=False,
)
llm = Llama(model_path=model_path, **llama_options)
print("Gemma loaded.")

In [None]:
# @title 5. Start GPU API Server
import io
import base64
import tempfile
import nest_asyncio
import uvicorn
from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel

# Patch asyncio so the server can be started from inside the notebook, which
# already has an event loop running.
nest_asyncio.apply()

# Single FastAPI application; every endpoint below is registered on it.
app = FastAPI(title="Vox GPU Offload")

class TranscribeRequest(BaseModel):
    """JSON body accepted by POST /transcribe."""
    audio_base64: str  # Base64 encoded audio

class GenerateRequest(BaseModel):
    """JSON body accepted by POST /generate; fields are forwarded to the Gemma call."""
    prompt: str  # passed to the model unchanged (no chat template is applied here)
    max_tokens: int = 512  # forwarded as the llm call's max_tokens
    temperature: float = 0.7  # forwarded as the llm call's temperature

@app.get("/health")
def health():
    """Liveness probe that also reports whether torch can see a CUDA device."""
    gpu_ready = torch.cuda.is_available()
    return {"status": "healthy", "gpu": gpu_ready}

def _first_transcript_text(transcription):
    """Return the first entry of an ``asr_model.transcribe`` result as plain text.

    Returns "" for an empty result. Plain strings are returned unchanged.
    """
    if not transcription:
        return ""
    first = transcription[0]
    # NOTE(review): depending on the installed NeMo version the result entries
    # may be nested lists and/or Hypothesis objects exposing `.text` rather
    # than plain strings - confirm against the pinned NeMo release. Unwrapping
    # here keeps the JSON response a string in either case.
    if isinstance(first, (list, tuple)):
        first = first[0] if first else ""
    return getattr(first, "text", first)


def _transcribe_audio_bytes(audio_bytes):
    """Transcribe raw audio bytes and return the recognised text.

    The bytes are written to a temporary ``.wav`` file because the model is
    called with a list of file paths. The file is removed even when writing
    or transcription fails, so failed requests do not leak temp files.
    """
    import os  # local import: this cell's import block does not import os

    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        with tmp:
            tmp.write(audio_bytes)
        # The file is closed (flushed) before the model reads it by path.
        return _first_transcript_text(asr_model.transcribe([tmp.name]))
    finally:
        os.unlink(tmp.name)


@app.post("/transcribe")
async def transcribe(request: TranscribeRequest):
    """Transcribe base64-encoded audio.

    Returns ``{"text": ..., "success": True}`` on success. Any failure
    (bad base64, model error, ...) is reported as
    ``{"error": ..., "success": False}`` rather than raised.
    """
    try:
        audio_bytes = base64.b64decode(request.audio_base64)
        return {"text": _transcribe_audio_bytes(audio_bytes), "success": True}
    except Exception as e:
        return {"error": str(e), "success": False}

@app.post("/transcribe/file")
async def transcribe_file(file: UploadFile = File(...)):
    """Transcribe uploaded audio file.

    Same response shape and error handling as ``/transcribe``; the audio is
    taken from the multipart upload instead of a base64 JSON field.
    """
    try:
        content = await file.read()
        return {"text": _transcribe_audio_bytes(content), "success": True}
    except Exception as e:
        return {"error": str(e), "success": False}

@app.post("/generate")
async def generate(request: GenerateRequest):
    """Run the prompt through Gemma and return the stripped completion text.

    Failures are reported as ``{"error": ..., "success": False}`` instead of
    being raised.
    """
    # Generation halts at either of these marker strings.
    stop_markers = ["<end_of_turn>", "<eos>"]
    try:
        completion = llm(
            request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            stop=stop_markers,
        )
        first_choice = completion["choices"][0]
        return {"text": first_choice["text"].strip(), "success": True}
    except Exception as exc:
        return {"error": str(exc), "success": False}

import asyncio

# Every endpoint (transcription and Gemma) is served by this one app on one port.
API_PORT = 8000

print(f"Starting GPU API server on port {API_PORT} (Gemma + Transcription)...")

# TAILSCALE_IP only exists if the Tailscale setup cell completed; look it up
# defensively so a missing value does not abort server start with a NameError.
tailscale_ip = globals().get("TAILSCALE_IP")
if tailscale_ip:
    print(f"Access via Tailscale: http://{tailscale_ip}:{API_PORT}")
else:
    print("Tailscale IP unknown - run the Tailscale setup cell to get a reachable address.")

# Drive the server on the notebook's already-running event loop (made
# re-entrant by nest_asyncio.apply() above) instead of calling uvicorn.run().
# NOTE(review): uvicorn.run() goes through asyncio.run(), which is reported to
# fail under nest_asyncio on recent uvicorn releases - confirm for the
# installed version. This form blocks the cell the same way uvicorn.run() did.
server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=API_PORT))
asyncio.get_event_loop().run_until_complete(server.serve())