In [11]:
import os, json, re, time
from typing import List, Dict, Any, Optional
import numpy as np
import pandas as pd

from qdrant_client import QdrantClient
from qdrant_client.http.models import Filter, FieldCondition, MatchValue, Range
from sentence_transformers import SentenceTransformer, CrossEncoder
import gradio as gr
import requests

# ------------------ Qdrant Configuration ------------------
# Connection settings are read from environment variables so credentials are never stored in
# (and leaked through) the notebook or its version history. The URL keeps the original
# instance as its default.
# NOTE(review): an API key was previously hard-coded here; treat it as compromised and rotate it.
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://20851a9b-65fb-47d0-982e-38fdfc7d76f8.europe-west3-0.gcp.cloud.qdrant.io",
)                                              # Hosted Qdrant instance URL
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")   # API key for authentication (None when unset)
if not QDRANT_API_KEY:
    print("⚠️ QDRANT_API_KEY is not set; requests to an authenticated Qdrant instance will likely fail.")
COLLECTION_NAME = "fda_maude_rag"          # Collection name in Qdrant
VECTOR_NAME = "foi_embedding"              # Name of vector field in the collection

# ------------------ Embedding & Re-ranking Models ------------------
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"                # Model to create embeddings
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"           # Cross-encoder for reranking results

# ------------------ Generation Settings ------------------
USE_OLLAMA = True                                                      # Use local Ollama for generation
OLLAMA_HOST = "http://localhost:11434"                                 # Ollama local server address
OLLAMA_MODEL = "mistral"                                               # Model name for Ollama

USE_MISTRAL_API = False                                                 # Use cloud Mistral API instead of Ollama
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY", "YOUR_MISTRAL_KEY")      # Retrieve API key from environment variables
MISTRAL_MODEL = "mistral-small-latest"                                 # Cloud Mistral model name

# ------------------ Retrieval Parameters ------------------
TOP_K = 25             # Number of documents fetched from Qdrant before reranking
TOP_N_FINAL = 8        # Number of top reranked documents passed to the generator
BLEND_ALPHA = 0.55     # Weighting factor between vector search and reranker scores
                       # NOTE(review): not referenced by the retrieval code in this notebook,
                       # which blends scores with ALPHA_VEC / ALPHA_RER / ALPHA_KEY instead.
MAX_CHUNK_CHARS = 1800 # Max characters per document chunk during reranking

# ------------------ Initialize Clients & Models ------------------
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)           # Qdrant client for the configured instance
embedder = SentenceTransformer(EMBED_MODEL)                            # Load sentence transformer for embeddings
cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)                      # Load cross-encoder for reranking

# ------------------ Startup Confirmation ------------------
print("✅ Connected to Qdrant and loaded models.")

✅ Connected to Qdrant and loaded models.


In [12]:
# ------------------ Synonym Dictionaries ------------------
# Event keyword -> terms appended to the query whenever that keyword appears in it
EVENT_SYNONYMS = {
    "malfunction": ["malfunction", "failed", "failure", "error", "fault", "broken"],
    "injury": ["injury", "harm", "bleeding", "laceration"],
    "death": ["death", "fatal"],
    "no answer": ["no answer", "unknown"],
}

# Brand keyword -> aliases / product-line names appended whenever the brand appears
BRAND_ALIASES = {
    "dexcom": ["dexcom", "g6", "g7"],
    "abbott": ["abbott", "freestyle", "libre"],
    "eversense": ["eversense", "senseonics"],
}

# ------------------ Term Expansion Function ------------------
def expand_terms(q: str) -> List[str]:
    """
    Expand query `q` with event synonyms and brand aliases to improve recall.

    A dictionary key triggers its expansion when it occurs as a substring of the
    lowercased query. The untouched query is always the first element; event
    synonyms are appended before brand aliases, and repeated terms are dropped
    keeping their first occurrence.

    Returns:
        List of unique terms in first-seen order.
    """
    lowered = q.lower()
    candidates = [q]
    for lookup in (EVENT_SYNONYMS, BRAND_ALIASES):
        candidates.extend(
            term
            for key, terms in lookup.items()
            if key in lowered
            for term in terms
        )
    # dicts keep insertion order, so this de-duplicates while preserving first occurrences
    return list(dict.fromkeys(candidates))

In [13]:
import re

# ------------------ Canonical Event Names ------------------
# Lowercase event keyword -> standardized label form
EVENT_CANON = {
    "malfunction": "Malfunction",
    "injury": "Injury",
    "death": "Death",
    "no answer": "No Answer",
}

