In [None]:
# Colab Notebook: JSON Knowledge Base Chatbot with Live Editable System Prompt

# When run, you will need to enter an OpenAI API key (one can be provided for testing purposes).
# It will then ask you to upload the JSON knowledge base.

# The system prompt does not need to be changed. You can edit it if, for instance, you want the system to be biased towards "Cybernetics".
# The bias can be entered as plain text.

# When you ask a question you will receive a few responses. These include information about the PDFs selected for use through keyword and semantic search.
# You will then receive the chat response, which is based on the selected PDFs and the system prompt.

# 🛠️ Install dependencies
!pip install -q openai PyMuPDF requests numpy scikit-learn psutil ipywidgets pillow pdf2image pytesseract

# 📦 Imports
import os
import json
import fitz  # PyMuPDF
import requests
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI
from google.colab import files
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
import psutil
import re
import html
from functools import lru_cache

# Enable widgets in Colab
try:
    from google.colab import output as colab_output
    colab_output.enable_custom_widget_manager()
except Exception:
    # Best effort: if the import or the call fails, carry on without it
    # rather than aborting the notebook.
    pass

# ----------------------------
# Global state
# ----------------------------
chat_history_html = ""  # accumulated HTML transcript; appended to and reset inside chat_ui
api_key_ready = False  # set to True by set_api_key once a non-empty key has been saved

# ----------------------------
# API key UI
# ----------------------------
api_key_input = widgets.Password(
    layout=widgets.Layout(width='50%'),
    placeholder='Enter your OpenAI API key',
    description='API Key:',
)
submit_button = widgets.Button(button_style='info', description='Save API Key')
output_area = widgets.Output()


def set_api_key(_):
    """Button callback: validate the typed key and store it for later use.

    A non-empty key (after stripping whitespace) is written to the
    ``OPENAI_API_KEY`` environment variable and ``api_key_ready`` is set.
    The outcome is reported inside ``output_area``, replacing any earlier
    message.

    Args:
        _: The clicked button (unused).
    """
    global api_key_ready
    entered_key = api_key_input.value.strip()
    with output_area:
        clear_output()
        if not entered_key:
            display(HTML("<b style='color:#e53935;'>❌ Please enter a valid API key.</b>"))
            return
        os.environ["OPENAI_API_KEY"] = entered_key
        api_key_ready = True
        display(HTML("<b style='color:#43a047;'>✅ API key saved. You can continue.</b>"))


submit_button.on_click(set_api_key)
display(HTML("<h3>🔑 Enter your OpenAI API key to continue:</h3>"))
display(api_key_input, submit_button, output_area)

