In [1]:
!pip install -q gradio

In [2]:
!pip install -q accelerate bitsandbytes fsspec==2025.3.2 datasets peft transformers trl

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.6/316.6 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# main.py
# --- Imports ---
import os
import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel
import requests
import json

# --- 1. App and Model Setup ---

# Initialize FastAPI app
app = FastAPI()

# --- Model Loading (Done once on startup) ---
print("Setting up the model... This may take a few minutes.")

# Read the Serper API key from the environment ONLY. Never ship a real key as a
# fallback default: it ends up in the notebook file, git history and shared outputs,
# and it also makes the "key not found" branch below unreachable.
# NOTE(review): a real key used to be hardcoded here — treat it as leaked and rotate it.
# On the GCP VM:  export SERPER_API_KEY="your_key_here"
SERPER_API_KEY = os.environ.get("SERPER_API_KEY", "").strip()
if not SERPER_API_KEY:
    # Chat-only mode still works; web-search requests will not return real results.
    print("WARNING: Serper API key not found. Web search will not work.")

# Model IDs
base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
adapter_id = "notninja/chad-gpt"
# NOTE(review): meta-llama repositories are typically gated on the Hugging Face Hub —
# confirm the VM is authenticated (HF_TOKEN / `huggingface-cli login`) before startup.

# bfloat16 compute needs hardware support; older GPUs (e.g. a T4) lack it, so fall
# back to float16 there instead of running the 4-bit matmuls in an unsupported dtype.
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)

# Quantization config for memory efficiency: 4-bit NF4 weights.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# Load base model; device_map="auto" lets accelerate place the layers on available devices.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# Load tokenizer (same checkpoint as the base model so the chat template matches).
tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Attach the fine-tuned PEFT adapter on top of the quantized base model.
model = PeftModel.from_pretrained(base_model, adapter_id)
model.eval()  # inference only: make sure dropout and similar layers are disabled

# Create the text generation pipeline (shared by every request).
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# Report completion only after everything, including the pipeline, is ready.
print("✅ Model setup complete!")


# --- 2. Pydantic Models for Request Body ---
class ChatRequest(BaseModel):
    """JSON body accepted by ``POST /chat``.

    Attributes:
        message: The user's chat message; it is inserted into the model prompt.
        use_web_search: When true, the reply is grounded in Serper web-search
            snippets; when omitted it defaults to a plain chat reply.
    """

    message: str  # raw user text
    use_web_search: bool = False  # opt-in web search, off unless the client asks


# --- 3. Generation Functions ---

