<a href="https://colab.research.google.com/github/seansphd/CombinedSearch/blob/main/CombinedSearch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab Notebook: JSON Knowledge Base Chatbot with Live Editable System Prompt
#
# This notebook builds a lightweight QA tool over a JSON or JSON-LD knowledge base.
# You can steer behaviour with a system prompt and an optional bias theme.
# The tool pulls in text from linked PDFs, selected by simple keyword scoring and embedding similarity.

# When run, you will need to enter an OpenAI API key (one can be provided for testing purposes).
# It will then ask you to upload the JSON knowledge base.

# The system prompt does not need to be changed. You can edit it if, for instance, you want the system to be biased towards "Cybernetics".
# Alternatively, type such a theme in plain text into the Bias field.

# When you ask a question you will receive several outputs. These include information about the PDFs selected for use through keyword and semantic search.
# You will then receive the chat response, which is based on the selected PDFs and the system prompt.

# 🛠️ Install dependencies
# Note: Colab supports shell commands with a leading !. This installs all Python packages used below.
!pip install -q openai PyMuPDF requests numpy scikit-learn psutil ipywidgets pillow pdf2image pytesseract

# 📦 Imports
import os
import json
import fitz  # PyMuPDF (fast PDF text extraction)
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI  # Official OpenAI Python SDK
from google.colab import files  # File upload dialog in Colab
from IPython.display import display, HTML, clear_output  # UI helpers
import ipywidgets as widgets  # Interactive controls
import psutil  # Process and memory information
import re
import html
from functools import lru_cache  # In-memory caching for repeated PDF fetches

# Enable widgets in Colab.
# Colab needs its custom widget manager switched on before some ipywidgets
# render; anywhere else the import fails and this step is skipped (best effort).
try:
    from google.colab import output as colab_output
    colab_output.enable_custom_widget_manager()
except Exception:
    pass  # not running in Colab, or the widget-manager API is unavailable

# ----------------------------
# Global state
# ----------------------------
# chat_history_html: accumulated HTML of the chat transcript shown by chat_ui.
# api_key_ready: set to True by set_api_key once a non-empty key is saved.
chat_history_html, api_key_ready = "", False

# ----------------------------
# API key UI
# ----------------------------
# A masked text box for the key, a save button, and a small status area so
# feedback messages stay next to the controls.
api_key_input = widgets.Password(
    description='API Key:',
    placeholder='Enter your OpenAI API key',
    layout=widgets.Layout(width='50%'),
)
submit_button = widgets.Button(description='Save API Key', button_style='info')
output_area = widgets.Output()


def set_api_key(_):
    """
    Save-button callback: store the typed key in OPENAI_API_KEY.

    A blank (empty or whitespace-only) entry shows an error and leaves the
    environment untouched; otherwise the stripped key is exported and the
    module-level api_key_ready flag is set.
    """
    global api_key_ready
    entered_key = api_key_input.value.strip()
    with output_area:
        clear_output()
        if not entered_key:
            display(HTML("<b style='color:#e53935;'>❌ Please enter a valid API key.</b>"))
            return
        os.environ["OPENAI_API_KEY"] = entered_key
        api_key_ready = True
        display(HTML("<b style='color:#43a047;'>✅ API key saved. You can continue.</b>"))


# Hook the handler to the button, then show the key-entry controls.
submit_button.on_click(set_api_key)
display(HTML("<h3>🔑 Enter your OpenAI API key to continue:</h3>"))
display(api_key_input, submit_button, output_area)

