In [1]:
!pip install pandas==2.2.2  sentence-transformers
!pip install --upgrade --force-reinstall numpy==1.26.4
!pip install chromadb==1.0.20
!pip install rank_bm25==0.2.2
!pip -q install flask flask-cors pyngrok waitress rank_bm25 sentence_transformers chromadb transformers huggingface_hub
!pip install -U bitsandbytes accelerate "transformers>=4.44" peft


Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m82.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
pytensor 2.35.1 requires numpy

Collecting chromadb==1.0.20
  Downloading chromadb-1.0.20-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.3 kB)
Collecting pybase64>=1.4.1 (from chromadb==1.0.20)
  Downloading pybase64-1.4.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (8.7 kB)
Collecting posthog<6.0.0,>=2.4.0 (from chromadb==1.0.20)
  Downloading posthog-5.4.0-py3-none-any.whl.metadata (5.7 kB)
Collecting onnxruntime>=1.14.1 (from chromadb==1.0.20)
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb==1.0.20)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.38.0-py3-none-any.whl.metadata (2.4 kB)
Collecting pypika>=0.48.9 (from chromadb==1.0.20)
  Downloading PyPika-0.48.9.tar.gz (67 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00

In [1]:
# Mount Google Drive at /content/drive (Colab-only helper).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import hmac
import os, re, json, csv, time, threading
from typing import List, Dict, Any, Tuple
from collections import defaultdict
from datetime import datetime, timezone

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from huggingface_hub import login

from chromadb import PersistentClient
from chromadb.config import Settings
from chromadb.utils import embedding_functions

from flask import Flask, request, jsonify
from flask_cors import CORS
from pyngrok import ngrok

In [None]:
# Shared secret compared against the X-API-Key request header by the /generate route.
# NOTE(review): the "dev-notebook" fallback is a guessable default — set NOTEBOOK_API_KEY
# in the environment before exposing this notebook through a public tunnel.
NOTEBOOK_API_KEY = os.getenv("NOTEBOOK_API_KEY", "dev-notebook")  # must match backend's NOTEBOOK_API_KEY

# Location of the persisted Chroma store, the collection to read, and the
# sentence-transformers model used to embed queries.
CHROMA_PATH     = "/content/drive/MyDrive/chroma"
COLLECTION_NAME = "actSectionsV2"
EMBED_MODEL     = "intfloat/e5-base-v2"

# Candidate pool sizes: dense hits, BM25 hits (used for both the document and the
# heading index), maximum candidates sent to the cross-encoder, and results returned.
TOPK_DENSE_WIDE   = 300
TOPK_BM25_WIDE    = 300
TOPK_CE_RERANK    = 400
TOPK_FINAL        = 6

# Early-fusion weights
W_BM25_DOC   = 0.35
W_BM25_HEAD  = 0.35
W_DENSE      = 0.30

# Final fusion
# ALPHA_FUSION weights the normalised cross-encoder score; (1 - ALPHA_FUSION) weights the dense score.
ALPHA_FUSION = 0.80

# Additive score adjustments returned by _heading_prior.
HEADING_PRIOR_POS = 0.05
HEADING_PRIOR_NEG = -0.05

# Soft Act prior weight added per candidate based on fused act mass
ACT_PRIOR_BETA = 0.05

# Act-gating
# Gating keeps candidates from the top ACT_GATING_K Acts, and only applies when the
# leading Act holds at least ACT_CONF_MIN of the fused score mass.
ACT_GATING_K   = 3
ACT_CONF_MIN   = 0.55

# Heading patterns that trigger a negative prior in _heading_prior.
DEAD_HEAD_RE  = re.compile(r"\b(spent|repealed|revoked|deleted)\b|\[\s*spent\s*\]", re.I)
INTERP_RE     = re.compile(r"\binterpretation\b", re.I)

# Substrings looked for (lower-cased) in both the heading and the query by _heading_prior.
SALIENT_TERMS = [
    "equality","non-discrimination","discrimination","privacy","consent",
    "detention","expression","housing","health","education","water",
    "termination","unfair termination","redundancy","dismissal",
    "children","bail","arrest","data","lawful processing","principles of data protection"
]


# Open the persisted store and fetch the existing collection.
# NOTE(review): allow_reset=True presumably enables client.reset(); nothing visible in
# this notebook calls it — confirm whether the setting is still needed.
client = PersistentClient(path=CHROMA_PATH, settings=Settings(allow_reset=True))
print("Existing collections:", [c.name for c in client.list_collections()])
coll = client.get_collection(COLLECTION_NAME)
print(f"Loaded collection '{COLLECTION_NAME}' with {coll.count()} documents")

# Embedding function used to embed queries before coll.query (see _dense_candidates).
ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBED_MODEL)


