RAG System based on NUMEN Retrieval

Author: Sangeet Sharma 

:)

In [None]:
#NUMEN RAG 

# Cell 0: Setup & Installation
# Notebook-only cell: lines starting with `!` are IPython shell escapes, so
# this cell is not valid plain Python outside a Jupyter/Colab kernel.

# Runtime dependencies (-q keeps pip output quiet).
!pip install datasets torch numpy faiss-cpu accelerate bitsandbytes -q
# transformers is installed from the GitHub repository without a pinned
# revision, so the resolved version can differ between runs.
!pip install git+https://github.com/huggingface/transformers.git -q
print("Dependencies installed.")


# Download MIRAGE Benchmark (skipped when benchmark.json already exists in
# the working directory).
import os
if not os.path.exists("benchmark.json"):
    print("Downloading MIRAGE Benchmark...")
    !wget -q https://raw.githubusercontent.com/Teddy-XiongGZ/MIRAGE/main/benchmark.json -O benchmark.json
    # NOTE(review): this message is printed unconditionally after wget; the
    # download's exit status is not checked here.
    print("Dataset 'benchmark.json' downloaded.")
else:
    print("Dataset already present.")

print("Setup complete!")


In [None]:
# Cell 1: Imports & Configuration

# Numen RAG: A Self-Auditing Medical QA System

import numpy as np
import zlib
import time
import os
from typing import List, Tuple, Dict, Union
from dataclasses import dataclass
from datasets import load_dataset
import warnings

# Try importing FAISS
try:
    import faiss
except ImportError:
    print("FAISS not found. Installing...")
    # Fallback if cell 0 wasn't run
    os.system('pip install faiss-cpu') 
    import faiss

warnings.filterwarnings('ignore')

@dataclass
class Config:
    """Tunable constants shared by the Numen encoder, retriever and verifier."""

    # Numen vector dimension (high dimension for precision).
    DIM: int = 32768
    # Character n-gram lengths that get hashed (extended range for medical terms).
    NGRAMS: Tuple = (3, 4, 5, 6, 7, 8)
    # Final number of context documents handed to the generator.
    TOP_K: int = 10
    # Number of candidates fetched from the index before reranking.
    INITIAL_K: int = 30
    # Hash score below this means automatic rejection (extreme hallucination).
    VETO_THRESHOLD: float = 0.10
    # Hash score at or above this, together with oracle approval, means high confidence.
    HIGH_CONFIDENCE: float = 0.15


# Module-wide settings instance read by the rest of the pipeline.
config = Config()
print("Configuration Loaded (FAISS Scalable Mode).")


# LLM SETUP (DeepSeek-R1 Distill Llama 8B)

GLOBAL_LLM_CLIENT = None