# ----------------------------
# OpenAI client
# ----------------------------
def get_openai_client():
    """
    Build an OpenAI client from the saved API key.

    The key comes from the OPENAI_API_KEY environment variable when set.
    Otherwise the value typed into the api_key_input widget is used (stripped)
    and, if non-empty, exported to OPENAI_API_KEY for later calls. A missing
    widget (NameError) is tolerated.

    Raises:
        ValueError: when no key is available from either source.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        try:
            typed_value = api_key_input.value
        except NameError:
            # The widget only exists when the key-entry cell has run.
            typed_value = ""
        if typed_value:
            api_key = typed_value.strip()
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
    if not api_key:
        raise ValueError("API key not set. Please enter your API key and try again.")
    return OpenAI(api_key=api_key)

# ----------------------------
# JSON loader with JSON-LD support
# ----------------------------
def normalise_to_list(obj):
    """
    Coerce a parsed JSON / JSON-LD document into a list of entries.

    - A list is returned unchanged.
    - A dict whose "@graph", "graph", "data", "items", "entries",
      "documents" or "records" key (checked in that order) holds a list
      yields that list.
    - Any other dict is treated as a single entry and wrapped in a list.

    Raises:
        ValueError: for any other top-level JSON type.
    """
    if isinstance(obj, list):
        return obj
    if not isinstance(obj, dict):
        raise ValueError("Unsupported JSON structure")
    container_keys = ("@graph", "graph", "data", "items", "entries", "documents", "records")
    found = next(
        (obj[key] for key in container_keys if isinstance(obj.get(key), list)),
        None,
    )
    # An empty container list is still a valid container, hence "is not None".
    return found if found is not None else [obj]

def upload_json():
    """
    Prompt the user to upload a JSON or JSON-LD knowledge base and load it.

    Only the first uploaded file is used. The document is parsed from the
    bytes returned by the upload dialog rather than re-read from disk in text
    mode: json.loads on bytes detects UTF-8/16/32 and accepts a UTF-8
    byte-order mark, whereas a text-mode read with the locale encoding fails
    on a BOM ("Unexpected UTF-8 BOM") or on UTF-16 files.

    Returns:
        list: the entries, normalised via normalise_to_list.

    Raises:
        ValueError: if no file was uploaded, the JSON is malformed
            (json.JSONDecodeError) or cannot be decoded, or the top-level
            structure is unsupported.
    """
    display(HTML("<h3>📥 Upload your JSON knowledge base file...</h3>"))
    uploaded = files.upload()  # Opens a file chooser; maps saved filename -> contents
    if not uploaded:
        raise ValueError("No file uploaded.")
    # Take the first uploaded file
    json_file_path, raw_contents = next(iter(uploaded.items()))
    kb_raw = json.loads(raw_contents)
    kb = normalise_to_list(kb_raw)
    display(HTML(f"<b>✅ Loaded {len(kb)} entries from {html.escape(json_file_path)}</b>"))
    return kb

# ----------------------------
# Summaries and embeddings
# ----------------------------
def _metadata_text(value):
    """
    Render one metadata value as plain text for an entry summary.

    Strings are returned as-is and numbers via str(). JSON-LD style objects
    contribute their "name", "@value" or "title" (first key present). Lists
    are rendered element by element, and the non-empty results are joined
    with spaces. Anything else (None, booleans, objects without those keys)
    yields ''.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, bool):
        # bool is a subclass of int; a bare true/false is not useful text.
        return ""
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, dict):
        for key in ("name", "@value", "title"):
            if key in value:
                return _metadata_text(value[key])
        return ""
    if isinstance(value, list):
        return " ".join(t for t in (_metadata_text(v) for v in value) if t)
    return ""


def entry_summary(entry):
    """
    Produce a brief text summary for a knowledge-base entry (a dict).

    Preference order:
      1) the first non-blank "summary", "abstract" or "description" value,
      2) " | "-joined non-blank title, date, datePublished, author, keywords,
      3) the entry as JSON (non-ASCII kept readable), truncated to 1000 chars.

    JSON-LD shaped values are understood: {"name": ...} authors,
    {"@value": ...} literals, lists of either, and numeric years.
    """
    for k in ["summary", "abstract", "description"]:
        text = _metadata_text(entry.get(k))
        if text.strip():
            return text
    parts = []
    for k in ["title", "date", "datePublished", "author", "keywords"]:
        text = _metadata_text(entry.get(k))
        # Skip blanks so they cannot yield an empty or " | "-padded summary.
        if text.strip():
            parts.append(text)
    if not parts:
        # Fall back to compact JSON if no obvious human fields; keep non-ASCII
        # characters literal so the 1000 chars are not spent on \uXXXX escapes.
        return json.dumps(entry, ensure_ascii=False)[:1000]
    return " | ".join(parts)

