<a href="https://colab.research.google.com/github/stephenkwok85/2048/blob/main/AI_Powered_RAG_Document_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Python packages used below: Gradio (UI), PyPDF2 (text-layer extraction),
# pdf2image + pytesseract (OCR for scanned PDFs), faiss-cpu (vector index),
# sentence-transformers (embeddings), transformers (QA model), pillow (images).
# NOTE(review): versions are unpinned; the UI code passes gr.Chatbot options
# such as bubble_full_width that newer Gradio releases may deprecate or
# remove — consider pinning known-good versions.
!pip install gradio PyPDF2 pdf2image pytesseract faiss-cpu sentence-transformers transformers pillow

# System package needed by pdf2image's convert_from_path (Poppler tools).
!apt-get update
!apt-get install -y poppler-utils

# System package providing the tesseract binary that pytesseract calls.
!apt-get install -y tesseract-ocr

In [None]:
import os
import re
import tempfile
import time
import unicodedata

import gradio as gr
from PyPDF2 import PdfReader
from pdf2image import convert_from_path
import pytesseract
from PIL import Image
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline

class AdvancedRAGChatbot:
    """Retrieval-augmented question answering over uploaded PDF documents.

    Workflow: ``process_documents`` extracts text from each PDF (PyPDF2 for
    digital PDFs, Tesseract OCR for scanned ones), splits it into chunks,
    embeds the chunks with a SentenceTransformer and stores them in an exact
    L2 FAISS index. ``generate_answer`` retrieves the nearest chunks for a
    question and runs an extractive question-answering model over them.

    Attributes:
        index: FAISS index over the chunk embeddings, or None before processing.
        chunks: chunk texts; position i corresponds to FAISS id i.
        metadata: one dict per chunk (document name/type, chunk id, page range, source label).
        chat_history: the most recent (question, answer) list, written by
            generate_answer and read by save_chat_history.
    """

    # Header written by extract_text_from_pdf before each page's text; used to
    # work out which pages a chunk spans.
    _PAGE_MARKER = re.compile(r"--- Page (\d+) ---")
    # A single whitespace character; used to start overlapping windows on a word boundary.
    _WHITESPACE = re.compile(r"\s")

    def __init__(self):
        # Initialize models for embeddings and QA
        print("🔄 Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        print("🔄 Loading QA model...")
        self.qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-cased-distilled-squad",
            tokenizer="distilbert-base-cased"
        )

        self.index = None
        self.chunks = []
        self.metadata = []
        self.chat_history = []
        print("✅ Models loaded successfully!")

    def extract_text_from_pdf(self, file_path, is_scanned=False):
        """Extract the text of a PDF as one ``--- Page N ---`` section per page.

        Args:
            file_path: path of the PDF on disk.
            is_scanned: if True, rasterise every page at 200 DPI and OCR it with
                Tesseract; otherwise read the embedded text layer with PyPDF2.

        Returns:
            The concatenated page text. Extraction is best-effort: on error the
            message is printed and whatever was extracted so far is returned
            (possibly ""), so callers must check for empty text.
        """
        parts = []
        try:
            if is_scanned:
                # Convert PDF pages to images and use OCR
                images = convert_from_path(file_path, dpi=200)
                for page_no, image in enumerate(images, start=1):
                    page_text = pytesseract.image_to_string(image)
                    parts.append(f"--- Page {page_no} ---\n{page_text}\n")
            else:
                # Direct text extraction for digital PDFs
                for page_no, page in enumerate(PdfReader(file_path).pages, start=1):
                    page_text = page.extract_text()
                    # Pages for which PyPDF2 returns no text are skipped entirely.
                    if page_text:
                        parts.append(f"--- Page {page_no} ---\n{page_text}\n")
        except Exception as e:
            print(f"❌ Error extracting text: {e}")
        return "".join(parts)

    def clean_and_chunk_text(self, text, chunk_size=500, chunk_overlap=50):
        """Normalise *text* and split it into chunks of at most *chunk_size* characters.

        Paragraphs (separated by blank lines) are packed together, newline-joined,
        while they fit. A paragraph longer than *chunk_size* is cut into windows
        that share roughly *chunk_overlap* characters, so text spanning a window
        boundary is still retrievable intact from one of them.

        Args:
            text: raw extracted text.
            chunk_size: maximum characters per chunk; must be positive.
            chunk_overlap: overlap between consecutive windows of a long
                paragraph; clamped to ``[0, chunk_size - 1]``.

        Returns:
            List of non-empty chunk strings.

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        chunk_size = int(chunk_size)
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        chunk_overlap = max(0, min(int(chunk_overlap), chunk_size - 1))

        # Fold ligatures/accents to their ASCII base letters ("ﬁ" -> "fi",
        # "é" -> "e") before dropping whatever is still non-ASCII; stripping
        # directly would split words such as "eﬃcient" into "e cient".
        text = unicodedata.normalize("NFKD", text)
        text = "".join(ch for ch in text if not unicodedata.combining(ch))
        text = re.sub(r'[^\x00-\x7F]+', ' ', text)

        chunks = []
        current = []
        current_len = 0
        # Split on blank lines *without* first collapsing newlines, otherwise no
        # paragraph boundary would survive and the whole text would be one chunk.
        for raw_paragraph in re.split(r'\n\s*\n', text):
            paragraph = raw_paragraph.strip()
            if not paragraph:
                continue

            if len(paragraph) > chunk_size:
                # Oversized paragraph: flush what we have, then emit its windows
                # as standalone chunks so overlapping text is never duplicated
                # inside a single chunk.
                if current:
                    chunks.append("\n".join(current))
                    current, current_len = [], 0
                chunks.extend(self._split_long_text(paragraph, chunk_size, chunk_overlap))
                continue

            added = len(paragraph) + (1 if current else 0)  # +1 for the joining "\n"
            if current_len + added > chunk_size:
                chunks.append("\n".join(current))
                current, current_len, added = [], 0, len(paragraph)
            current.append(paragraph)
            current_len += added

        if current:
            chunks.append("\n".join(current))

        print(f"📝 Created {len(chunks)} chunks from text")
        return chunks

    @classmethod
    def _split_long_text(cls, text, chunk_size, chunk_overlap):
        """Split *text* into windows of at most *chunk_size* characters.

        Consecutive windows share about *chunk_overlap* characters. A window ends
        at its last space/newline when that still lets the next window advance;
        otherwise (e.g. a single very long token) it is cut at exactly
        *chunk_size* characters. Requires ``0 <= chunk_overlap < chunk_size``.
        """
        pieces = []
        start, length = 0, len(text)
        while start < length:
            end = min(start + chunk_size, length)
            cut_on_space = False
            if end < length:
                cut = max(text.rfind(" ", start, end), text.rfind("\n", start, end))
                # Accept the word boundary only if the next start stays ahead of this one.
                if cut > start + chunk_overlap:
                    end = cut
                    cut_on_space = True

            piece = text[start:end].strip()
            if piece:
                pieces.append(piece)
            if end >= length:
                break

            next_start = end - chunk_overlap
            # Begin the overlap on a word boundary instead of mid-word. Only safe
            # after a whitespace cut: after a hard cut, skipping ahead could drop
            # the unfinished remainder of the token.
            if cut_on_space and not text[next_start - 1].isspace():
                match = cls._WHITESPACE.search(text, next_start, end)
                next_start = match.end() if match else end
            start = next_start
        return pieces

    @classmethod
    def _page_ranges(cls, chunks):
        """Return a "first-last" page label for each chunk, in order.

        Pages are read from the ``--- Page N ---`` markers inside the chunks; a
        chunk that does not start with a marker is assumed to continue the last
        page seen. Approximate where overlapping windows straddle a marker.
        """
        ranges = []
        current_page = 1
        for chunk in chunks:
            pages = [int(p) for p in cls._PAGE_MARKER.findall(chunk)]
            first = pages[0] if pages and cls._PAGE_MARKER.match(chunk) else current_page
            if pages:
                current_page = pages[-1]
            ranges.append(f"{first}-{current_page}")
        return ranges

    @staticmethod
    def _resolve_upload(file_info):
        """Return ``(file_path, display_name)`` for one upload value from Gradio."""
        if isinstance(file_info, tuple):
            file_path = file_info[0]
            return file_path, os.path.basename(file_path)
        if hasattr(file_info, 'name'):
            file_path = file_info.name
            return file_path, getattr(file_info, 'orig_name', os.path.basename(file_path))
        return file_info, os.path.basename(file_info)

    @staticmethod
    def _read_upload_bytes(file_info, file_path):
        """Return the raw bytes of an upload (file object, bytes, or path on disk)."""
        if hasattr(file_info, 'read'):
            return file_info.read()
        if isinstance(file_info, bytes):
            return file_info
        with open(file_path, 'rb') as f:
            return f.read()

    def process_documents(self, files, document_types, chunk_size=500, chunk_overlap=50):
        """Extract, chunk, embed and index the uploaded PDFs.

        The previous index/chunks are replaced only after every file has been
        read and the new index has been built successfully.

        Args:
            files: uploads as delivered by Gradio — paths, ``(path, ...)``
                tuples, objects with ``.name`` (optionally ``.orig_name`` /
                ``.read()``), or raw ``bytes``.
            document_types: per-file "digital"/"scanned"; missing entries
                default to "digital".
            chunk_size: maximum characters per chunk (see clean_and_chunk_text).
            chunk_overlap: overlap between windows of a long paragraph.

        Returns:
            A human-readable status string. Processing stops at the first file
            that raises, leaving the previous index untouched.
        """
        all_chunks = []
        all_metadata = []

        for i, file_info in enumerate(files):
            file_path, file_name = self._resolve_upload(file_info)
            file_type = document_types[i] if i < len(document_types) else "digital"
            is_scanned = file_type == "scanned"

            print(f"📄 Processing {file_name} (type: {file_type})...")

            temp_path = None
            try:
                # Copy the upload to a private .pdf temp file so every input form
                # (path, file object, bytes) is handled the same way downstream.
                # temp_path is recorded before writing so cleanup covers write errors.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                    temp_path = tmp_file.name
                    tmp_file.write(self._read_upload_bytes(file_info, file_path))

                text = self.extract_text_from_pdf(temp_path, is_scanned)

                if not text.strip():
                    print(f"⚠️ No text extracted from {file_name}")
                    continue

                print(f"📖 Extracted {len(text)} characters from {file_name}")

                chunks = self.clean_and_chunk_text(
                    text, chunk_size=chunk_size, chunk_overlap=chunk_overlap
                )
                page_ranges = self._page_ranges(chunks)

                # Store chunks with metadata
                for j, (chunk, page_range) in enumerate(zip(chunks, page_ranges)):
                    all_chunks.append(chunk)
                    all_metadata.append({
                        "document_id": i,
                        "document_name": file_name,
                        "document_type": file_type,
                        "chunk_id": j,
                        "page_range": page_range,
                        "source": f"{file_name} - Chunk {j+1}"
                    })

            except Exception as e:
                print(f"❌ Error processing {file_name}: {str(e)}")
                return f"❌ Error processing {file_name}: {str(e)}"

            finally:
                # Remove the temp copy on every path: success, "no text", and error.
                if temp_path is not None:
                    try:
                        os.unlink(temp_path)
                    except OSError:
                        # Best-effort cleanup; a leftover temp file is not fatal.
                        pass

        if not all_chunks:
            return "⚠️ No text could be extracted from the documents."

        # Create vector index for semantic search
        print("🔨 Creating vector index...")
        # FAISS expects a C-contiguous float32 matrix.
        embeddings = np.ascontiguousarray(self.embedding_model.encode(all_chunks), dtype="float32")
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        # Publish index, chunks and metadata together, only after indexing
        # succeeded, so self.chunks[i] always corresponds to FAISS id i.
        self.index = index
        self.chunks = all_chunks
        self.metadata = all_metadata

        print(f"✅ Index created with {len(all_chunks)} chunks")
        return f"✅ Processed {len(files)} docs, {len(all_chunks)} chunks"

    def retrieve_context(self, query, k=3):
        """Return the *k* chunks nearest to *query* in embedding space.

        Args:
            query: question text.
            k: number of chunks to retrieve; cast to int (UI sliders may send
                floats) and capped at the number of indexed chunks.

        Returns:
            ``(chunks, metadata, scores)`` in nearest-first order, where each
            score is ``1 / (1 + L2 distance)`` as a Python float. Three empty
            lists when nothing is indexed or *k* <= 0.
        """
        if self.index is None or not self.chunks:
            return [], [], []
        k = min(int(k), len(self.chunks))
        if k <= 0:
            return [], [], []

        print(f"🔍 Searching for: '{query}'")
        query_embedding = np.ascontiguousarray(self.embedding_model.encode([query]), dtype="float32")
        distances, indices = self.index.search(query_embedding, k)

        # FAISS pads with id -1 when it has fewer than k results; drop those.
        hits = [(int(idx), float(dist)) for idx, dist in zip(indices[0], distances[0]) if idx >= 0]
        retrieved_chunks = [self.chunks[idx] for idx, _ in hits]
        retrieved_metadata = [self.metadata[idx] for idx, _ in hits]
        confidence_scores = [1.0 / (1.0 + dist) for _, dist in hits]

        print(f"📚 Retrieved {len(retrieved_chunks)} relevant chunks")
        return retrieved_chunks, retrieved_metadata, confidence_scores

    def _reply(self, chat_history, query, reply):
        """Append ``(query, reply)`` to *chat_history*, remember it for saving, and return it."""
        chat_history.append((query, reply))
        self.chat_history = list(chat_history)
        return chat_history

    def generate_answer(self, query, chat_history):
        """Answer *query* from the indexed documents and append the exchange to the chat.

        Args:
            query: the user's question.
            chat_history: current Chatbot value — a list of (question, answer)
                pairs, or None. It is copied, not mutated in place.

        Returns:
            ``(chat_history, input_text, sources, confidence)``: the updated
            history, "" to clear the question box, a list of source dicts
            (source, document_type, retrieval confidence, preview) and the QA
            model's score (0.0 when no answer was produced). Warnings and errors
            are shown as the bot's reply in the chat.
        """
        chat_history = list(chat_history or [])

        if not self.chunks:
            history = self._reply(chat_history, query, "⚠️ Please upload and process documents first.")
            return history, "", [], 0.0

        retrieved_chunks, retrieved_metadata, confidence_scores = self.retrieve_context(query)

        if not retrieved_chunks:
            history = self._reply(chat_history, query, "❌ No relevant context found in documents.")
            return history, "", [], 0.0

        # Prepare context for QA model
        context = "\n\n".join(
            f"Source: {meta['source']}\nContent: {chunk}"
            for chunk, meta in zip(retrieved_chunks, retrieved_metadata)
        )

        print(f"📋 Context length: {len(context)} characters")
        print(f"❓ Question: {query}")

        try:
            # Generate answer using transformer model
            print("🤖 Generating answer...")
            result = self.qa_pipeline(question=query, context=context)

            answer = result['answer']
            # Plain float so Gradio's Number/JSON components can serialise it.
            confidence = float(result['score'])

            print(f"✅ Answer: {answer}")
            print(f"🎯 Confidence: {confidence}")

            # Prepare source citations
            sources = [{
                'source': meta['source'],
                'document_type': meta['document_type'],
                'confidence': round(float(conf), 3),
                'preview': chunk[:100] + "..."
            } for chunk, meta, conf in zip(retrieved_chunks, retrieved_metadata, confidence_scores)]

            # Format answer with confidence and source info
            confidence_emoji = "🔴" if confidence < 0.3 else "🟡" if confidence < 0.7 else "🟢"
            full_answer = f"{answer}\n\n{confidence_emoji} **Confidence: {confidence:.3f}** | 📚 **Sources: {len(sources)}**"

            return self._reply(chat_history, query, full_answer), "", sources, confidence

        except Exception as e:
            error_msg = f"❌ Error generating answer: {str(e)}"
            print(f"❌ QA Error: {e}")
            return self._reply(chat_history, query, error_msg), "", [], 0.0

    def save_chat_history(self):
        """Write the last recorded conversation to ``rag_chat_history_<unix time>.txt``.

        The file is created in the current working directory.

        Returns:
            A status string naming the file, or a warning if there is nothing to save.
        """
        if not self.chat_history:
            return "⚠️ No chat history to save."

        filename = f"rag_chat_history_{int(time.time())}.txt"
        # Explicit UTF-8: the header and Q/A prefixes contain emoji, which a
        # platform default encoding such as cp1252 cannot encode.
        with open(filename, 'w', encoding='utf-8') as f:
            f.write("🧠 RAG Chatbot - Conversation History\n")
            f.write("=" * 50 + "\n\n")
            for i, (question, answer) in enumerate(self.chat_history, 1):
                f.write(f"❓ Q{i}: {question}\n")
                f.write(f"💡 A{i}: {answer}\n")
                f.write("-" * 60 + "\n")

        return f"✅ Saved as {filename}"

# Initialize the chatbot: a single module-level instance used by every Gradio
# callback below. Construction loads the embedding and QA models, so this line
# blocks until both are ready.
print("🚀 Initializing RAG Chatbot...")
rag_bot = AdvancedRAGChatbot()

# Gradio callback for the "Process" button.
def process_docs_wrapper(files, doc_type, chunk_size_val):
    """Normalize the Gradio upload value and index it with ``rag_bot``.

    Args:
        files: value of the ``gr.File`` input — a list of uploads, a single
            upload, or empty/None when nothing was uploaded.
        doc_type: "digital" or "scanned"; applied to every uploaded file.
        chunk_size_val: value of the "Chunk Size" slider.

    Returns:
        A status string for the Status textbox.
    """
    # NOTE(review): chunk_size_val is accepted so the slider can be wired as an
    # input, but it is not forwarded to rag_bot.process_documents — TODO.
    if not files:
        return "⚠️ Please upload PDF documents first"

    uploads = files if isinstance(files, list) else [files]
    per_file_types = [doc_type for _ in uploads]
    return rag_bot.process_documents(uploads, per_file_types)

# Custom stylesheet passed to gr.Blocks(css=...) below. Most selectors target
# the elem_classes assigned to components in the interface (compact-card,
# compact-button, compact-chatbot, ...).
# NOTE(review): `.main-container` and `.source-preview` are not assigned to any
# component in this file, and the header Markdown's `main-header` class has no
# rule here; the `.gr-markdown h1` selector may not match the class names that
# newer Gradio versions render — verify the header styling actually applies.
custom_css = """
:root {
    --primary: #6366f1;
    --primary-dark: #4f46e5;
    --secondary: #f1f5f9;
    --accent: #8b5cf6;
    --success: #10b981;
    --warning: #f59e0b;
    --error: #ef4444;
    --text: #1e293b;
    --text-light: #64748b;
    --bg: #ffffff;
    --bg-secondary: #f8fafc;
    --border: #e2e8f0;
    --shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
}