class DeepSeekLLM:
    """Local DeepSeek-R1 Distill 8B model behind an OpenAI-style client surface.

    After construction, callers use
    ``client.chat.completions.create(messages=[...])`` and read the text from
    ``response.choices[0].message.content``.
    """
    def __init__(self, model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
        """Download (or load from the local cache) the tokenizer and 4-bit model.

        Args:
            model_id: Hugging Face model identifier to load.
        """
        # Imported lazily so the module can be imported without these packages.
        from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
        import torch

        print(f"[System] Downloading/Loading {model_id} from HF...")

        # 4-bit quantization to ensure it fits in 15GB GPU
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4"
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # NOTE(review): trust_remote_code=True allows code shipped in the model
        # repository to run locally; confirm it is actually required here.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )
        self.chat = self._Chat(self.model, self.tokenizer)
        print(f"[System] {model_id} loaded successfully.")

    class _Chat:
        """Namespace object that exposes ``.completions``."""
        def __init__(self, model, tokenizer):
            self.completions = self._Completions(model, tokenizer)

        class _Completions:
            """Runs chat-formatted prompts through the local model."""
            def __init__(self, model, tokenizer):
                self.model = model
                self.tokenizer = tokenizer

            @staticmethod
            def _generation_kwargs(temperature, max_tokens, pad_token_id):
                """Build the keyword arguments for ``model.generate``.

                Sampling is enabled only for a positive temperature. For greedy
                decoding (``temperature`` of 0, negative, or None) no
                ``temperature`` key is emitted, so a sampling-only option is
                never combined with ``do_sample=False``.
                """
                sampling = temperature is not None and temperature > 0
                gen_kwargs = {
                    "max_new_tokens": max_tokens,
                    "do_sample": sampling,
                    "pad_token_id": pad_token_id,
                }
                if sampling:
                    gen_kwargs["temperature"] = temperature
                return gen_kwargs

            def create(self, messages, temperature=0.7, max_tokens=1024, **kwargs):
                """Generate one reply for a list of chat messages.

                Args:
                    messages: List of ``{"role": ..., "content": ...}`` dicts,
                        rendered with the tokenizer's chat template.
                    temperature: Sampling temperature; 0 (or None) selects
                        greedy decoding.
                    max_tokens: Maximum number of newly generated tokens.
                    **kwargs: Accepted for call-site compatibility and ignored.

                Returns:
                    An object whose ``choices[0].message.content`` is the
                    decoded reply text.
                """
                from types import SimpleNamespace
                import torch
                # Standard HF chat template usage
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
                gen_kwargs = self._generation_kwargs(
                    temperature, max_tokens, self.tokenizer.eos_token_id
                )
                with torch.no_grad():
                    outputs = self.model.generate(**inputs, **gen_kwargs)

                # Decode only the newly generated tokens, not the echoed prompt.
                input_length = inputs.input_ids.shape[1]
                generated_tokens = outputs[0][input_length:]
                response_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

                # Any reasoning segment in the reply is returned untouched;
                # callers decide whether to strip it.
                # NOTE(review): R1-style models usually mark it with
                # <think>...</think> rather than <thought> - confirm.
                message = SimpleNamespace(content=response_text)
                return SimpleNamespace(choices=[SimpleNamespace(message=message)])

def initialize_llm_client():
    """Return the shared LLM client, creating it on first use.

    The client is cached in the module-level ``GLOBAL_LLM_CLIENT``. If
    construction fails, the error is printed and ``None`` is returned, so a
    later call will try again.
    """
    global GLOBAL_LLM_CLIENT
    if GLOBAL_LLM_CLIENT is not None:
        return GLOBAL_LLM_CLIENT

    try:
        GLOBAL_LLM_CLIENT = DeepSeekLLM()
    except Exception as err:
        print(f"[System] LLM Init Failed: {err}")
    return GLOBAL_LLM_CLIENT




# Cell 2: Layer 1 - Numen Core

class NumenCore:
    """
    Handles fast, training-free vectorization using CRC32 n-gram hashing.
    Used for Retrieval and Entity-Level Hallucination Detection.
    """
    def __init__(self, dim=None, ngrams=None):
        """Create an encoder.

        Args:
            dim: Length of the output vector. ``None`` means ``config.DIM``.
            ngrams: Iterable of character n-gram lengths to hash. ``None``
                means ``config.NGRAMS``.

        The defaults are looked up when the instance is created (not when the
        class is defined), so changes made to ``config`` beforehand are used.
        """
        self.dim = config.DIM if dim is None else dim
        self.ngrams = config.NGRAMS if ngrams is None else ngrams

    def encode(self, text: str) -> np.ndarray:
        """Hash the character n-grams of ``text`` into a unit-length vector.

        The text is lower-cased and stripped. Each n-gram adds a length-dependent
        weight to the bucket ``crc32(gram) % dim``; the bucket totals are then
        passed through ``log1p`` and L2-normalised.

        Returns:
            A float32 array of length ``dim``. It is all zeros when ``text`` is
            empty/None or shorter than every configured n-gram length.
        """
        vec = np.zeros(self.dim, dtype=np.float32)
        if not text:
            return vec
        text = text.lower().strip()
        if not text:
            return vec

        length = len(text)
        hashed_any = False
        for n in self.ngrams:
            if length < n:
                continue
            # Entity-aware weighting: longer n-grams (medical entities) weigh
            # more. Depends only on n, so it is computed once per length.
            weight = 1.0 + (n - 3) * 0.8
            for start in range(length - n + 1):
                gram = text[start:start + n]
                # zlib.crc32 is deterministic and already unsigned in Python 3.
                bucket = zlib.crc32(gram.encode('utf-8')) % self.dim
                vec[bucket] += weight
                hashed_any = True

        if hashed_any:
            vec = np.log1p(vec)  # Log-saturation
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec /= norm
        return vec

print("NumenCore Class Ready.")



# Cell 3: Numen RAG System (Dense + Chunking + FAISS)

class NumenRAG:
    """Retrieval-augmented QA pipeline: Numen hashing, FAISS search and LLM generation.

    Typical use: ``index_data(docs)`` once, then ``retrieve(query)`` and
    ``generate(query, context)`` per question.
    """

    def __init__(self):
        """Create the encoder and obtain the shared LLM client.

        ``self.client`` is whatever ``initialize_llm_client()`` returns, which
        may be ``None``; the methods below degrade gracefully in that case.
        """
        self.numen = NumenCore()

        self.index = None    # FAISS index, built by index_data()
        self.documents = []  # Chunk texts, row-aligned with the FAISS index
        self.client = initialize_llm_client()

    @staticmethod
    def _split_reasoning(text: str) -> Tuple[str, str]:
        """Separate a model reply into ``(answer, reasoning)``.

        Handles a complete ``<thought>...</thought>`` or ``<think>...</think>``
        block, and also a reply that contains only the closing tag (everything
        before the last closing tag is then treated as reasoning). A reply with
        no such tags is returned unchanged with empty reasoning.
        """
        import re
        block = re.compile(r'<(thought|think)>(.*?)</\1>', re.DOTALL)
        found = block.search(text)
        if found:
            return block.sub('', text).strip(), found.group(2).strip()

        closers = list(re.finditer(r'</(?:thought|think)>', text))
        if closers:
            last = closers[-1]
            return text[last.end():].strip(), text[:last.start()].strip()
        return text, ""

    @staticmethod
    def _parse_ranked_ids(text: str, n_candidates: int) -> List[int]:
        """Extract candidate indices from a reranker reply.

        Prefers the last comma-separated run of numbers in ``text`` (so a stray
        number such as "top 5" in a preamble is not mistaken for an id); falls
        back to plain comma splitting for a single-id reply. Ids that are out
        of range or repeated are dropped; order is preserved.
        """
        import re
        runs = re.findall(r'\d+(?:\s*,\s*\d+)+', text)
        tokens = runs[-1].split(',') if runs else text.split(',')

        ids = []
        for token in tokens:
            token = token.strip()
            if not token.isdecimal():
                continue
            idx = int(token)
            if idx < n_candidates and idx not in ids:
                ids.append(idx)
        return ids

    def chunk_text(self, text: str, chunk_size: int = 800, overlap: int = 300) -> List[str]:
        """Split ``text`` into overlapping chunks at sentence boundaries.

        Args:
            text: Document text.
            chunk_size: Target maximum chunk length in characters (before the
                overlap prefix is added).
            overlap: Number of trailing characters of the previous chunk that
                are prepended to each chunk; 0 disables overlap.

        Returns:
            ``[text]`` unchanged when it already fits in ``chunk_size``.
            Otherwise the list of chunks, with chunks of 50 characters or fewer
            dropped.
        """
        if len(text) <= chunk_size:
            return [text]

        import re
        # Split after sentence-ending punctuation followed by whitespace, or at
        # newlines. The punctuation stays attached to its sentence, so nothing
        # is re-appended (re-appending used to turn a final "foo." into "foo..").
        step = max(chunk_size, 1)
        sentences = []
        for piece in re.split(r'(?<=[.?!])\s+|\n+', text):
            piece = piece.strip()
            if not piece:
                continue
            # A single sentence longer than chunk_size is cut into fixed-size
            # slices so that no chunk exceeds the limit.
            for start in range(0, len(piece), step):
                sentences.append(piece[start:start + step])

        # Greedily pack sentences into chunks, counting the joining space.
        chunks = []
        current = ""
        for sent in sentences:
            if not current:
                current = sent
            elif len(current) + 1 + len(sent) <= chunk_size:
                current = current + " " + sent
            else:
                chunks.append(current)
                current = sent
        if current:
            chunks.append(current)

        # Prefix each chunk with the tail of the previous one. The explicit
        # overlap > 0 test matters: previous[-0:] would be the whole chunk.
        overlapped = []
        for i, chunk in enumerate(chunks):
            if i > 0 and overlap > 0 and len(chunks[i - 1]) > overlap:
                overlapped.append(chunks[i - 1][-overlap:] + " " + chunk)
            else:
                overlapped.append(chunk)

        return [c for c in overlapped if len(c) > 50]

    def index_data(self, doc_source: Union[Dict[str, str], List[str]]):
        """
        Builds Numen FAISS Index. Accepts Dict {id: text} or List [text].
        Auto-chunks long documents.

        Chunks are stored in ``self.documents`` in the same order as the rows
        added to ``self.index``. Whitespace-only chunks are skipped because
        they encode to a zero vector that can never match a query.
        """
        if isinstance(doc_source, dict):
            raw_docs = list(doc_source.values())
        else:
            raw_docs = list(doc_source)

        print(f"[Index] Pre-processing {len(raw_docs)} documents...")
        t0 = time.time()

        # Chunking Phase
        self.documents = []
        for doc in raw_docs:
            self.documents.extend(c for c in self.chunk_text(doc) if c.strip())

        print(f"[Index] Created {len(self.documents)} retrievable chunks.")

        print("[Index] Encoding vectors...")
        print("[Index] Building FAISS Index...")
        # Inner product equals cosine similarity because vectors are unit-length.
        self.index = faiss.IndexFlatIP(self.numen.dim)

        # Encode and add in batches: one float32 vector is dim * 4 bytes
        # (128 KiB at dim=32768), so a full matrix held alongside the index
        # would double peak memory.
        batch_size = 256
        for start in range(0, len(self.documents), batch_size):
            batch = self.documents[start:start + batch_size]
            matrix = np.stack([self.numen.encode(chunk) for chunk in batch])
            self.index.add(matrix)

        print(f"[Index] FAISS Index Built in {time.time() - t0:.3f}s. Stored {self.index.ntotal} vectors.")

    def retrieve(self, query: str, k: Union[int, None] = None) -> List[str]:
        """Hybrid Retrieval: Numen (Stage 1) + LLM Reranking (Stage 2).

        Args:
            query: Question text.
            k: Number of chunks to return; ``None`` means ``config.TOP_K``
                (looked up at call time).

        Returns:
            Up to ``k`` chunk texts. The Numen similarity order is used when no
            LLM client is available, when there is nothing to rerank, or when
            reranking fails. An empty list is returned when no index exists.
        """
        if self.index is None or self.index.ntotal == 0:
            return []
        if k is None:
            k = config.TOP_K

        # Stage 1: Numen retrieval (fast, broad). Fetch at least k candidates.
        n_fetch = max(config.INITIAL_K, k)
        q_vec = self.numen.encode(query).reshape(1, -1)
        _scores, indices = self.index.search(q_vec, n_fetch)

        # FAISS pads missing results with -1.
        candidates = [self.documents[i] for i in indices[0] if i >= 0]
        if len(candidates) <= k or not self.client:
            return candidates[:k]

        # Stage 2: LLM reranking (accurate, narrow)
        try:
            docs_str = "\n\n".join([f"[{i}] {doc[:200]}..." for i, doc in enumerate(candidates)])
            prompt = f"""Query: {query}

Documents:
{docs_str}

Rank the document IDs by relevance to the query. Output ONLY the top {k} IDs as comma-separated numbers (e.g., 3,7,1,9,2)."""

            resp = self.client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                temperature=0
            )
            reply, _reasoning = self._split_reasoning(resp.choices[0].message.content)
            ranked_ids = self._parse_ranked_ids(reply, len(candidates))
        except Exception as err:
            print(f"[Retrieve] Rerank failed, using Numen order: {err}")
            return candidates[:k]

        reranked = [candidates[i] for i in ranked_ids]

        # Top up with the remaining candidates in their original order when the
        # model returned fewer than k usable ids.
        if len(reranked) < k:
            seen = set(reranked)
            for cand in candidates:
                if cand not in seen:
                    seen.add(cand)
                    reranked.append(cand)

        return reranked[:k]

    def multi_hop_retrieve(self, query: str, initial_docs: List[str]) -> List[str]:
        """Multi-hop reasoning: Extract key entities and retrieve additional context.

        The LLM is asked for key terms based on the first three documents; those
        terms are used for a second index search. New documents replace the
        lowest-ranked initial documents so the result stays within
        ``config.TOP_K``. (Simply truncating ``initial + new`` to ``TOP_K``
        would discard every new document whenever the first hop already
        returned ``TOP_K`` results.)

        Returns ``initial_docs`` unchanged when there is no client, no index,
        fewer than three initial documents, no usable entity list, or any error.
        """
        if not self.client or self.index is None or len(initial_docs) < 3:
            return initial_docs

        try:
            # Extract key medical entities from initial context
            sample = " ".join(initial_docs[:3])[:500]
            prompt = f"""Question: {query}
Context: {sample}

Extract 3-5 key medical terms/entities that need more context. Output as comma-separated list."""

            resp = self.client.chat.completions.create(
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                max_tokens=50
            )

            entities, _reasoning = self._split_reasoning(resp.choices[0].message.content)
            entities = entities.strip()
            if len(entities) <= 5:
                return initial_docs

            # Second hop: retrieve using extracted entities
            hop2_vec = self.numen.encode(entities).reshape(1, -1)
            _scores, indices = self.index.search(hop2_vec, 5)

            known = set(initial_docs)
            new_docs = []
            for i in indices[0]:
                if i < 0:
                    continue
                doc = self.documents[i]
                if doc not in known:
                    known.add(doc)
                    new_docs.append(doc)
        except Exception as err:
            print(f"[MultiHop] Second hop skipped: {err}")
            return initial_docs

        # Reserve room for the new documents, then cap at TOP_K.
        keep = max(config.TOP_K - len(new_docs), 0)
        return (initial_docs[:keep] + new_docs)[:config.TOP_K]

    def generate(self, query: str, context: List[str], strict=False, use_multihop=False) -> str:
        """Generates answer from context with nuanced medical reasoning.

        Args:
            query: Question text.
            context: Source passages, numbered ``[1]``, ``[2]``, ... in the prompt.
            strict: Accepted for call-site compatibility; it does not currently
                change the prompt or the sampling settings.
            use_multihop: When true and at least three passages are given, the
                context is first extended via ``multi_hop_retrieve``.

        Returns:
            The model's answer with any reasoning block removed (the first 500
            characters of the reasoning are printed). An ``"[Error] ..."`` or
            ``"[Gen Error: ...]"`` string is returned instead of raising.
        """
        if not self.client:
            return "[Error] API Client not initialized."

        # Multi-hop for complex questions
        if use_multihop and len(context) >= 3:
            context = self.multi_hop_retrieve(query, context)

        ctx_str = "\n\n".join([f"[{i+1}] {doc}" for i, doc in enumerate(context)])

        # Cleaner Prompt Structure for R1. Written as joined literals so the
        # trailing space on the first line is explicit.
        system_msg = (
            "You are a medical expert assistant. \n"
            "Use the provided Sources to answer the Question.\n"
            "- Think step-by-step in <thought> tags (optional).\n"
            "- Answer PRECISELY based on the text.\n"
            "- If multiple choice, pick the best letter (A/B/C/D).\n"
            "- Cite sources [1] etc."
        )

        user_msg = f"""=== Sources ===
{ctx_str}

=== Question ===
{query}

=== Instruction ===
Answer based strictly on the above sources."""

        try:
            resp = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_msg},
                    {"role": "user", "content": user_msg}
                ],
                temperature=0.6
            )
            raw_content = resp.choices[0].message.content

            # Clean BPE Artifacts (DeepSeek/Llama specific)
            raw_content = raw_content.replace('Ġ', ' ').replace('Ċ', '\n').replace('ĉ', '\t').replace('Ã', '')

            # Post-Process: separate the final answer from the reasoning block.
            clean_answer, thinking = self._split_reasoning(raw_content)

            if thinking:
                print(f"\n[Reasoning] {thinking[:500]}...")  # Increased visible reasoning length

            return clean_answer
        except Exception as e:
            return f"[Gen Error: {e}]"

