# 🧠 RADE: Retriever-Augmented Document Entity Extraction

**Note: This notebook uses in-memory encoding and search, so it does not require a FAISS GPU installation. However, a GPU machine may still be required.**

This notebook demonstrates how to extract information from trust documents using the **RADE** pipeline.  
Each query retrieves relevant chunks using **ColBERT**, then:
- Runs a **QA model** (RoBERTa) to extract precise answers
- Uses **GLiNER** for entity extraction tied to each query

***
## Section 1: In-Memory Document Encoding and Retrieval

This section demonstrates how to use the **RADE** (Retriever-Augmented Document Entity Extraction) system to parse unstructured PDF documents and encode them in memory. This is an alternative to the FAISS indexing approach. Because the encodings are held in memory, keep the number and size of documents small.
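Why the limit matters: ColBERT stores one embedding per token rather than one per page, so memory grows with the total token count. The back-of-envelope sketch below assumes 128-dimensional embeddings (the ColBERTv2 default), 2 bytes per value (float16), and an illustrative 512 tokens per page. All three are assumptions rather than RADE settings, the estimate ignores model weights and framework overhead, and `estimate_colbert_memory_mib` is a hypothetical helper.

```python
def estimate_colbert_memory_mib(num_pages: int, tokens_per_page: int = 512,
                                dim: int = 128, bytes_per_value: int = 2) -> float:
    """Approximate memory (MiB) for per-token ColBERT embeddings of `num_pages` pages.

    Assumptions: every page has `tokens_per_page` tokens, each token gets a
    `dim`-dimensional vector, and each component takes `bytes_per_value`
    bytes (2 = float16).
    """
    total_bytes = num_pages * tokens_per_page * dim * bytes_per_value
    return total_bytes / (1024 ** 2)


print(estimate_colbert_memory_mib(100))     # 12.5
print(estimate_colbert_memory_mib(10_000))  # 1250.0
```

Under these assumptions, 100 pages need about 12.5 MiB of embeddings and 10,000 pages about 1.2 GiB, before counting the models themselves.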

RADE integrates:
- **ColBERT** for semantic retrieval
- **Azure Document Intelligence** for document parsing (PDFs & scanned pages)

---

## 🗂️ Steps:
1. **Initialize RADE** with model names and Azure credentials
2. **Add a document** (PDF) — scanned or digital
3. **Build the ColBERT encodings in memory**
4. **Retrieve top-k relevant pages**


### 🔧 Step 1 – Initialize RADE

In [0]:
# Initialize RADE
from rade import RADE  # assuming you saved it as rade.py
rade = RADE()


### 📄 Step 2 – Add a PDF Document

In [0]:
# ✅ Path to the trust PDF to add (update this to your own workspace path)

pdf_path = "/Workspace/Users/morashed@hdinhfdic.onmicrosoft.com/RADE/data/living-trust-forms-03_Ileana Hardison Rev Family Trust DTD 05052010.pdf.crdownload"

rade.add_document(pdf_path, 1)
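The path above ends in `.crdownload`, the suffix Chrome uses for downloads that are still in progress, so it is worth confirming that the file is a complete PDF before adding it. A heuristic sketch that relies only on the PDF file structure (a `%PDF-` header at the start and an `%%EOF` marker near the end); `looks_like_complete_pdf` is a hypothetical helper, not part of RADE:

```python
from pathlib import Path


def looks_like_complete_pdf(path: str) -> bool:
    """Heuristic check that `path` is a fully downloaded PDF.

    A PDF starts with the b"%PDF-" header and normally ends with a b"%%EOF"
    marker (possibly followed by a newline); a truncated download usually
    lacks the trailing marker.
    """
    p = Path(path)
    if not p.is_file():
        return False
    size = p.stat().st_size
    with p.open("rb") as f:
        header = f.read(5)
        f.seek(max(0, size - 1024))  # look at the last 1 KiB only
        tail = f.read()
    return header == b"%PDF-" and b"%%EOF" in tail
```

For example, call `looks_like_complete_pdf(pdf_path)` before `rade.add_document(pdf_path, 1)` and re-download the file if it returns `False`. Treat `False` as a prompt to inspect the file rather than proof of corruption, since this is only a heuristic.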

### 📚 Step 3 – Encode the Document in Memory

In [0]:
rade.encode_docuemnt()

### 🔍 Step 4 – Retrieve Top-k Pages for a Query

In [0]:
results = rade.search_encoded_docs(["Who are the grantors"], k=3)
for idx, r in enumerate(results):
    print(f"📄 Result #{idx} \n Page {r.page.page_num} | Score: {r.score:.2f} | DocName: {r.page.doc_name}")
    print(r.page.text)

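# Free the in-memory encodings when you are done with this document.
# Section 2 searches the same encodings, so run this only after finishing it.
# rade.retrieval_model.clear_encoded_docs()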
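The `Score` printed above is presumably ColBERT's late-interaction relevance score. As a rough illustration (not RADE's code): ColBERT encodes the query and each page as one vector per token; for every query token it takes the highest similarity to any page token (MaxSim) and sums these maxima. The pure-Python sketch below uses toy 2-dimensional vectors, whereas real ColBERT embeddings are learned and normalized so the dot product acts as cosine similarity. `maxsim_score` and `top_k_pages` are hypothetical names.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def maxsim_score(query_vecs: List[Vector], page_vecs: List[Vector]) -> float:
    """Sum over query tokens of the best dot product with any page token."""
    return sum(
        max(sum(q * p for q, p in zip(qv, pv)) for pv in page_vecs)
        for qv in query_vecs
    )


def top_k_pages(query_vecs: List[Vector], pages: Dict[str, List[Vector]],
                k: int = 3) -> List[Tuple[str, float]]:
    """Score every page and return the k best as (page_name, score) pairs."""
    scored = [(name, maxsim_score(query_vecs, vecs)) for name, vecs in pages.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


query = [[1.0, 0.0], [0.0, 1.0]]          # two query tokens
pages = {
    "page1": [[1.0, 0.0], [0.5, 0.5]],    # 1.0 + 0.5 = 1.5
    "page2": [[0.0, 1.0]],                # 0.0 + 1.0 = 1.0
}
print(top_k_pages(query, pages, k=2))     # [('page1', 1.5), ('page2', 1.0)]
```

Because the score is a sum over query tokens, it is comparable across pages for one query but not across different queries.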

In [0]:
results[1]

---

## Section 2: 🔍 RADE Trust Document Analyzer – Entity and QA Extraction

This section supports:
- Named Entity Recognition (NER) using GLiNER for labeled entity queries
- Question Answering (QA) using RoBERTa for direct questions

For each query, the notebook:
- Retrieves the top-k pages via ColBERT
- Runs either QA or NER, depending on the query type
- Displays the results interactively, with Prev/Next buttons to flip through the retrieved pages



### 🧩 Step 1 – Define the Query-to-Label Mapping

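Define the query plan in `data/query_plan.json` using the schema below. Each key is a natural-language query; `type` selects the pipeline (`qa` for RoBERTa question answering, `ner` for GLiNER entity extraction), and `labels` lists the GLiNER entity labels (empty for QA queries):

```json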
{
  "Who are the Grantors?": {
    "type": "ner",
    "labels": ["Grantor", "Settlor"]
  },
  "Who are the Trustees?": {
    "type": "ner",
    "labels": ["Primary Trustee", "Trustee"]
  },
  "Who are the Successor Trustees?": {
    "type": "ner",
    "labels": ["Successor Trustee"]
  },
  "Who are the Beneficiaries?": {
    "type": "ner",
    "labels": ["Primary Beneficiary", "Beneficiary", "Residuary Beneficiary"]
  },
  "Who are the Successor Beneficiaries?": {
    "type": "ner",
    "labels": ["Successor Beneficiary", "Secondary Beneficiary"]
  },
  "What is the name of the trust?": {
    "type": "qa",
    "labels": []
  },
  "What is the date of the trust?": {
    "type": "qa",
    "labels": []
  },
  "Is this trust revocable or irrevocable?": {
    "type": "qa",
    "labels": []
  }
}
```

In [0]:
# Load the query plan (query -> type and labels) from JSON
import json
query_file = "data/query_plan.json"
with open(query_file, "r") as j:
    query_plan = json.load(j)
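The pipeline below handles only `"qa"` and `"ner"` entries and silently skips anything else, so a typo such as `"NER"` would drop a query without warning. A small validation pass over the loaded plan can catch this early; `validate_query_plan` is a hypothetical helper that checks the schema shown above.

```python
VALID_TYPES = {"qa", "ner"}


def validate_query_plan(plan: dict) -> list:
    """Return a list of problems found in a query plan (empty if it looks valid).

    Each query must map to an object with a "type" of "qa" or "ner" and a
    "labels" list; NER queries need at least one label.
    """
    problems = []
    for query, meta in plan.items():
        if not isinstance(meta, dict):
            problems.append(f"{query!r}: entry must be an object")
            continue
        qtype = meta.get("type")
        labels = meta.get("labels")
        if qtype not in VALID_TYPES:
            problems.append(f"{query!r}: unknown type {qtype!r}")
        if not isinstance(labels, list):
            problems.append(f"{query!r}: 'labels' must be a list")
        elif qtype == "ner" and not labels:
            problems.append(f"{query!r}: NER queries need at least one label")
    return problems
```

For example, `problems = validate_query_plan(query_plan)` followed by `assert not problems, problems` stops the notebook with a readable list of issues.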


In [0]:
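# Optional helper: run the entity model (GLiNER) on each retrieved page separately and key the results by page.
# The pipeline below calls rade.extract_entities_with_gliner instead.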
def extract_entities_per_page(rade, retrieved_pages, labels):
    page_entities = {}

    for rp in retrieved_pages:
        text = rp.page.text.replace("\n", " ")
        entities = rade.entity_extraction_model.predict_entities(text, labels)
        key = f"{rp.page.doc_name}_page{rp.page.page_num}"
        page_entities[key] = entities

    return page_entities
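Because the same person often appears on several retrieved pages, per-page results such as those returned by `extract_entities_per_page` can be merged into one answer per label. A minimal sketch, assuming each entity dict carries GLiNER's `text`, `label`, and `score` keys; `merge_page_entities` is a hypothetical helper and the sample data is illustrative only.

```python
from typing import Dict, List


def merge_page_entities(page_entities: Dict[str, List[dict]]) -> Dict[str, Dict[str, float]]:
    """Merge per-page entities into {label: {entity_text: best_score}}.

    Whitespace inside each entity text is normalized so that the same name
    split across lines on different pages is counted once; duplicate
    mentions keep their highest score.
    """
    merged: Dict[str, Dict[str, float]] = {}
    for entities in page_entities.values():
        for ent in entities:
            text = " ".join(ent["text"].split())
            by_text = merged.setdefault(ent["label"], {})
            by_text[text] = max(by_text.get(text, 0.0), ent.get("score", 0.0))
    return merged


page_entities = {
    "trust.pdf_page1": [{"text": "Jane Doe", "label": "Grantor", "score": 0.81}],
    "trust.pdf_page2": [
        {"text": "Jane  Doe", "label": "Grantor", "score": 0.93},
        {"text": "John Roe", "label": "Trustee", "score": 0.55},
    ],
}
print(merge_page_entities(page_entities))
# {'Grantor': {'Jane Doe': 0.93}, 'Trustee': {'John Roe': 0.55}}
```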


### 🧠 Step 2 – Run RADE: Retrieve + QA/NER

In [0]:
# Retrieve the top-3 pages for each query, then run the QA or NER pipeline on them
query_results = []

for query, meta in query_plan.items():
    retrieved = rade.search_encoded_docs([query], 3)

    if meta["type"] == "qa":
        qa_result = rade.run_qa_pipeline(
            question=query,
            retrieved_texts=[
                {
                    "content": rp.page.text,
                    "document_metadata": {
                        "document": rp.page.doc_name,
                        "page": rp.page.page_num
                    }
                } for rp in retrieved
            ]
        )
        query_results.append({
            "query": query,
            "type": "qa",
            "labels": [],
            "results": retrieved,
            "qa": qa_result,
            "entities": {}
        })

    elif meta["type"] == "ner":
        retrieved_texts = [
            {
                "content": rp.page.text,
                "document_metadata": {
                    "document": rp.page.doc_name,
                    "page": rp.page.page_num
                }
            } for rp in retrieved
        ]

        entity_result = rade.extract_entities_with_gliner(
            retrieved_texts=retrieved_texts,
            labels=meta["labels"],
            threshold=0.3
        )

        query_results.append({
            "query": query,
            "type": "ner",
            "labels": meta["labels"],
            "results": retrieved,
            "qa": None,
            "entities": {
                # Key entities by page, keeping only those whose text appears on that page
                f"{rp.page.doc_name}_page{rp.page.page_num}": [
                    e for e in entity_result["entities"]
                    if e["text"] in rp.page.text
                ]
                for rp in retrieved
            }
        })
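Both branches above build the same `retrieved_texts` payload from the search results. A small helper could remove that duplication; this sketch assumes each result exposes `.page.text`, `.page.doc_name`, and `.page.page_num` as in the loops above, and `to_retrieved_texts` is a hypothetical name.

```python
def to_retrieved_texts(retrieved) -> list:
    """Convert ColBERT search results into the payload the QA and NER pipelines take."""
    return [
        {
            "content": rp.page.text,
            "document_metadata": {
                "document": rp.page.doc_name,
                "page": rp.page.page_num,
            },
        }
        for rp in retrieved
    ]
```

Each branch could then pass `to_retrieved_texts(retrieved)` as the `retrieved_texts` argument to `rade.run_qa_pipeline` or `rade.extract_entities_with_gliner`.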


### Step 3 – Visualize the Results

In [0]:
import html
import re
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML, clear_output

def highlight_entities(text: str, entities: list, highlight_answer: str = None) -> str:
    """
    Highlight the QA answer and NER entities using bright yellow <mark> tags.

    All spans are matched in a single regex pass over the raw text, so a
    highlight is never inserted inside another <mark> tag or its tooltip.
    """
    # Map each span to its tooltip; the QA answer takes precedence over entities
    tooltips = {}
    if highlight_answer and highlight_answer.strip():
        tooltips[highlight_answer.strip()] = "QA Answer"
    for ent in entities:
        if not ent["text"].strip():
            continue
        score = ent.get("score")
        tooltips.setdefault(
            ent["text"],
            ent["label"] + (f" ({score:.2f})" if score is not None else ""),
        )

    if not tooltips:
        return html.escape(text)

    # Longest spans first so a full name wins over a shorter name it contains
    pattern = re.compile(
        "|".join(re.escape(s) for s in sorted(tooltips, key=len, reverse=True))
    )

    parts, last = [], 0
    for m in pattern.finditer(text):
        parts.append(html.escape(text[last:m.start()]))
        tooltip = html.escape(tooltips[m.group(0)])
        parts.append(
            f'<mark style="background-color:yellow" title="{tooltip}">'
            f'{html.escape(m.group(0))}</mark>'
        )
        last = m.end()
    parts.append(html.escape(text[last:]))
    return "".join(parts)


def show_query_results(query_result: dict):
    query = query_result["query"]
    query_type = query_result["type"]
    pages = query_result["results"]
    entity_map = query_result["entities"]
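    qa = query_result["qa"]

    if not pages:
        display(Markdown(f"## 🔎 Query: `{query}`\n\nNo pages were retrieved for this query."))
        return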

    index = widgets.IntText(value=0, layout=widgets.Layout(width="40px"), disabled=True)
    output = widgets.Output()

    prev_button = widgets.Button(description="◀️ Prev", layout=widgets.Layout(width="80px"))
    next_button = widgets.Button(description="Next ▶️", layout=widgets.Layout(width="80px"))
    nav_box = widgets.HBox([prev_button, index, next_button])

    def render_page(i):
        with output:
            clear_output(wait=True)
            page = pages[i].page
            score = pages[i].score
            key = f"{page.doc_name}_page{page.page_num}"
            ents = entity_map.get(key, []) if query_type == "ner" else []

            display(Markdown(f"## 🔎 Query: `{query}`"))

            if query_type == "qa":
                answer = qa.get("answer", "[No answer]")
                display(Markdown(f"### 🤖 QA Answer: `{answer}`"))
            else:
                answer = None  # Not used

            display(Markdown(f"**📄 Page {page.page_num} — Score: `{score:.2f}` — File: `{page.doc_name}`**"))

            highlighted_text = highlight_entities(
                text=page.text[:3000],
                entities=ents,
                highlight_answer=answer
            )

            display(HTML(f"<pre style='line-height:1.5;font-family:monospace'>{highlighted_text}</pre>"))

    def on_prev_clicked(_):
        if index.value > 0:
            index.value -= 1
            render_page(index.value)

    def on_next_clicked(_):
        if index.value < len(pages) - 1:
            index.value += 1
            render_page(index.value)

    prev_button.on_click(on_prev_clicked)
    next_button.on_click(on_next_clicked)

    display(nav_box, output)
    render_page(index.value)


In [0]:
for result in query_results:
    show_query_results(result)
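If `ipywidgets` is unavailable (for example, in a headless job), a plain-text summary can stand in for the interactive viewer. A sketch assuming the `query_results` structure built in Step 2; `summarize_query_results` is a hypothetical helper.

```python
def summarize_query_results(query_results: list) -> str:
    """One line per query: the QA answer, or the unique entity texts for NER queries."""
    lines = []
    for result in query_results:
        if result["type"] == "qa":
            answer = (result["qa"] or {}).get("answer", "[No answer]")
        else:
            seen = []  # unique entity texts in first-seen order
            for entities in result["entities"].values():
                for ent in entities:
                    if ent["text"] not in seen:
                        seen.append(ent["text"])
            answer = ", ".join(seen) or "[No entities]"
        lines.append(f"{result['query']} -> {answer}")
    return "\n".join(lines)
```

Calling `print(summarize_query_results(query_results))` prints one line per query in the plan.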