# ----------------------------
# OpenAI client
# ----------------------------
def get_openai_client():
    """Return an OpenAI client configured with the user's API key.

    The key is read from the ``OPENAI_API_KEY`` environment variable. When
    that is unset or empty, the text currently typed into ``api_key_input``
    is used instead (stripped) and, if non-empty, copied into the
    environment variable.

    Raises:
        ValueError: If no key could be found.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        try:
            typed_value = api_key_input.value
        except NameError:
            # The key widget has not been created; nothing to fall back to.
            typed_value = None
        if typed_value:
            api_key = typed_value.strip()
            if api_key:
                os.environ["OPENAI_API_KEY"] = api_key
    if not api_key:
        raise ValueError("API key not set. Please enter your API key and try again.")
    return OpenAI(api_key=api_key)

# ----------------------------
# JSON loader with JSON-LD support
# ----------------------------
def normalise_to_list(obj):
    """Coerce a parsed JSON document into a list of entries.

    A list is returned unchanged. For a dict, the first of the known
    container keys (``@graph``, ``graph``, ``data``, ``items``, ``entries``,
    ``documents``, ``records``) whose value is a list supplies the entries;
    if none matches, the dict itself is treated as the single entry.

    Raises:
        ValueError: If ``obj`` is neither a list nor a dict.
    """
    if isinstance(obj, list):
        return obj
    if not isinstance(obj, dict):
        raise ValueError("Unsupported JSON structure")

    container_keys = ("@graph", "graph", "data", "items", "entries", "documents", "records")
    for key in container_keys:
        candidate = obj.get(key)
        if isinstance(candidate, list):
            return candidate
    return [obj]

def upload_json():
    """Prompt for a JSON knowledge-base upload and return its entries.

    Only the first uploaded file is used. Its content is parsed straight
    from the bytes handed back by ``files.upload()`` rather than by
    re-opening a file on disk: that avoids depending on the platform's
    default text encoding, tolerates a UTF-8 BOM (``json.loads`` detects the
    encoding of byte input), and cannot pick up a stale local copy that
    happens to share the uploaded file's name.

    Returns:
        list: The knowledge-base entries, as produced by
        ``normalise_to_list``.

    Raises:
        ValueError: If nothing was uploaded, the content is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or the JSON has
            an unsupported top-level structure.
    """
    display(HTML("<h3>📥 Upload your JSON knowledge base file...</h3>"))
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("No file uploaded.")
    # files.upload() maps each file name to that file's contents.
    json_file_path, payload = next(iter(uploaded.items()))
    kb_raw = json.loads(payload)
    kb = normalise_to_list(kb_raw)
    display(HTML(f"<b>✅ Loaded {len(kb)} entries from {html.escape(json_file_path)}</b>"))
    return kb

# ----------------------------
# Summaries and embeddings
# ----------------------------
def entry_summary(entry):
    """Return a short text describing a knowledge-base entry.

    Preference order:
      1. The first non-blank string found under ``summary``, ``abstract`` or
         ``description`` (returned as stored, not stripped).
      2. Otherwise the ``title``, ``date``, ``datePublished``, ``author`` and
         ``keywords`` fields that are strings or lists, joined with `` | ``
         (list items are space-joined).
      3. Otherwise the first 1000 characters of the entry's JSON dump.
    """
    for field in ("summary", "abstract", "description"):
        text = entry.get(field)
        if isinstance(text, str) and text.strip():
            return text

    fragments = []
    for field in ("title", "date", "datePublished", "author", "keywords"):
        value = entry.get(field)
        if isinstance(value, str):
            fragments.append(value)
        elif isinstance(value, list):
            fragments.append(" ".join(str(item) for item in value))

    if fragments:
        return " | ".join(fragments)
    return json.dumps(entry)[:1000]

def create_embeddings(kb, batch_size=64):
    """Embed a summary of every knowledge-base entry.

    Summaries are sent to the embeddings endpoint in batches instead of one
    request per entry, which cuts the number of API round-trips.

    Args:
        kb: List of knowledge-base entries.
        batch_size: Maximum number of summaries sent per embeddings request.
            Keep this moderate: the API limits both the number of inputs and
            the total tokens per request.

    Returns:
        tuple: ``(summaries, summary_embeddings)`` where ``summaries`` is the
        list of summary strings (index-aligned with ``kb``) and
        ``summary_embeddings`` is a float32 array with one row per entry.

    Raises:
        ValueError: If the knowledge base is empty, ``batch_size`` is not
            positive, or no API key is available.
    """
    client = get_openai_client()
    if not kb:
        # np.stack on an empty list would fail with an unhelpful message.
        raise ValueError("Knowledge base is empty; there is nothing to embed.")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")
    display(HTML("<b>🔄 Generating embeddings for entries...</b>"))

    summaries = [entry_summary(e) for e in kb]
    # The embeddings endpoint rejects empty strings, so blank summaries are
    # replaced by a placeholder for the request only; the returned summaries
    # keep their original text.
    inputs = [s if s.strip() else "(no summary)" for s in summaries]

    vectors = []
    for start in range(0, len(inputs), batch_size):
        resp = client.embeddings.create(
            model="text-embedding-3-small",
            input=inputs[start:start + batch_size],
        )
        # Each result carries the index of its input; sort on it so the rows
        # stay aligned with the summaries regardless of response order.
        for item in sorted(resp.data, key=lambda d: d.index):
            vectors.append(np.asarray(item.embedding, dtype=np.float32))

    summary_embeddings = np.stack(vectors, axis=0)
    display(HTML("<b>✅ Embeddings created.</b>"))
    return summaries, summary_embeddings

# ----------------------------
# PDF fetch with cache and optional OCR
# ----------------------------
def maybe_install_ocr_deps():
    """Install the system packages OCR relies on, if their binaries are missing.

    Looks for ``pdftoppm`` and ``tesseract`` on the PATH. When either is
    absent, runs ``apt-get update`` followed by the install command for each
    missing tool. Command output is discarded and exit codes are not checked.
    """
    from shutil import which

    # (binary looked up on PATH, shell command that installs it)
    requirements = (
        ("pdftoppm", "apt-get install -y poppler-utils >/dev/null 2>&1"),
        ("tesseract", "apt-get install -y tesseract-ocr >/dev/null 2>&1"),
    )
    install_commands = [cmd for binary, cmd in requirements if which(binary) is None]
    if not install_commands:
        return

    print("Installing OCR system packages. This may take a minute...")
    os.system("apt-get update -y >/dev/null 2>&1")
    for command in install_commands:
        os.system(command)

@lru_cache(maxsize=256)
def fetch_pdf_bytes(url):
    """Download a PDF and return its raw bytes (results are cached per URL).

    GitHub "blob" page URLs are rewritten to their raw-content equivalent
    before downloading; every other URL is requested unchanged.

    Args:
        url: Address of the PDF.

    Returns:
        bytes: The response body.

    Raises:
        ValueError: If ``url`` is empty.
        requests.RequestException: On network failures or non-2xx responses.
    """
    if not url:
        raise ValueError("Empty URL")
    raw_url = _to_raw_github_url(url)
    resp = requests.get(raw_url, timeout=60)
    resp.raise_for_status()
    return resp.content


def _to_raw_github_url(url):
    """Map a github.com ``/owner/repo/blob/<ref>/<path>`` URL to raw.githubusercontent.com.

    Only the host and the single ``blob`` path segment are touched, so
    non-GitHub URLs and later path segments that merely contain "blob" are
    left intact. URLs that do not match the pattern are returned unchanged.
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    if parts.hostname not in ("github.com", "www.github.com"):
        return url
    # Path layout: ['', owner, repo, 'blob', ref, *file_path]
    segments = parts.path.split("/")
    if len(segments) < 6 or segments[3] != "blob":
        return url
    raw_path = "/".join(segments[:3] + segments[4:])
    return urlunsplit(
        (parts.scheme, "raw.githubusercontent.com", raw_path, parts.query, parts.fragment)
    )