print("NumenRAG Ready.")



In [None]:

print("NumenRAG System v7 (SOTA Push: Semantic Chunking + Multi-Hop Reasoning) Ready.")


# Cell 5: Dual Verification Loop (Hash Veto + Oracle + Confidence Levels)
# ------------------------------------------------------------------
def run_complete_loop(rag_system, query: str):
    print(f"\n[Question] {query[:80]}...")
    
    # 1. Hybrid Retrieve (Numen + LLM Reranking)
    docs = rag_system.retrieve(query)
    if not docs:
        return {'status': 'NO_CONTEXT', 'final_answer': 'No relevant context found', 'confidence': 'NONE', 'metrics': 0.0}
    
    # 2. Generate with Dual Verification
    best_answer = ""
    best_verdict = "UNSAFE"
    best_confidence = "LOW"
    best_hash = 0.0
    
    for attempt in range(2):
        # Use multi-hop on second attempt if first failed
        use_multihop = (attempt > 0 and best_confidence in ["LOW", "REJECTED"])
        cand_ans = rag_system.generate(query, docs, strict=(attempt>0), use_multihop=use_multihop)
        
        # Hash Score (Lexical Overlap)
        ans_vec = rag_system.numen.encode(cand_ans)
        ctx_vec = rag_system.numen.encode(" ".join(docs))
        hash_score = float(np.dot(ans_vec, ctx_vec))
        
        # Oracle Check (Semantic Verification)
        oracle_v = "SAFE"
        if rag_system.client:
            ctx_sample = " ".join(docs)[:1200]
            p = f"""Context: {ctx_sample}

Answer: {cand_ans}

Is the answer supported by the context? Reply ONLY 'YES' or 'NO'."""
            try:
                r = rag_system.client.chat.completions.create(
                    messages=[{"role": "user", "content": p}],
                    temperature=0
                )
                if "NO" in r.choices[0].message.content.upper(): 
                    oracle_v = "UNSAFE"
            except: 
                pass
        
        # Dual Verification Logic with Answer-Length Aware Veto
        answer_words = len(cand_ans.split())
        if answer_words < 15:
            # Short answer (yes/no, single choice): lenient veto
            veto_threshold = 0.05
        else:
            # Long answer (clinical explanation): standard veto
            veto_threshold = config.VETO_THRESHOLD
        
        if hash_score < veto_threshold:
            # VETO: Extreme hallucination detected
            verdict = "UNSAFE"
            confidence = "REJECTED"
        elif hash_score >= config.HIGH_CONFIDENCE and oracle_v == "SAFE":
            # HIGH: Both hash and oracle agree
            verdict = "SAFE"
            confidence = "HIGH"
        elif oracle_v == "SAFE":
            # MEDIUM: Oracle approves, but low lexical overlap (paraphrase/synonym)
            verdict = "SAFE"
            confidence = "MEDIUM"
        else:
            # LOW: Oracle rejected
            verdict = "UNSAFE"
            confidence = "LOW"
        
        # Keep best candidate
        if confidence in ["HIGH", "MEDIUM"] and best_confidence not in ["HIGH", "MEDIUM"]:
            best_answer = cand_ans
            best_verdict = verdict
            best_confidence = confidence
            best_hash = hash_score
            if confidence == "HIGH":
                break  # Found high confidence answer
        elif attempt == 1 and best_confidence == "LOW":
            best_answer = cand_ans
            best_verdict = verdict
            best_confidence = confidence
            best_hash = hash_score

    print(f"  > [Audit] Hash: {best_hash:.3f} | Oracle: {best_verdict} | Confidence: {best_confidence}")
    print(f"  > [Answer] {best_answer}")
        
    return {
        'status': best_verdict,
        'final_answer': best_answer,
        'confidence': best_confidence,
        'metrics': 1.0 if best_verdict == "SAFE" else 0.0
    }

