In [None]:
!python -m pip install --upgrade pip
!python -m pip uninstall -y torch torchvision torchaudio
!python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
!pip install sentence-transformers faiss-gpu-cu12 fastapi uvicorn

In [None]:
import os

from huggingface_hub import login

# Never store an access token in the notebook itself. The token is read from the
# HF_TOKEN environment variable; when that is unset, `token=None` makes `login`
# ask for it interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import threading
import uvicorn
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
from data import example_data
from fastapi import FastAPI

In [None]:
# Web application object; the route defined further down is registered on it.
app = FastAPI()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sentence-embedding model used to match a query against the example corpus.
embedding_model = SentenceTransformer("google/embeddinggemma-300m")

# Causal language model (and matching tokenizer) used to generate answers.
lm_model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(lm_model_id)
llm = AutoModelForCausalLM.from_pretrained(lm_model_id).to(device)

# Put the model in evaluation mode; it is only used for inference here.
llm.eval()

In [None]:
class QueryRequest(BaseModel):
    """Request body accepted by the `/answer` endpoint."""

    # The user's question; used both for retrieval and in the generation prompt.
    query: str

# Embeddings of the `example_data` texts, computed on first use and reused so
# that each request only has to encode the query. Re-running this cell clears it.
_corpus_embs = None


def _get_corpus_embeddings():
    """Return one embedding per `example_data` entry, encoding the corpus only when needed."""
    global _corpus_embs
    # Re-encode when nothing is cached yet or the corpus size no longer matches.
    if _corpus_embs is None or _corpus_embs.shape[0] != len(example_data):
        _corpus_embs = embedding_model.encode(
            [entry["text"] for entry in example_data],
            convert_to_tensor=True,
        )
    return _corpus_embs


def find_top_entry(query):
    """Return the `example_data` entry whose text is most similar to `query`.

    Similarity is the cosine similarity between the embedding of `query` and
    the embedding of each entry's ``"text"`` field.

    Args:
        query: The question text to search with.

    Returns:
        The element of `example_data` with the highest similarity score.

    Raises:
        ValueError: If `example_data` is empty.
    """
    if not example_data:
        raise ValueError("example_data is empty; there is nothing to retrieve from")

    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    corpus_embs = _get_corpus_embeddings()
    # The single query embedding is compared against every corpus row.
    similarities = torch.nn.functional.cosine_similarity(query_emb, corpus_embs)
    top_idx = torch.argmax(similarities).item()
    return example_data[top_idx]

In [None]:
def build_prompt(context_text, question):
    """Assemble the generation prompt from a retrieved context and the user's question.

    Args:
        context_text: Text of the retrieved entry to ground the answer in.
        question: The question to be answered.

    Returns:
        A four-line prompt: instruction, context, question, and a trailing
        "Answer:" cue for the model to continue from.
    """
    prompt_lines = (
        "Use the following context to answer the question.",
        f"Context: {context_text}",
        f"Question: {question}",
        "Answer:",
    )
    return "\n".join(prompt_lines)

def generate_answer(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
    """Sample a completion for `prompt` and return only the newly generated text.

    Args:
        prompt: Full prompt text to feed to the language model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature passed to `generate`.
        top_p: Nucleus-sampling threshold passed to `generate`.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped. The prompt itself is not included.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = llm.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

    # For a causal LM the returned sequence starts with the prompt tokens;
    # drop them so the caller receives the answer rather than an echo of
    # the instruction, context and question.
    new_tokens = outputs[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [None]:
@app.post("/answer")
async def answer_query(request: QueryRequest):
    """Answer a question using the most similar corpus entry as context.

    Retrieves the best-matching entry for the query, builds a prompt from its
    text, and returns the generated answer as ``{"answer": <text>}``.

    NOTE(review): retrieval and generation are plain synchronous calls made
    directly inside this coroutine.
    """
    question = request.query
    context_text = find_top_entry(question)["text"]
    generated = generate_answer(build_prompt(context_text, question))
    return {"answer": generated}

In [None]:
def run_api():
    """Serve `app` with uvicorn on port 8000, listening on all interfaces."""
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")


# Start the server in a daemon thread so the notebook stays interactive.
# If this cell is re-run while the previous server thread is still alive, do
# not start another one: it would only fail to bind the already-used port.
_previous_thread = globals().get("thread")
if _previous_thread is None or not _previous_thread.is_alive():
    thread = threading.Thread(target=run_api, daemon=True)
    thread.start()