def extract_text_pymupdf(pdf_bytes):
    """Return the concatenated text of every page in a PDF, using PyMuPDF.

    Args:
        pdf_bytes: Raw PDF content.

    Returns:
        str: The page texts joined without a separator, or a
        ``"[PDF parse error: ...]"`` message if anything goes wrong.
    """
    try:
        with fitz.open(stream=pdf_bytes, filetype="pdf") as document:
            page_texts = [page.get_text() or "" for page in document]
    except Exception as err:
        return f"[PDF parse error: {err}]"
    return "".join(page_texts)

def ocr_pdf_bytes(pdf_bytes):
    """Run OCR over a PDF and return the recognised text.

    The PDF is rasterised to PNG images at 200 dpi with ``pdf2image`` and
    each image is passed to ``pytesseract``.

    Args:
        pdf_bytes: Raw PDF content.

    Returns:
        str: The per-page texts joined by newlines, or an
        ``"[OCR error: ...]"`` message if importing the OCR libraries or any
        later step fails.
    """
    try:
        import pytesseract
        from pdf2image import convert_from_bytes

        page_images = convert_from_bytes(pdf_bytes, fmt="png", dpi=200)
        return "\n".join([pytesseract.image_to_string(image) for image in page_images])
    except Exception as err:
        return f"[OCR error: {err}]"

def fetch_pdf_text(url, try_ocr=False):
    """Download a PDF and return its text, optionally falling back to OCR.

    Args:
        url: Address of the PDF.
        try_ocr: When true and the embedded text is shorter than 100
            characters (ignoring surrounding whitespace), OCR is attempted
            and its result is used if it is longer than the embedded text.

    Returns:
        str: The extracted text, or a ``"[PDF fetch error: ...]"`` message if
        an exception is raised along the way.
    """
    try:
        pdf_bytes = fetch_pdf_bytes(url)
        embedded_text = extract_text_pymupdf(pdf_bytes)
        if not try_ocr or len(embedded_text.strip()) >= 100:
            return embedded_text
        recognised_text = ocr_pdf_bytes(pdf_bytes)
        if len(recognised_text.strip()) > len(embedded_text.strip()):
            return recognised_text
        return embedded_text
    except Exception as err:
        return f"[PDF fetch error: {err}]"