def _generate_reply(messages, max_new_tokens: int = 500, temperature: float = 0.7) -> str:
    """Render `messages` with the chat template, sample a reply, and return only the reply.

    Shared by both response functions so prompt rendering, stop tokens and reply
    extraction live in one place.

    Args:
        messages: Chat-format list of ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.

    Returns:
        The newly generated assistant text with surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Stop on the tokenizer's EOS *and* Llama 3's end-of-turn token (<|eot_id|>), which
    # can differ; otherwise generation may run past the end of the assistant's answer.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if isinstance(eot_id, int) and eot_id != tokenizer.unk_token_id and eot_id not in terminators:
        terminators.append(eot_id)

    result = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,            # explicit sampling; temperature has no effect under greedy decoding
        temperature=temperature,
        eos_token_id=terminators,
        return_full_text=False,    # only the new text, instead of splitting it off the echoed prompt
        add_special_tokens=False,  # the Llama 3 template already emits <|begin_of_text|>; avoid a 2nd BOS
    )
    return result[0]["generated_text"].strip()


def _search_web_context(message: str, max_results: int = 5, timeout: float = 10.0) -> str:
    """Query Serper for `message` and return up to `max_results` result snippets.

    Best effort: a missing key or any failure (network error, HTTP error status,
    unexpected JSON) yields a canned fallback string so the chat still answers.

    Returns:
        Newline-terminated snippets, or a fallback message when nothing usable came back.
    """
    if not SERPER_API_KEY:
        print("Skipping Serper search: SERPER_API_KEY is not set.")
        return "The web search is down bad rn, couldn't find anything."

    try:
        response = requests.post(
            "https://google.serper.dev/search",
            headers={'X-API-KEY': SERPER_API_KEY, 'Content-Type': 'application/json'},
            data=json.dumps({"q": message}),
            timeout=timeout,  # without a timeout a stalled connection hangs the request forever
        )
        response.raise_for_status()  # surface 4xx/5xx (bad key, quota) instead of parsing an error body
        search_results = response.json()
        snippets = [
            result.get('snippet', '')
            for result in search_results.get('organic', [])[:max_results]
        ]
    except Exception as e:
        # Deliberately broad: search is optional, so any failure degrades to a canned context.
        print(f"Error during Serper search: {e}")
        return "The web search is down bad rn, couldn't find anything."

    # Skip empty snippets so results without text don't count as "found something".
    context = "".join(f"{snippet}\n" for snippet in snippets if snippet)
    return context or "Couldn't find anything on the web about that, fam."


def get_normal_response(message: str) -> str:
    """Generates a direct chat response without web search."""
    system_prompt = "You are a 'Chad' chatbot that speaks in Gen-Z slang and give answers from that perspective"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": message},
    ]
    return _generate_reply(messages)


def get_search_response(message: str) -> str:
    """Performs a web search, then generates a response grounded in the result snippets."""
    context = _search_web_context(message)

    system_prompt = "You are a 'Chad' chatbot that speaks in Gen-Z slang."
    # Context is capped at 2000 characters to keep the prompt short.
    user_instruction = f"""Based on the following web search results, answer my original question. My question was: '{message}'. Here are the search results: --- {context[:2000]} --- Now, answer my question in a short, confident, Gen-Z way."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_instruction},
    ]
    return _generate_reply(messages)


# --- 4. API Endpoints ---

# Endpoint to serve the HTML frontend.
# Plain `def` (not `async def`) so FastAPI runs the blocking file read in its
# threadpool instead of on the event loop.
@app.get("/", response_class=HTMLResponse)
def read_root():
    """Serve ``index.html`` from the current working directory as the chat UI.

    Returns a small 404 page, instead of an unhandled 500, when the file is missing.
    """
    try:
        with open("index.html", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return HTMLResponse("<h1>index.html not found</h1>", status_code=404)

# Endpoint to handle chat messages.
# Declared with plain `def`: model generation and the web-search HTTP call are
# blocking, and inside `async def` they would stall the event loop and freeze every
# other request (including "/") for the whole generation. FastAPI runs `def`
# endpoints in its worker threadpool instead.
# NOTE(review): concurrent requests can now generate in parallel on the shared model —
# add a lock around generation if GPU memory becomes a problem.
@app.post("/chat")
def chat_endpoint(request: ChatRequest):
    """Answer one chat message, optionally grounded in web search.

    Returns:
        A JSON object of the form ``{"response": <reply text>}``.
    """
    print(f"Received request: message='{request.message}', use_web_search={request.use_web_search}")
    if request.use_web_search:
        response = get_search_response(request.message)
    else:
        response = get_normal_response(request.message)
    return {"response": response}

# --- 5. Main entry point to run the app ---
# Note: inside Jupyter/Colab `__name__ == "__main__"` is also true, so this runs there too.
if __name__ == "__main__":
    import asyncio

    # Binds all interfaces; the port can be overridden via $PORT (default 8000).
    port = int(os.environ.get("PORT", "8000"))

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        # Plain `python main.py`: block here until the server shuts down.
        uvicorn.run(app, host="0.0.0.0", port=port)
    else:
        # A notebook kernel already runs an event loop, and uvicorn.run() would fail with
        # "asyncio.run() cannot be called from a running event loop". Schedule the server
        # on the existing loop instead; it keeps serving while the kernel is alive.
        server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=port))
        running_loop.create_task(server.serve())


Setting up the model... This may take a few minutes.


config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]



model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]