In [None]:
def load_corpus_df(chroma_coll, batch_size: int = 5000) -> pd.DataFrame:
    """Read every record of a Chroma collection into a flat DataFrame.

    The collection is paged through with ``chroma_coll.get(limit=..., offset=...)``
    until a page comes back empty or shorter than ``batch_size``.

    Args:
        chroma_coll: Object exposing ``get(limit, offset, include)`` that returns a
            mapping with ``"ids"``, ``"documents"`` and ``"metadatas"`` lists.
        batch_size: Number of records requested per page (must be positive).

    Returns:
        DataFrame with string columns ``id``, ``doc``, ``heading``, ``act`` and
        ``section_num``. An empty collection yields an empty frame that still has
        these columns, so callers can test ``df.empty`` instead of hitting a KeyError.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    columns = ["id", "doc", "heading", "act", "section_num"]
    rows = []
    offset = 0
    while True:
        batch = chroma_coll.get(limit=batch_size, offset=offset, include=["documents", "metadatas"])
        ids = batch.get("ids") or []
        if not ids:
            break
        # A missing/None documents or metadatas list is treated as "no value" for
        # each id of the page rather than crashing or silently dropping rows.
        docs = batch.get("documents") or [None] * len(ids)
        metas = batch.get("metadatas") or [None] * len(ids)

        for _id, doc, meta in zip(ids, docs, metas):
            meta = meta or {}
            # Metadata keys vary between ingests, so try each known spelling in turn.
            act = meta.get("act") or meta.get("Act") or ""
            heading = meta.get("section_title") or meta.get("heading") or meta.get("title") or ""
            section_num = str(meta.get("section") or meta.get("section_num") or "").strip()
            rows.append({
                "id": _id,
                "doc": doc or "",
                "heading": heading or "",
                "act": act,
                "section_num": section_num
            })

        offset += len(ids)
        if len(ids) < batch_size:
            break

    # Passing `columns` keeps the schema stable even when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)
    df["heading"] = df["heading"].fillna("")
    df["section_num"] = df["section_num"].fillna("")
    return df

# Read the whole collection into memory once, fail fast if it is empty, and build
# an id-indexed view (id2) that the retrieval helpers use for per-id lookups.
corpus_df = load_corpus_df(coll)
assert not corpus_df.empty, "corpus_df is empty — check collection contents."
print("corpus_df shape:", corpus_df.shape)
id2 = corpus_df.set_index("id")


In [None]:
# corpus_df and id2 are already built by the preceding cell; this cell used to repeat
# that cell verbatim, re-reading the entire collection a second time. Only the
# precondition the BM25 indexes below depend on is re-checked here.
assert not corpus_df.empty, "corpus_df is empty — check collection contents."

def _tok(t: str):
    return [x for x in re.split(r"\W+", (t or "").lower()) if x]

# Two BM25 indexes over the same rows (and therefore the same row order) of corpus_df:
# one over the section text, one over the section headings. Both use _tok for tokenisation.
bm25_doc  = BM25Okapi([_tok(d) for d in corpus_df["doc"].tolist()])
bm25_head = BM25Okapi([_tok(h) for h in corpus_df["heading"].fillna("").tolist()])


In [None]:
def _minmax(xs):
    if not xs: return []
    mn, mx = float(min(xs)), float(max(xs))
    if mx <= mn: return [0.0]*len(xs)
    return [(x-mn)/(mx-mn) for x in xs]

# ----------------------------
# Dense + BM25 candidate functions
# ----------------------------
def _dense_candidates(q, n):
    """Embed the query with the "query: " prefix and fetch the ``n`` nearest records.

    Returns a pair ``(ids, distances)`` for the single query; smaller distance
    means a closer match.
    """
    query_vectors = ef([f"query: {q}"])
    hits = coll.query(
        query_embeddings=query_vectors,
        n_results=n,
        include=["distances"],
    )
    hit_ids = hits["ids"][0]
    hit_distances = hits["distances"][0]  # distances lower=better
    return hit_ids, hit_distances

def _bm25_doc_candidates(q, n):
    """Score every document body against the query with BM25 and keep the best ``n``.

    Returns ``(ids, scores)`` ordered from highest to lowest BM25 score.
    """
    all_scores = bm25_doc.get_scores(_tok(q))
    best_rows = np.argsort(all_scores)[::-1][:n]
    hit_ids, hit_scores = [], []
    for row in best_rows:
        hit_ids.append(corpus_df.loc[row, "id"])
        hit_scores.append(float(all_scores[row]))
    return hit_ids, hit_scores

def _bm25_head_candidates(q, n):
    """Score every section heading against the query with BM25 and keep the best ``n``.

    Returns ``(ids, scores)`` ordered from highest to lowest BM25 score.
    """
    all_scores = bm25_head.get_scores(_tok(q))
    best_rows = np.argsort(all_scores)[::-1][:n]
    hit_ids, hit_scores = [], []
    for row in best_rows:
        hit_ids.append(corpus_df.loc[row, "id"])
        hit_scores.append(float(all_scores[row]))
    return hit_ids, hit_scores

def _dense_to_sim(dists):
    """Turn distances (lower is better) into min-max normalised similarities."""
    negated = []
    for dist in dists:
        negated.append(-dist)
    return _minmax(negated)


In [None]:

def _act_topk_share_and_map(ids, scores, k=3):
    """Aggregate candidate scores per Act.

    Ids that are not present in ``id2`` are ignored.

    Returns:
        A 3-tuple ``(topk, share, act_share)`` where ``topk`` lists up to ``k``
        non-empty Act names taken from the ``k`` highest-mass entries, ``share`` is
        the highest-mass entry's fraction of the total score (0.0 when there is no
        mass), and ``act_share`` maps every Act seen to its fraction of the total.
    """
    mass = defaultdict(float)
    total = 0.0
    for doc_id, score in zip(ids, scores):
        if doc_id not in id2.index:
            continue
        act_name = str(id2.loc[doc_id, "act"] or "")
        mass[act_name] += score
        total += score

    ranked = sorted(mass.items(), key=lambda item: item[1], reverse=True)
    if ranked and total > 0:
        share = ranked[0][1] / total
    else:
        share = 0.0

    topk = [name for name, _ in ranked[:k] if name]

    act_share = {}
    for name, value in mass.items():
        act_share[name] = (value / total if total > 0 else 0.0)
    return topk, share, act_share

In [None]:
def _heading_prior(heading, q):
    """Small additive score adjustment derived from the section heading.

    Checked in this order:
      1. heading contains any salient term AND query contains any salient term
         (not necessarily the same one) -> HEADING_PRIOR_POS
      2. heading matches DEAD_HEAD_RE -> HEADING_PRIOR_NEG
      3. heading matches INTERP_RE -> -0.03
      4. otherwise -> 0.0
    """
    heading_lc = (heading or "").lower()
    query_lc = (q or "").lower()

    heading_is_salient = any(term in heading_lc for term in SALIENT_TERMS)
    if heading_is_salient and any(term in query_lc for term in SALIENT_TERMS):
        return HEADING_PRIOR_POS

    if DEAD_HEAD_RE.search(heading_lc):
        return HEADING_PRIOR_NEG

    if INTERP_RE.search(heading_lc):
        return -0.03

    return 0.0

def _overlap_prior(heading, q):
    """Bonus for lexical overlap between the heading and the query.

    The bonus is 0.30 x Jaccard similarity of the two token sets, capped at 0.05,
    plus 0.02 if any pair of adjacent query tokens (joined by a space) occurs in
    the lower-cased heading. The total is capped at 0.07. Returns 0.0 when either
    side has no tokens.
    """
    query_tokens = _tok(q)
    heading_set = set(_tok(heading))
    query_set = set(query_tokens)
    if not (heading_set and query_set):
        return 0.0

    union_size = max(1, len(heading_set | query_set))
    jaccard = len(heading_set & query_set) / union_size
    bonus = min(0.05, 0.30 * jaccard)

    heading_lc = (heading or "").lower()
    for first, second in zip(query_tokens, query_tokens[1:]):
        bigram = " ".join((first, second))
        if bigram and bigram in heading_lc:
            bonus += 0.02
            break

    return min(bonus, 0.07)


In [None]:
def _section_num_bonus(q, sec_num):
    if not sec_num: return 0.0
    q_nums = re.findall(r"\d+", (q or ""))
    return 0.03 if sec_num in q_nums else 0.0

def _fmt_section(sec_num, head):
    sec_num = (sec_num or "").strip()
    head    = (head or "").strip()
    return f"{sec_num} - {head}" if sec_num and head else (sec_num or head or "")


### RERANKER

In [None]:
# Load the cross-encoder used for reranking; it runs on the GPU when CUDA is available,
# otherwise on CPU. Inputs are capped at max_length=512.
device = "cuda" if torch.cuda.is_available() else "cpu"
ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512, device=device)


In [None]:
def retrieve_top6_for_question_T6(q: str) -> Dict[str, Any]:
    """Retrieve the best TOPK_FINAL sections for a question.

    Pipeline:
      1. Pull wide candidate pools from dense search, BM25 over documents and
         BM25 over headings.
      2. Min-max normalise each pool and fuse them with the W_* weights.
      3. If one Act dominates the fused score mass (>= ACT_CONF_MIN), keep only
         candidates from the top ACT_GATING_K Acts; cap at TOPK_CE_RERANK.
      4. Score (question, document) pairs with the cross-encoder.
      5. Combine cross-encoder and dense scores (ALPHA_FUSION) and add the
         heading, overlap, Act-share and section-number priors.
      6. Keep the best candidate per (Act, section number) and return the top
         TOPK_FINAL, back-filling with skipped candidates if needed.

    Args:
        q: The user's question.

    Returns:
        Dict with parallel lists ``Top6_IDs``, ``Top6_Sections_fmt``,
        ``Top6_Answers`` and ``Top6_Acts``. All lists are empty when no candidate
        could be retrieved.
    """
    # 1) Wide pools
    d_ids, d_dists = _dense_candidates(q, TOPK_DENSE_WIDE)
    bd_ids, bd_scs = _bm25_doc_candidates(q, TOPK_BM25_WIDE)
    bh_ids, bh_scs = _bm25_head_candidates(q, TOPK_BM25_WIDE)

    # 2) Early fusion (each pool normalised independently, then weighted and summed)
    fuse = defaultdict(float)
    pools = (
        (W_DENSE, d_ids, _dense_to_sim(d_dists)),
        (W_BM25_DOC, bd_ids, _minmax(bd_scs)),
        (W_BM25_HEAD, bh_ids, _minmax(bh_scs)),
    )
    for weight, pool_ids, pool_norm in pools:
        for _id, norm in zip(pool_ids, pool_norm):
            fuse[_id] += weight * norm

    # Drop ids the in-memory corpus does not know about up front, so every later
    # id2.loc lookup is safe in both the gated and the ungated branch.
    known_ids = id2.index
    cand_ids = [
        _id for _id in sorted(fuse.keys(), key=lambda k: fuse[k], reverse=True)
        if _id in known_ids
    ]
    cand_scores = [fuse[_id] for _id in cand_ids]

    # 3) Act-gating + soft act prior map
    top_acts, share, act_share = _act_topk_share_and_map(cand_ids, cand_scores, k=ACT_GATING_K)
    if share >= ACT_CONF_MIN and top_acts:
        gated = [_id for _id in cand_ids if id2.loc[_id, "act"] in top_acts][:TOPK_CE_RERANK]
    else:
        gated = cand_ids[:TOPK_CE_RERANK]

    if not gated:
        # Nothing to rerank; return an empty bundle with the usual keys.
        return {"Top6_IDs": [], "Top6_Sections_fmt": [], "Top6_Answers": [], "Top6_Acts": []}

    # 4) CE rerank
    qdoc = [(q, id2.loc[_id, "doc"]) for _id in gated]
    ce_scores = ce.predict(qdoc, batch_size=32, show_progress_bar=False).tolist()
    ce_norm = _minmax(ce_scores)

    # Dense similarity for the same gated set. Candidates that only came from BM25
    # have no dense distance; they get 0.0. (Using -inf as the placeholder made the
    # min-max step compute inf/inf, i.e. NaN for *every* candidate, which in turn
    # made all final scores NaN and the final sort meaningless.)
    d_map = dict(zip(d_ids, d_dists))
    present_norm = iter(_minmax([-d_map[_id] for _id in gated if _id in d_map]))
    dense_norm2 = [next(present_norm) if _id in d_map else 0.0 for _id in gated]

    # 5) Final fusion + priors
    finals = []
    for i, _id in enumerate(gated):
        row    = id2.loc[_id]
        head   = row["heading"]
        secnum = str(row["section_num"] or "")
        act    = str(row["act"] or "")

        score = (
            ALPHA_FUSION * ce_norm[i] +
            (1.0 - ALPHA_FUSION) * dense_norm2[i] +
            _heading_prior(head, q) +
            _overlap_prior(head, q) +
            ACT_PRIOR_BETA * float(act_share.get(act, 0.0)) +
            _section_num_bonus(q, secnum)
        )
        finals.append((_id, score))

    finals_sorted = sorted(finals, key=lambda t: t[1], reverse=True)

    # 6) De-dup by (Act, section number): the same section number in two different
    # Acts is a different provision and must not be collapsed. Records without a
    # section number are never treated as duplicates of each other.
    seen_secs, top_ids = set(), []
    for _id, _sc in finals_sorted:
        sec = str(id2.loc[_id, "section_num"] or "")
        key = (str(id2.loc[_id, "act"] or ""), sec)
        if sec and key in seen_secs:
            continue
        seen_secs.add(key)
        top_ids.append(_id)
        if len(top_ids) >= TOPK_FINAL:
            break

    # Back-fill with the best skipped candidates if de-duplication left gaps.
    if len(top_ids) < TOPK_FINAL:
        for _id, _ in finals_sorted:
            if _id not in top_ids:
                top_ids.append(_id)
            if len(top_ids) >= TOPK_FINAL:
                break

    sub = id2.loc[top_ids].reset_index()
    return {
        "Top6_IDs": top_ids,
        "Top6_Sections_fmt": [_fmt_section(s, h) for s, h in zip(sub["section_num"], sub["heading"])],
        "Top6_Answers": sub["doc"].tolist(),
        "Top6_Acts": sub["act"].tolist(),
    }

In [None]:
def _build_context_from_bundle(bundle: Dict[str, Any]) -> str:
    ctx = []
    for i, (sec, act, doc) in enumerate(zip(bundle["Top6_Sections_fmt"], bundle["Top6_Acts"], bundle["Top6_Answers"]), start=1):
        ctx.append(f"[{i}] {act} - {sec}\n{doc.strip()}")
    return "\n\n".join(ctx)


### LOADING QWEN

In [None]:
# Best-effort Hugging Face login. A failure here is not fatal for this cell, but it is
# reported so that a later model-download error can be traced back to it.
try:
    login()
except Exception as exc:
    print(f"Hugging Face login skipped: {exc!r}")

In [None]:
MODEL_NAME = "Qwen/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# 4-bit loading is requested through an explicit BitsAndBytesConfig; passing
# load_in_4bit=True straight to from_pretrained is the deprecated spelling.
_bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=_bnb_config
)
print("Qwen model loaded successfully!")


### MODEL WARM-UP


In [None]:
# Warm up embedding and reranker to avoid first-request latency
# Warm up embedding and reranker to avoid first-request latency
# One throw-away call through the embedding function and one through the cross-encoder;
# both results are discarded.
_warmup_q = "Warm-up query about employment law"
_warmup_doc = "This is placeholder legal text for warm-up purposes only."
_ = ef([f"query: {_warmup_q}"])
with torch.inference_mode():
    _ = ce.predict([(_warmup_q, _warmup_doc)], batch_size=1, show_progress_bar=False)
print("Retrieval models warmed up.")


In [None]:
# Warm up generator to cache weights on device
# Warm up generator to cache weights on device
# Builds a minimal chat prompt, tokenises it onto the model's device and generates up to
# 32 tokens; the output is discarded.
warmup_messages = [
    {"role": "system", "content": "You are Uhaki, an AI legal assistant."},
    {"role": "user", "content": "Provide a short legal summary for warm-up."}
]
warmup_text = tokenizer.apply_chat_template(
    warmup_messages,
    tokenize=False,
    add_generation_prompt=True
)
warmup_inputs = tokenizer(warmup_text, return_tensors="pt").to(model.device)
with torch.no_grad():
    _ = model.generate(
        **warmup_inputs,
        max_new_tokens=32,
        temperature=0.7,
        top_p=0.9
    )
print("Generation model warmed up.")


In [None]:
def _clean_generated_text(text: str) -> str:
    """Remove <think> blocks from generated text and collapse whitespace to single spaces."""
    # Complete reasoning blocks.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL | re.IGNORECASE)
    # A reasoning block that was cut off before its closing tag (e.g. the token
    # budget ran out) would otherwise leak into the answer.
    text = re.sub(r"<think>.*\Z", "", text, flags=re.DOTALL | re.IGNORECASE)
    return re.sub(r"\s+", " ", text).strip()


def generate_uhaki_answer(query: str, bundle: Dict[str, Any], enable_thinking: bool = False) -> str:
    """Generate an answer to ``query`` grounded in the retrieved sections.

    Args:
        query: The user's question.
        bundle: Retrieval bundle with ``Top6_Sections_fmt``, ``Top6_Acts`` and
            ``Top6_Answers`` (as consumed by ``_build_context_from_bundle``).
        enable_thinking: Forwarded to the tokenizer's chat template.

    Returns:
        The model's answer as a single line of text with any <think> blocks removed.
    """
    context = _build_context_from_bundle(bundle)
    system_prompt = (
        "You are Uhaki, an AI legal assistant for Kenyan law. "
        "Answer the user's question using only the legal information provided in the context below. "
        "Provide a concise but complete legal summary. "
        "Cite Acts and sections in parentheses (e.g., Employment Act s.44). "
        "If the answer is not in the context, say so."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking
    )
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.2,
            top_p=0.9
        )
    # generate() returns prompt + continuation. Decode only the continuation instead
    # of decoding everything and splitting on "Question:" / "assistant", which cut
    # the answer short whenever the answer itself contained one of those markers.
    new_tokens = outputs[0][prompt_len:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return _clean_generated_text(decoded)

### FLASK APP

In [None]:
# Flask application that hosts the /health and /generate routes.
# NOTE(review): CORS(app) is called with no origin restriction — presumably this allows
# any origin; confirm that is intended for a publicly tunnelled server.
app = Flask(__name__)
CORS(app)


In [None]:
@app.get("/health")
def health():
    """Report liveness plus the collection size and the model names in use."""
    status = {"ok": True}
    status["collection"] = COLLECTION_NAME
    status["docs"] = int(coll.count())
    status["embed_model"] = EMBED_MODEL
    status["ce_model"] = "cross-encoder/ms-marco-MiniLM-L-6-v2"
    status["gen_model"] = MODEL_NAME
    return jsonify(status)

In [None]:
@app.post("/generate")
def generate():
    """Answer a question: authenticate, retrieve supporting sections, generate, respond.

    Request: JSON object with ``query`` (string, required) and optional
    ``top_k_return`` (integer, clamped to 1..TOPK_FINAL). The ``X-API-Key`` header
    must equal NOTEBOOK_API_KEY.

    Responses:
        200 with ``ok``, ``query``, ``answer``, ``top6`` and ``raw``;
        400 when the query is missing or ``top_k_return`` is not an integer;
        401 when the API key is missing or wrong.
    """
    # Simple auth (constant-time comparison; a missing header is always rejected)
    supplied_key = request.headers.get("X-API-Key")
    if supplied_key is None or not hmac.compare_digest(
        supplied_key.encode("utf-8"), NOTEBOOK_API_KEY.encode("utf-8")
    ):
        return jsonify({"error": "Unauthorized"}), 401

    # Malformed JSON or a non-object body is treated like an empty request, so it
    # falls through to the "No query provided" response below.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        data = {}

    raw_query = data.get("query")
    query = raw_query.strip() if isinstance(raw_query, str) else ""
    if not query:
        return jsonify({"error": "No query provided"}), 400

    # Note: our pipeline already uses wide pools internally; top_k_retrieve is ignored.
    try:
        top_k_return = int(data.get("top_k_return", TOPK_FINAL))
    except (TypeError, ValueError):
        return jsonify({"error": "top_k_return must be an integer"}), 400
    # Zero or negative values would slice the result lists incorrectly.
    top_k_return = max(1, min(top_k_return, TOPK_FINAL))

    # Retrieval
    bundle = retrieve_top6_for_question_T6(query)
    # If caller asked fewer than TOPK_FINAL back
    if top_k_return < TOPK_FINAL:
        for k in ["Top6_IDs", "Top6_Sections_fmt", "Top6_Answers", "Top6_Acts"]:
            bundle[k] = bundle[k][:top_k_return]

    # Generation
    answer = generate_uhaki_answer(query, bundle)

    # Response
    acts_sections = [
        {"act": act, "section": sec}
        for act, sec in zip(bundle["Top6_Acts"], bundle["Top6_Sections_fmt"])
    ]
    return jsonify({
        "ok": True,
        "query": query,
        "answer": answer,
        "top6": acts_sections,
        "raw": {
            "ids": bundle["Top6_IDs"],
            "sections_fmt": bundle["Top6_Sections_fmt"],
            "acts": bundle["Top6_Acts"]
        }
    })


In [None]:
# The ngrok auth token is read from the environment instead of being pasted into the
# notebook, and is registered through pyngrok rather than a shell command.
NGROK_TOKEN = os.getenv("NGROK_AUTHTOKEN", "")
if NGROK_TOKEN:
    ngrok.set_auth_token(NGROK_TOKEN)
else:
    print("NGROK_AUTHTOKEN is not set; relying on any token already saved in ngrok's config.")

# ngrok.connect returns a tunnel object; the URL itself lives on its public_url attribute.
# str(tunnel) is a descriptive repr, not a usable URL, so it must not be exported as one.
_tunnel = ngrok.connect(addr=5001, proto="http")
public_url = getattr(_tunnel, "public_url", None) or str(_tunnel)
print("Notebook public URL:", public_url)
os.environ["PUBLIC_NOTEBOOK_URL"] = public_url

In [None]:
def run_server():
    """Serve the Flask app with waitress on 0.0.0.0:5001 using 8 threads."""
    import waitress
    waitress.serve(app, host="0.0.0.0", port=5001, threads=8)

# Start the server in a daemon thread so this cell returns while it keeps serving.
threading.Thread(target=run_server, daemon=True).start()
# Fixed 2-second pause before announcing readiness.
# NOTE(review): this is a grace period, not an actual readiness check.
time.sleep(2)
print("Notebook server ready at:", public_url, " — endpoint: /generate")