# ----------------------------
# Helpers
# ----------------------------
# Common English words ignored when pulling keywords out of a question.
stopwords = set("the and is it to for in of on with a an as at by from this that".split())


def extract_entities(text):
    """Return the lower-cased keywords of ``text``, in order of appearance.

    Words are runs of ``\\w`` characters; stopwords and words of one or two
    characters are dropped. Repeated words are kept.
    """
    keywords = []
    for token in re.findall(r'\b\w+\b', text.lower()):
        if len(token) <= 2 or token in stopwords:
            continue
        keywords.append(token)
    return keywords

def keyword_score(entry, entities):
    """Count how often the given keywords occur in an entry's values.

    All values of ``entry`` are converted to strings, space-joined and
    lower-cased; the score is the total number of (substring) occurrences of
    each keyword in that text.
    """
    haystack = " ".join(str(value) for value in entry.values()).lower()
    total = 0
    for keyword in entities:
        total += haystack.count(keyword)
    return total

def print_memory_usage(conversation_history):
    """Display a small HTML panel with memory statistics.

    Shows the resident memory of the current process, the machine's used and
    total memory as reported by ``psutil.virtual_memory()``, and the size of
    ``conversation_history`` when serialised as JSON.

    Args:
        conversation_history: JSON-serialisable list of chat messages.
    """
    bytes_per_mb = 1024 * 1024
    process_mb = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
    system_memory = psutil.virtual_memory()
    used_mb = system_memory.used / bytes_per_mb
    total_mb = system_memory.total / bytes_per_mb
    history_kb = len(json.dumps(conversation_history)) / 1024
    panel_html = f"""
        <div style='border:1px solid #444; padding:6px; margin:6px; border-radius:6px; background:#222; color:#ccc; font-family:monospace;'>
        🧠 <b>Memory Stats:</b><br>
        🔹 Process RAM: {process_mb:.2f} MB<br>
        🔹 Colab RAM: {used_mb:.2f} MB / {total_mb:.2f} MB<br>
        🔹 Chat History Size: {history_kb:.1f} KB
        </div>
    """
    display(HTML(panel_html))

