In [None]:
# Install every dependency in one step. %pip (rather than !pip) installs into
# the environment of the running kernel. huggingface_hub is listed once here;
# a separate un-upgraded install beforehand was redundant.
%pip install -U datasets sentence-transformers umap-learn scikit-learn faiss-cpu huggingface_hub numpy

In [None]:


import os
from huggingface_hub import InferenceClient


from google.colab import userdata
from huggingface_hub import InferenceClient

# Smoke test: confirm that the Hugging Face token and the chat endpoint work
# before running the full pipeline.
HF_TOKEN = userdata.get("HF_TOKEN")  # read from Colab Secrets; never hard-code it
if not HF_TOKEN:
    # Fail early with an actionable message instead of an opaque auth error
    # from the inference API.
    raise ValueError(
        "HF_TOKEN not found in Colab Secrets. "
        "Add it in the Secrets panel and allow notebook access."
    )
client = InferenceClient(api_key=HF_TOKEN)

resp = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    messages=[{"role": "user", "content": "Explain RAPTOR in simple terms."}],
    max_tokens=200,
)

print(resp.choices[0].message.content)




In [None]:
# =========================
# RAPTOR-only on PubMedQA (public healthcare dataset)
# - Builds RAPTOR tree for ONE selected PubMedQA example
# - Runs 1-2 queries against the collapsed tree
# - Uses Colab Secret HF_TOKEN via userdata.get("HF_TOKEN")
# =========================

# !pip -q install -U datasets sentence-transformers umap-learn scikit-learn faiss-cpu huggingface_hub

import re
import numpy as np
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import umap
from sklearn.mixture import GaussianMixture
import faiss

from huggingface_hub import InferenceClient
from google.colab import userdata


# ----------------------------
# 0) Tokenization + chunking
# ----------------------------
def simple_tokens(text: str) -> List[str]:
    """Split ``text`` into word runs and individual punctuation marks.

    A token is either a maximal run of word characters or a single character
    that is neither a word character nor whitespace. Whitespace itself is
    never returned.
    """
    token_pattern = re.compile(r"\w+|[^\w\s]", re.UNICODE)
    return token_pattern.findall(text)