print("Verification Loop Ready.")


# Cell 6: Data Loading & Main Execution (with Accuracy)
# ------------------------------------------------------------------
def load_mirage_data():
    """
    Loads official MIRAGE benchmark data from 'benchmark.json'.
    Returns the full dictionary of datasets.

    The file is downloaded first when it is not in the working directory. The
    download goes to a temporary name and is renamed only after it completes,
    so an interrupted transfer cannot leave a truncated 'benchmark.json' that
    later runs would mistake for a valid file.

    Returns:
        The parsed JSON object, or None when the download or parsing fails
        (the error is printed).
    """
    import json
    import os
    import urllib.request

    filename = "benchmark.json"
    url = "https://raw.githubusercontent.com/Teddy-XiongGZ/MIRAGE/main/benchmark.json"

    # 1. auto-download
    if not os.path.exists(filename):
        print(f"[Data] Downloading MIRAGE Benchmark from {url}...")
        partial = filename + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, filename)
            print(" > Download complete.")
        except Exception as e:
            print(f"[Data] Download failed: {e}")
            # Remove any partial file; failing to do so is not fatal.
            try:
                if os.path.exists(partial):
                    os.remove(partial)
            except OSError:
                pass
            return None

    # 2. Load JSON
    try:
        print(f"[Data] Loading MIRAGE Benchmark suite from {filename}...")
        with open(filename, 'r', encoding='utf-8') as f:
            full_benchmark = json.load(f)

        print(f" > Benchmark Datasets Found: {list(full_benchmark.keys())}")
        return full_benchmark

    except Exception as e:
        print(f"[Data] Error processing benchmark.json: {e}")
        return None