# ------------------ Regex Patterns for Event Matching ------------------
# (pattern, label) pairs tried in list order against the lowercased query; the first
# pattern that matches decides the event label. The patterns accept plural forms
# ("malfunctions", "injuries", "deaths") and "no" + whitespace + "answer".
EVENT_PATTERNS = [
    (r"\bmalfunction(s)?\b", "Malfunction"),
    (r"\binjur(y|ies)\b",    "Injury"),
    (r"\bdeath(s)?\b",       "Death"),
    (r"\bno\s+answer\b",     "No Answer"),
]

# ------------------ Query Filter Inference ------------------
def infer_filters_from_query(q: str) -> dict:
    """
    Derive structured search filters from free-text query `q`.

    Keys that may appear in the result (in this order):
        event_type : label of the first EVENT_PATTERNS entry matching the query
        year_from  : smallest four-digit 19xx/20xx year mentioned
        year_to    : largest such year (equal to year_from when only one is mentioned)

    Brand detection is not implemented. Returns {} when nothing is recognized.
    """
    lowered = q.lower()
    filters = {}

    matched_label = next(
        (label for pattern, label in EVENT_PATTERNS if re.search(pattern, lowered)),
        None,
    )
    if matched_label is not None:
        filters["event_type"] = matched_label

    mentioned_years = {int(y) for y in re.findall(r"\b(19\d{2}|20\d{2})\b", lowered)}
    if mentioned_years:
        filters["year_from"] = min(mentioned_years)
        filters["year_to"] = max(mentioned_years)

    return filters

In [14]:
import re
from collections import Counter
import numpy as np  # numerical operations (used for log boosting)

# ------------------ Tokenizer ------------------
def _tok(s: str) -> list:
    """
    Lowercase and tokenize a string into alphanumeric terms.
    Returns an empty list if input is None or empty.
    """
    if not s:
        return []
    return re.findall(r"[a-z0-9]+", str(s).lower())

# ------------------ Payload Fields to Search ------------------
# List of fields in the Qdrant payload to check for keyword matches
_PFIELDS = [
    "brand_name", "generic_name", "manufacturer_d_name", "event_type",
    "model_number", "report_number", "mdr_report_key", "date_received", "foi_text"
]

# ------------------ Keyword Match Scoring ------------------
def _keyword_score(query: str, payload: dict) -> float:
    """
    Compute a keyword match score between the query and a document payload.
    - Tokenizes both query and relevant payload fields
    - Applies a small term-frequency (TF) log boost for repeated terms
    - Normalizes score by query length to avoid bias toward longer queries
    """
    q = Counter(_tok(query))  # token frequency in query
    if not q:
        return 0.0

    hay = []  # all tokens from payload fields
    for f in _PFIELDS:
        v = payload.get(f)
        if v:
            hay.extend(_tok(v))

    if not hay:
        return 0.0

    h = Counter(hay)  # token frequency in payload
    score = 0.0
    for t, w in q.items():
        if t in h:
            # Match score += query term weight * (1 + log(1 + term freq in payload))
            score += w * (1.0 + np.log1p(h[t]))

    # Normalize by total query term count
    return score / (sum(q.values()) + 1e-6)

# ------------------ Blend Weights ------------------
# Relative weights of the three relevance signals blended in retrieval (sum = 1.0):
#   ALPHA_VEC - vector similarity score returned by Qdrant
#   ALPHA_RER - cross-encoder re-ranking score (sigmoid-normalized before blending)
#   ALPHA_KEY - keyword overlap score against payload fields
ALPHA_VEC, ALPHA_RER, ALPHA_KEY = 0.45, 0.35, 0.20

In [15]:
def make_filter(
    brand: Optional[str] = None,
    event_type: Optional[str] = None,
    adverse_flag: Optional[str] = None,
    product_flag: Optional[str] = None,
    year_from: Optional[int] = None,
    year_to: Optional[int] = None
) -> Optional[Filter]:
    """
    Translate optional search parameters into a Qdrant `Filter` (all conditions ANDed).

    Parameters:
        brand         : exact match on payload "brand_name"
        event_type    : exact match on payload "event_type" (e.g. "Injury", "Malfunction")
        adverse_flag  : exact match on "adverse_event_flag"; "Any"/"any" means no condition
        product_flag  : exact match on "product_problem_flag"; "Any"/"any" means no condition
        year_from     : inclusive lower bound on payload "year"
        year_to       : inclusive upper bound on payload "year"

    Falsy values add no condition.

    Returns:
        A Filter with `must` conditions, or None when no condition applies.
    """
    def _flag_is_set(flag):
        # "Any" is the UI's "no restriction" choice
        return bool(flag) and flag not in ("Any", "any")

    exact_matches = [
        ("brand_name", brand, bool(brand)),
        ("event_type", event_type, bool(event_type)),
        ("adverse_event_flag", adverse_flag, _flag_is_set(adverse_flag)),
        ("product_problem_flag", product_flag, _flag_is_set(product_flag)),
    ]
    conditions = [
        FieldCondition(key=key, match=MatchValue(value=value))
        for key, value, enabled in exact_matches
        if enabled
    ]

    bounds = {}
    if year_from:
        bounds["gte"] = int(year_from)
    if year_to:
        bounds["lte"] = int(year_to)
    if bounds:
        conditions.append(FieldCondition(key="year", range=Range(**bounds)))

    if not conditions:
        return None
    return Filter(must=conditions)