def create_embeddings(kb, batch_size=64):
    """
    Embed the summary of every knowledge-base entry.

    Uses text-embedding-3-small. Summaries are sent in batches of
    `batch_size` strings per request instead of one HTTP round trip per
    entry. Results are re-ordered by their returned `index` so row i always
    matches kb[i].

    Args:
        kb: non-empty list of entries.
        batch_size: maximum number of summaries per embeddings request.

    Returns:
        summaries: list[str], one per entry (as produced by entry_summary).
        summary_embeddings: np.ndarray of float32, one row per entry.

    Raises:
        ValueError: if kb is empty.
    """
    client = get_openai_client()
    display(HTML("<b>🔄 Generating embeddings for entries...</b>"))

    summaries = [entry_summary(e) for e in kb]
    if not summaries:
        raise ValueError("Knowledge base is empty; nothing to embed.")

    step = max(1, int(batch_size))
    vectors = []
    for start in range(0, len(summaries), step):
        # The embeddings endpoint rejects empty input strings, so blank
        # summaries are embedded as a short placeholder instead.
        batch = [s if s.strip() else "(no summary)" for s in summaries[start:start + step]]
        resp = client.embeddings.create(model="text-embedding-3-small", input=batch)
        for item in sorted(resp.data, key=lambda d: d.index):
            vectors.append(np.asarray(item.embedding, dtype=np.float32))

    # Stack into an array of shape (num_entries, embedding_dim)
    summary_embeddings = np.stack(vectors, axis=0)
    display(HTML("<b>✅ Embeddings created.</b>"))
    return summaries, summary_embeddings

# ----------------------------
# PDF fetch with cache and optional OCR
# ----------------------------
def maybe_install_ocr_deps():
    """
    Ensure the OCR system tools are present, installing them via apt-get if not.

    - pdftoppm (package poppler-utils) is used to rasterise PDF pages,
    - tesseract (package tesseract-ocr) performs the OCR itself.

    Nothing is run when both binaries are already on PATH. apt output is
    discarded and exit codes are not checked (best effort).
    """
    import shutil

    # (binary to look for on PATH, command that installs it)
    requirements = (
        ("pdftoppm", "apt-get install -y poppler-utils >/dev/null 2>&1"),
        ("tesseract", "apt-get install -y tesseract-ocr >/dev/null 2>&1"),
    )
    pending_installs = [command for binary, command in requirements if shutil.which(binary) is None]
    if not pending_installs:
        return

    print("Installing OCR system packages. This may take a minute...")
    os.system("apt-get update -y >/dev/null 2>&1")
    for command in pending_installs:
        os.system(command)

@lru_cache(maxsize=256)
def fetch_pdf_bytes(url):
    """
    Download PDF bytes with a small cache to avoid repeated network calls.
    GitHub 'blob' URLs are converted to their raw counterparts.
    """
    if not url:
        raise ValueError("Empty URL")
    raw_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob", "")
    resp = requests.get(raw_url, timeout=60)
    resp.raise_for_status()
    return resp.content

def extract_text_pymupdf(pdf_bytes):
    """
    Extract the text layer of a PDF with PyMuPDF.

    Page texts are concatenated in page order with no extra separator. This
    works for digitally produced PDFs; scanned PDFs usually yield little or
    no text.

    Args:
        pdf_bytes: raw PDF file content.

    Returns:
        str: the extracted text ('' when there is no text layer), or a
        "[PDF parse error: ...]" marker if PyMuPDF cannot read the data.
    """
    try:
        with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_doc:
            # Collect per-page strings and join once, instead of growing a
            # string with += on every page.
            page_texts = [page.get_text() or "" for page in pdf_doc]
        return "".join(page_texts)
    except Exception as e:
        # Returned rather than raised so the problem is visible in the UI.
        return f"[PDF parse error: {e}]"

