# 🐍 Ouroboros Colab Inference Server (LLaDA - CPU Offload)

This notebook runs LLaDA-8B-Instruct with 8-bit quantization and **CPU offloading** on a Colab T4 GPU.

### Instructions
1.  **Runtime**: Go to `Runtime` -> `Change runtime type` -> Select **T4 GPU**.
2.  **Ngrok**: Get your authtoken from [ngrok.com](https://ngrok.com).
3.  **Hugging Face**: Have your HF Token ready.
4.  **Run**: Execute all cells below **IN ORDER**.
5.  **Connect**: Copy the ngrok URL to your local `.env` as `COLAB_API_URL`.

In [None]:
# @title 1. Install Dependencies
# Installs the packages the later cells import (transformers, pyngrok, fastapi,
# uvicorn, nest-asyncio, torch, huggingface_hub) plus accelerate and
# bitsandbytes; `-q` keeps pip's output short.
# NOTE(review): accelerate/bitsandbytes are presumably required by the
# `device_map` / 8-bit options used when loading the model - confirm.
# NOTE(review): transformers is pinned to 4.38.2, presumably for compatibility
# with LLaDA's remote model code - confirm before changing the pin.
!pip install -q transformers==4.38.2 pyngrok fastapi uvicorn nest-asyncio torch accelerate huggingface_hub bitsandbytes

In [None]:
# @title 2. Clone LLaDA Repository
import os
import sys

if not os.path.exists('/content/LLaDA'):
    !git clone https://github.com/ML-GSAI/LLaDA.git /content/LLaDA
    print("‚úÖ Cloned LLaDA repository")
else:
    print("‚úÖ LLaDA repository already exists")

if '/content/LLaDA' not in sys.path:
    sys.path.insert(0, '/content/LLaDA')
    print(f"‚úÖ Added to Python path")

In [None]:
# @title 3. Setup Ngrok
import getpass
from pyngrok import ngrok, conf

# Read the ngrok authtoken without echoing it and register it as pyngrok's
# default configuration for the tunnel opened in the server cell.
print("Enter your ngrok authtoken (hidden):")
# Strip stray whitespace/newlines that often come along with a pasted token,
# matching how the Hugging Face token is handled.
token = getpass.getpass().strip()
if not token:
    # Fail here rather than later, when opening the tunnel would be rejected.
    raise ValueError(
        "No ngrok authtoken entered; re-run this cell and paste your token."
    )
conf.get_default().auth_token = token
print("‚úÖ Ngrok configured!")

In [None]:
# @title 4. Setup Hugging Face
import getpass

# Ask for the Hugging Face token without echoing it; an empty answer skips
# the login step entirely.
print("Enter your Hugging Face Token:")
hf_token = getpass.getpass("HF Token: ")

cleaned_hf_token = hf_token.strip()
if cleaned_hf_token:
    # Imported lazily: only needed when a token was actually supplied.
    from huggingface_hub import login

    login(token=cleaned_hf_token)
    print("‚úÖ Logged in!")

In [None]:
# @title 5. Load LLaDA with 8-bit + CPU Offload
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from generate import generate

print("‚úÖ Imported official generate()")

MODEL_NAME = "GSAI-ML/LLaDA-8B-Instruct"
print(f"\nLoading {MODEL_NAME} with 8-bit + CPU offload...")

# The checkpoint ships its own tokenizer/model code, hence trust_remote_code.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Requested placement: the "model" module on GPU 0, "lm_head" on the CPU.
# NOTE(review): assumes the checkpoint exposes top-level modules with exactly
# these names - confirm against the printed hf_device_map below.
device_map = {"model": 0, "lm_head": "cpu"}

# 8-bit weights; the fp32 CPU-offload flag is what permits a device map that
# places some modules on the CPU.
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map=device_map,
    quantization_config=quant_config,
)
# Inference only: switch to evaluation mode.
model.eval()

print("‚úÖ LLaDA loaded successfully (8-bit + CPU offload)!")
print(f"Model devices: {model.hf_device_map}")