In [16]:
def _qdrant_search(q_vec, query_filter, limit):
    """Run one named-vector search on COLLECTION_NAME and return hits with payloads (stored vectors omitted)."""
    return client.search(
        collection_name=COLLECTION_NAME,
        query_vector={"name": VECTOR_NAME, "vector": q_vec},
        limit=limit,
        with_payload=True,   # include metadata
        with_vectors=False,  # we don’t need the stored vectors back
        query_filter=query_filter,
    )


def _rerank_hits(query, q_text, hits):
    """Score hits with the cross-encoder and keyword overlap, blend with the vector score, and sort best-first."""
    pairs, meta = [], []
    for h in hits:
        pl = h.payload or {}
        txt = (pl.get("foi_text") or "").strip()

        # Truncate long text for reranker speed
        if len(txt) > MAX_CHUNK_CHARS:
            txt = txt[:MAX_CHUNK_CHARS] + "…"

        pairs.append([q_text, txt if txt else " "])
        meta.append({
            "id": h.id,
            "score_vec": float(h.score),  # original vector similarity score
            "payload": pl
        })

    ce_scores = cross_encoder.predict(pairs).tolist()

    out = []
    for m, ce in zip(meta, ce_scores):
        kw = _keyword_score(query, m["payload"])   # keyword match score
        ce_norm = 1.0 / (1.0 + np.exp(-ce / 5.0))  # squash cross-encoder logit into (0, 1)
        final = (ALPHA_VEC * m["score_vec"]) + (ALPHA_RER * ce_norm) + (ALPHA_KEY * kw)

        # Store scoring breakdown for debugging
        m["score_ce"] = float(ce)
        m["score_kw"] = float(kw)
        m["score_final"] = float(final)
        out.append(m)

    out.sort(key=lambda x: x["score_final"], reverse=True)
    return out


def retrieve(query: str,
             brand=None, event_type=None, adverse_flag=None, product_flag=None,
             year_from=None, year_to=None,
             top_k: int = TOP_K, top_n_final: int = TOP_N_FINAL):
    """
    Retrieve and rank documents from Qdrant for a query.

    Steps:
    1. Infer filters (event type, year range) from the query text. An inferred event
       type overrides `event_type`; inferred years only fill bounds left as None.
    2. Expand the query with synonyms/aliases for better recall.
    3. Encode the expanded query into a normalized vector.
    4. Search Qdrant with the structured filter from `make_filter`.
    5. If an event type was inferred and nothing matched, retry with ONLY the
       event-type condition removed (brand, flags and year bounds still apply).
    6. Rerank hits with the cross-encoder plus a keyword score and blend them with
       the vector score into `score_final`, sorted best-first.
    7. If the strict search succeeded, keep only docs whose payload event_type equals
       the inferred event. This is skipped after a fallback, because the fallback's
       purpose is to return reports of other event types.
    8. Return the top `top_n_final` docs.

    Returns:
        List of dicts with keys "id", "payload", "score_vec", "score_ce", "score_kw",
        "score_final"; empty when nothing is found.
    """
    # ---- 1. Infer filters from the query (if function available) ----
    auto = infer_filters_from_query(query) if 'infer_filters_from_query' in globals() else {}
    inferred_event = auto.get("event_type")

    # Force event_type from inference if present (e.g., "malfunctions", "injuries")
    if inferred_event:
        event_type = inferred_event

    # Fill year range from inferred values if not explicitly provided
    year_from = year_from if year_from is not None else auto.get("year_from")
    year_to   = year_to   if year_to   is not None else auto.get("year_to")

    # ---- 2./3. Expand and encode the query ----
    q_text = " ".join(expand_terms(query))
    q_vec = embedder.encode([q_text], normalize_embeddings=True)[0].tolist()

    # ---- 4. Strict search with every filter ----
    filt = make_filter(brand, event_type, adverse_flag, product_flag, year_from, year_to)
    hits = _qdrant_search(q_vec, filt, top_k)

    # ---- 5. Fallback: relax only the event-type condition ----
    used_fallback = False
    if inferred_event and not hits:
        relaxed = make_filter(brand, None, adverse_flag, product_flag, year_from, year_to)
        hits = _qdrant_search(q_vec, relaxed, top_k)
        used_fallback = True

    if not hits:
        return []

    # ---- 6. Rerank and blend scores ----
    out = _rerank_hits(query, q_text, hits)

    # ---- 7. Hard event filter (safety net for the strict path only) ----
    if inferred_event and not used_fallback:
        out = [d for d in out if (d["payload"] or {}).get("event_type") == inferred_event]

    # ---- 8. Return top N results ----
    return out[:top_n_final]

