<a href="https://colab.research.google.com/github/prem-cre/Multirag/blob/main/compilance_rule_checker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
# @title Install Required Packages
!pip install -q langchain langchain_groq langchain-community langchain-huggingface \
    langgraph faiss-cpu pypdf sentence-transformers langchain-google-genai pydantic

In [22]:
# @title Configure API Keys, LLM, and FAISS Vector Store
import os
import glob
import pickle
import re
import json
from google.colab import userdata

# API Keys
os.environ["GROQ_API_KEY"] = userdata.get('groq')

from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate

# Initialize LLM - Using a fast model is key for performance
llm = ChatGroq(model_name="llama-3.1-8b-instant", temperature=0)

# Initialize embeddings
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")

def read_documents(directory_path: str):
    """Loads PDF documents from a specified directory."""
    loader = PyPDFDirectoryLoader(directory_path)
    documents = loader.load()
    return documents

def chunk_data(docs, chunk_size=1100, chunk_overlap=70):
    """Splits documents into smaller chunks."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    chunks = text_splitter.split_documents(docs)
    return chunks

# Define paths
FAISS_DB_PATH = "faiss_index"
DOCUMENT_LIST_PATH = "indexed_documents.pkl"

# Get current PDF files
current_pdf_files = sorted(glob.glob('/content/*.pdf'))

# Load previously indexed documents
indexed_pdf_files = []
if os.path.exists(DOCUMENT_LIST_PATH):
    with open(DOCUMENT_LIST_PATH, 'rb') as f:
        indexed_pdf_files = pickle.load(f)

# Check if rebuild needed
if os.path.exists(FAISS_DB_PATH) and current_pdf_files == indexed_pdf_files:
    vector_store = FAISS.load_local(FAISS_DB_PATH, embeddings, allow_dangerous_deserialization=True)
    retriever = vector_store.as_retriever(search_kwargs={"k": 5}) # Reduced k for faster retrieval
    print("✓ Loaded existing FAISS index")
else:
    docs = read_documents('/content/')
    documents = chunk_data(docs)

    if documents:
        vector_store = FAISS.from_documents(documents, embeddings)
        retriever = vector_store.as_retriever(search_kwargs={"k": 5})
        vector_store.save_local(FAISS_DB_PATH)
        with open(DOCUMENT_LIST_PATH, 'wb') as f:
            pickle.dump(current_pdf_files, f)
        print(f"✓ Created new FAISS index with {len(documents)} chunks")
    else:
        vector_store = None
        retriever = None
        print("⚠ No documents found - using fallback context")

⚠ No documents found - using fallback context


In [23]:
# @title Define Pydantic Models for Structured Output
from pydantic import BaseModel, Field
from typing import List

class CheckResult(BaseModel):
    """Represents the result of a single check (primary or secondary)."""
    violation_score: float = Field(..., description="A score from 0.0 (compliant) to 1.0 (severe violations).")
    has_violations: bool = Field(..., description="True if the score indicates a violation, otherwise false.")
    violation_types: List[str] = Field(..., description="A list of specific violation types found.")
    summary: str = Field(..., description="A brief summary of the findings.")

class ParallelChecksResult(BaseModel):
    """The combined output of the parallel analysis in a single structure."""
    primary_check: CheckResult = Field(..., description="Results from the critical legal accuracy check.")
    secondary_check: CheckResult = Field(..., description="Results from the compliance and professional standards check.")

class ViolationDetail(BaseModel):
    """Detailed information about a single violation found in a sentence."""
    sentence: str = Field(..., description="The exact sentence where the violation occurred.")
    violation_type: str = Field(..., description="The specific type of violation detected.")
    severity: str = Field(..., description="The severity of the violation (critical, high, medium, low).")
    explanation: str = Field(..., description="A concise explanation of why the sentence is a violation.")

class ViolationAnalysisResult(BaseModel):
    """The complete list of all detailed violations found in the text."""
    violations: List[ViolationDetail] = Field(..., description="A list of all violation details. Should be empty if no violations are found.")

In [24]:
# @title Define State Schema
from typing import TypedDict, List, Dict, Any

class ComplianceState(TypedDict):
    input_text: str
    context: str  # To hold context from FAISS
    primary_check: Dict[str, Any]
    secondary_check: Dict[str, Any]
    merged_check: Dict[str, Any]
    violations: List[Dict[str, Any]]
    final_output: Dict[str, Any]

In [25]:
# @title New Node: Retrieve Context from FAISS
def retrieve_context(state: ComplianceState) -> dict:
    """Retrieves relevant context from the FAISS vector store."""
    print("🔎 Retrieving context from FAISS...")
    if not retriever:
        print("   - No retriever available. Skipping.")
        return {"context": "No external context provided."}

    retrieved_docs = retriever.invoke(state["input_text"])
    context = "\n\n---\n\n".join([doc.page_content for doc in retrieved_docs])
    print(f"   - Retrieved {len(retrieved_docs)} context chunks.")
    return {"context": context}

In [26]:
# @title Fast Parallel Check Node (1 LLM Call)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.exceptions import OutputParserException

def fast_parallel_initial_check(state: ComplianceState) -> dict:
    """
    Runs a single, powerful LLM check that simulates two experts (Primary and Secondary)
    and merges the results intelligently. This is the main speed improvement.
    """
    print("\n" + "="*60)
    print("🚀 RUNNING FAST PARALLEL VIOLATION DETECTION (1 LLM Call)")
    print("="*60)

    # Use the structured_output feature for reliable JSON
    structured_llm = llm.with_structured_output(ParallelChecksResult)

    prompt = ChatPromptTemplate.from_template("""