# ----------------------------
# Chat UI
# ----------------------------
def chat_ui(kb, summaries, summary_embeddings):
    """Build and display the interactive chat interface for a knowledge base.

    For every question the entries are ranked by cosine similarity between
    the question embedding and the summary embeddings, plus a weighted,
    max-normalised keyword score. The PDFs linked from the top entries are
    fetched, truncated to the configured length and sent to the chat model
    together with the question. The system prompt (and optional bias theme)
    is read from the widgets at send time, so edits apply to the next
    question.

    The transcript is rendered inside a dedicated output widget that is
    cleared on each refresh, so a refresh replaces the previous rendering
    instead of adding another copy.

    Args:
        kb: List of knowledge-base entries (accessed with ``.get``).
        summaries: One summary string per entry, index-aligned with ``kb``.
        summary_embeddings: Array with one embedding row per entry,
            index-aligned with ``kb``.

    Raises:
        ValueError: If no API key is available (from ``get_openai_client``).
    """
    global chat_history_html
    client = get_openai_client()

    # Controls
    base_system_prompt = "You are a helpful assistant with access to a structured JSON knowledge base and linked PDFs."
    bias_input = widgets.Text(
        value="",
        placeholder='Optional bias theme, e.g. Cybernetics',
        description='Bias:',
        layout=widgets.Layout(width='100%')
    )
    system_prompt_input = widgets.Textarea(
        value=base_system_prompt,
        placeholder='Enter the base system prompt here...',
        description='System Prompt:',
        layout=widgets.Layout(width='100%', height='80px')
    )
    top_k_slider = widgets.IntSlider(value=2, min=1, max=10, step=1, description='Top K:')
    kw_weight_slider = widgets.FloatSlider(value=0.5, min=0.0, max=2.0, step=0.1, description='Keyword weight:')
    ctx_chars_slider = widgets.IntSlider(value=6000, min=1000, max=20000, step=500, description='Chars/doc:')
    ocr_checkbox = widgets.Checkbox(value=False, description='Use OCR fallback for scanned PDFs')
    show_mem_checkbox = widgets.Checkbox(value=False, description='Show memory stats after replies')
    # Holds the header and transcript; cleared and redrawn on every refresh.
    chat_out = widgets.Output()

    # State
    chat_messages = []  # list of {"role", "content"} dicts sent to the model
    query_embed_cache = {}  # first 2000 chars of a query -> embedding

    def get_embedding(text):
        """Return the (cached) embedding vector for a query string."""
        key = text[:2000]
        if key in query_embed_cache:
            return query_embed_cache[key]
        resp = client.embeddings.create(model="text-embedding-3-small", input=text if text else "")
        arr = np.array(resp.data[0].embedding, dtype=np.float32)
        query_embed_cache[key] = arr
        return arr

    def current_system_prompt():
        """Return the system prompt, with the bias theme appended if one is set."""
        bias = bias_input.value.strip()
        if bias:
            return f"{system_prompt_input.value}\nBias theme: {bias}"
        return system_prompt_input.value

    def update_chat_display():
        """Redraw the header and transcript in place."""
        safe_prompt = html.escape(current_system_prompt())
        with chat_out:
            clear_output(wait=True)
            display(HTML(f"""
        <div style='position:sticky; top:0; background:#222; color:#fff; padding:8px; border-bottom:2px solid #555; z-index:10;'>
          📝 <b>System Prompt:</b> {safe_prompt} | 🔑 <b>API Key:</b> ✅ Active
        </div>
        <div style='max-height:420px; overflow-y:auto; border:1px solid #444; background:#111; color:#ddd; padding:10px; border-radius:8px;'>
          {chat_history_html}
        </div>
        """))

    def pick_top_entries(user_text, top_k, kw_weight):
        """Return (indices of the best entries, best first; all combined scores)."""
        q_emb = get_embedding(user_text)
        sims = cosine_similarity([q_emb], summary_embeddings)[0]
        ents = extract_entities(user_text)
        kw_scores = np.array([keyword_score(e, ents) for e in kb], dtype=np.float32)
        kw_max = kw_scores.max()
        # Normalise keyword scores to [0, 1]; avoid dividing by zero when no
        # keyword matched anywhere.
        combined = sims + kw_weight * (kw_scores / (kw_max if kw_max > 0 else 1.0))
        idxs = np.argsort(combined)[-top_k:][::-1]
        return idxs.tolist(), combined

    def safe_html(s):
        """Escape text for HTML and keep its line breaks."""
        return html.escape(s).replace("\n", "<br>")

    def entry_url(entry):
        """Return the entry's document link as a string ('' when absent)."""
        link = entry.get('url') or entry.get('pdf') or entry.get('link') or ''
        # Metadata values are not guaranteed to be strings; html.escape and
        # the cached downloader both need one.
        return link if isinstance(link, str) else str(link)

    def on_send(_):
        """Button callback: answer the typed question and refresh the transcript."""
        global chat_history_html
        user_input = text_input.value.strip()
        text_input.value = ""
        if not user_input:
            return

        if ocr_checkbox.value:
            maybe_install_ocr_deps()

        top_k = top_k_slider.value
        kw_w = kw_weight_slider.value
        ctx_chars = ctx_chars_slider.value

        # Show the question itself so the transcript reads as a conversation.
        chat_history_html += f"""
        <div style='background:#2a2a2a; padding:8px; margin:6px 0; border-radius:6px;'>
          <b style='color:#7fd1ff;'>You:</b><br>{safe_html(user_input)}
        </div>
        """

        try:
            top_idxs, _ = pick_top_entries(user_input, top_k=top_k, kw_weight=kw_w)
        except Exception as e:
            # Report retrieval failures (e.g. the embedding request) in the
            # transcript instead of letting the callback die silently.
            chat_history_html += f"""
        <div style='background:#222; padding:8px; margin:6px 0; border-radius:6px;'>
          <b style='color:#ffdd57;'>Assistant:</b><br>{safe_html(f"⚠️ Retrieval Error: {e}")}
        </div>
        """
            update_chat_display()
            return

        selected_entries = [kb[i] for i in top_idxs]

        for i in top_idxs:
            entry = kb[i]
            title = str(entry.get('title') or entry.get('name') or 'Untitled')
            date = entry.get('date') or entry.get('datePublished') or ''
            date_html = f"({html.escape(str(date))})" if date else ""
            url = entry_url(entry)
            summ = summaries[i][:300] + ("..." if len(summaries[i]) > 300 else "")
            url_html = f"<a href='{html.escape(url)}' target='_blank' style='color:#1e90ff;'>Open</a>" if url else "<i>No link</i>"
            chat_history_html += f"""
            <div style='background:#333; padding:8px; margin:6px 0; border-radius:6px;'>
              📄 <b>{html.escape(title)}</b> {date_html}<br>
              <i>{html.escape(summ)}</i><br>
              {url_html}
            </div>
            """

        texts = []
        for e in selected_entries:
            url = entry_url(e)
            doc_text = fetch_pdf_text(url, try_ocr=ocr_checkbox.value) if url else ""
            if isinstance(doc_text, str) and len(doc_text) > ctx_chars:
                doc_text = doc_text[:ctx_chars] + f"\n[Truncated to {ctx_chars} characters]"
            texts.append(doc_text)

        combined_pdf_text = "\n\n---\n\n".join(texts)

        # NOTE: the PDF excerpts stay in chat_messages, so the request grows
        # with every turn until the chat is cleared.
        user_message = {
            "role": "user",
            "content": f"Relevant PDF Content:\n{combined_pdf_text}\n\nUser Question: {user_input}"
        }
        chat_messages.append(user_message)

        conversation_history = [{"role": "system", "content": current_system_prompt()}] + chat_messages

        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=conversation_history,
                max_tokens=700,
                temperature=0.7,
            )
            # content can be None; keep the transcript renderable.
            assistant_msg = response.choices[0].message.content or ""
        except Exception as e:
            # Drop the failed turn so the error text is never sent back to
            # the model as if it had been part of the conversation.
            chat_messages.pop()
            assistant_msg = f"⚠️ API Error: {e}"
        else:
            chat_messages.append({"role": "assistant", "content": assistant_msg})

        chat_history_html += f"""
        <div style='background:#222; padding:8px; margin:6px 0; border-radius:6px;'>
          <b style='color:#ffdd57;'>Assistant:</b><br>{safe_html(assistant_msg)}
        </div>
        """

        update_chat_display()
        if show_mem_checkbox.value:
            # Rendered under the transcript; replaced on the next refresh.
            with chat_out:
                print_memory_usage(conversation_history)

    text_input = widgets.Text(placeholder='Type your question here...')
    send_button = widgets.Button(description='Send', button_style='success')
    send_button.on_click(on_send)
    clear_button = widgets.Button(description='Clear Chat', button_style='danger')

    def on_clear(_):
        """Button callback: forget the conversation and empty the transcript."""
        global chat_history_html
        chat_history_html = ""
        chat_messages.clear()
        update_chat_display()
    clear_button.on_click(on_clear)

    controls = widgets.HBox([top_k_slider, kw_weight_slider, ctx_chars_slider, ocr_checkbox, show_mem_checkbox])
    update_chat_display()
    display(widgets.VBox([
        chat_out,
        bias_input,
        system_prompt_input,
        controls,
        widgets.HBox([text_input, send_button, clear_button]),
    ]))