In [17]:
# Ensure an embedding model is available for semantic grouping.
# Re-running this cell keeps an existing `embedding_model`; otherwise the retrieval
# `embedder` is reused, and only as a last resort a new model is loaded.
try:
    embedding_model  # already defined?
except NameError:
    try:
        # Reuse the retrieval embedder so the same weights are not loaded twice
        if globals().get("embedder") is not None:
            embedding_model = embedder
        else:
            from sentence_transformers import SentenceTransformer
            # Same checkpoint as retrieval when EMBED_MODEL is configured
            embedding_model = SentenceTransformer(
                globals().get("EMBED_MODEL", "all-MiniLM-L6-v2")
            )
    except Exception as e:
        # Chain the original error so its traceback is not lost
        raise RuntimeError(f"Failed to initialize embedding_model: {e}") from e


In [18]:
# ===== LLM: Ollama Local Configuration =====
import os, requests

# Backend selection for generation: local Ollama on, cloud Mistral off
USE_OLLAMA, USE_MISTRAL_API = True, False

# Ollama endpoint and model name; environment variables take precedence over the defaults
# (model can be e.g. "mistral", "mistral:latest", "llama3").
OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "mistral")

def call_ollama(prompt: str) -> str:
    """
    Generate a completion for `prompt` with the Ollama server at OLLAMA_HOST.

    Posts a non-streaming request to `/api/generate` using OLLAMA_MODEL and a
    180-second timeout.

    Returns:
        The "response" field of the JSON reply ("" if absent). On any failure
        (network, HTTP status, JSON parsing) a "[Generation error] ..." string is
        returned instead of raising, so callers such as the chat UI keep running.
    """
    try:
        endpoint = f"{OLLAMA_HOST}/api/generate"
        request_body = {"model": OLLAMA_MODEL, "prompt": prompt, "stream": False}
        response = requests.post(endpoint, json=request_body, timeout=180)
        response.raise_for_status()
        reply = response.json()
        return reply.get("response", "")
    except Exception as exc:
        return f"[Generation error] Ollama request failed: {exc}"

In [19]:
def _llm_generate(prompt: str) -> str:
    """
    Route a text prompt to the configured LLM backend and return the generated output.

    Priority:
        1. Local Ollama model (if USE_OLLAMA is True)
        2. Remote Mistral API (if USE_MISTRAL_API is True)

    Parameters:
        prompt (str): The text input for the language model.

    Returns:
        str: Generated text from the chosen backend, or an error/status message.
    """
    try:
        # Use local Ollama if enabled
        if USE_OLLAMA:
            return call_ollama(prompt)

        # Use cloud Mistral API if enabled
        if USE_MISTRAL_API:
            return call_mistral_api(prompt)

    except Exception as e:
        # Capture and return any generation errors
        return f"[Generation error] {e}"

    # If neither backend is configured, return a notice
    return "[No generation backend configured]"

In [20]:
# ==========================================
# Minimal ChatGPT-style RAG with history
# - Single answer style
# - History-aware responses
# - Sources shown separately in a table (not embedded in answer)
# ==========================================
import re, time, pandas as pd, gradio as gr
from sentence_transformers import util
import numpy as np

# ---- LLM routing (select backend) ----
def _llm_generate(prompt: str) -> str:
    """
    Sends prompt to the selected LLM backend.
    Priority:
        1. Ollama local model if USE_OLLAMA is True
        2. Remote Mistral API if USE_MISTRAL_API is True
    """
    try:
        if USE_OLLAMA:
            return call_ollama(prompt)
        if USE_MISTRAL_API:
            return call_mistral_api(prompt)
    except Exception as e:
        return f"[Generation error] {e}"
    return "[No generation backend configured]"

# ---- Prompt-building configuration ----
MAX_TURNS_IN_HISTORY = 6      # Number of recent user/assistant turns to keep for continuity
MAX_SNIPPET_CHARS   = 2200    # Approximate token budget for evidence snippets