def ocr_pdf_bytes(pdf_bytes):
    """
    OCR fallback for PDFs without a usable text layer.

    Every page is rasterised at 200 dpi with pdf2image and run through
    Tesseract; the page texts are joined with newlines. Any failure (missing
    packages or binaries, unreadable PDF) is returned as an
    "[OCR error: ...]" marker instead of being raised.
    """
    try:
        from pdf2image import convert_from_bytes
        import pytesseract

        page_images = convert_from_bytes(pdf_bytes, fmt="png", dpi=200)
        return "\n".join([pytesseract.image_to_string(image) for image in page_images])
    except Exception as err:
        return f"[OCR error: {err}]"

def fetch_pdf_text(url, try_ocr=False):
    """
    Download a PDF and return its text.

    PyMuPDF extraction is tried first. When try_ocr is set and the stripped
    text is shorter than 100 characters, OCR is attempted and its output is
    used if it is longer (after stripping) than the extracted text. Any
    exception is reported as a "[PDF fetch error: ...]" string.
    """
    try:
        pdf_bytes = fetch_pdf_bytes(url)
        page_text = extract_text_pymupdf(pdf_bytes)
        text_len = len(page_text.strip())
        # Very little extractable text suggests a scanned PDF, where OCR may help.
        if try_ocr and text_len < 100:
            ocr_text = ocr_pdf_bytes(pdf_bytes)
            if len(ocr_text.strip()) > text_len:
                return ocr_text
        return page_text
    except Exception as e:
        return f"[PDF fetch error: {e}]"

# ----------------------------
# Helpers
# ----------------------------
# Deliberately tiny stopword set for the crude keyword extractor below, so the
# scoring logic stays easy to follow.
stopwords = {*"the and is it to for in of on with a an as at by from this that".split()}


def extract_entities(text):
    """
    Return the lower-cased word tokens of `text` worth using as keywords.

    Tokens are runs of letters, digits and underscores; stopwords and tokens
    of two characters or fewer are dropped. Order and duplicates are kept.
    """
    candidates = re.findall(r'\b\w+\b', text.lower())
    return [token for token in candidates if len(token) > 2 and token not in stopwords]

def keyword_score(entry, entities):
    """
    Count how often the entity tokens occur in an entry.

    All of the entry's values are stringified, joined with spaces and
    lower-cased. Each entity then contributes its number of substring
    occurrences in that text (so "art" also matches inside "start").
    """
    haystack = " ".join(map(str, entry.values())).lower()
    return sum(map(haystack.count, entities))

def print_memory_usage(conversation_history):
    """
    Display a small panel of memory diagnostics.

    Shows this process's resident memory, the machine's used/total RAM (both
    in MB), and the size of `conversation_history` serialised as JSON (in KB).
    Useful for spotting out-of-memory trends during long sessions.
    """
    mib = 1024 * 1024
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / mib
    system_mem = psutil.virtual_memory()
    used_mb = system_mem.used / mib
    total_mb = system_mem.total / mib
    history_kb = len(json.dumps(conversation_history)) / 1024
    panel_html = f"""
        <div style='border:1px solid #444; padding:6px; margin:6px; border-radius:6px; background:#222; color:#ccc; font-family:monospace;'>
        🧠 <b>Memory Stats:</b><br>
        🔹 Process RAM: {rss_mb:.2f} MB<br>
        🔹 Colab RAM: {used_mb:.2f} MB / {total_mb:.2f} MB<br>
        🔹 Chat History Size: {history_kb:.1f} KB
        </div>
    """
    display(HTML(panel_html))