# ----------------------------
# Proceed gate
# ----------------------------
proceed_button = widgets.Button(button_style='primary', description='Proceed')
proceed_out = widgets.Output()


def proceed(_):
    """Button callback: load the knowledge base and start the chat.

    Requires a saved API key (``api_key_ready``) or a non-empty
    ``OPENAI_API_KEY`` environment variable. Then runs the upload, embedding
    and chat-UI steps in order; any exception is shown as an error message
    inside ``proceed_out`` instead of propagating.

    Args:
        _: The clicked button (unused).
    """
    with proceed_out:
        clear_output()
        key_available = api_key_ready or os.getenv("OPENAI_API_KEY")
        if not key_available:
            display(HTML("<b style='color:#e53935;'>❌ Please save your API key first.</b>"))
            return
        try:
            knowledge_base = upload_json()
            entry_summaries, entry_embeddings = create_embeddings(knowledge_base)
            chat_ui(knowledge_base, entry_summaries, entry_embeddings)
        except Exception as err:
            display(HTML(f"<b style='color:#e53935;'>⚠️ Error:</b> {html.escape(str(err))}"))


display(HTML("<h3>▶️ When your API key shows as saved, click Proceed:</h3>"))
display(proceed_button, proceed_out)
proceed_button.on_click(proceed)