# ---- Compose grouped context from semantically similar docs ----
def _compose_context_semantic_grouped(docs, similarity_threshold=0.85):
    """
    Groups semantically similar FOI reports to avoid redundancy before passing to LLM.
    Groups are based on cosine similarity between embeddings of FOI text.
    """
    entries = []
    # Extract key metadata and FOI text from each doc
    for d in docs:
        p = d["payload"] or {}
        text = (p.get("foi_text") or "").strip()
        if not text:
            continue
        entries.append({
            "brand": p.get("brand_name", ""),
            "event": p.get("event_type", ""),
            "date": p.get("date_received", ""),
            "text": text
        })

    if not entries:
        return ""

    # Encode FOI texts into embeddings
    corpus_embeddings = embedding_model.encode(
        [e["text"] for e in entries],
        convert_to_tensor=True
    )

    # Group assignments
    grouped = []
    used = set()
    for i in range(len(entries)):
        if i in used:
            continue
        group_indices = [i]
        used.add(i)

        # Compare entry i to all later entries
        for j in range(i + 1, len(entries)):
            if j in used:
                continue
            sim = util.pytorch_cos_sim(corpus_embeddings[i], corpus_embeddings[j]).item()
            if sim >= similarity_threshold:
                group_indices.append(j)
                used.add(j)

        grouped.append(group_indices)

    # Build text block for each group
    parts = []
    for idx, group in enumerate(grouped, 1):
        brand = entries[group[0]]["brand"]
        event = entries[group[0]]["event"]
        dates = sorted({entries[k]["date"] for k in group})
        detail = entries[group[0]]["text"]

        parts.append(
            f"[{idx}] brand={brand}, event={event}, dates={', '.join(dates)}\n{detail}\n"
        )

    return "\n".join(parts)

# ---- Build the final prompt for the LLM ----
def _build_chat_prompt(history, user_query, docs):
    """
    Builds the full chat prompt:
    - System instruction block
    - Short conversation history
    - Context from retrieved FOI docs (grouped by semantic similarity)
    - Current user query
    """
    # Keep last few turns to guide LLM style and context continuity
    history_lines = []
    for u, b in history[-MAX_TURNS_IN_HISTORY:]:
        history_lines.append(f"User: {u}")
        history_lines.append(f"Assistant: {b}")
    history_text = "\n".join(history_lines) if history_lines else "[no prior turns]"

    # Gather grouped evidence snippets
    context = _compose_context_semantic_grouped(docs, similarity_threshold=0.85)

    # System role and style instructions
    sys = (
        "You are an academic researcher specializing in medical device safety analysis. "
        "Write responses in a formal, evidence-based, and professional tone suitable for an academic or regulatory report. "
        "Summarize findings using precise, objective language, avoiding conversational expressions. "
        "Organize the output into a 'Summary' section followed by a 'Key Reports' section with enumerated points. "
        "Each key report should clearly state the device, event type, and date(s), followed by a concise description of the incident. "
        "Do NOT include a 'Sources' section, document IDs, or any reference numbers in the answer. "
        "If the evidence is insufficient, state this formally."
    )

    # Final assembled prompt
    prompt = (
        f"{sys}\n\n"
        f"Conversation so far:\n{history_text}\n\n"
        f"Evidence snippets:\n{context}\n\n"
        f"User: {user_query}\n"
        f"Assistant:"
    )
    return prompt

# ---- Create sources table from retrieved docs ----
def _sources_table(docs):
    """Create a DataFrame of source metadata for display below the chat."""
    rows = []
    for i, d in enumerate(docs, 1):
        p = d["payload"] or {}
        rows.append([
            i,
            p.get("brand_name", ""),
            p.get("event_type", ""),
            p.get("date_received", ""),
            p.get("report_number", p.get("mdr_report_key", "")),
            f"{d.get('score_final', 0.0):.3f}"
        ])
    return pd.DataFrame(rows, columns=["#", "Brand", "Event", "Date", "Report", "Score"])

# ---- Strip any model-added 'Sources' text from answer ----
def _strip_sources_block(text: str) -> str:
    """Remove any 'Sources:' section the model might have added."""
    if not text:
        return text
    m = re.search(r"\n\s*Sources:\s*.*", text, flags=re.I | re.S)
    return text[:m.start()].rstrip() if m else text

# ---- One conversational turn ----
TOP_K, TOP_N_FINAL = 25, 8   # docs fetched from the vector DB / docs kept after reranking

def chat_turn(user_msg, history):
    """
    Handle one chat message end-to-end.

    Blank or missing input leaves `history` unchanged and yields an empty sources
    table. Otherwise the message is sent to `retrieve`; with no docs a fixed hint is
    answered, else the LLM answer (with any 'Sources' block stripped) is used.

    Returns:
        (chat history for the Chatbot, sources DataFrame, history for gr.State); the
        first and last items are the same list.
    """
    columns = ["#", "Brand", "Event", "Date", "Report", "Score"]
    if not (user_msg and user_msg.strip()):
        return history, pd.DataFrame(columns=columns), history

    docs = retrieve(user_msg, top_k=TOP_K, top_n_final=TOP_N_FINAL)
    if docs:
        prompt = _build_chat_prompt(history, user_msg, docs)
        reply = _strip_sources_block(_llm_generate(prompt))
        table = _sources_table(docs)
    else:
        reply = "I couldn’t find anything clearly relevant. Try adding a brand, model, or year (e.g., 'Dexcom G6 malfunctions 2020')."
        table = pd.DataFrame(columns=columns)

    updated = history + [(user_msg, reply)]
    return updated, table, updated