.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    font-family: 'Inter', 'Segoe UI', system-ui, sans-serif !important;
    padding: 10px !important;
}

.main-container {
    background: var(--bg) !important;
    border-radius: 16px !important;
    box-shadow: var(--shadow) !important;
    margin: 0 !important;
    padding: 0 !important;
    max-width: 100% !important;
}

.gr-markdown h1 {
    background: linear-gradient(135deg, var(--primary) 0%, var(--accent) 100%) !important;
    -webkit-background-clip: text !important;
    -webkit-text-fill-color: transparent !important;
    background-clip: text !important;
    text-align: center !important;
    font-size: 2em !important;
    font-weight: 800 !important;
    margin: 10px 0 !important;
    padding: 0 !important;
}

.compact-card {
    background: var(--bg) !important;
    border: 1px solid var(--border) !important;
    border-radius: 12px !important;
    padding: 16px !important;
    margin: 8px 0 !important;
    box-shadow: var(--shadow) !important;
}

.horizontal-row {
    gap: 12px !important;
    margin: 0 !important;
    padding: 0 !important;
}

.horizontal-section {
    margin: 0 !important;
    padding: 0 !important;
}

.compact-button {
    padding: 10px 16px !important;
    margin: 4px 0 !important;
    border-radius: 10px !important;
    font-size: 13px !important;
    height: auto !important;
    min-height: auto !important;
}