def load_real_pubmed_data(max_docs: int = 1000, test_set_size: int = 10):
    """
    Robustly loads PubMedQA from Hugging Face.

    Args:
        max_docs: Maximum number of dataset rows turned into documents.
        test_set_size: Number of leading rows that also become evaluation queries.

    Returns:
        ``(documents, queries)``: ``documents`` maps an id string to the joined
        context text; ``queries`` is a list of dicts with keys 'id', 'question'
        and 'truth'. ``({}, [])`` is returned when the dataset cannot be loaded.
    """
    print("\n[Data] Loading PubMedQA dataset (Standard PQA-L)...")
    try:
        # Uses the 'pqa_labeled' subset, which loads without trust_remote_code.
        dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")
        print(" > Loaded Standard PubMedQA")
    except Exception as e:
        print(f" > PubMedQA Load Failed: {e}")
        return {}, []

    documents = {}
    queries = []

    print(f"[Data] Processing first {max_docs} samples...")
    for i, item in enumerate(dataset):
        if i >= max_docs:
            break

        # Handle Schema: fall back to the row position when no id field is set.
        pid = str(item.get('pubid') or item.get('document_id') or str(i))

        # Context may be a dict holding a 'contexts' list, a plain list, or
        # something else that is simply stringified.
        context_raw = item['context']
        if isinstance(context_raw, dict) and 'contexts' in context_raw:
            text = " ".join(context_raw['contexts'])
        elif isinstance(context_raw, list):
            text = " ".join(context_raw)
        else:
            text = str(context_raw)

        documents[pid] = text

        # Ground truth: prefer the long answer, then the final decision.
        truth = item.get('long_answer') or item.get('final_decision') or "N/A"

        if i < test_set_size:
            queries.append({
                'id': pid,
                'question': item['question'],
                'truth': truth
            })

    return documents, queries