# ---- Gradio UI setup ----
with gr.Blocks(title="FDA Medical Device Chatbot") as chat_demo:
    gr.Markdown("### FDA Medical Device Chatbot")
    state = gr.State([])  # Holds conversation history [(user, bot), ...]

    # Main chat window; chat_turn feeds it the same (user, bot) list stored in `state`
    chatbot = gr.Chatbot(height=440, label=None, show_copy_button=True, value=[])

    # Input row
    with gr.Row():
        msg = gr.Textbox(placeholder="Ask anything about MAUDE adverse events…", scale=6)
        send = gr.Button("Send", variant="primary", scale=1)
        clear = gr.Button("Clear")

    # Sources table display
    sources = gr.Dataframe(
        headers=["#", "Brand", "Event", "Date", "Report", "Score"],
        interactive=False,
        wrap=True,
        label="Sources (last answer)"
    )

    # Send button click handler
    send.click(
        fn=chat_turn,
        inputs=[msg, state],
        outputs=[chatbot, sources, state],  # Writes back updated history to state
    ).then(lambda: "", None, msg)  # Clears input box after sending

    # Pressing Enter in the textbox triggers same as send
    msg.submit(
        fn=chat_turn,
        inputs=[msg, state],
        outputs=[chatbot, sources, state],
    ).then(lambda: "", None, msg)

    # Clear chat button handler
    def _clear():
        """Reset chat, sources table, and state."""
        return [], pd.DataFrame(columns=["#", "Brand", "Event", "Date", "Report", "Score"]), []
    clear.click(_clear, inputs=None, outputs=[chatbot, sources, state])

# Launch the Gradio app.
# `share` takes a bool; the previous string "True" only worked because non-empty strings are truthy.
# share=True also publishes a temporary public gradio.live URL, so anyone with the link can query the app.
chat_demo.launch(share=True)

  chatbot = gr.Chatbot(height=440, label=None, show_copy_button=True, value=[])


* Running on local URL:  http://127.0.0.1:7861


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


* Running on public URL: https://6b28fce633e24fcc35.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




  return forward_call(*args, **kwargs)
  hits = client.search(
  return forward_call(*args, **kwargs)
  hits = client.search(


In [21]:
import time
import numpy as np
import pandas as pd
from tqdm import tqdm

import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)

from rouge_score import rouge_scorer
from bert_score import score as bertscore

smooth = SmoothingFunction().method4                             # BLEU smoothing used by _bleu
rouge = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)   # ROUGE-L scorer (with stemming) used by _rouge_l

def _gen_answer_with_latency(query: str):
    """
    Run the RAG pipeline for one query (no chat history) and time each stage.

    Mirrors `chat_turn`: same TOP_K / TOP_N_FINAL, and any model-added 'Sources'
    section is stripped, so metrics score the answer users actually see.
    Timing uses time.perf_counter() (monotonic, high resolution) rather than
    wall-clock time.time().

    Returns:
        (answer, retrieval_seconds, generation_seconds). When retrieval returns no
        docs, returns ("", retrieval_seconds, 0.0) without calling the LLM.
    """
    t0 = time.perf_counter()
    docs = retrieve(query, top_k=TOP_K, top_n_final=TOP_N_FINAL)
    t_retrieve = time.perf_counter() - t0
    if not docs:
        return "", t_retrieve, 0.0
    prompt = _build_chat_prompt([], query, docs)
    t1 = time.perf_counter()
    raw = _llm_generate(prompt)
    t_gen = time.perf_counter() - t1
    out = _strip_sources_block(raw).strip()
    return out, t_retrieve, t_gen

def _bleu(ref: str, hyp: str):
    if not ref or not hyp:
        return np.nan
    return sentence_bleu([ref.split()], hyp.split(), smoothing_function=smooth)

def _rouge_l(ref: str, hyp: str):
    if not ref or not hyp:
        return np.nan
    return rouge.score(ref, hyp)["rougeL"].fmeasure

def _bertscore_f1(ref: str, hyp: str):
    if not ref or not hyp:
        return np.nan
    P, R, F1 = bertscore([hyp], [ref], lang="en")
    return float(F1.mean().item())


