In [None]:
!pip install -U -q \
    "torchvision" \
    "torch" \
    "transformers>=4.40.0" \
    "sentence-transformers" \
    "accelerate>=0.29.0" \
    "bitsandbytes" \
    "langchain-huggingface" \
    "langchain-text-splitters" \
    "faiss-cpu" \
    "pypdf" \
    "gradio" \
    "torchaudio" \
    "peft" \
    "langchain" \
    "pymupdf" \
    "langchain-core" \
    "langchain-community" \
    "python-multipart" \
    "fastapi" \
    "uvicorn"

In [None]:
import os
import re
import gc
import torch
import shutil
import threading
import uvicorn
import requests
import gradio as gr
import time
import concurrent.futures
import random


from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader, PyPDFLoader
from langchain_community.vectorstores import FAISS
from fastapi import FastAPI, UploadFile
from fastapi.responses import StreamingResponse
from operator import itemgetter

In [None]:
# --- Embedding model ---------------------------------------------------------
# Sentence embeddings used for the FAISS index; runs on the GPU when one is
# visible, and asks the encoder to normalise every vector it produces.
embedding_model_name = "BAAI/bge-m3"
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=dict(device="cuda" if torch.cuda.is_available() else "cpu"),
    encode_kwargs=dict(normalize_embeddings=True),
)

# --- Generator LLM -----------------------------------------------------------
# Tokenizer and causal LM are loaded from the same checkpoint; the weights are
# requested in bfloat16 with automatic device placement and SDPA attention.
model_name = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)

# --- LangChain wrapper -------------------------------------------------------
# A sampling text-generation pipeline over the model above, exposed to
# LangChain through HuggingFacePipeline.
text_generation_pipeline = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    streamer=None,
    do_sample=True,
    temperature=0.4,
    max_new_tokens=4096,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

print(f"Loaded {model_name}")

In [68]:
# FastAPI application; the route handlers in the following cells register on it.
app = FastAPI(title="RAG Server")

# Shared FAISS index. Stays None until /ingest builds it from the first batch of chunks.
vector_store = None
# Names of uploaded files registered by /ingest; /chat_stream matches question terms against them.
ingested_files = set()

def parse_pdf(file_info):
    """Load one PDF into documents and tag each with its original filename.

    PyMuPDF is tried first; if it raises, the failure is logged and PyPDF is
    used as a fallback. Meant to be mapped over a list of files, so it never
    raises for an unreadable file.

    Args:
        file_info: ``(path, original_filename)`` tuple. ``path`` is where the
            file was saved on disk; ``original_filename`` is the name written
            to each document's ``metadata['source']``.

    Returns:
        The list of loaded documents, or ``[]`` when both loaders fail.
    """
    path, original_filename = file_info
    try:
        docs = PyMuPDFLoader(path).load()
    except Exception as primary_err:
        # Previously swallowed silently; surface it so the fallback is visible.
        print(f"PyMuPDF could not parse {original_filename} ({primary_err}); retrying with PyPDF.")
        try:
            docs = PyPDFLoader(path).load()
        except Exception as fallback_err:
            print(f"Failed to parse {original_filename}: {fallback_err}")
            return []

    # Downstream grouping and filtering key on the upload name, not the temp path.
    for doc in docs:
        doc.metadata['source'] = original_filename
    return docs