You are a dual-persona legal analysis AI. You must analyze the text from two perspectives simultaneously: a Critical Legal Expert and a Compliance Auditor.

**Relevant Legal Context from Documents:**
{context}

**Text to Analyze:**
{text}

**Instructions:**
Analyze the text and provide a single JSON response containing two nested objects: `primary_check` and `secondary_check`.

**1. Primary Check (Critical Legal Expert Persona):**
   - **Focus:** CRITICAL legal violations, factual falsehoods, dangerous misinformation.
   - **Scoring:** 0.0 (accurate) to 1.0 (severely incorrect).
   - **Ignore:** Minor tone/grammar issues.

**2. Secondary Check (Compliance Auditor Persona):**
   - **Focus:** Policy violations (grPC, LawVriksh), unprofessional language, missing citations, ambiguity.
   - **Scoring:** 0.0 (compliant) to 1.0 (major compliance failures).
   - **Ignore:** Minor typos.

Provide your complete analysis in the required JSON format. Ensure all strings are valid.
""")

    chain = prompt | structured_llm

    try:
        result: ParallelChecksResult = chain.invoke({
            "text": state["input_text"],
            "context": state["context"]
        })
    except (OutputParserException, Exception) as e:
        print(f"   - ⚠ Error in parallel check. Using fallback. Error: {e}")
        # Fallback to a default non-violating state to prevent crash
        result = ParallelChecksResult(
            primary_check=CheckResult(violation_score=0, has_violations=False, violation_types=[], summary="Error parsing initial check."),
            secondary_check=CheckResult(violation_score=0, has_violations=False, violation_types=[], summary="Error parsing initial check.")
        )


    # FIX: Replaced .dict() with .model_dump()
    primary_result = result.primary_check.model_dump()
    secondary_result = result.secondary_check.model_dump()

    # --- Intelligent Merging Logic (Identical to your original code) ---
    primary_score = primary_result.get("violation_score", 0)
    secondary_score = secondary_result.get("violation_score", 0)
    merged_score = (primary_score * 0.65) + (secondary_score * 0.35)

    primary_types = set(primary_result.get("violation_types", []))
    secondary_types = set(secondary_result.get("violation_types", []))
    all_violation_types = list(primary_types.union(secondary_types))

    has_violations = (
        merged_score >= 0.3 or
        primary_score >= 0.5 or
        secondary_score >= 0.6
    )

    combined_summary = f"""Critical Analysis: {primary_result.get('summary', 'N/A')}
Compliance Analysis: {secondary_result.get('summary', 'N/A')}"""

    merged_result = {
        "violation_score": round(merged_score, 2),
        "has_violations": has_violations,
        "violation_types": all_violation_types,
        "summary": combined_summary.strip(),
        "primary_score": round(primary_score, 2),
        "secondary_score": round(secondary_score, 2),
        "detection_confidence": "high" if abs(primary_score - secondary_score) < 0.25 else "medium"
    }

    print(f"Primary Check Score: {primary_score:.2f}")
    print(f"Secondary Check Score: {secondary_score:.2f}")
    print(f"\n📊 Merged Score: {merged_score:.2f}")
    print(f"   Violations Detected: {has_violations}")
    print(f"   Confidence: {merged_result['detection_confidence']}")

    return {
        "primary_check": primary_result,
        "secondary_check": secondary_result,
        "merged_check": merged_result
    }

In [27]:
# @title Conditional Router - Decides Next Step
def should_continue(state: ComplianceState) -> str:
    """
    Route based on merged violation detection.
    If violations detected, perform detailed sentence-level analysis.
    Otherwise, skip to no violations output.
    """
    merged_check = state["merged_check"]
    has_violations = merged_check.get("has_violations", False)
    score = merged_check.get("violation_score", 0)

    print(f"\n🔀 Routing Decision: {'ANALYZE VIOLATIONS' if has_violations else 'NO VIOLATIONS'}")
    print(f"   (Score: {score}, Threshold: 0.3)\n")

    if has_violations:
        return "analyze_violations"
    else:
        return "no_violations_output"

In [28]:
# @title Holistic Detailed Violation Analysis (1 LLM Call)
from groq.types.chat import ChatCompletion
from groq import BadRequestError

def holistic_sentence_analysis(state: ComplianceState) -> dict:
    """
    Performs a single, holistic, sentence-level analysis of the entire text,
    returning a structured list of all violations. Replaces the slow loop.
    """
    print("\n" + "="*60)
    print("⚡ RUNNING HOLISTIC SENTENCE-LEVEL ANALYSIS (1 LLM Call)")
    print("="*60)

    structured_llm = llm.with_structured_output(ViolationAnalysisResult)
    merged_check = state["merged_check"]

    prompt = ChatPromptTemplate.from_template("""