In [22]:
# Evaluation set: each item pairs a query (sent verbatim, after .strip(), to the
# RAG pipeline) with a reference answer in the same "Summary: / Key Reports:"
# layout the generated answers use.
# Query strings keep their typos ("Occured", "Injurys") on purpose: changing
# them changes what retrieval sees and therefore the scores.
# NOTE(review): the references read like LLM output rather than curated ground
# truth. Examples: "on specific dates mentioned above" with no dates above; item 2
# cites both "B Zone" and "D Zone"; item 1 says "six events" but reports 1 and 6
# look like duplicates. If they were generated by this same pipeline,
# BLEU/ROUGE/BERTScore measure self-consistency, not factual accuracy — confirm
# their provenance.
test_items = [
    {
        "query": "Events where Death Occured",
        "reference": """Summary:
The analysis focused on six events involving medical devices that resulted in death, spanning from 2020 to 2021. The devices include True Metrix and Dexcom G6 continuous glucose monitoring systems. The evidence gathered indicates that the root causes of these incidents are not fully determined due to insufficient information about the underlying conditions and potential contributing factors.

Key Reports:
1. Device: True Metrix, Event Type: Death, Date(s): 2021-07-23
   - The customer's deceased husband had used a True Metrix meter prior to his death. The customer reported issues with the meter but declined further investigation or blood tests. No cause of death was disclosed for either the customer or her husband.

2. Device: G6 Continuous Glucose Monitoring System, Event Type: Death, Date(s): 2020-12-15
   - The patient's relatives were unable to be reached, and the cause of death was unknown to both the patient's doctor and primary care physician.

3. Device: Dexcom G6 Continuous Glucose Monitoring System, Event Type: Death, Date(s): 2020-07-01
   - Signal loss over one hour was reported, but no product or data was provided for evaluation. The cause of death and a probable cause could not be determined as no injury or medical intervention was reported.

4. Device: True Metrix Air, Event Type: Death, Date(s): 2021-10-22
   - The customer passed away with a high blood glucose level (600 mg/dl), but the fasting/non-fasting status was not disclosed. The death was likely related to the patient being a terminal diabetic, according to their caregiver. No product or data were returned for evaluation.

5. Device: Dexcom G6 Continuous Glucose Monitoring System, Event Type: Death, Date(s): 2020-07-01, 2021-07-30
   - Insufficient information was provided to determine the cause of death or the role of the device in these incidents.

6. Device: True Metrix, Event Type: Death, Date(s): 2021-07-23
   - The customer reported issues with recharging their deceased husband's True Metrix meter, but no further investigation or details were provided. No cause of death was disclosed for the customer or her husband."""
    },
    {
        "query": "Adverse Events in Glucose Monitoring systems",
        "reference": """Summary:
This report presents an analysis of adverse events associated with the Dexcom G6 Continuous Glucose Monitoring System (CGMS). Between 2019 and 2021, multiple incidents were reported where the CGMS readings differed from those obtained by a blood glucose meter. The events occurred on specific dates mentioned above. The discrepancy in values fell within the B Zone of the Parkes Error Grid, indicating moderate inaccuracies, although no injuries or medical interventions were reported.

Key Reports:
1. Device: Dexcom G6 CGMS, Event Type: Malfunction, Date(s): 2019-09-11, 2020-09-25, 2020-10-05, 2020-11-12, 2021-01-13, 2021-03-02, 2021-08-03, 2021-10-07
   Description: Inaccuracies between the CGMS and blood glucose meter were reported on several occasions. The sensor was inserted off-label into the arm on 09/02/2020, and data investigation confirmed the allegation of discrepancies. However, the cause could not be determined via data analysis. The reported glucose values fell within the D Zone of the Parkes Error Grid but were found to be different enough to fall within the B Zone. No injury or medical intervention was required."""
    },
    {
        "query": "Malfunctions eversense sensor devices in 2019",
        "reference": """Summary:
The review of available evidence reveals multiple reported malfunction incidents involving the Eversense sensor device during the year 2019. The specific issues encountered included difficulties with explantation, which required more than one attempt to remove the sensor.

Key Reports:
1. Device: Eversense Sensor
   Event Type: Malfunction - Explantation Issue
   Date(s): August 30, September 13, October 14 (2019)
   Description: In one incident, a healthcare professional encountered difficulty in explanting the Eversense sensor on the first attempt made on August 30. The sensor was successfully removed during the second attempt made on October 14.

2. Device: Eversense Sensor
   Event Type: Malfunction - Explantation Issue
   Date(s): August 2, September 18, October 9, October 14 (2019)
   Description: In another incident, a healthcare professional faced difficulties in explanting the Eversense sensor on the first attempt. The exact date of the initial attempt and subsequent removal is not specified in the available data."""
    },
    {
        "query": "Injurys in libre Devices in 2021",
        "reference": """Summary:

This analysis focuses on adverse events related to the use of Freestyle Libre glucose monitoring systems during the year 2021. A total of five incidents were identified, all categorized as injuries. The incidents involved two different models: Freestyle Libre 14-day and Freestyle Libre 2.

Key Reports:

1. Event: Injury related to Freestyle Libre 14-day device
   - Date(s): March 8, 2021, December 29, 2021
   - Description: Upon removal of the sensor filament remained in the customer's skin, causing pain. No further information was provided regarding the extent or duration of the injury.

2. Event: Injury related to Freestyle Libre 2 device
   - Date: February 25, 2021
   - Description: The applicator needle got stuck in the sensor during application, causing "hemorrhage". The needle had to be surgically removed due to threatening bleeding. No further treatment was reported.

3. Event: Injury related to Freestyle Libre 14-day device
   - Date: June 8, 2021
   - Description: Insertion issue resulted in a broken sensor filament that remained in the customer's arm, causing pain, swelling, and bleeding. A healthcare provider removed the sensor filament, and no further treatment was required.

4. Event: Injury related to Freestyle Libre 2 device
   - Date: October 18, 2021
   - Description: Bleeding occurred after sensor insertion, which continued into the next day. A healthcare professional attempted to stop bleeding with a pressure bandage, but this was not effective. The wound was then cauterized with silver nitrate. No further treatment or medication was required.

5. Event: Injury related to Freestyle Libre 14-day device
   - Date: February 1, 2021
   - Description: A caregiver reported heavy bleeding while a customer was wearing the sensor. The customer lost consciousness but was able to self-treat once they regained consciousness. No further information was provided regarding the extent or duration of the injury."""
    },
]