In [69]:
@app.post("/ingest")
async def ingest_files(files: list[UploadFile]):
    """Parse uploaded PDFs, chunk them, and add the chunks to the shared FAISS index.

    Each upload is written to a scratch directory, parsed in a small thread
    pool by ``parse_pdf``, split into overlapping chunks, and embedded into
    ``vector_store`` (created on first use). A filename is added to
    ``ingested_files`` only after at least one of its chunks has been indexed,
    so files that could not be parsed are not later treated as searchable.

    Args:
        files: Files from the multipart request.

    Returns:
        ``{"status": "success", "chunks": int, "files": int}`` on success,
        otherwise ``{"status": "error", "message": str}``. Failures inside the
        processing pipeline are reported in the body rather than raised.
    """
    global vector_store, ingested_files

    temp_dir = "temp_server_files"
    if os.path.exists(temp_dir): shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    saved_files = []

    try:
        for file in files:
            # The filename is client-controlled: keep only its final path
            # component so "../x" or an absolute path cannot escape temp_dir.
            safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
            if not safe_name or safe_name in (".", ".."):
                continue
            path = os.path.join(temp_dir, safe_name)
            with open(path, "wb") as f:
                shutil.copyfileobj(file.file, f)
            saved_files.append((path, safe_name))

        all_docs = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            results = list(executor.map(parse_pdf, saved_files))

        for res in results:
            all_docs.extend(res)

        if not all_docs:
            return {"status": "error", "message": "No text extracted from files."}

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400,
            chunk_overlap=50,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
        )
        splits = text_splitter.split_documents(all_docs)

        # Pages can load successfully yet contain no text (e.g. scanned PDFs);
        # building or extending the index with zero chunks would fail opaquely.
        if not splits:
            return {"status": "error", "message": "No text extracted from files."}

        if vector_store is None:
            vector_store = FAISS.from_documents(splits, embeddings)
        else:
            vector_store.add_documents(splits)

        # Register only the files that actually contributed indexed chunks.
        ingested_files.update(doc.metadata['source'] for doc in splits)

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return {"status": "success", "chunks": len(splits), "files": len(saved_files)}

    except Exception as e:
        return {"status": "error", "message": str(e)}

    finally:
        # The uploads have been fully read by now; do not leave copies on disk.
        shutil.rmtree(temp_dir, ignore_errors=True)