You are a legal compliance analyst. An initial screening of the text below has detected potential violations. Your task is to perform a detailed analysis and identify every specific sentence that contains a violation.

**Initial Findings Summary:**
{summary}

**Potential Violation Types Detected:**
{violation_types}

**Relevant Legal Context from Documents:**
{context}

**Full Text to Analyze:**
{text}

**Instructions:**
1. Read the entire text.
2. For EACH sentence that contains a violation, create a JSON object with the exact sentence, violation type, severity, and explanation.
3. Combine all these objects into a final list under the "violations" key.
4. If, upon closer inspection, you find no specific sentences with violations, return an empty "violations" list.
5. **CRITICAL JSON RULE:** Your output MUST be a single, valid JSON object. All strings inside the JSON, especially the 'sentence' and 'explanation' fields, must be correctly formatted. Properly escape any special characters like apostrophes (') or double quotes ("). For example, write "it didn't happen" instead of "it didn"t happen".
""")

    chain = prompt | structured_llm
    violations = []

    # FIX: Added try...except block to handle the BadRequestError gracefully
    try:
        result: ViolationAnalysisResult = chain.invoke({
            "text": state["input_text"],
            "context": state["context"],
            "summary": merged_check["summary"],
            "violation_types": ", ".join(merged_check["violation_types"])
        })
        # FIX: Replaced .dict() with .model_dump()
        violations = [v.model_dump() for v in result.violations]

    except BadRequestError as e:
        print(f"   - ⚠ CRITICAL ERROR: The LLM failed to generate valid structured JSON for detailed analysis.")
        print(f"   - This is often due to unescaped characters in its output. The workflow will continue with 0 violations found.")
        print(f"   - Groq Error: {e}")
        # Return an empty list to prevent the graph from crashing
        violations = []

    print(f"   - Found {len(violations)} specific violations.")
    return {"violations": violations}

In [29]:
# @title Output and Summary Nodes
def no_violations_output(state: ComplianceState) -> dict:
    """Generate output when no violations are detected."""
    merged_check = state["merged_check"]
    output = {
        "status": "compliant",
        "message": "No violations detected",
        "score": merged_check.get("violation_score", 0),
        "primary_score": merged_check.get("primary_score", 0),
        "secondary_score": merged_check.get("secondary_score", 0),
        "summary": "Text passes all compliance checks."
    }
    print("\n✅ FINAL RESULT: NO VIOLATIONS DETECTED\n")
    return {"final_output": output}


def generate_summary(state: ComplianceState) -> dict:
    """Generate comprehensive summary of all violations found."""
    violations = state.get("violations")
    if not violations:
        return no_violations_output(state)

    merged_check = state["merged_check"]

    # Group violations by severity
    critical = [v for v in violations if v["severity"] == "critical"]
    high = [v for v in violations if v["severity"] == "high"]
    medium = [v for v in violations if v["severity"] == "medium"]
    low = [v for v in violations if v["severity"] == "low"]

    output = {
        "status": "violations_detected",
        "overall_score": merged_check.get("violation_score", 0),
        "primary_score": merged_check.get("primary_score", 0),
        "secondary_score": merged_check.get("secondary_score", 0),
        "detection_confidence": merged_check.get("detection_confidence", "unknown"),
        "total_violations": len(violations),
        "breakdown": {
            "critical": len(critical),
            "high": len(high),
            "medium": len(medium),
            "low": len(low)
        },
        "violations": violations,
        "summary": merged_check.get("summary", ""),
        "violation_types": merged_check.get("violation_types", [])
    }

    print("\n" + "="*60)
    print("FINAL COMPLIANCE REPORT")
    print("="*60)
    print(f"Status: {output['status']}")
    print(f"Overall Score: {output['overall_score']}")
    print(f"Total Violations: {output['total_violations']}")
    print(f"  Critical: {len(critical)} | High: {len(high)} | Medium: {len(medium)} | Low: {len(low)}")
    print("="*60 + "\n")

    return {"final_output": output}

In [30]:
# @title Build and Compile the High-Speed LangGraph Workflow
from langgraph.graph import StateGraph, END

# Initialize graph
workflow = StateGraph(ComplianceState)

# Add nodes
workflow.add_node("retrieve_context", retrieve_context)
workflow.add_node("parallel_check", fast_parallel_initial_check)
workflow.add_node("analyze_violations", holistic_sentence_analysis)
workflow.add_node("no_violations", no_violations_output)
workflow.add_node("generate_summary", generate_summary)

# Define edges
workflow.set_entry_point("retrieve_context")
workflow.add_edge("retrieve_context", "parallel_check")

workflow.add_conditional_edges(
    "parallel_check",
    should_continue,
    {
        "analyze_violations": "analyze_violations",
        "no_violations_output": "no_violations"
    }
)

workflow.add_edge("analyze_violations", "generate_summary")
workflow.add_edge("no_violations", END)
workflow.add_edge("generate_summary", END)

# Compile graph
fast_compliance_checker = workflow.compile()

In [34]:
# @title Test Cases
import time

# --- Test Cases ---
test_clean = """
Dowry is just a normal part of Indian weddings, and everyone gives gifts so it’s not really illegal. If the bride’s family gives money or stuff to the groom’s side, it’s fine as long as no one complains. The law doesn’t clearly say what counts as dowry, so people can interpret it however they want. Most dowry cases are fake and filed just to harass the husband’s family. Judges usually ignore these cases unless there’s a video or something.

IPC Section 498A says that if a husband is mean, he can be arrested without any proof. That’s why a lot of innocent men get jailed for no reason. The Dowry Act doesn’t apply if the gifts are given out of love, so it’s okay to give gold and cars. Police can arrest anyone accused of dowry harassment even if there’s no FIR. Courts don’t need evidence—they just go with what feels right.

If someone returns the dowry items, the case is automatically closed. The law says that dowry is only illegal if it’s forced, so voluntary dowry is fine. Legal citations aren’t really needed in dowry cases because everyone knows the rules. The judge can decide based on common sense and doesn’t have to explain the verdict.

However, under IPC Section 498A, cruelty related to dowry demands is a cognizable and non-bailable offense. Courts have held that mental harassment and threats are sufficient grounds for prosecution. The Supreme Court has clarified that even indirect pressure for dowry can constitute an offense.

Legal writing in dowry cases must be precise, supported by evidence, and aligned with statutory definitions. The law mandates that any dowry articles received must be returned to the bride within a reasonable time. Failure to do so can result in prosecution under Section 6 of the Dowry Prohibition Act, 1961.

Dowry deaths are rare and usually happen because of other reasons like accidents or depression. The law doesn’t define cruelty properly, so anything can be called harassment. If the husband says sorry, the case can be withdrawn. Lawyers often skip formal language in dowry cases because it’s too technical. The law changes all the time, so it’s hard to keep track of what’s allowed.
"""


result = fast_compliance_checker.invoke({"input_text": test_clean})

print(json.dumps(result["final_output"], indent=2))



TEST: CLEAN TEXT
🔎 Retrieving context from FAISS...
   - No retriever available. Skipping.

🚀 RUNNING FAST PARALLEL VIOLATION DETECTION (1 LLM Call)
Primary Check Score: 0.90
Secondary Check Score: 0.80

📊 Merged Score: 0.86
   Violations Detected: True
   Confidence: high

🔀 Routing Decision: ANALYZE VIOLATIONS
   (Score: 0.86, Threshold: 0.3)


⚡ RUNNING HOLISTIC SENTENCE-LEVEL ANALYSIS (1 LLM Call)
   - Found 19 specific violations.

FINAL COMPLIANCE REPORT
Status: violations_detected
Overall Score: 0.86
Total Violations: 19
  Critical: 0 | High: 0 | Medium: 0 | Low: 0

{
  "status": "violations_detected",
  "overall_score": 0.86,
  "primary_score": 0.9,
  "secondary_score": 0.8,
  "detection_confidence": "high",
  "total_violations": 19,
  "breakdown": {
    "critical": 0,
    "high": 0,
    "medium": 0,
    "low": 0
  },
  "violations": [
    {
      "sentence": "Dowry is just a normal part of Indian weddings, and everyone gives gifts so it\u2019s not really illegal.",
      "vio