In [23]:
# Run every test query through the full RAG pipeline, then score the generated
# answer against its reference (BLEU, ROUGE-L, BERTScore F1). Also record how
# long retrieval and generation took.
PREVIEW_CHARS = 220  # characters of each answer shown in the per-query table
METRIC_COLS = [
    "bleu",
    "rougeL",
    "bertscore_f1",
    "latency_retrieve_s",
    "latency_generate_s",
    "latency_total_s",
]

rows = []
for item in tqdm(test_items, desc="Evaluating"):
    q = item["query"].strip()
    ref = (item.get("reference") or "").strip()

    hyp, t_retrieve, t_gen = _gen_answer_with_latency(q)

    # Single-line preview; a trailing "…" signals the answer was truncated.
    hyp_preview = hyp[:PREVIEW_CHARS].replace("\n", " ")
    if len(hyp) > PREVIEW_CHARS:
        hyp_preview += "…"

    rows.append(dict(
        query=q,
        bleu=_bleu(ref, hyp),
        rougeL=_rouge_l(ref, hyp),
        bertscore_f1=_bertscore_f1(ref, hyp),
        ref_len=len(ref.split()),
        hyp_len=len(hyp.split()),
        latency_retrieve_s=t_retrieve,
        latency_generate_s=t_gen,
        latency_total_s=t_retrieve + t_gen,
        hyp_preview=hyp_preview,
    ))

df_gen = pd.DataFrame(rows)

# Unweighted mean over queries; pandas skips NaN (e.g. metrics of empty answers).
metric_means = df_gen[METRIC_COLS].mean()
macro = metric_means.to_frame("mean").T
macro.index = ["macro_avg"]

print("Macro averages")
display(macro)

print("Per-query scores")
display(df_gen)


  return forward_call(*args, **kwargs)
  hits = client.search(
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return forward_call(*args, **kwargs)
Evaluating:  25%|██▌       | 1/4 [01:38<04:56, 98.93s/it]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Evaluating:  50%|█████     | 2/4 [02:30<02:22, 71.03s/it]Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be ab

Macro averages


Unnamed: 0,bleu,rougeL,bertscore_f1,latency_retrieve_s,latency_generate_s,latency_total_s
macro_avg,0.272436,0.539485,0.92208,6.428841,53.47934,59.908182


Per-query scores


Unnamed: 0,query,bleu,rougeL,bertscore_f1,ref_len,hyp_len,latency_retrieve_s,latency_generate_s,latency_total_s,hyp_preview
0,Events where Death Occured,0.241838,0.475436,0.904291,328,275,4.520098,87.343865,91.863963,Summary: This analysis presents five incidents...
1,Adverse Events in Glucose Monitoring systems,0.292434,0.564767,0.933937,168,180,6.475954,36.342108,42.818062,Summary: The analysis focuses on adverse event...
2,Malfunctions eversense sensor devices in 2019,0.285814,0.569659,0.928134,148,173,7.136031,37.834176,44.970207,Summary: The analysis of incidents involving t...
3,Injurys in libre Devices in 2021,0.269655,0.548077,0.921957,296,321,7.583283,52.397212,59.980495,Summary: The analysis of incident reports from...