# ----------------------------
# Chat UI
# ----------------------------
def chat_ui(kb, summaries, summary_embeddings):
    """
    Build and run the interactive chat UI over a loaded knowledge base.

    For every question the UI:
      - ranks entries by embedding similarity plus a weighted keyword score,
      - shows a short card for each of the top-K entries,
      - pulls (optionally OCR'd) text from the entries' linked PDFs,
      - calls the Chat Completions API with the current system prompt,
      - appends the question, the cards and the reply to the transcript.

    Args:
        kb: list of knowledge-base entries (dicts), as from upload_json.
        summaries: per-entry summary strings, aligned with kb.
        summary_embeddings: per-entry embedding matrix, rows aligned with kb.

    Notes:
        - Retrieved PDF text is sent only with the question it was retrieved
          for. The stored history keeps the bare questions and the replies,
          so requests do not grow by up to top_k * chars/doc characters per
          turn until they overflow the model's context window.
        - A failed turn is reported in the transcript but is not added to
          the history sent on later turns.
        - The transcript lives in its own Output widget and is redrawn in
          place instead of a new copy being displayed on every send.
    """
    global chat_history_html
    client = get_openai_client()
    # The message history below starts empty, so the transcript does too
    # (otherwise reopening the UI would show a previous session's transcript).
    chat_history_html = ""

    # ---- Controls ----
    base_system_prompt = "You are a helpful assistant with access to a structured JSON knowledge base and linked PDFs."
    bias_input = widgets.Text(
        value="",
        placeholder='Optional bias theme, e.g. Cybernetics',
        description='Bias:',
        layout=widgets.Layout(width='100%')
    )
    system_prompt_input = widgets.Textarea(
        value=base_system_prompt,
        placeholder='Enter the base system prompt here...',
        description='System Prompt:',
        layout=widgets.Layout(width='100%', height='80px')
    )
    # Number of top-ranked entries whose PDFs are used as context
    top_k_slider = widgets.IntSlider(value=2, min=1, max=10, step=1, description='Top K:')
    # Weight of the max-normalised keyword score relative to cosine similarity
    kw_weight_slider = widgets.FloatSlider(value=0.5, min=0.0, max=2.0, step=0.1, description='Keyword weight:')
    # Max characters per document context included in the prompt
    ctx_chars_slider = widgets.IntSlider(value=6000, min=1000, max=20000, step=500, description='Chars/doc:')
    # OCR toggle for scanned PDFs that lack extractable text
    ocr_checkbox = widgets.Checkbox(value=False, description='Use OCR fallback for scanned PDFs')
    # Optional memory stats after each reply to aid debugging
    show_mem_checkbox = widgets.Checkbox(value=False, description='Show memory stats after replies')
    text_input = widgets.Text(placeholder='Type your question here...')
    send_button = widgets.Button(description='Send', button_style='success')
    clear_button = widgets.Button(description='Clear Chat', button_style='danger')
    # Transcript area, redrawn in place by update_chat_display
    chat_out = widgets.Output()

    # ---- State ----
    chat_messages = []      # Completed turns only: bare questions and replies
    query_embed_cache = {}  # Exact query text -> embedding

    def get_embedding(text):
        """Return the query embedding, cached by the exact text that was embedded."""
        if text not in query_embed_cache:
            resp = client.embeddings.create(model="text-embedding-3-small", input=text)
            query_embed_cache[text] = np.array(resp.data[0].embedding, dtype=np.float32)
        return query_embed_cache[text]

    def current_system_prompt():
        """Base system prompt plus a 'Bias theme:' line when a bias is set."""
        bias = bias_input.value.strip()
        if bias:
            return f"{system_prompt_input.value}\nBias theme: {bias}"
        return system_prompt_input.value

    def update_chat_display():
        """Redraw the system-prompt banner and scrollable transcript in place."""
        safe_prompt = html.escape(current_system_prompt())
        with chat_out:
            # wait=True defers clearing until new output arrives (less flicker).
            clear_output(wait=True)
            display(HTML(f"""
            <div style='position:sticky; top:0; background:#222; color:#fff; padding:8px; border-bottom:2px solid #555; z-index:10;'>
              📝 <b>System Prompt:</b> {safe_prompt} | 🔑 <b>API Key:</b> ✅ Active
            </div>
            <div style='max-height:420px; overflow-y:auto; border:1px solid #444; background:#111; color:#ddd; padding:10px; border-radius:8px;'>
              {chat_history_html}
            </div>
            """))

    def pick_top_entries(user_text, top_k, kw_weight):
        """
        Rank entries by cosine similarity of query and summary embeddings plus
        kw_weight times the max-normalised keyword score.
        Return the top_k indices (best first) and the combined score vector.
        """
        q_emb = get_embedding(user_text)
        sims = cosine_similarity([q_emb], summary_embeddings)[0]
        ents = extract_entities(user_text)
        kw_scores = np.array([keyword_score(e, ents) for e in kb], dtype=np.float32)
        kw_max = kw_scores.max()
        # Normalise keyword scores by their max to keep them on a 0..1 scale
        combined = sims + kw_weight * (kw_scores / (kw_max if kw_max > 0 else 1.0))
        idxs = np.argsort(combined)[-top_k:][::-1]
        return idxs.tolist(), combined

    def safe_html(s):
        """Escape text for HTML and turn newlines into <br>; None renders as ''."""
        return html.escape(s or "").replace("\n", "<br>")

    def entry_url(entry):
        """The entry's url/pdf/link value if it is a string, else ''."""
        url = entry.get('url') or entry.get('pdf') or entry.get('link') or ''
        return url if isinstance(url, str) else ''

    def entry_card_html(i):
        """HTML card (title, date, truncated summary, link) for kb[i]."""
        entry = kb[i]
        title = str(entry.get('title') or entry.get('name') or 'Untitled')
        date = entry.get('date') or entry.get('datePublished') or ''
        url = entry_url(entry)
        summ = summaries[i][:300] + ("..." if len(summaries[i]) > 300 else "")
        if url.lower().startswith(("http://", "https://")):
            url_html = f"<a href='{html.escape(url)}' target='_blank' style='color:#1e90ff;'>Open</a>"
        elif url:
            # Other schemes (e.g. javascript:) are shown as text, never linked.
            url_html = f"<i>{html.escape(url)}</i>"
        else:
            url_html = "<i>No link</i>"
        date_html = f"({html.escape(str(date))})" if date else ""
        return f"""
            <div style='background:#333; padding:8px; margin:6px 0; border-radius:6px;'>
              📄 <b>{html.escape(title)}</b> {date_html}<br>
              <i>{html.escape(summ)}</i><br>
              {url_html}
            </div>
            """

    def gather_context(idxs, ctx_chars, use_ocr):
        """Fetch each selected entry's PDF text, cap it at ctx_chars, join with separators."""
        texts = []
        for i in idxs:
            url = entry_url(kb[i])
            doc_text = fetch_pdf_text(url, try_ocr=use_ocr) if url else ""
            if len(doc_text) > ctx_chars:
                doc_text = doc_text[:ctx_chars] + f"\n[Truncated to {ctx_chars} characters]"
            texts.append(doc_text)
        return "\n\n---\n\n".join(texts)

    def answer(user_input):
        """
        Run retrieval and one chat completion for user_input.

        Appends the selected-entry cards to the transcript HTML. On success,
        records the turn (bare question + reply) in chat_messages and returns
        (reply, request_messages). Exceptions propagate to the caller.
        """
        global chat_history_html
        use_ocr = ocr_checkbox.value
        if use_ocr:
            maybe_install_ocr_deps()

        top_idxs, _ = pick_top_entries(user_input, top_k=top_k_slider.value, kw_weight=kw_weight_slider.value)
        chat_history_html += "".join(entry_card_html(i) for i in top_idxs)

        context = gather_context(top_idxs, ctx_chars_slider.value, use_ocr)
        # The retrieved context rides only on the current question.
        request_messages = (
            [{"role": "system", "content": current_system_prompt()}]
            + chat_messages
            + [{"role": "user", "content": f"Relevant PDF Content:\n{context}\n\nUser Question: {user_input}"}]
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=request_messages,
            max_tokens=700,
            temperature=0.7,
        )
        reply = response.choices[0].message.content or ""
        chat_messages.append({"role": "user", "content": user_input})
        chat_messages.append({"role": "assistant", "content": reply})
        return reply, request_messages

    def on_send(_):
        """Send-button handler: answer the typed question and redraw the transcript."""
        global chat_history_html
        user_input = text_input.value.strip()
        text_input.value = ""
        if not user_input:
            return

        chat_history_html += f"""
        <div style='background:#2b3a4a; padding:8px; margin:6px 0; border-radius:6px;'>
          <b style='color:#8ab4f8;'>You:</b><br>{safe_html(user_input)}
        </div>
        """
        request_messages = [{"role": "system", "content": current_system_prompt()}] + chat_messages
        try:
            assistant_msg, request_messages = answer(user_input)
        except Exception as e:
            # Report embedding / PDF / chat API failures in the transcript
            # rather than only in the kernel log; the turn is not recorded.
            assistant_msg = f"⚠️ Error: {e}"

        chat_history_html += f"""
        <div style='background:#222; padding:8px; margin:6px 0; border-radius:6px;'>
          <b style='color:#ffdd57;'>Assistant:</b><br>{safe_html(assistant_msg)}
        </div>
        """
        update_chat_display()
        if show_mem_checkbox.value:
            # Shown under the transcript; the next redraw replaces it.
            with chat_out:
                print_memory_usage(request_messages)

    def on_clear(_):
        """Clear-button handler: forget the transcript and the message history."""
        global chat_history_html
        chat_history_html = ""
        chat_messages.clear()
        update_chat_display()

    send_button.on_click(on_send)
    clear_button.on_click(on_clear)

    # Group controls together and show the UI (transcript first)
    controls = widgets.HBox([top_k_slider, kw_weight_slider, ctx_chars_slider, ocr_checkbox, show_mem_checkbox])
    update_chat_display()
    display(widgets.VBox([
        chat_out,
        bias_input,
        system_prompt_input,
        controls,
        widgets.HBox([text_input, send_button, clear_button]),
    ]))