def chunk_text_sentence_aware(text: str, max_tokens: int = 100) -> List[str]:
    """Pack whole sentences into chunks of at most ``max_tokens`` tokens.

    Whitespace is collapsed, the text is split after ``.``, ``?`` or ``!``
    followed by whitespace, and consecutive sentences are packed greedily.
    Token counts come from ``simple_tokens``.

    Args:
        text: Raw document text.
        max_tokens: Maximum number of tokens per chunk; must be at least 1.

    Returns:
        The chunks in document order. Empty or whitespace-only input gives
        an empty list.

    Raises:
        ValueError: If ``max_tokens`` is smaller than 1. Without this check a
            value of 0 fails inside ``range()`` and a negative value silently
            discards every sentence.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    text = re.sub(r"\s+", " ", text).strip()
    if not text:
        return []

    chunks: List[str] = []
    cur: List[str] = []   # sentences collected for the chunk being built
    cur_len = 0           # token count of `cur`

    for s in re.split(r"(?<=[\.\?\!])\s+", text):
        s = s.strip()
        if not s:
            continue
        # Tokenise once and reuse the result for both the length check and
        # the hard split below.
        toks = simple_tokens(s)
        tok_len = len(toks)
        if tok_len == 0:
            continue

        # A sentence longer than the budget cannot be packed: flush what we
        # have, then cut the sentence into windows of max_tokens tokens.
        # Note that the windows are re-joined with single spaces, so the
        # original spacing around punctuation is not preserved there.
        if tok_len > max_tokens:
            if cur:
                chunks.append(" ".join(cur).strip())
                cur, cur_len = [], 0
            for i in range(0, tok_len, max_tokens):
                piece = " ".join(toks[i:i + max_tokens]).strip()
                if piece:
                    chunks.append(piece)
            continue

        # Greedy packing: extend the current chunk while it fits, otherwise
        # close it and start a new one with this sentence.
        if cur_len + tok_len <= max_tokens:
            cur.append(s)
            cur_len += tok_len
        else:
            chunks.append(" ".join(cur).strip())
            cur = [s]
            cur_len = tok_len

    if cur:
        chunks.append(" ".join(cur).strip())

    return [c for c in chunks if c.strip()]


# ----------------------------
# 1) GMM selection by BIC
# ----------------------------
def bic_best_gmm(X: np.ndarray, k_min: int = 2, k_max: int = 12, seed: int = 42):
    """Fit full-covariance Gaussian mixtures and keep the one with lowest BIC.

    Component counts from ``k_min`` up to a cap are tried. The cap is the
    smallest of ``k_max``, ``len(X) - 1`` and ``max(2, len(X) // 2)``.

    Returns:
        The fitted ``GaussianMixture`` with the lowest BIC, or ``None`` when
        there are fewer than 3 points or the capped range is empty.
    """
    n_points = len(X)
    if n_points < 3:
        return None

    upper = min(k_max, n_points - 1, max(2, n_points // 2))
    if upper < k_min:
        return None

    winner = None
    lowest_bic = float("inf")
    for n_components in range(k_min, upper + 1):
        candidate = GaussianMixture(
            n_components=n_components,
            covariance_type="full",
            random_state=seed,
        )
        candidate.fit(X)
        score = candidate.bic(X)
        # Strict comparison: on ties the smaller component count wins.
        if score < lowest_bic:
            winner = candidate
            lowest_bic = score
    return winner


# ----------------------------
# 2) Node structure (RAPTOR)
# ----------------------------
@dataclass
class Node:
    """One node of the RAPTOR tree: a piece of text plus tree and retrieval metadata."""
    # Unique id; Raptor._new_node also uses it as the key in Raptor.nodes.
    node_id: int
    # Chunk text for leaves, LLM-written summary for higher-level nodes.
    text: str
    # 0 for leaf chunks; summary nodes are created with the current level + 1.
    level: int
    # Ids of the nodes summarised by this node; empty for leaves.
    children: List[int] = field(default_factory=list)
    # Set by Raptor._embed_nodes (float32 vector); None until then.
    embedding: Optional[np.ndarray] = None
    # Raptor._new_node sets this to len(simple_tokens(text)).
    token_len: int = 0


# ----------------------------
# 3) Hugging Face LLM wrapper
# ----------------------------
class HFChatLLM:
    """Minimal single-turn chat wrapper around ``huggingface_hub.InferenceClient``."""

    def __init__(self, api_key: str, model: str = "mistralai/Mistral-7B-Instruct-v0.2"):
        """Create the inference client.

        Args:
            api_key: Hugging Face access token passed to ``InferenceClient``.
            model: Model id sent with every chat request.
        """
        self.client = InferenceClient(api_key=api_key)
        self.model = model

    def chat(self, user_msg: str, max_tokens: int = 256) -> str:
        """Send ``user_msg`` as a single user turn and return the reply text.

        Args:
            user_msg: Content of the user message.
            max_tokens: Generation limit forwarded to the endpoint.

        Returns:
            The first choice's message content with surrounding whitespace
            removed, or ``""`` when the response carries no content.
        """
        resp = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": user_msg}],
            max_tokens=max_tokens,
        )
        # `content` may be None (it is optional in the response schema);
        # calling .strip() on it directly would raise AttributeError.
        content = resp.choices[0].message.content
        return (content or "").strip()


# ----------------------------
# 4) RAPTOR implementation
# ----------------------------
class Raptor:
    """RAPTOR-style index over a single document.

    Leaf nodes are sentence-aware text chunks. Higher levels are LLM
    summaries of clusters of lower-level nodes. Retrieval searches all nodes
    of all levels at once ("collapsed tree").
    """

    def __init__(
        self,
        llm: HFChatLLM,
        embedder: SentenceTransformer,
        seed: int = 42,
        umap_dim: int = 10,
        base_neighbors: int = 15,
    ):
        """Store the collaborators and clustering settings; start with no nodes.

        Args:
            llm: Object whose ``chat(prompt, max_tokens=...)`` writes summaries.
            embedder: Sentence embedding model used for nodes and queries.
            seed: Random seed passed to UMAP and to the GMM selection.
            umap_dim: Upper bound on the UMAP output dimensionality.
            base_neighbors: Upper bound on UMAP ``n_neighbors``.
        """
        self.llm = llm
        self.embedder = embedder
        self.seed = seed
        self.umap_dim = umap_dim
        self.base_neighbors = base_neighbors
        self.nodes: Dict[int, Node] = {}
        self.next_id = 0

    def reset(self):
        """Discard every node and restart id numbering at 0."""
        self.nodes = {}
        self.next_id = 0

    def _new_node(self, text: str, level: int, children: List[int]) -> int:
        """Create and register a node; return its freshly allocated id.

        The embedding is not computed here; see ``_embed_nodes``.
        """
        nid = self.next_id
        self.next_id += 1
        self.nodes[nid] = Node(
            node_id=nid,
            text=text,
            level=level,
            children=children,
            token_len=len(simple_tokens(text))
        )
        return nid

    def _embed_nodes(self, node_ids: List[int]) -> None:
        """Embed the given nodes' texts and store each vector on its node.

        Vectors are requested normalised and stored as float32.
        """
        texts = [self.nodes[i].text for i in node_ids]
        embs = self.embedder.encode(texts, normalize_embeddings=True, show_progress_bar=False)
        for i, e in zip(node_ids, embs):
            self.nodes[i].embedding = e.astype(np.float32)

    def _make_umap(self, n_samples: int):
        """Build a UMAP reducer whose sizes are capped by ``n_samples``.

        Both ``n_neighbors`` and ``n_components`` are limited to
        ``n_samples - 1`` (but never below 2). ``init="random"`` is kept from
        the original code, which notes that it avoids an error in spectral
        initialisation when the requested size reaches the sample count.
        """
        nn = min(self.base_neighbors, max(2, n_samples - 1))
        ncomp = min(self.umap_dim, max(2, n_samples - 1))
        return umap.UMAP(
            n_components=ncomp,
            n_neighbors=nn,
            metric="cosine",
            random_state=self.seed,
            init="random",
            low_memory=True,
        )

    def summarize_cluster(self, child_texts: List[str]) -> str:
        """Ask the LLM for a bullet-point summary of the given passages.

        The passages are joined with blank lines and truncated to 12000
        characters before being placed in the prompt.
        """
        joined = "\n\n".join(child_texts)[:12000]
        # The range is written with a plain ASCII hyphen: the previous text
        # contained a mis-encoded dash that was sent to the model verbatim.
        prompt = (
            "You are summarizing a cluster of medical research passages.\n"
            "Produce 5-8 bullet points capturing key facts, methods, results, and entities.\n"
            "Be information-dense and avoid fluff.\n\n"
            f"PASSAGES:\n{joined}\n\nSUMMARY:"
        )
        return self.llm.chat(prompt, max_tokens=240)

    def build_tree(
        self,
        document_text: str,
        chunk_tokens: int = 100,
        max_levels: int = 3,
        min_cluster_size: int = 3,
        min_points_for_clustering: int = 8,  # stop clustering on very small layers
    ) -> List[int]:
        """Rebuild the tree for ``document_text`` and return the top layer.

        Any existing nodes are discarded first. The document is chunked, the
        chunks become level-0 nodes, and then up to ``max_levels`` times the
        current layer is reduced with UMAP, clustered with a BIC-selected
        GMM, and each sufficiently large cluster is replaced by a summary
        node.

        Args:
            document_text: Text to index.
            chunk_tokens: Token budget per leaf chunk.
            max_levels: Maximum number of summary levels to add.
            min_cluster_size: Clusters with fewer members are not summarised;
                their members are carried into the next layer unchanged.
            min_points_for_clustering: A layer with fewer nodes than
                ``max(min_cluster_size, min_points_for_clustering)`` ends
                the build.

        Returns:
            Ids of the nodes in the final layer (summaries plus carried-over
            nodes), or an empty list when the document yields no chunks.
        """
        self.reset()

        chunks = chunk_text_sentence_aware(document_text, max_tokens=chunk_tokens)
        leaf_ids = [self._new_node(c, level=0, children=[]) for c in chunks]
        if not leaf_ids:
            return []

        self._embed_nodes(leaf_ids)

        current = leaf_ids
        level = 0
        min_layer_size = max(min_cluster_size, min_points_for_clustering)

        while level < max_levels:
            if len(current) < min_layer_size:
                break

            X = np.stack([self.nodes[i].embedding for i in current], axis=0)
            reducer = self._make_umap(n_samples=len(current))
            Xr = reducer.fit_transform(X)

            gmm = bic_best_gmm(Xr, k_min=2, k_max=min(12, len(current) - 1), seed=self.seed)
            if gmm is None:
                break

            # Hard assignment: each node goes to its most likely component.
            labels = gmm.predict(Xr)
            clusters: Dict[int, List[int]] = {}
            for nid, lab in zip(current, labels):
                clusters.setdefault(int(lab), []).append(nid)

            # A single cluster would only restate the whole layer.
            if len(clusters) <= 1:
                break

            parents, new_parent_ids = [], []
            for kids in clusters.values():
                if len(kids) < min_cluster_size:
                    # Too small to summarise: promote the members as they are.
                    parents.extend(kids)
                    continue

                summary = self.summarize_cluster([self.nodes[k].text for k in kids])
                pid = self._new_node(summary, level=level + 1, children=kids)
                parents.append(pid)
                new_parent_ids.append(pid)

            # No summary was produced, so another pass could not make progress.
            if not new_parent_ids:
                break

            self._embed_nodes(new_parent_ids)
            current = parents
            level += 1

        return current

    def collapsed_tree_retrieval(
        self,
        query: str,
        top_k: int = 12,
        max_tokens: int = 900,
    ) -> Tuple[str, List[int]]:
        """Retrieve the most similar nodes from all levels within a token budget.

        All nodes are searched together by inner product. Up to ``top_k * 6``
        candidates are scanned in similarity order; a candidate that would
        push the total over ``max_tokens`` is skipped (not a stop condition),
        so smaller lower-ranked nodes can still be added.

        Args:
            query: Query text.
            top_k: Maximum number of nodes to return.
            max_tokens: Budget on the sum of the selected nodes' ``token_len``.

        Returns:
            ``(context, node_ids)`` where ``context`` is the selected texts
            joined by blank lines, in selection order. ``("", [])`` when the
            tree is empty or ``top_k`` is not positive.
        """
        # With top_k <= 0 the search below would request zero neighbours,
        # which FAISS does not accept; nothing can be selected anyway.
        if not self.nodes or top_k <= 0:
            return "", []

        q = self.embedder.encode([query], normalize_embeddings=True, show_progress_bar=False)[0].astype(np.float32)

        node_ids = list(self.nodes.keys())
        embs = np.stack([self.nodes[i].embedding for i in node_ids], axis=0).astype(np.float32)

        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)

        # Over-fetch so that skipping over-budget nodes still leaves candidates.
        _, idxs = index.search(q.reshape(1, -1), min(len(node_ids), top_k * 6))

        chosen, total = [], 0
        for j in idxs[0]:
            nid = node_ids[int(j)]
            tlen = self.nodes[nid].token_len
            if total + tlen > max_tokens:
                continue
            chosen.append(nid)
            total += tlen
            if len(chosen) >= top_k:
                break

        context = "\n\n".join([self.nodes[i].text for i in chosen])
        return context, chosen


# ----------------------------
# 5) Dataset formatting (PubMedQA)
# ----------------------------
def pubmedqa_to_document(example) -> str:
    """Flatten one PubMedQA record into a single document string.

    The entries of ``example["context"]["contexts"]`` are joined with spaces
    (empty entries are dropped), followed by a blank line and
    ``example["long_answer"]``.

    Args:
        example: Mapping with optional ``"context"`` and ``"long_answer"``
            keys. Missing or ``None`` values are treated as empty.

    Returns:
        The combined text with surrounding whitespace removed; ``""`` when
        nothing is present.
    """
    ctx = example.get("context") or {}
    contexts = ctx.get("contexts")
    # A key that is present but None must count as "no contexts"; passing it
    # to str() would put the literal text "None" into the document.
    if contexts is None:
        contexts = []
    if isinstance(contexts, (list, tuple)):
        ctx_text = " ".join(str(x) for x in contexts if x)
    else:
        ctx_text = str(contexts)
    long_answer = str(example.get("long_answer") or "")
    return f"{ctx_text}\n\n{long_answer}".strip()


# ----------------------------
# 6) Answer helper (optional)
# ----------------------------
def answer_with_llm(llm: HFChatLLM, question: str, context: str, max_tokens: int = 220) -> str:
    """Ask ``llm`` for a yes/no/maybe decision grounded in ``context``.

    Builds an instruction prompt containing the context and the question,
    and returns whatever ``llm.chat`` produces for it.
    """
    instructions = "\n".join(
        [
            "Answer the medical question using ONLY the provided context.",
            "Return:",
            "Decision: yes / no / maybe",
            "Justification: 2-4 sentences grounded in context",
        ]
    )
    full_prompt = f"{instructions}\n\nCONTEXT:\n{context}\n\nQUESTION:\n{question}\n\nANSWER:"
    return llm.chat(full_prompt, max_tokens=max_tokens)


# ----------------------------
# 7) Run RAPTOR on 1 example + 1 or 2 queries
# ----------------------------
def run_raptor_pubmedqa_single_example(
    example_idx: int = 0,
    queries: Optional[List[str]] = None,
    chunk_tokens: int = 100,
    max_levels: int = 3,
    top_k: int = 12,
    max_context_tokens: int = 900,
):
    """Build a RAPTOR tree for one PubMedQA example and answer up to two queries.

    Everything is printed; nothing is returned.

    Args:
        example_idx: Index into the ``pqa_labeled`` train split.
        queries: Queries to run; only the first two are used. When ``None``,
            the example's own question plus one generic question are used.
        chunk_tokens: Token budget per leaf chunk.
        max_levels: Maximum number of summary levels in the tree.
        top_k: Maximum number of nodes retrieved per query.
        max_context_tokens: Token budget for the retrieved context.

    Raises:
        ValueError: If the ``HF_TOKEN`` secret is missing or empty.
    """
    HF_TOKEN = userdata.get("HF_TOKEN")  # read from Colab Secrets; never hard-code it
    # An empty string is as unusable as None, so reject both. The message is
    # plain ASCII; the previous one contained a mis-encoded icon.
    if not HF_TOKEN:
        raise ValueError(
            "HF_TOKEN not found in Colab Secrets. "
            "Add it in the Secrets panel (key icon) and allow notebook access."
        )

    llm = HFChatLLM(api_key=HF_TOKEN, model="mistralai/Mistral-7B-Instruct-v0.2")
    embedder = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-cos-v1")
    raptor = Raptor(llm=llm, embedder=embedder)

    # Load the dataset and pick ONE example (public healthcare dataset)
    ds = load_dataset("pubmed_qa", "pqa_labeled", split="train")
    ex = ds[example_idx]

    question = ex["question"]
    gold = (ex.get("final_decision") or "").strip()
    doc_text = pubmedqa_to_document(ex)

    print("=" * 100)
    print("DATASET: PubMedQA (pqa_labeled)")
    print("Example index:", example_idx)
    print("Question:", question)
    print("Gold label:", gold)
    print("Doc length (chars):", len(doc_text))
    print("=" * 100)

    # Build RAPTOR tree for this document
    roots = raptor.build_tree(
        doc_text,
        chunk_tokens=chunk_tokens,
        max_levels=max_levels,
        min_cluster_size=3,
        min_points_for_clustering=8
    )
    print("RAPTOR built.")
    print("Total nodes:", len(raptor.nodes))
    print("Root node ids:", roots)

    # Default: use dataset question + one extra query (2 queries total)
    if queries is None:
        queries = [
            question,
            "What is the main conclusion and evidence from this study?"
        ]

    # Run at most two queries; each one triggers an LLM call below.
    for qi, q in enumerate(queries[:2], start=1):
        print("\n" + "-" * 100)
        print(f"QUERY {qi}: {q}")

        context, chosen = raptor.collapsed_tree_retrieval(
            q, top_k=top_k, max_tokens=max_context_tokens
        )

        print(f"\nRetrieved nodes: {len(chosen)} (showing first 1500 chars of context)")
        print(context[:1500])

        # Ask the LLM to answer using only the retrieved context
        ans = answer_with_llm(llm, q, context)
        print("\nLLM Answer:")
        print(ans)


# ----------------------------
# Run it
# ----------------------------
# Entry point of the notebook: builds the tree for one example and prints the
# retrieved context and the LLM answer for each query.
run_raptor_pubmedqa_single_example(
    example_idx=0,   # index into the pqa_labeled train split; change to try another sample
    queries=None,    # None -> dataset question + one generic query; or pass your own list (first two are used)
    chunk_tokens=100,
    max_levels=3,
    top_k=12,
    max_context_tokens=900,
)