def evaluate_accuracy(client, system_answer, ground_truth):
    """Uses LLM to judge if system answer matches ground truth."""
    if not client: 
        print("  > [Eval] Skipped (No Client)")
        return False
    
    prompt = f"""Compare these two medical answers.
Ground Truth: {ground_truth}
System Answer: {system_answer}

Are they factually consistent? Reply EXACTLY 'YES' or 'NO'."""
    try:
        resp = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}]
        )
        return "YES" in resp.choices[0].message.content.upper()
    except:
        return False

# Script entry point: load a benchmark, index its corpus, then run and score
# every question through the verification loop.
if __name__ == "__main__":
    print("\n=== Numen RAG: Production Loop (Full MIRAGE Suite) ===")

    # 1. Load Full Benchmark
    mirage_suite = load_mirage_data()

    datasets_to_run = {}
    all_docs = []  # Corpus to index, accumulated in dataset order

    if not mirage_suite:
        # Fallback to PubMedQA (Standard) if MIRAGE fails
        print("[System] MIRAGE failed, falling back to Standard PubMedQA...")
        docs, queries = load_real_pubmed_data()
        datasets_to_run["Standard_PubMedQA"] = queries
        # Fallback queries have no 'doc_context' key; the corpus here is the
        # document dict returned by the loader.
        all_docs.extend(docs.values())
    else:
        # Parse MIRAGE into our usable format
        for ds_name, ds_data in mirage_suite.items():
            print(f"\n[Parser] Processing {ds_name}...")
            parsed_queries = []

            # Sort keys to ensure deterministic order; every question is kept.
            for qid in sorted(ds_data.keys()):
                item = ds_data[qid]
                q_text = item['question']

                # Check for options (MCE) vs Yes/No
                options = item.get('options', {})
                answer_key = item['answer']

                if options:
                    # Multiple Choice
                    truth_val = options.get(answer_key, "N/A")
                    truth_str = f"{answer_key}: {truth_val}"
                    # Context for Indexing: Question + Options
                    doc_blob = f"Question: {q_text}\nOptions:\n" + \
                               "\n".join([f"{k}: {v}" for k, v in options.items()])
                else:
                    # Yes/No (PubMedQA/BioASQ in MIRAGE format)
                    truth_str = answer_key  # likely "yes", "no", "maybe"
                    # Context for Indexing: Question only (Self-retrieval)
                    doc_blob = f"Question: {q_text}"

                parsed_queries.append({
                    'id': qid,
                    'question': q_text,
                    'truth': truth_str,
                    'doc_context': doc_blob
                })

            datasets_to_run[ds_name] = parsed_queries
            all_docs.extend(q['doc_context'] for q in parsed_queries)

    # 2. Build One Massive Index for Production Run
    rag = NumenRAG()

    print(f"\n[Production] Building unified index for all datasets...")
    print(f"[Production] Indexing {len(all_docs)} documents across all benchmarks...")
    rag.index_data(all_docs)

    # 3. Execution Loop across all datasets
    for ds_name, test_queries in datasets_to_run.items():
        print(f"\n" + "="*60)
        print(f"BENCHMARK: {ds_name.upper()}")
        print(f"="*60)

        if not test_queries:
            print(f"[Warn] No queries found for {ds_name}.")
            continue

        correct = 0
        for i, q_item in enumerate(test_queries):
            # Running accuracy over the questions answered so far.
            print(f"\n[Progress] Question {i+1}/{len(test_queries)} | Correct: {correct}/{i if i > 0 else 1} ({(correct/max(i,1))*100:.1f}%)")
            q_text = q_item['question']
            print(f"[Question] {q_text}")
            truth = q_item['truth']

            # Run Pipeline
            result = run_complete_loop(rag, q_text)
            print(f"[Answer] {result['final_answer']}")
            # Eval
            is_correct = evaluate_accuracy(rag.client, result['final_answer'], truth)
            if is_correct:
                correct += 1
                print(f"  > [Eval] PASS")
            else:
                print(f"  > [Eval] FAIL (Expected: {truth})")
            print("-" * 40)

        print(f"\n[Score] {ds_name}: {correct}/{len(test_queries)} ({(correct/len(test_queries))*100:.1f}%)")

    print(f"\n[System] Full Benchmark Compliance Run Complete.")