.compact-input {
    padding: 10px 12px !important;
    border-radius: 10px !important;
    font-size: 13px !important;
    margin: 4px 0 !important;
}

.compact-chatbot {
    min-height: 350px !important;
    max-height: 350px !important;
    border-radius: 12px !important;
    margin: 4px 0 !important;
}

.compact-upload {
    border: 2px dashed var(--primary) !important;
    border-radius: 10px !important;
    padding: 20px 10px !important;
    margin: 4px 0 !important;
    min-height: 80px !important;
}

.compact-label {
    font-size: 12px !important;
    font-weight: 600 !important;
    margin-bottom: 4px !important;
    color: var(--text) !important;
}

.compact-json {
    max-height: 120px !important;
    overflow-y: auto !important;
    font-size: 11px !important;
    border-radius: 8px !important;
    padding: 8px !important;
    margin: 4px 0 !important;
}

.status-compact {
    font-size: 11px !important;
    padding: 6px 8px !important;
    border-radius: 8px !important;
    margin: 2px 0 !important;
}

.compact-radio {
    padding: 8px !important;
    border-radius: 8px !important;
    margin: 4px 0 !important;
}

.compact-slider {
    margin: 8px 0 !important;
    padding: 0 !important;
}

.compact-confidence {
    font-size: 12px !important;
    padding: 8px !important;
    border-radius: 8px !important;
    margin: 4px 0 !important;
}