# ----------------------------
# Proceed gate
# ----------------------------
# Separate button to enforce ordering:
#  1) Provide API key,
#  2) Upload JSON,
#  3) Build embeddings,
#  4) Open chat UI.
proceed_button = widgets.Button(description='Proceed', button_style='primary')
proceed_out = widgets.Output()


def proceed(_):
    """
    Proceed-button callback: upload the KB, embed it, then open the chat UI.

    Refuses to start (with an inline message) until an API key has been saved
    or is present in OPENAI_API_KEY. Any exception from the upload, embedding
    or UI steps is shown inline, HTML-escaped.
    """
    with proceed_out:
        clear_output()
        key_available = api_key_ready or os.getenv("OPENAI_API_KEY")
        if not key_available:
            display(HTML("<b style='color:#e53935;'>❌ Please save your API key first.</b>"))
            return
        try:
            knowledge_base = upload_json()
            entry_summaries, entry_vectors = create_embeddings(knowledge_base)
            chat_ui(knowledge_base, entry_summaries, entry_vectors)
        except Exception as err:
            display(HTML(f"<b style='color:#e53935;'>⚠️ Error:</b> {html.escape(str(err))}"))


# Register the handler, then show the button with its output area.
proceed_button.on_click(proceed)
display(HTML("<h3>▶️ When your API key shows as saved, click Proceed:</h3>"))
display(proceed_button, proceed_out)