In [70]:
@app.post("/chat_stream")
async def chat_stream(data: dict):
    """Answer a question over the ingested documents and stream the reply as plain text.

    Steps:
      1. Tokenise the question and try to match it against ingested filenames
         (by shared numbers, shared stemmed words, or a substring hit). The
         word ``all`` selects every ingested file.
      2. If files matched, run a similarity search restricted to those files;
         any matched file with no retrieved chunk gets a placeholder note so
         the model reports it explicitly. Otherwise run a thresholded search
         over the whole index.
      3. Group retrieved chunks per source file, build the prompt, and stream
         generation from a background thread.

    Args:
        data: JSON body; the question is read from ``data["message"]``.

    Returns:
        A ``StreamingResponse`` (``text/plain``). The first chunk is a
        ``**Context Searched:** ... |---| `` header listing the source files,
        followed by the generated tokens. Guard conditions (no index, empty
        question, nothing relevant) stream a single explanatory sentence.
    """
    # Imported once here rather than on every iteration of the fallback loop below.
    from langchain_core.documents import Document

    question = data.get("message", "")
    if vector_store is None:
        def err(): yield "Please ingest documents first."
        return StreamingResponse(err(), media_type="text/plain")

    # A missing, non-string, or blank message would otherwise crash in
    # tokenisation or send an empty query to the retriever.
    if not isinstance(question, str) or not question.strip():
        def no_question(): yield "Please enter a question."
        return StreamingResponse(no_question(), media_type="text/plain")

    def get_tokens(text):
        return set(re.split(r'[^a-z0-9]+', text.lower()))

    query_tokens = get_tokens(question)
    query_nums = {t for t in query_tokens if t.isdigit()}

    # Words too generic to identify a file; compared against *stemmed* tokens.
    STOP_WORD_STEMS = {
        'contract', 'agreement', 'pdf', 'file', 'the', 'in', 'of', 'and', 'or', 
        'to', 'for', 'with', 'a', 'service', 'what', 'are', 'is', 
        'this', 'can', 'be', 'how', 'why', 'who', 'do', 'does', 'under', 'which',
        'about', 'these', 'those', 'say', 'work'
    }

    def get_stem(word):
        # Crude plural stripping; applied to both query and filename tokens so
        # the two sides stay comparable even where the stem is not a real word.
        if word.endswith('ies') and len(word) > 4:
            return word[:-3] + 'y'
        if word.endswith('es') and len(word) > 4:
            return word[:-2]
        if word.endswith('s') and not word.endswith('ss') and len(word) > 3:
            return word[:-1]
        return word

    search_all_intent = any(w in query_tokens for w in ['all'])

    query_words = set()
    for t in query_tokens:
        if t in query_nums:
            continue

        stemmed_t = get_stem(t)

        if stemmed_t not in STOP_WORD_STEMS and len(stemmed_t) > 3:
            query_words.add(stemmed_t)

    target_files = []

    if search_all_intent:
        target_files = list(ingested_files)
    else:
        for file_name in ingested_files:
            base_name = file_name.rsplit('.', 1)[0].lower()
            file_tokens = {get_stem(t) for t in get_tokens(base_name)}
            file_nums = {t for t in file_tokens if t.isdigit()}

            is_match = False

            # Both sides carry numbers but none agree: a different numbered
            # document, so word overlap must not rescue it.
            if query_nums and file_nums:
                if not (query_nums & file_nums):
                    continue

            if query_nums & file_nums:
                is_match = True
            elif query_words & file_tokens:
                is_match = True
            elif query_words:
                # Last resort: a longer query word appearing inside the
                # concatenated filename tokens.
                clean_file_str = "".join(file_tokens)
                for qw in query_words:
                    if len(qw) > 4 and qw in clean_file_str:
                        is_match = True
                        break

            if is_match:
                target_files.append(file_name)

    target_files = list(set(target_files))

    print(f"Query: '{question}'\n   -> Extracted Nums: {query_nums}\n   -> Extracted Words: {query_words}\n   -> MATCHED FILES: {target_files if target_files else 'ALL FILES'}")

    final_docs = []

    if target_files:
        retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={
                'k': 10,
                # FAISS applies `filter` only to the top `fetch_k` hits
                # (default 20). Consider every vector so chunks of the targeted
                # files are not lost behind better-scoring chunks of other files.
                'fetch_k': max(vector_store.index.ntotal, 10),
                'filter': lambda metadata: metadata.get("source") in target_files
            }
        )
        final_docs = retriever.invoke(question)

        # Make every explicitly targeted file visible to the model, even when
        # retrieval returned nothing for it.
        found_sources = set(d.metadata['source'] for d in final_docs)
        for target in target_files:
            if target not in found_sources:
                final_docs.append(Document(
                    page_content=f"SYSTEM NOTE: User explicitly asked about '{target}', but no relevant text was found.",
                    metadata={'source': target}
                ))
    else:
        retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                'k': 40,
                'score_threshold': 0.25
            }
        )
        clean_query_words = [w for w in question.lower().split() if get_stem(w) not in STOP_WORD_STEMS]
        search_query = " ".join(clean_query_words) if clean_query_words else question

        final_docs = retriever.invoke(search_query)

        if not final_docs:
            def empty_err(): yield "No highly relevant sections found in the documents."
            return StreamingResponse(empty_err(), media_type="text/plain")

    unique_sources = list(set([d.metadata['source'] for d in final_docs]))
    num_docs = len(unique_sources)

    # Output shape requested from the model depends on how many files contributed.
    if num_docs > 3:
        format_instr = "## FORMATTING: You found data from >3 files. You **MUST** use a **Markdown Table**."
    elif num_docs > 1:
        format_instr = "## FORMATTING: You found data from 2-3 files. Use a **Bulleted List**."
    else:
        format_instr = "## FORMATTING: You found data from 1 file. Provide a direct, concise sentence. Do NOT use a table."

    grouped_docs = {}
    for d in final_docs:
        grouped_docs.setdefault(d.metadata['source'], []).append(d.page_content)

    context_parts = []
    for source, texts in grouped_docs.items():
        combined_text = "\n...".join(texts)
        context_parts.append(f'<document filename="{source}">\n{combined_text}\n</document>')

    context = "\n\n".join(context_parts)

    prompt = f"""<|im_start|>system
You are an ultra-strict, literal-minded Legal Contract Auditor. Your ONLY job is to extract exact text from the provided context.

{format_instr}

### CRITICAL RULES (FAILURE IS NOT AN OPTION):
1. **ABSOLUTE ISOLATION:** The context is divided by <document> tags. You MUST evaluate each document in a vacuum. A document cannot "borrow" or "inherit" numbers, names, or terms from another document.
2. **ZERO HALLUCINATION:** If a specific document does not contain the answer, you are FORBIDDEN from looking at other documents to fill in the blank. 
3. **MANDATORY QUOTING:** For every single document, you must find the EXACT sentence that contains the answer. If you cannot find a relevant quote within that specific <document> block, the answer DOES NOT EXIST in that document.
4. **HANDLING MISSING DATA:** If a document lacks the requested information, you MUST output exactly "None found" in the Exact Quote column and "Not specified in document." in the Summary column. 
5. **TABLE FORMAT:** Use exactly three columns: | Document | Exact Quote | Summary |. Provide exactly ONE row per document. DO NOT put spaces before the table.
7. **NO LOOPHOLES OR BREACH ASSISTANCE:** You are strictly FORBIDDEN from identifying legal loopholes, vulnerabilities, or methods to circumvent the contract. You must NEVER provide instructions, suggestions, or advice on how to breach the agreement, avoid obligations, or engage in illegal activities.

### EXAMPLE OF CORRECT BEHAVIOR:
If asked for "hourly rate", and Doc_A says "$50/hr" but Doc_B mentions no money:
| Doc_A | "The rate is $50/hr." | The rate is $50/hr. |
| Doc_B | None found | Not specified in document. |
<|im_end|>
<|im_start|>user
### Context:
{context}

### Question:
{question}
<|im_end|>
<|im_start|>assistant
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, temperature=0.3, do_sample=True)

    # generate() blocks until completion, so it runs in its own thread while
    # the response generator drains the streamer.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    def generate():
        sources_str = ", ".join(unique_sources)
        if not sources_str: sources_str = "None"

        # Header chunk; the client splits on the "|---|" marker.
        yield f"**Context Searched:** {sources_str} |---| \n"
        for new_text in streamer:
            yield new_text

    return StreamingResponse(generate(), media_type="text/plain")

In [None]:
@app.post("/evaluate")
async def evaluate_system(data: dict):
    """Score the RAG pipeline with synthetic questions and an LLM judge.

    For each round a random indexed chunk is used to generate a question and
    ground-truth answer; the question is answered from retrieved context, and
    the model then judges the RAG answer against the ground truth ([1] = worse,
    [2] = as good or better).

    Args:
        data: JSON body; ``data["num_questions"]`` (default 3) is the number of
            rounds to attempt. It must be a positive integer and is capped at
            20 because every round costs three sequential generations.

    Returns:
        ``{"status": "success", "preference_score": float, "details": [...]}``
        where ``preference_score`` is the fraction of scored rounds judged
        ``[2]``, or ``{"status": "error", "message": str}``. Rounds whose
        generated Q/A pair cannot be parsed are skipped, so ``details`` may
        hold fewer entries than requested.
    """
    max_questions = 20

    # The body is untrusted JSON: reject non-integers instead of crashing in range().
    try:
        num_questions = int(data.get("num_questions", 3))
    except (TypeError, ValueError):
        return {"status": "error", "message": "num_questions must be an integer."}
    if num_questions < 1:
        return {"status": "error", "message": "num_questions must be at least 1."}
    num_questions = min(num_questions, max_questions)

    if vector_store is None:
        return {"status": "error", "message": "No documents ingested. Please ingest files first."}

    docs = list(vector_store.docstore._dict.values())
    if len(docs) < 2:
        return {"status": "error", "message": "Not enough documents. Need at least 2 chunks."}

    results = []
    scores = []

    def generate_eval_text(prompt_text, max_tokens=512):
        """Greedy-decode a completion for ``prompt_text`` and return only the new text."""
        inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
        # Greedy decoding; a temperature would be ignored here and only
        # produce a generation-config warning.
        outputs = model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
        prompt_len = inputs['input_ids'].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()

    for _ in range(num_questions):
        doc = random.choice(docs)
        source_file = doc.metadata.get("source", "the document")

        context = f"Document File Name: {source_file}\n\nContext:\n{doc.page_content}"

        qa_prompt = (
            "<|im_start|>system\nUse the document provided to generate a HIGHLY SPECIFIC question-answer pair.\n"
            "The question must be narrow and point to a specific factual detail (e.g., a specific dollar amount, a specific timeframe, or a unique condition).\n"
            "CRITICAL: Do NOT ask broad summary questions like 'What are the key aspects?' or 'What does this document cover?'\n"
            "CRITICAL: You MUST explicitly include the 'Document File Name' in your generated question. (e.g., 'According to Contract_X.pdf, what is the...').\n"
            "Use the format:\nQuestion: (your question)\nAnswer: (your answer)\n"
            "DO NOT SAY ANYTHING ELSE.<|im_end|>\n"
            f"<|im_start|>user\n{context}<|im_end|>\n<|im_start|>assistant\n"
        )
        qa_pair = generate_eval_text(qa_prompt)

        # Expect "Question: ...\nAnswer: ..."; skip the round if the model
        # did not follow the format or left either side empty.
        qa_parts = qa_pair.split("Answer:")
        if len(qa_parts) < 2:
            continue
        synth_q = qa_parts[0].replace("Question:", "").strip()
        synth_a = qa_parts[1].strip()
        if not synth_q or not synth_a:
            continue

        retrieved_docs = vector_store.as_retriever(search_kwargs={'k': 5}).invoke(synth_q)
        rag_context = "\n\n".join([d.page_content for d in retrieved_docs])
        rag_prompt = (
            "<|im_start|>system\nAnswer the question using only the context below.\n<|im_end|>\n"
            f"<|im_start|>user\nContext:\n{rag_context}\n\nQuestion: {synth_q}<|im_end|>\n<|im_start|>assistant\n"
        )
        rag_a = generate_eval_text(rag_prompt)

        eval_prompt = (
            "<|im_start|>system\nEvaluate the following Question-Answer pair for accuracy and completeness.\n"
            "Assume 'Answer 1' is the Ground Truth.\n"
            "Criteria:\n"
            "[1] Answer 2 lies, contradicts the ground truth, fails to answer the question, or misses critical context.\n"
            "[2] Answer 2 is factually accurate, contains the same core information as the ground truth, and is equally good or better.\n\n"
            "Output Format MUST start with the exact tag [1] or [2], followed by the justification.\n"
            "Example: [2] Answer 2 correctly identifies the same details as Answer 1.<|im_end|>\n"
            f"<|im_start|>user\nQuestion: {synth_q}\n\nAnswer 1 (Ground Truth): {synth_a}\n\nAnswer 2 (New Answer): {rag_a}<|im_end|>\n<|im_start|>assistant\n"
        )
        eval_res = generate_eval_text(eval_prompt)

        # Only the start of the verdict is inspected so a "[2]" quoted later
        # in the justification cannot flip the score.
        prefix = eval_res[:15]
        score_val = 2 if "[2]" in prefix or "Score: 2" in prefix else 1
        scores.append(score_val)

        results.append({
            "question": synth_q,
            "ground_truth": synth_a,
            "rag_answer": rag_a,
            "evaluation": eval_res,
            "score": score_val
        })

    if not scores:
        return {"status": "error", "message": "Evaluation failed to generate properly formatted Q&A pairs."}

    pref_score = scores.count(2) / len(scores)

    return {
        "status": "success",
        "preference_score": pref_score,
        "details": results
    }

In [75]:
def start_server(port):
    """Serve ``app`` with uvicorn on ``port``; blocks until the server exits.

    The running ``uvicorn.Server`` is published as the module-level
    ``uvicorn_server`` so that re-running this cell can ask it to shut down
    instead of leaving it alive in a background thread.
    """
    global uvicorn_server
    try:
        config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="error")
        server = uvicorn.Server(config)
        uvicorn_server = server
        server.run()
    except Exception as e:
        print(f"Port {port} failed: {e}")

# Each re-run of this cell moves to the next port so the new server never
# collides with a previous one that is still shutting down.
if 'current_port' not in globals():
    current_port = 7865
else:
    current_port += 1 

if 'server_thread' in globals() and server_thread.is_alive():
    print("Stopping existing server...")
    # Actually request shutdown (the previous version only slept), then give
    # the old thread a bounded amount of time to finish.
    if 'uvicorn_server' in globals():
        uvicorn_server.should_exit = True
    server_thread.join(timeout=5)

# Daemon thread: the server must not keep the kernel process alive on exit.
server_thread = threading.Thread(target=start_server, args=(current_port,), daemon=True)
server_thread.start()
print(f"Server running on http://localhost:{current_port}")

Server running on http://localhost:7874


In [76]:
def process_files_client(files):
    """Upload the selected PDFs to the local ``/ingest`` endpoint.

    Args:
        files: File paths selected in the UI (may be empty or None).

    Returns:
        A one-line status string for the UI: the number of indexed chunks, the
        server's error message, or a connection-failure message.
    """
    if not files: return "No files selected."

    handles = []
    try:
        # Open inside the try/finally so every handle is closed again, even if
        # a later open() or the request itself fails.
        for f in files:
            handles.append(open(f, 'rb'))
        file_payload = [
            ('files', (os.path.basename(f), handle, 'application/pdf'))
            for f, handle in zip(files, handles)
        ]
        try:
            response = requests.post(f"http://localhost:{current_port}/ingest", files=file_payload)
            res_data = response.json()
            if res_data.get("status") == "success":
                return f"Indexed {res_data.get('chunks')} chunks."
            else:
                return f"Server Error: {res_data.get('message')}"
        except Exception as e:
            return f"Connection Failed: {str(e)}"
    finally:
        for handle in handles:
            handle.close()

def chat_client(message, history):
    """Stream an answer from the local ``/chat_stream`` endpoint into the chat UI.

    The server's first chunk carries a sources header terminated by ``|---|``;
    it is split off and shown underneath the answer text.

    Args:
        message: The user's question.
        history: Chat history supplied by the chat UI (unused).

    Yields:
        The answer accumulated so far, followed by a rule and the sources line.
        On failure a single ``Stream Error: ...`` message is yielded.
    """
    try:
        response = requests.post(f"http://localhost:{current_port}/chat_stream", json={"message": message}, stream=True)
        full_text, sources, started = "", "", False
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            if not chunk: continue
            if "|---|" in chunk and not started:
                # Split once only: the answer itself may contain "|---|" as a
                # markdown table separator, and when it shares a chunk with the
                # header an unbounded split would silently drop that text.
                header, remainder = chunk.split("|---|", 1)
                sources = header.replace("SOURCES:", "**Sources:**")
                full_text += remainder
                started = True
            else:
                full_text += chunk
            yield f"{full_text}\n\n---\n{sources}"
    except Exception as e:
        yield f"Stream Error: {str(e)}"
        
def run_eval_client(num_qs):
    """Run the server-side evaluation and format the result as two markdown blocks.

    Args:
        num_qs: Number of test questions to request.

    Returns:
        ``(stats_md, details_md)``: a summary with the overall preference score
        and a per-question breakdown. On a server-reported error the message is
        returned with an empty breakdown; on any exception an ``**Error:**``
        line is returned instead.
    """
    try:
        response = requests.post(f"http://localhost:{current_port}/evaluate", json={"num_questions": int(num_qs)})
        data = response.json()

        if data.get("status") == "error":
            return data.get("message"), ""

        score = data.get("preference_score", 0)
        details = data.get("details", [])

        # Count from the returned rows: the server skips questions it could not
        # generate, so fewer than num_qs may have been scored, and
        # int(score * num_qs) can also truncate a float like 2.9999 to 2.
        answered = len(details)
        passed = sum(1 for res in details if res.get("score") == 2)

        stats_md = f"### Overall Preference Score: {score * 100:.2f}%\n"
        stats_md += f"*(The pipeline successfully answered and matched/exceeded the ground truth on {passed} out of {answered} questions)*"

        details_md = ""
        for i, res in enumerate(details):
            details_md += f"### QA Pair {i+1}\n"
            details_md += f"**Question:** {res['question']}\n\n"
            details_md += f"**Ground Truth (Synthetic):** {res['ground_truth']}\n\n"
            details_md += f"**RAG Answer:** {res['rag_answer']}\n\n"
            details_md += f"**Judge Evaluation:** {res['evaluation']}\n\n"
            details_md += "---\n"

        return stats_md, details_md
    except Exception as e:
        return f"**Error:** {str(e)}", ""

# Gradio front end: a chat tab (upload + chat) and an evaluation tab, both of
# which talk to the local FastAPI server through the *_client helpers above.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Smart Legal Contract RAG Assistant")
    
    with gr.Tabs():
        with gr.Tab("Chat"):
            with gr.Row():
                # Left column: PDF upload and ingestion status.
                with gr.Column(scale=1):
                    files = gr.File(label="Upload PDFs", file_count="multiple")
                    btn = gr.Button("Process", variant="primary")
                    stat = gr.Textbox(label="Status")
                # Right column: chat window fed by the streaming chat_client generator.
                with gr.Column(scale=2):
                    gr.ChatInterface(fn=chat_client)
                    
            # "Process" uploads the selected files and shows the returned status line.
            btn.click(process_files_client, inputs=[files], outputs=[stat])

        with gr.Tab("Evaluation"):
            gr.Markdown("### RAG Evaluation via LLM-as-a-Judge\n"
                        "This tool generates synthetic question-answer pairs from your ingested documents, "
                        "runs the questions through your retrieval pipeline, and uses the LLM to score the RAG answer.")
            
            with gr.Row():
                num_qs_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Test Questions")
                eval_btn = gr.Button("Run Evaluation Test", variant="primary")
            
            with gr.Row():
                with gr.Column():
                    eval_stats = gr.Markdown(label="Statistics")
                    eval_details = gr.Markdown(label="Detailed Breakdown")
            
            # run_eval_client returns (summary markdown, per-question markdown).
            eval_btn.click(run_eval_client, inputs=[num_qs_slider], outputs=[eval_stats, eval_details])

# NOTE(review): share=True requests a public Gradio share link, so anyone with
# the URL can reach this UI; confirm that exposure is intended.
demo.launch(share=True)

  with gr.Blocks(theme=gr.themes.Soft()) as demo:


* Running on local URL:  http://127.0.0.1:7875
* Running on public URL: https://47dcda7ffbcfa26100.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


