In [None]:
import yaml
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from qdrant_client import QdrantClient


In [27]:
# Load the YAML configuration that drives the rest of the notebook.
# The encoding is pinned to UTF-8 so the file is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open("config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Display the parsed config as the cell output.
# NOTE(review): this echoes qdrant.api_key into the saved notebook output —
# clear the output before sharing if a real key is configured.
cfg


{'qdrant': {'url': 'http://localhost:6333',
  'api_key': None,
  'collection': 'documents'},
 'models': {'ocr': 'deepseek-ai/DeepSeek-OCR',
  'embedding': 'sentence-transformers/all-mpnet-base-v2',
  'generation': 'gpt2'},
 'ingest': {'chunk_chars': 1000, 'chunk_overlap': 100, 'top_k': 5},
 'docs_dir': {'docs_to_ingest': '/home/cpu064/Downloads/docs/'}}

In [28]:
# Qdrant settings, plus the name of the collection every search runs against.
qcfg = cfg["qdrant"]
COLLECTION = qcfg["collection"]

# Model names handed to SentenceTransformer and the text-generation pipeline.
model_names = cfg["models"]
embed_model_name = model_names["embedding"]
gen_model_name = model_names["generation"]

# Echo the resolved values as the cell output.
qcfg, embed_model_name, gen_model_name


({'url': 'http://localhost:6333', 'api_key': None, 'collection': 'documents'},
 'sentence-transformers/all-mpnet-base-v2',
 'gpt2')

In [29]:
# Embedding model used to turn a question into a query vector.
embedder = SentenceTransformer(model_name_or_path=embed_model_name)
# Text-generation pipeline used to produce the final answer.
generator = pipeline(task="text-generation", model=gen_model_name)


Device set to use cuda:0


In [30]:
# Client for the Qdrant instance at the configured URL.
# api_key is optional: .get() yields None when the config has no key.
q = QdrantClient(url=qcfg["url"], api_key=qcfg.get("api_key"))


In [31]:
class QueryRequest(BaseModel):
    """Input model for `query`: the question to answer and an optional hit count."""

    # The question text; it is embedded for retrieval and appended to the prompt.
    question: str
    # Number of hits to retrieve. When None (or 0), `query` falls back to
    # cfg["ingest"]["top_k"].
    top_k: int | None = None


In [None]:
def _retrieve_hits(q_vector, top_k):
    """Search COLLECTION for the top_k nearest points, using whichever search API the client exposes."""
    if hasattr(q, "query_points"):
        # Current client API; the scored points are wrapped in a response object.
        return q.query_points(collection_name=COLLECTION, query=q_vector, limit=top_k).points
    if hasattr(q, "search"):
        return q.search(collection_name=COLLECTION, query_vector=q_vector, limit=top_k)
    if hasattr(q, "search_points"):
        try:
            return q.search_points(collection_name=COLLECTION, query_vector=q_vector, limit=top_k)
        except TypeError:
            # fallback to positional parameters for some client versions
            return q.search_points(COLLECTION, q_vector, top_k)
    raise AttributeError(
        "QdrantClient has no 'query_points', 'search' or 'search_points' method. "
        "Please upgrade/downgrade qdrant-client or check the client API."
    )


def _build_prompt(question, hits):
    """Assemble the generation prompt from the retrieved hit payloads and the question."""
    contexts = []
    for h in hits:
        # A point stored without a payload comes back with payload=None.
        payload = h.payload or {}
        contexts.append(
            f"(source:{payload.get('source_file')} page_chunk:{payload.get('chunk_index')}) {payload.get('text')}"
        )

    prompt = "You are a helpful assistant. Use the following context to answer the question.\n\n"
    prompt += "\n\n---\n\n".join(contexts)
    prompt += f"\n\nQuestion: {question}\nAnswer:"
    return prompt


def query(req: QueryRequest):
    """Answer a question with retrieval-augmented generation.

    Embeds the question, retrieves the nearest chunks from Qdrant, builds a
    prompt from their payloads and generates an answer.

    Args:
        req: The question and an optional ``top_k``; when ``top_k`` is None
            (or 0) the value from ``cfg["ingest"]["top_k"]`` is used.

    Returns:
        A dict with ``"answer"`` (the generated continuation only, without the
        prompt) and ``"retrieved"`` (the payload of every hit, in rank order).

    Raises:
        AttributeError: If the Qdrant client exposes none of the supported
            search methods.
    """
    top_k = req.top_k or cfg["ingest"]["top_k"]

    # embed the query
    q_vector = embedder.encode(req.question).tolist()

    # retrieve from Qdrant
    hits = _retrieve_hits(q_vector, top_k)

    # build prompt and generate
    prompt = _build_prompt(req.question, hits)

    # max_new_tokens budgets only the generated text; max_length would also
    # count the prompt tokens, leaving no room once context is included.
    # return_full_text=False keeps the prompt out of the returned answer.
    # NOTE(review): a long prompt can still exceed the generation model's
    # context window — consider a longer-context or instruction-tuned model.
    out = generator(prompt, max_new_tokens=256, do_sample=False, return_full_text=False)
    answer = out[0]["generated_text"]

    return {"answer": answer, "retrieved": [h.payload for h in hits]}

In [None]:
# Smoke test: send one question through query(), asking for three hits
# instead of the configured default.
test = QueryRequest(
    question="What is this document about?",
    top_k=3,
)
response = query(test)
response