.source-preview {
    font-size: 11px !important;
    max-width: 200px !important;
    white-space: nowrap !important;
    overflow: hidden !important;
    text-overflow: ellipsis !important;
}
"""

# Create the interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="🧠 RAG Chatbot") as demo:
    # Layout: a top row (documents | chat | controls) and a bottom row
    # (advanced settings | source citations); event wiring follows the layout.

    gr.Markdown(
        """
        # 🧠 Advanced RAG Chatbot
        *Upload PDFs • Ask Questions • Get AI Answers with Sources*
        """,
        elem_classes=["main-header"]
    )

    # Main horizontal layout
    with gr.Row(equal_height=True, variant="compact", elem_classes=["horizontal-row"]):

        # Left: Document Upload & Processing
        with gr.Column(scale=1, min_width=300, elem_classes=["horizontal-section"]):
            with gr.Group(elem_classes=["compact-card"]):
                gr.Markdown("**📄 Document Control**", elem_classes=["compact-label"])

                with gr.Row():
                    file_input = gr.File(
                        file_count="multiple",
                        file_types=[".pdf"],
                        label="Upload PDFs",
                        elem_classes=["compact-upload"],
                        scale=3
                    )

                with gr.Row():
                    document_type = gr.Radio(
                        choices=["digital", "scanned"],
                        value="digital",
                        label="Type",
                        elem_classes=["compact-radio"],
                        scale=2
                    )
                    process_btn = gr.Button(
                        "🚀 Process",
                        variant="primary",
                        elem_classes=["compact-button"],
                        scale=1
                    )

                process_status = gr.Textbox(
                    label="Status",
                    interactive=False,
                    elem_classes=["status-compact"],
                    show_label=True
                )

        # Middle: Chat Input & Controls
        with gr.Column(scale=2, min_width=400, elem_classes=["horizontal-section"]):
            with gr.Group(elem_classes=["compact-card"]):
                gr.Markdown("**💬 Chat Interface**", elem_classes=["compact-label"])

                chatbot = gr.Chatbot(
                    label="Conversation",
                    elem_classes=["compact-chatbot"],
                    show_copy_button=True,
                    bubble_full_width=False
                )

                with gr.Row():
                    query_input = gr.Textbox(
                        placeholder="Ask a question about your documents...",
                        label="Question",
                        elem_classes=["compact-input"],
                        scale=4,
                        max_lines=2
                    )
                    ask_btn = gr.Button(
                        "Ask 🚀",
                        variant="primary",
                        elem_classes=["compact-button"],
                        scale=1
                    )

        # Right: Settings & Analytics
        with gr.Column(scale=1, min_width=250, elem_classes=["horizontal-section"]):
            with gr.Group(elem_classes=["compact-card"]):
                gr.Markdown("**⚙️ Controls & Analytics**", elem_classes=["compact-label"])

                with gr.Row():
                    save_btn = gr.Button(
                        "💾 Save Chat",
                        elem_classes=["compact-button"],
                        scale=1
                    )
                    retrieval_k = gr.Slider(
                        1, 10, value=3, step=1,
                        label="Retrieval Count",
                        elem_classes=["compact-slider"],
                        scale=2
                    )

                confidence_score = gr.Number(
                    label="Confidence Score",
                    precision=3,
                    elem_classes=["compact-confidence"]
                )

    # BOTTOM ROW: Results & Sources
    with gr.Row(equal_height=True, variant="compact", elem_classes=["horizontal-row"]):

        # Left: Advanced Settings
        with gr.Column(scale=1, min_width=200, elem_classes=["horizontal-section"]):
            with gr.Group(elem_classes=["compact-card"]):
                gr.Markdown("**🔧 Advanced Settings**", elem_classes=["compact-label"])

                # Passed to process_docs_wrapper as its chunk_size_val argument.
                chunk_size = gr.Slider(
                    100, 1000, value=500,
                    label="Chunk Size",
                    elem_classes=["compact-slider"]
                )

                save_status = gr.Textbox(
                    label="Save Status",
                    interactive=False,
                    elem_classes=["status-compact"],
                    show_label=True
                )

        # Right: Source Citations (wider)
        with gr.Column(scale=3, min_width=500, elem_classes=["horizontal-section"]):
            with gr.Group(elem_classes=["compact-card"]):
                gr.Markdown("**📚 Source Citations & Context**", elem_classes=["compact-label"])
                sources_output = gr.JSON(
                    label="Retrieved Sources",
                    elem_classes=["compact-json"]
                )

    def ask_question_wrapper(query, chat_history, k_value):
        """Answer *query* using the top-*k_value* retrieved chunks.

        ``rag_bot.generate_answer`` calls ``retrieve_context(query)`` with its
        default k, so the slider value is injected by temporarily installing a
        wrapper bound to the requested k. The original method is restored in
        ``finally``: if an exception escaped without restoring it, the next call
        would wrap the stale one-argument lambda and every later question would
        fail with a TypeError.

        Returns the 4-tuple expected by the outputs
        ``[chatbot, query_input, sources_output, confidence_score]``.
        """
        if not query or not query.strip():
            return chat_history, "", [], 0.0
        # Slider values may arrive as floats; retrieval needs an integer count.
        k = int(k_value)
        # NOTE(review): the override lives on the shared rag_bot instance, so
        # concurrently running requests could observe each other's k.
        original_retrieve = rag_bot.retrieve_context
        rag_bot.retrieve_context = lambda q: original_retrieve(q, k)
        try:
            return rag_bot.generate_answer(query, chat_history)
        finally:
            rag_bot.retrieve_context = original_retrieve

    # Connect UI components to functions
    process_btn.click(
        fn=process_docs_wrapper,
        inputs=[file_input, document_type, chunk_size],
        outputs=[process_status]
    )

    ask_btn.click(
        fn=ask_question_wrapper,
        inputs=[query_input, chatbot, retrieval_k],
        outputs=[chatbot, query_input, sources_output, confidence_score]
    )

    query_input.submit(
        fn=ask_question_wrapper,
        inputs=[query_input, chatbot, retrieval_k],
        outputs=[chatbot, query_input, sources_output, confidence_score]
    )

    save_btn.click(
        fn=rag_bot.save_chat_history,
        outputs=[save_status]
    )

    # Initialization message
    def init_message():
        """Return the placeholder text shown in the Status box when the page loads."""
        return "👋 Upload PDFs and click Process to start"

    demo.load(init_message, outputs=[process_status])

# Launch the application
if __name__ == "__main__":
    # Keyword arguments for demo.launch(): request a share link, run in debug
    # mode, show errors in the UI, and listen on all interfaces at port 7860.
    # NOTE(review): share=True plus server_name "0.0.0.0" makes the app
    # reachable from outside — fine for a Colab demo, reconsider elsewhere.
    launch_options = {
        "share": True,
        "debug": True,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    print("🌐 Launching Gradio interface...")
    demo.launch(**launch_options)