In [None]:
# @title 6. Start Server
import nest_asyncio
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Patch asyncio so an event loop can be run from inside the notebook.
# NOTE(review): presumably needed because the notebook kernel already has a
# running loop when the server is started at the end of this cell - confirm.
nest_asyncio.apply()
# The FastAPI application that the /generate route below is registered on.
app = FastAPI()

# Request body accepted by POST /generate.
class GenerationRequest(BaseModel):
    # User message; wrapped in a single-turn chat template before generation.
    prompt: str
    # Passed to generate() as both `steps` and `gen_length`.
    max_tokens: int = 128
    # Forwarded unchanged to generate() as `temperature`.
    temperature: float = 0.0

# Response body returned by POST /generate.
class GenerationResponse(BaseModel):
    # Decoded completion (prompt tokens removed, special tokens skipped).
    generated_text: str
    # Number of token positions produced after the prompt.
    tokens_used: int

# Block size handed to LLaDA's generate(). The upstream sampler asserts that the
# generation length is a whole number of blocks, so lengths are rounded up to it.
_BLOCK_LENGTH = 32


def _round_up_to_block(length: int, block: int = _BLOCK_LENGTH) -> int:
    """Return the smallest multiple of *block* that is >= *length*."""
    return -(-length // block) * block


@app.post("/generate")
async def generate_text(req: GenerationRequest):
    """Generate a completion for ``req.prompt`` with LLaDA.

    The prompt is wrapped in a single-turn chat template, sampled with the
    official ``generate()`` routine, and the decoded completion is returned
    together with the number of completion tokens.

    Args:
        req: Prompt, maximum number of tokens to return, and temperature.

    Returns:
        GenerationResponse with the decoded text and the token count.

    Raises:
        HTTPException: 422 for a non-positive ``max_tokens`` or a negative
            ``temperature``; 500 (message plus traceback) if generation fails.
    """
    # Validate before the catch-all `try`: HTTPException is itself an Exception,
    # so raising it inside the block below would be rewrapped as a 500.
    if req.max_tokens < 1:
        raise HTTPException(status_code=422, detail="max_tokens must be >= 1")
    if req.temperature < 0:
        raise HTTPException(status_code=422, detail="temperature must be >= 0")

    try:
        messages = [{"role": "user", "content": req.prompt}]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        )

        input_ids = tokenizer(prompt_text)['input_ids']
        # Add the batch dimension expected by generate().
        input_ids = torch.tensor(input_ids).to('cuda').unsqueeze(0)
        prompt_len = input_ids.shape[1]

        # generate() requires gen_length % block_length == 0 and a step count
        # divisible by the number of blocks; a whole number of blocks with
        # steps == gen_length satisfies both for any requested max_tokens.
        gen_length = _round_up_to_block(req.max_tokens)

        print(f"Generating {req.max_tokens} tokens...")

        output_ids = generate(
            model,
            input_ids,
            steps=gen_length,
            gen_length=gen_length,
            block_length=_BLOCK_LENGTH,
            temperature=req.temperature,
            cfg_scale=0.0,
            remasking='low_confidence'
        )

        # Drop the prompt and anything beyond the caller's cap (only present
        # when max_tokens had to be rounded up to a full block).
        completion_ids = output_ids[:, prompt_len:prompt_len + req.max_tokens]
        generated_text = tokenizer.batch_decode(
            completion_ids,
            skip_special_tokens=True
        )[0]

        tokens_used = completion_ids.shape[1]
        print(f"‚úÖ Generated: {generated_text[:100]}...")

        return GenerationResponse(
            generated_text=generated_text,
            tokens_used=tokens_used
        )
    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(f"‚ùå Error: {error_trace}")
        # NOTE(review): the traceback is sent to the client to ease debugging
        # from the local app; reconsider if the tunnel URL is shared.
        raise HTTPException(status_code=500, detail=f"{str(e)}\n{error_trace}")

public_url = ngrok.connect(8000)
print(f"\nüî• SERVER RUNNING! üî•")
print(f"Copy this URL to your local .env: {public_url.public_url}\n")

config = uvicorn.Config(app, host="0.0.0.0", port=8000)
server = uvicorn.Server(config)
await server.serve()