In [1]:
!pip install gradio pdfplumber rank-bm25 pinecone huggingface_hub

Collecting pdfplumber
  Downloading pdfplumber-0.11.7-py3-none-any.whl.metadata (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rank-bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Collecting pinecone
  Downloading pinecone-7.3.0-py3-none-any.whl.metadata (9.5 kB)
Collecting pdfminer.six==20250506 (from pdfplumber)
  Downloading pdfminer_six-20250506-py3-none-any.whl.metadata (4.2 kB)
Collecting pypdfium2>=4.18.0 (from pdfplumber)
  Downloading pypdfium2-4.30.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.2/48.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting pinecone-plugin-assistant<2.0.0,>=1.6.0 (from pinecone)
  Downloading pinecone_plugin_assistant-1

In [2]:
import os
import json
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import re
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
from rank_bm25 import BM25Okapi
import pickle
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Configuration
# Name of the Pinecone index; passed as `index_name` when the system is built.
PINECONE_INDEX_NAME = "legal-contract-search"
# Model identifier handed to SentenceTransformer for query embeddings.
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Pinecone namespace queried by HybridSearchEngine.search_metadata_namespace().
METADATA_NAMESPACE = "contract_metadata"
# Pinecone namespace used as the default for HybridSearchEngine.vector_search().
CHUNKS_NAMESPACE = "contract_chunks"

In [4]:
class QueryType(Enum):
    """Categories a user query can be classified into.

    QueryClassifier.classify_query() returns one of these members, and
    ContractAgent keys its system prompts by them.
    """
    RISK_ANALYSIS = "risk_analysis"
    CLAUSE_EXTRACTION = "clause_extraction"
    COMPLIANCE_CHECK = "compliance_check"
    COMPARISON = "comparison"
    # Returned by classify_query() when no keyword family matches.
    GENERAL_SEARCH = "general_search"

In [13]:
class SearchResult:
    """A single retrieval hit, regardless of which search path produced it.

    Attributes are plain and mutable: the hybrid search rescales ``score`` and
    relabels ``search_type`` in place when it merges result lists.
    """

    def __init__(self, id: str, text: str, metadata: Dict[str, Any], score: float, search_type: str):
        """Store the hit's identifier, text, metadata, score and origin label."""
        self.id = id
        self.text = text
        self.metadata = metadata
        self.score = score
        self.search_type = search_type

    def __repr__(self):
        """Compact developer-facing form: id, score (3 decimals) and origin."""
        identifier, kind = self.id, self.search_type
        return f"SearchResult(id='{identifier}', score={self.score:.3f}, type='{kind}')"

    def __str__(self):
        """Human-facing form: source filename (or 'Unknown') and score."""
        source_name = self.metadata.get('filename', 'Unknown')
        return f"SearchResult: {source_name} (Score: {self.score:.3f})"

In [14]:
class HybridSearchEngine:
    def __init__(self, pinecone_api_key: str, index_name: str, embedding_model: str):
        """Connect to the Pinecone index and load the sentence-embedding model.

        Args:
            pinecone_api_key: API key handed to the ``Pinecone`` client.
            index_name: Name of the Pinecone index to open.
            embedding_model: Model identifier passed to ``SentenceTransformer``.
        """
        self.pc = Pinecone(api_key=pinecone_api_key)
        self.index = self.pc.Index(index_name)
        self.model = SentenceTransformer(embedding_model)
        # BM25 state stays empty until load_corpus_for_bm25() is called;
        # bm25_search() returns [] while self.bm25 is None.
        self.bm25 = None
        self.corpus_texts = []
        self.corpus_metadata = []

    def load_corpus_for_bm25(self, chunks_data: List[Dict[str, Any]]):
        """Build the in-memory BM25 index from pre-chunked contract text.

        Args:
            chunks_data: One dict per chunk, each providing a ``'text'`` string
                and a ``'metadata'`` dict.

        Raises:
            ValueError: If ``chunks_data`` is empty. ``BM25Okapi`` cannot be
                built from an empty corpus (it divides by the corpus size), so
                fail early with an explicit message instead.
            KeyError: If a chunk lacks ``'text'`` or ``'metadata'``.
        """
        if not chunks_data:
            raise ValueError("chunks_data is empty; cannot build a BM25 index")

        # Build everything locally first so a failure part-way through does not
        # leave the engine with texts/metadata that disagree with self.bm25.
        texts = []
        metadata = []
        for chunk in chunks_data:
            texts.append(chunk['text'])
            metadata.append(chunk['metadata'])

        # Same tokenisation as bm25_search(): lower-case, whitespace split.
        tokenized_corpus = [text.lower().split() for text in texts]
        bm25 = BM25Okapi(tokenized_corpus)

        self.corpus_texts = texts
        self.corpus_metadata = metadata
        self.bm25 = bm25

    def vector_search(self, query: str, namespace: str = CHUNKS_NAMESPACE,
                     top_k: int = 50, metadata_filter: Dict[str, Any] = None) -> List[SearchResult]:
        """Embed the query and run a similarity search against a Pinecone namespace.

        Args:
            query: Free-text query to embed.
            namespace: Pinecone namespace to search.
            top_k: Number of matches requested from the index.
            metadata_filter: Optional filter dict forwarded as ``filter``; it is
                omitted from the request when falsy.

        Returns:
            One ``SearchResult`` per match, labelled ``"vector"``, with the text
            taken from the match metadata's ``'text'`` field ('' if absent).
            Any exception is printed and converted into an empty list.
        """
        try:
            embedding = self.model.encode([query])[0]

            request = dict(
                vector=embedding.tolist(),
                top_k=top_k,
                namespace=namespace,
                include_metadata=True,
            )
            if metadata_filter:
                request["filter"] = metadata_filter

            response = self.index.query(**request)

            return [
                SearchResult(
                    id=hit['id'],
                    text=hit['metadata'].get('text', ''),
                    metadata=hit['metadata'],
                    score=hit['score'],
                    search_type="vector"
                )
                for hit in response['matches']
            ]
        except Exception as e:
            print(f"Vector search error: {e}")
            return []

    def bm25_search(self, query: str, top_k: int = 50,
                   filename_filter: str = None) -> List[SearchResult]:
        """Rank the loaded corpus against the query with BM25.

        Args:
            query: Free-text query; tokenised by lower-casing and whitespace split,
                matching how the corpus was tokenised.
            top_k: Maximum number of results to return.
            filename_filter: Optional substring that a chunk's ``'filename'``
                metadata must contain (compared case-insensitively).

        Returns:
            Up to ``top_k`` results with a strictly positive BM25 score, best
            first, labelled ``"bm25"``. Returns [] when no corpus is loaded,
            when ``top_k`` is not positive, or when an error occurs (the error
            is printed).
        """
        if self.bm25 is None or top_k <= 0:
            return []

        try:
            scores = self.bm25.get_scores(query.lower().split())
            wanted = filename_filter.lower() if filename_filter else None

            search_results = []
            # Walk the whole ranking rather than a fixed-size head: with a
            # filename filter the matching chunks may sit well below the global
            # top hits, and truncating before filtering would drop them.
            for idx in np.argsort(scores)[::-1]:
                score = scores[idx]
                if score <= 0:
                    # Ranking is descending, so nothing useful follows.
                    break
                if not score > 0:
                    # NaN scores compare false both ways; skip them.
                    continue

                metadata = self.corpus_metadata[idx]
                if wanted:
                    # Tolerate a missing/None filename and mixed-case names.
                    filename = str(metadata.get('filename') or '').lower()
                    if wanted not in filename:
                        continue

                search_results.append(SearchResult(
                    id=f"bm25_{idx}",
                    text=self.corpus_texts[idx],
                    metadata=metadata,
                    score=float(score),  # plain float instead of numpy scalar
                    search_type="bm25"
                ))

                if len(search_results) >= top_k:
                    break

            return search_results
        except Exception as e:
            print(f"BM25 search error: {e}")
            return []

    def search_metadata_namespace(self, query: str, top_k: int = 10,
                               metadata_filter: Dict[str, Any] = None) -> List[SearchResult]:
        """Run a similarity search against the contract-metadata namespace.

        Args:
            query: Free-text query to embed.
            top_k: Number of matches requested from the index.
            metadata_filter: Optional filter dict forwarded as ``filter``; it is
                omitted from the request when falsy.

        Returns:
            One ``SearchResult`` per match, labelled ``"metadata"``, with the
            text taken from the match metadata's ``'embedding_text'`` field
            ('' if absent). Any exception is printed and converted into [].
        """
        try:
            embedding = self.model.encode([query])[0]

            request = dict(
                vector=embedding.tolist(),
                top_k=top_k,
                namespace=METADATA_NAMESPACE,
                include_metadata=True,
            )
            if metadata_filter:
                request["filter"] = metadata_filter

            response = self.index.query(**request)

            return [
                SearchResult(
                    id=hit['id'],
                    text=hit['metadata'].get('embedding_text', ''),
                    metadata=hit['metadata'],
                    score=hit['score'],
                    search_type="metadata"
                )
                for hit in response['matches']
            ]
        except Exception as e:
            print(f"Metadata search error: {e}")
            return []

    def enhanced_hybrid_search(self, query: str, top_k: int = 20, alpha: float = 0.7,
                  metadata_filter: Dict[str, Any] = None,
                  filename_filter: str = None) -> List[SearchResult]:
        """Merge metadata, vector and BM25 hits into one ranked list.

        Scoring:
            * Vector scores are divided by the best vector score, then weighted
              by ``alpha``.
            * BM25 scores are min-max normalised, then weighted by ``1 - alpha``.
            * Metadata-namespace scores are multiplied by 1.1 and not weighted.
            * Hits are de-duplicated on the first 100 characters of text plus the
              filename; a BM25 hit that collides with an existing entry adds its
              weighted score to that entry instead of creating a new one.

        Args:
            query: Free-text query.
            top_k: Maximum number of merged results to return.
            alpha: Weight of the vector score; BM25 gets ``1 - alpha``.
            metadata_filter: Optional Pinecone filter applied to both the
                metadata and the chunk vector searches.
            filename_filter: Optional filename substring applied to BM25 hits.

        Returns:
            Up to ``top_k`` results sorted by combined score, best first. The
            returned objects are the ones produced by the underlying searches,
            with ``score`` and ``search_type`` updated in place. Returns []
            when ``top_k`` is not positive or nothing matched.
        """
        if top_k <= 0:
            return []

        # top_k // 2 is 0 for top_k == 1, which would ask the index for zero
        # matches; always request at least one metadata hit.
        metadata_results = self.search_metadata_namespace(
            query, top_k=max(1, top_k // 2), metadata_filter=metadata_filter
        )

        # Over-fetch chunk candidates so de-duplication still leaves top_k.
        vector_results = self.vector_search(query, top_k=top_k*2, metadata_filter=metadata_filter)
        bm25_results = self.bm25_search(query, top_k=top_k*2, filename_filter=filename_filter)

        # Scale vector scores relative to the best hit (skipped when the best
        # score is not positive, to avoid dividing by zero or flipping signs).
        if vector_results:
            max_vector_score = max(r.score for r in vector_results)
            if max_vector_score > 0:
                for result in vector_results:
                    result.score = result.score / max_vector_score

        # Min-max normalise BM25 scores, with explicit handling of the
        # degenerate cases where the range is zero.
        if bm25_results:
            bm25_scores = [r.score for r in bm25_results]
            max_bm25_score = max(bm25_scores)
            min_bm25_score = min(bm25_scores)

            if max_bm25_score > 0 and max_bm25_score != min_bm25_score:
                score_range = max_bm25_score - min_bm25_score
                for result in bm25_results:
                    result.score = (result.score - min_bm25_score) / score_range
            elif max_bm25_score > 0:
                # All scores are the same positive value.
                for result in bm25_results:
                    result.score = 1.0
            else:
                # All scores are zero or negative.
                for result in bm25_results:
                    result.score = 0.001

        # 10% boost for metadata hits, which carry document-level context.
        for result in metadata_results:
            result.score = result.score * 1.1

        def dedup_key(result):
            # Leading text plus filename identifies "the same chunk" across
            # search paths, whose ids differ.
            return result.text[:100] + result.metadata.get('filename', '')

        combined_results = {}

        # Metadata hits go in first and are never replaced by chunk hits.
        for result in metadata_results:
            result.search_type = "metadata"
            combined_results[dedup_key(result)] = result

        for result in vector_results:
            key = dedup_key(result)
            if key not in combined_results:
                result.score = alpha * result.score
                combined_results[key] = result

        for result in bm25_results:
            key = dedup_key(result)
            existing = combined_results.get(key)
            if existing is not None:
                # Same chunk already present: add the keyword evidence to it.
                existing.score += (1 - alpha) * result.score
                if existing.search_type == "vector":
                    existing.search_type = "hybrid"
                elif existing.search_type == "metadata":
                    existing.search_type = "metadata+bm25"
            else:
                result.score = (1 - alpha) * result.score
                combined_results[key] = result

        if not combined_results:
            print(f"Warning: No results found for query: {query}")
            return []

        final_results = sorted(combined_results.values(), key=lambda x: x.score, reverse=True)
        return final_results[:top_k]

In [15]:
class QueryClassifier:
    def __init__(self):
        """Set up the keyword and phrase lists used by the classifier methods.

        All lists are matched as lower-case substrings of the query by the
        methods of this class.
        """
        # Keyword families consulted, in this order, by classify_query().
        self.risk_keywords = ['risk', 'liability', 'penalty', 'breach', 'default', 'damages', 'indemnification', 'exposure', 'consequence']
        self.clause_keywords = ['clause', 'term', 'provision', 'section', 'article', 'condition', 'requirement', 'obligation']
        self.compliance_keywords = ['compliance', 'regulation', 'law', 'legal', 'regulatory', 'standard', 'requirement', 'audit']
        self.comparison_keywords = ['compare', 'difference', 'similar', 'contrast', 'versus', 'vs', 'between', 'against']

        # Enhanced indicators with more precise patterns
        # Used by is_general_query(): any hit marks the query as general.
        self.general_indicators = [
            'typical', 'common', 'usual', 'standard', 'general', 'normally', 'generally',
            'what are', 'what should', 'how to', 'best practices', 'types of', 'kinds of',
            'explain', 'describe', 'define', 'what is', 'how does', 'why do'
        ]

        # More specific contract indicators
        # Used by extract_contract_identifier(): at least one must be present
        # before any contract name is extracted.
        self.specific_indicators = [
            'in the', 'from the', 'show me', 'extract from', 'analyze the', 'review the',
            'find in', 'locate in', 'summarize the', 'summarise the', 'according to',
            'based on the', 'as per the'
        ]

    def classify_query(self, query: str) -> QueryType:
        """Map a query to a ``QueryType`` by substring keyword lookup.

        Keyword families are tried in a fixed priority order (risk, clause,
        compliance, comparison); the first family with any keyword occurring in
        the lower-cased query decides the type. Queries matching none of them
        are classified as ``GENERAL_SEARCH``.
        """
        lowered = query.lower()

        # Order is significant: earlier families shadow later ones.
        prioritized_rules = (
            (self.risk_keywords, QueryType.RISK_ANALYSIS),
            (self.clause_keywords, QueryType.CLAUSE_EXTRACTION),
            (self.compliance_keywords, QueryType.COMPLIANCE_CHECK),
            (self.comparison_keywords, QueryType.COMPARISON),
        )

        for keywords, label in prioritized_rules:
            for keyword in keywords:
                if keyword in lowered:
                    return label

        return QueryType.GENERAL_SEARCH

    def extract_contract_identifier(self, query: str) -> Optional[str]:
        """Enhanced contract identifier extraction with better filtering"""
        query_lower = query.lower()

        # Check for specific indicators first
        has_specific_indicator = any(indicator in query_lower for indicator in self.specific_indicators)

        if not has_specific_indicator:
            return None

        # Enhanced exclusion of general terms
        general_terms = [
            'typical', 'common', 'standard', 'general', 'what are', 'what should',
            'types of', 'kinds of', 'best', 'worst', 'most', 'least', 'any',
            'all', 'some', 'many', 'few', 'several', 'various'
        ]
        if any(term in query_lower for term in general_terms):
            return None

        # Improved contract patterns with more specific capture
        contract_patterns = [
            # Direct contract mentions with action words
            r'(?:summarise|summarize|analyze|review|examine|check)\s+(?:the\s+)?([a-zA-Z0-9_-]+(?:\s+[a-zA-Z0-9_-]+)*)\s+(?:contract|agreement|document)',

            # Prepositional phrases
            r'(?:in|from|of|within)\s+(?:the\s+)?([a-zA-Z0-9_-]+(?:\s+[a-zA-Z0-9_-]+)*)\s+(?:contract|agreement)',

            # Contract name patterns
            r'\b([a-zA-Z0-9_-]+(?:\s+[a-zA-Z0-9_-]+)*)\s+(?:contract|agreement)(?:\s+(?:file|document))?\b',

            # Named contracts
            r'(?:contract|agreement)\s+(?:called|named|titled|for)\s+([a-zA-Z0-9_-]+(?:\s+[a-zA-Z0-9_-]+)*)',
        ]

        for pattern in contract_patterns:
            matches = re.findall(pattern, query_lower, re.IGNORECASE)
            if matches:
                potential_contract = matches[0].strip()

                # Enhanced filtering
                excluded_words = {
                    'the', 'this', 'that', 'a', 'an', 'contract', 'agreement', 'service',
                    'typical', 'common', 'standard', 'general', 'main', 'key', 'important',
                    'new', 'old', 'current', 'existing', 'draft', 'final', 'signed',
                    'any', 'all', 'some', 'every', 'each', 'most', 'best', 'worst',
                    'employment', 'lease', 'franchise', 'license', 'nda', 'merger'
                }

                # Clean and validate
                contract_parts = potential_contract.split()
                clean_parts = [part for part in contract_parts if part not in excluded_words and len(part) > 1]

                if clean_parts and len(''.join(clean_parts)) > 2:
                    return '_'.join(clean_parts).lower()

        return None

    def is_general_query(self, query: str) -> bool:
        """Check if query is asking for general information rather than specific contract analysis"""
        query_lower = query.lower()
        return any(indicator in query_lower for indicator in self.general_indicators)

In [16]:
class ContractAgent:
    def __init__(self, agent_type: QueryType, model, tokenizer, device):
        """Create an agent bound to one query type and a shared language model.

        Args:
            agent_type: The ``QueryType`` whose system prompt this agent uses.
            model: Language model object; ``generate_response`` calls its
                ``generate`` method.
            tokenizer: Tokenizer paired with ``model``.
            device: Device that tokenised inputs are moved to before generation.
        """
        self.agent_type = agent_type
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # One system prompt per QueryType; generate_response() looks up the
        # entry for self.agent_type and falls back to "" if it is missing.
        self.system_prompts = {
            QueryType.RISK_ANALYSIS: """You are a legal risk analysis expert. Analyze the provided contract chunks and identify potential risks, liabilities, and areas of concern. Focus on:
- Financial risks and liability exposure
- Compliance risks
- Operational risks
- Legal risks
- Confidentiality and Data privacy risks
Provide a structured analysis with specific references to contract terms.""",

            QueryType.CLAUSE_EXTRACTION: """You are a contract clause extraction specialist. Extract and analyze specific clauses from the provided contract chunks. Focus on:
- Key terms and conditions
- Rights and obligations
- Performance requirements
- Termination clauses
Present the information in a clear, structured format.""",

            QueryType.COMPLIANCE_CHECK: """You are a compliance verification expert. Review the provided contract chunks for compliance with legal and regulatory requirements. Focus on:
- Regulatory compliance
- Industry standards
- Legal requirements
- Best practices
Highlight any compliance gaps or concerns.""",

            QueryType.COMPARISON: """You are a contract comparison specialist. Compare and contrast the provided contract chunks to identify:
- Similarities and differences
- Variations in terms
- Inconsistencies
- Relative advantages/disadvantages
Provide a detailed comparative analysis.""",

            QueryType.GENERAL_SEARCH: """You are a general contract analysis expert. Analyze the provided contract chunks and provide relevant insights based on the user's query. Be comprehensive and accurate in your analysis."""
        }

    def generate_response(self, query: str, search_results: List[SearchResult],
             context_type: str = "database", min_token: int = 512, max_tokens: int = 2048) -> str:
        """Enhanced response generation with better context organization"""
        try:
            if context_type == "uploaded":
                # Enhanced context for uploaded documents
                doc_text = search_results[0].text
                filename = search_results[0].metadata.get('filename', 'Unknown')

                # Extract key sections if document is long
                if len(doc_text) > 3000:
                    sentences = doc_text.split('. ')
                    # Use query keywords to find most relevant sections
                    query_words = set(query.lower().split())
                    scored_sentences = []

                    for i, sentence in enumerate(sentences):
                        sentence_words = set(sentence.lower().split())
                        overlap = len(query_words.intersection(sentence_words))
                        if overlap > 0:
                            scored_sentences.append((overlap, i, sentence))

                    # Sort by relevance and take top sections
                    scored_sentences.sort(reverse=True)
                    relevant_text = '. '.join([sent[2] for sent in scored_sentences[:10]])
                    context = f"Document: {filename}\n\nMost Relevant Sections:\n{relevant_text}"
                else:
                    context = f"Document: {filename}\n\n{doc_text}"

            else:
                # Enhanced database context organization
                metadata_results = [r for r in search_results if r.search_type in ["metadata", "metadata+bm25"]]
                chunk_results = [r for r in search_results if r.search_type not in ["metadata", "metadata+bm25"]]

                context_parts = []

                # Add metadata overview
                if metadata_results:
                    context_parts.append("=== CONTRACT OVERVIEW ===")
                    for result in metadata_results[:2]:
                        context_parts.append(f"Document: {result.metadata.get('filename', 'Unknown')}")
                        context_parts.append(f"Summary: {result.text[:1000]}")
                        context_parts.append("")

                # Group chunks by document for better organization
                if chunk_results:
                    context_parts.append("=== DETAILED CONTENT ===")

                    # Group by filename
                    doc_chunks = {}
                    for result in chunk_results:
                        filename = result.metadata.get('filename', 'Unknown')
                        if filename not in doc_chunks:
                            doc_chunks[filename] = []
                        doc_chunks[filename].append(result)

                    # Add chunks grouped by document
                    for filename, chunks in list(doc_chunks.items())[:3]:  # Max 3 documents
                        context_parts.append(f"--- {filename.upper()} ---")
                        for i, chunk in enumerate(chunks[:2]):  # Max 2 chunks per doc
                            context_parts.append(f"Section {i+1}:")
                            context_parts.append(chunk.text[:1500])
                            context_parts.append("")

                context = "\n".join(context_parts)

            # Enhanced system prompt based on query type
            system_prompt = self.system_prompts.get(self.agent_type, "")

            # Build comprehensive prompt
            user_prompt = f"""System: {system_prompt}

Query: {query}

Contract Information:
{context}

Instructions:
- Provide a detailed, accurate response based on the contract information
- Quote specific clauses or sections when relevant
- Explain legal implications clearly
- Structure your response with clear headings if appropriate
- If information is insufficient, state what additional details would be helpful

Response:"""

            # Generate with improved parameters
            inputs = self.tokenizer(user_prompt, return_tensors="pt", truncation=True, max_length=3072)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    temperature=0.2,  # Lower temperature for more focused responses
                    top_p=0.9,  # Add top_p sampling
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.15,  # Slightly higher to reduce repetition
                    no_repeat_ngram_size=3,
                    early_stopping=True
                )

            # Enhanced response cleaning
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract generated part
            if "Response:" in full_response:
                response = full_response.split("Response:")[-1].strip()
            else:
                response = full_response[len(user_prompt):].strip()

            # Clean and format response
            response = self._clean_response(response)

            return response if response else "I apologize, but I couldn't generate a proper response. Please try rephrasing your query."

        except Exception as e:
            print(f"Generation error: {e}")
            return f"Error generating response: {str(e)}"

    def _clean_response(self, response: str) -> str:
        """Clean and format the generated response"""
        lines = response.split('\n')
        cleaned_lines = []
        prev_line = ""

        for line in lines:
            line = line.strip()
            # Remove repetitive lines and empty lines at the start
            if line and line != prev_line and not (line.startswith("System:") or line.startswith("Query:")):
                cleaned_lines.append(line)
                prev_line = line

        # Join and remove any remaining artifacts
        final_response = '\n'.join(cleaned_lines)

        # Remove common generation artifacts
        artifacts_to_remove = [
            "Based on the contract information provided:",
            "According to the contract:",
            "Here is my analysis:",
            "Response:",
            "Answer:"
        ]

        for artifact in artifacts_to_remove:
            if final_response.startswith(artifact):
                final_response = final_response[len(artifact):].strip()

        return final_response

In [17]:
class UnifiedContractSystem:
    def __init__(self, pinecone_api_key: str, index_name: str, embedding_model: str,
                 llm_model_name: str = "Qwen/Qwen2.5-0.5B"):
        """Wire together retrieval, query classification and the language model.

        Args:
            pinecone_api_key: API key forwarded to ``HybridSearchEngine``.
            index_name: Pinecone index name forwarded to ``HybridSearchEngine``.
            embedding_model: Embedding model name forwarded to
                ``HybridSearchEngine``.
            llm_model_name: Hugging Face model id used for both the tokenizer
                and the causal language model.

        If the language model cannot be loaded, the error is printed and
        ``model``, ``tokenizer`` are set to ``None`` and ``agents`` to ``{}``;
        ``process_query`` then reports that the model is unavailable.
        """
        self.search_engine = HybridSearchEngine(pinecone_api_key, index_name, embedding_model)
        self.query_classifier = QueryClassifier()

        # Resolved before loading so the attribute exists even if loading fails.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model = AutoModelForCausalLM.from_pretrained(llm_model_name)
            self.model.to(self.device)
            self.model.eval()

            # One agent per query type, all sharing the same model and tokenizer.
            self.agents = {
                query_type: ContractAgent(query_type, self.model, self.tokenizer, self.device)
                for query_type in QueryType
            }

        except Exception as e:
            print(f"Error initializing LLM: {e}")
            self.model = None
            self.tokenizer = None
            self.agents = {}

        # No uploaded document and no BM25 corpus until the caller provides them.
        self.current_document = None
        self.rag_initialized = False

    def initialize_rag(self, chunks_data: List[Dict[str, Any]]):
        """Initialize RAG system with dataset"""
        try:
            self.search_engine.load_corpus_for_bm25(chunks_data)
            self.rag_initialized = True
            print("RAG system initialized successfully")
        except Exception as e:
            print(f"Error initializing RAG: {str(e)}")
            self.rag_initialized = False

    def load_document(self, document_text: str, filename: str = "uploaded_contract"):
        """Load the uploaded document for analysis"""
        self.current_document = {
            "text": document_text,
            "filename": filename.lower()  # Ensure filename is lowered to match metadata
        }

    def determine_query_mode(self, query: str, has_uploaded_doc: bool) -> Tuple[str, Optional[str]]:
        """Determine whether to use uploaded document, RAG, or specific contract search"""
        query_lower = query.lower()

        # Check if this is a general query first
        is_general = self.query_classifier.is_general_query(query)

        # Check for specific contract identifier
        contract_id = self.query_classifier.extract_contract_identifier(query)

        # Priority order:
        # 1. If user uploaded a document and query is about "this contract" or "the contract"
        if has_uploaded_doc and any(phrase in query_lower for phrase in ['this contract', 'uploaded contract', 'the contract', 'this document']):
            return "uploaded", None

        # 2. If query is clearly general/typical, use RAG regardless of contract identifier
        if is_general:
            return "rag", None

        # 3. If query mentions specific contract identifier and RAG is available
        if contract_id and self.rag_initialized:
            return "specific_contract", contract_id

        # 4. If user uploaded a document and query doesn't have general indicators
        if has_uploaded_doc and not is_general:
            return "uploaded", None

        # 5. Default to RAG for general questions
        return "rag", None

    def create_metadata_filter(self, contract_id: str) -> Dict[str, Any]:
        """Create metadata filter for Pinecone search with exact match on lowered filenames"""
        # Since all filenames in metadata are lowered, use $eq for exact match
        return {"filename": {"$eq": contract_id.lower()}}

    def process_query(self, query: str, top_k: int = 15) -> Dict[str, Any]:
        """Process user query through the unified system"""
        if not query or query.strip() == "":
            return {
                "query": query,
                "query_type": "error",
                "query_mode": "error",
                "response": "Please enter a valid query.",
                "sources": []
            }

        has_uploaded_doc = self.current_document is not None
        query_mode, contract_id = self.determine_query_mode(query, has_uploaded_doc)
        query_type = self.query_classifier.classify_query(query)

        # Debug information
        print(f"Query: {query}")
        print(f"Query mode: {query_mode}")
        print(f"Contract ID: {contract_id}")
        print(f"Query type: {query_type}")

        # Check if LLM is available
        if not self.model or not self.tokenizer:
            return {
                "query": query,
                "query_type": "error",
                "query_mode": "error",
                "response": "Language model not initialized. Cannot generate responses.",
                "sources": []
            }

        try:
            if query_mode == "uploaded":
                # Process uploaded document
                search_results = [SearchResult(
                    id="uploaded_doc",
                    text=self.current_document["text"],
                    metadata={"filename": self.current_document["filename"]},
                    score=1.0,
                    search_type="uploaded_document"
                )]

                agent = self.agents[query_type]
                response = agent.generate_response(query, search_results, context_type="uploaded")

                return {
                    "query": query,
                    "query_type": query_type.value,
                    "query_mode": "uploaded_document",
                    "search_results": 1,
                    "response": response,
                    "sources": [{
                        "filename": self.current_document["filename"],
                        "score": 1.0,
                        "search_type": "uploaded_document"
                    }]
                }

            elif query_mode == "specific_contract":
                # Search for specific contract in database
                if not self.rag_initialized:
                    return {
                        "query": query,
                        "query_type": query_type.value,
                        "query_mode": "error",
                        "response": "RAG system not initialized. Cannot search contract database.",
                        "sources": []
                    }

                # Create metadata filter for both namespaces
                metadata_filter = self.create_metadata_filter(contract_id)

                # Use enhanced search that combines metadata and chunks
                search_results = self.search_engine.enhanced_hybrid_search(
                    query,
                    top_k=top_k,
                    metadata_filter=metadata_filter,
                    filename_filter=contract_id.lower()
                )

                # CORRECTED: Handle empty results with fallback search
                if not search_results:
                    print(f"No exact match for contract_id: {contract_id}. Attempting broader search with post-filtering.")
                    all_results = self.search_engine.enhanced_hybrid_search(query, top_k=top_k*2)
                    search_results = [
                        result for result in all_results
                        if re.search(contract_id.lower(), result.metadata.get('filename', ''), re.IGNORECASE)
                    ][:top_k]

                # CORRECTED: Better handling of no results case
                if not search_results:
                    return {
                        "query": query,
                        "query_type": query_type.value,
                        "query_mode": "specific_contract",
                        "response": f"No documents found for contract '{contract_id}' in the database. Please check the contract name and try again, or try a more general query.",
                        "sources": [],
                        "search_results": 0
                    }

                agent = self.agents[query_type]
                response = agent.generate_response(query, search_results, context_type="database")

                return {
                    "query": query,
                    "query_type": query_type.value,
                    "query_mode": "specific_contract",
                    "search_results": len(search_results),
                    "response": response,
                    "sources": [{
                        "filename": result.metadata.get('filename', 'Unknown'),
                        "score": result.score,
                        "search_type": result.search_type
                    } for result in search_results[:5]]
                }

            else:  # RAG mode
                if not self.rag_initialized:
                    return {
                        "query": query,
                        "query_type": query_type.value,
                        "query_mode": "error",
                        "response": "RAG system not initialized. Cannot answer general questions.",
                        "sources": []
                    }

                # Use enhanced search that combines metadata and chunks for comprehensive results
                search_results = self.search_engine.enhanced_hybrid_search(query, top_k=top_k)

                # CORRECTED: Better handling of no results in RAG mode
                if not search_results:
                    return {
                        "query": query,
                        "query_type": query_type.value,
                        "query_mode": "rag",
                        "response": "No relevant documents found in the database. Please try rephrasing your query or using different keywords.",
                        "sources": [],
                        "search_results": 0
                    }

                agent = self.agents[query_type]
                response = agent.generate_response(query, search_results, context_type="database")

                return {
                    "query": query,
                    "query_type": query_type.value,
                    "query_mode": "rag",
                    "search_results": len(search_results),
                    "response": response,
                    "sources": [{
                        "filename": result.metadata.get('filename', 'Unknown'),
                        "score": result.score,
                        "search_type": result.search_type
                    } for result in search_results[:5]]
                }

        except Exception as e:
            print(f"Error in process_query: {e}")  # Added logging
            return {
                "query": query,
                "query_type": "error",
                "query_mode": "error",
                "response": f"Error processing query: {str(e)}",
                "sources": [],
                "search_results": 0
            }

In [18]:
# Gradio Interface
import gradio as gr
import pdfplumber
import tempfile

def extract_text_from_file(file_path: str) -> str:
    """Extract text from an uploaded ``.txt`` or ``.pdf`` file.

    The extension is matched case-insensitively, so ``CONTRACT.PDF`` and
    ``notes.TXT`` are handled the same as their lower-case forms.

    Args:
        file_path: Path of the file on disk.

    Returns:
        The extracted text. An empty string is returned when the extension
        is not supported or when reading/parsing fails (the error is printed
        rather than raised, so callers can treat "" as "nothing extracted").
    """
    try:
        # Compare against a lower-cased copy so the check ignores case
        # while the original path is still used to open the file.
        lowered_path = file_path.lower()
        if lowered_path.endswith(".txt"):
            # utf-8-sig decodes plain UTF-8 too, but drops a leading BOM
            # that would otherwise leak into the text as "\ufeff".
            with open(file_path, 'r', encoding='utf-8-sig') as f:
                return f.read()
        if lowered_path.endswith(".pdf"):
            with pdfplumber.open(file_path) as pdf:
                # extract_text() may return None for a page; treat that page as empty.
                return "\n".join(page.extract_text() or "" for page in pdf.pages)
        return ""
    except Exception as e:
        # Best-effort by design: report the problem and fall back to "".
        print(f"Error extracting text: {e}")
        return ""

# Initialize system globally
# The Pinecone API key is read from the environment only. A key must never be
# embedded in the notebook: anything committed or shared here is effectively
# public. NOTE: the key that used to be hard-coded as a fallback on this line
# should be treated as compromised and rotated in the Pinecone console.
_pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not _pinecone_api_key:
    raise RuntimeError(
        "PINECONE_API_KEY is not set. Set it in the environment "
        "(e.g. os.environ['PINECONE_API_KEY'] from a secrets store) "
        "before running this cell."
    )

system = UnifiedContractSystem(
    pinecone_api_key=_pinecone_api_key,
    index_name=PINECONE_INDEX_NAME,
    embedding_model=EMBEDDING_MODEL
)

# Try to initialize RAG system
# Failure here is non-fatal: the warning is printed and the notebook continues
# without RAG data being loaded.
try:
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a chunks file that was produced by a trusted pipeline.
    # The context manager guarantees the handle is closed even if unpickling fails.
    with open('/content/chunks_with_metadata.pkl', 'rb') as chunks_file:
        chunks_with_metadata = pickle.load(chunks_file)
    system.initialize_rag(chunks_with_metadata)
except FileNotFoundError:
    print("Warning: chunks_with_metadata.pkl not found. RAG system not initialized.")
except Exception as e:
    print(f"Warning: Error loading RAG data: {str(e)}")

def process_contract_query(file, query):
    """Gradio click handler: optionally load an uploaded file, then answer the query.

    Args:
        file: Path of the uploaded document, or None when nothing was uploaded.
        query: The user's question.

    Returns:
        A 3-tuple of strings ``(response, sources_text, metadata_text)``.
        On validation failures or any exception the first element carries the
        message and the other two are empty strings.
    """
    try:
        # Guard: reject a missing or whitespace-only query up front.
        if not query or not query.strip():
            return "Please enter a query.", "", ""

        # An uploaded file is parsed and registered with the system first.
        if file is not None:
            document_text = extract_text_from_file(file)
            if not document_text.strip():
                return "No text could be extracted from the uploaded file.", "", ""

            if file:
                document_name = os.path.basename(file).lower()
            else:
                document_name = "uploaded_contract"
            system.load_document(document_text, document_name)

        # Run the query through the unified system.
        outcome = system.process_query(query)
        answer = outcome["response"]

        # Build the "sources" panel only when there is at least one source.
        sources_panel = ""
        if outcome["sources"]:
            header = (
                f"Query Mode: {outcome.get('query_mode', 'unknown')}\n"
                f"Query Type: {outcome.get('query_type', 'unknown')}\n\n"
                "Sources:\n"
            )
            bullet_lines = []
            for entry in outcome["sources"]:
                bullet_lines.append(
                    f"• {entry['filename']} (Score: {entry['score']:.3f}, Type: {entry['search_type']})"
                )
            sources_panel = header + "\n".join(bullet_lines)

        # Short summary line for the "Search Info" box.
        search_info = f"Found {outcome.get('search_results', 0)} relevant chunks"

        return answer, sources_panel, search_info

    except Exception as e:
        return f"Error processing request: {str(e)}", "", ""

RAG system initialized successfully


In [19]:
def create_interface():
    """Build the Gradio Blocks UI for the contract analysis system.

    Layout: a header with usage notes, an input row (file upload, query box,
    "Analyze" and "Clear Document" buttons), an output row (analysis text,
    sources, search info) and a list of example queries.

    Returns:
        The constructed ``gr.Blocks`` object; the caller is responsible for
        launching it.
    """
    with gr.Blocks(title="Unified Contract Analysis System") as demo:
        # Page header and a short description of the supported modes.
        gr.Markdown("# 📋 Unified Contract Analysis System")
        gr.Markdown("""
        **Three modes of operation:**
        - **Upload a document**: Analyze specific contracts you upload
        - **Ask about specific contracts**: Query contracts in the database (e.g., "What are the risks in acme contract?")
        - **General questions**: Ask general legal questions using our knowledge base
        """)

        # Input row: optional document upload, the query, and the action buttons.
        with gr.Row():
            with gr.Column(scale=1):
                # type="filepath" is configured here; process_contract_query
                # treats the value it receives as a path on disk.
                file_input = gr.File(
                    label="Upload Contract Document (Optional)",
                    file_types=[".txt", ".pdf"],
                    type="filepath"
                )

                query_input = gr.Textbox(
                    label="Your Query",
                    placeholder="e.g., What are typical termination clauses? or What are the risks in acme contract?",
                    lines=3
                )

                submit_btn = gr.Button("Analyze", variant="primary")
                clear_btn = gr.Button("Clear Document", variant="secondary")

        # Output row: the analysis on the left, sources and search info on the right.
        with gr.Row():
            with gr.Column(scale=2):
                response_output = gr.Textbox(
                    label="AI Analysis",
                    lines=20,
                    max_lines=30,
                    interactive=False
                )

            with gr.Column(scale=1):
                sources_output = gr.Textbox(
                    label="Sources & Metadata",
                    lines=10,
                    interactive=False
                )

                metadata_output = gr.Textbox(
                    label="Search Info",
                    lines=3,
                    interactive=False
                )

        # "Analyze" runs the handler; its 3-tuple maps onto the three output boxes.
        submit_btn.click(
            fn=process_contract_query,
            inputs=[file_input, query_input],
            outputs=[response_output, sources_output, metadata_output]
        )

        def clear_document():
            """Drop the currently loaded document and reset the file and output widgets."""
            # Resets the module-level system's current document.
            system.current_document = None
            # One value per component listed in the clear button's outputs.
            return None, "", "", ""

        # The query box is not among the outputs, so the typed query is kept.
        clear_btn.click(
            fn=clear_document,
            outputs=[file_input, response_output, sources_output, metadata_output]
        )

        # Static help text with example queries for each mode.
        gr.Markdown("### Example Queries:")
        gr.Markdown("""
        **For uploaded documents:**
        - "What are the main risks in this contract?"
        - "Extract all termination clauses from this document"

        **For specific contracts in database:**
        - "What are the parties involved in acme contract?"
        - "Show me liability clauses in microsoft contract"
        - "Analyze the adamsgolf contract"

        **General questions:**
        - "What are typical risks in service agreements?"
        - "What are typical termination clauses?"
        - "What should I look for in liability clauses?"
        """)

    return demo

In [20]:
# Launch the interface
# The guard keeps the UI from starting when this code is imported as a module;
# it only launches when run as the main program.
if __name__ == "__main__":
    demo = create_interface()
    # share=True publishes the app on a public gradio.live URL (see the cell
    # output), so anyone with the link can query the system while it runs.
    # NOTE(review): the cell output reports that in Colab errors are only shown
    # in the notebook when debug=True is passed to launch() -- confirm whether
    # that is wanted in addition to show_error=True.
    demo.launch(share=True, show_error=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://363a924e52ceac2743.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
