In [12]:
# Runtime > Change runtime type > GPU (T4/A100 preferred)
!nvidia-smi -L || print("No GPU. Enable it in Runtime settings for faster training.")


/bin/bash: -c: line 1: syntax error near unexpected token `"No GPU. Enable it in Runtime settings for faster training."'
/bin/bash: -c: line 1: `nvidia-smi -L || print("No GPU. Enable it in Runtime settings for faster training.")'


In [1]:
!pip -q install sentence-transformers faiss-cpu pandas scikit-learn tqdm


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m78.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# OPTION A: manual upload
# from google.colab import files
# uploaded = files.upload()
# DATA_PATH = next(iter(uploaded))

# OPTION B: if file is in Drive
# from google.colab import drive
# drive.mount('/content/drive')
# DATA_PATH = "/content/drive/MyDrive/medquad.csv"

# 👉 set this if you just uploaded:
DATA_PATH = "/content/medquad.csv"

import pandas as pd, re, numpy as np
df = pd.read_csv(DATA_PATH)
df = df.dropna(subset=["question","answer"]).reset_index(drop=True)
print(df.shape, df.columns.tolist())

# sentence chunker
SENT_SPLIT = re.compile(r'(?<=[.!?])\s+(?=[A-Z0-9(])')
def chunk_passages(text, window=3, overlap=1):
    sents = [s.strip() for s in SENT_SPLIT.split(str(text)) if s.strip()]
    chunks, i = [], 0
    while i < len(sents):
        ch = " ".join(sents[i:i+window])
        if ch: chunks.append(ch)
        i += max(1, window-overlap)
    return chunks or [str(text).strip()]

# build corpus (passages) + relevance map
corpus = []
qid2docids = {}
for qi, row in df.iterrows():
    chunks = chunk_passages(row["answer"])
    ids = []
    for j, ch in enumerate(chunks):
        pid = f"d{qi}_{j}"
        corpus.append({"doc_id": pid, "text": ch})
        ids.append(pid)
    qid2docids[f"q{qi}"] = set(ids)

import pandas as pd
corpus_df = pd.DataFrame(corpus)
corpus_dict = {r.doc_id: r.text for _, r in corpus_df.iterrows()}
print("Questions:", len(df), "Passages:", len(corpus_df))


(16407, 2) ['question', 'answer']
Questions: 16407 Passages: 87316


In [3]:
from sklearn.model_selection import train_test_split

qids = [f"q{i}" for i in range(len(df))]
train_q, test_q = train_test_split(qids, test_size=0.2, random_state=42)
train_q, dev_q  = train_test_split(train_q, test_size=0.2, random_state=42)

def make_split(qid_list):
    return {
        "queries": {qid: df.iloc[int(qid[1:])]["question"] for qid in qid_list},
        "relevant": {qid: {pid:1 for pid in qid2docids[qid]} for qid in qid_list}
    }

train_split = make_split(train_q)
dev_split   = make_split(dev_q)
test_split  = make_split(test_q)

len(train_q), len(dev_q), len(test_q)


(10500, 2625, 3282)

In [4]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
import torch, time, gc

device = "cuda" if torch.cuda.is_available() else "cpu"
CANDIDATES = [
    "sentence-transformers/msmarco-distilbert-base-v4",   # fast, strong general
    "multi-qa-MiniLM-L6-cos-v1",                          # very small, fast
    "pritamdeka/S-PubMedBert-MS-MARCO",                   # biomedical domain
]

results = {}
for m in CANDIDATES:
    print(f"\n=== {m} ===")
    t0 = time.time()
    model = SentenceTransformer(m, device=device)
    evaluator = InformationRetrievalEvaluator(
        queries=dev_split["queries"],
        corpus=corpus_dict,
        relevant_docs=dev_split["relevant"],
        show_progress_bar=True,
        name="IR"
    )
    scores = evaluator(model, output_path=None)
    # keep the common metrics
    keep = {k:v for k,v in scores.items() if any(x in k for x in ["mrr@10","ndcg@10","map","recall@10"])}
    results[m] = keep
    # cleanup
    del model; gc.collect();
    if device=="cuda":
        torch.cuda.empty_cache()
    print("Time (min):", round((time.time()-t0)/60,2))
    print(keep)

results



=== sentence-transformers/msmarco-distilbert-base-v4 ===


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/545 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/319 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/83 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1563 [00:00<?, ?it/s]

Corpus Chunks:  50%|█████     | 1/2 [00:29<00:29, 29.96s/it]

Batches:   0%|          | 0/1167 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 2/2 [00:50<00:00, 25.50s/it]


Time (min): 0.96
{'IR_cosine_recall@10': 0.3495415396699336, 'IR_cosine_ndcg@10': 0.34829766288357955, 'IR_cosine_mrr@10': 0.558679213907785, 'IR_cosine_map@100': 0.2615622024752071}

=== multi-qa-MiniLM-L6-cos-v1 ===


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/83 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1563 [00:00<?, ?it/s]

Corpus Chunks:  50%|█████     | 1/2 [00:14<00:14, 14.90s/it]

Batches:   0%|          | 0/1167 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 2/2 [00:25<00:00, 12.67s/it]


Time (min): 0.49
{'IR_cosine_recall@10': 0.3674803418872263, 'IR_cosine_ndcg@10': 0.36133157046177483, 'IR_cosine_mrr@10': 0.5699644746787593, 'IR_cosine_map@100': 0.274859037800378}

=== pritamdeka/S-PubMedBert-MS-MARCO ===


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/123 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/388 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/83 [00:00<?, ?it/s]

Corpus Chunks:   0%|          | 0/2 [00:00<?, ?it/s]

Batches:   0%|          | 0/1563 [00:00<?, ?it/s]

Corpus Chunks:  50%|█████     | 1/2 [00:52<00:52, 52.13s/it]

Batches:   0%|          | 0/1167 [00:00<?, ?it/s]

Corpus Chunks: 100%|██████████| 2/2 [01:29<00:00, 44.57s/it]


Time (min): 1.58
{'IR_cosine_recall@10': 0.4302946218162567, 'IR_cosine_ndcg@10': 0.4207182907510491, 'IR_cosine_mrr@10': 0.6491455782312913, 'IR_cosine_map@100': 0.32778627360178475}


{'sentence-transformers/msmarco-distilbert-base-v4': {'IR_cosine_recall@10': 0.3495415396699336,
  'IR_cosine_ndcg@10': 0.34829766288357955,
  'IR_cosine_mrr@10': 0.558679213907785,
  'IR_cosine_map@100': 0.2615622024752071},
 'multi-qa-MiniLM-L6-cos-v1': {'IR_cosine_recall@10': 0.3674803418872263,
  'IR_cosine_ndcg@10': 0.36133157046177483,
  'IR_cosine_mrr@10': 0.5699644746787593,
  'IR_cosine_map@100': 0.274859037800378},
 'pritamdeka/S-PubMedBert-MS-MARCO': {'IR_cosine_recall@10': 0.4302946218162567,
  'IR_cosine_ndcg@10': 0.4207182907510491,
  'IR_cosine_mrr@10': 0.6491455782312913,
  'IR_cosine_map@100': 0.32778627360178475}}

In [6]:
import os
os.environ["WANDB_DISABLED"] = "true"   # fully disable
os.environ["WANDB_MODE"] = "offline"    # extra safety; no network
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizer warnings


In [7]:
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from torch.utils.data import DataLoader
import torch, math, random

BASE_MODEL = "pritamdeka/S-PubMedBert-MS-MARCO"  # ← replace with your best
EPOCHS = 1         # start with 1; try 2 if time allows
BATCH_SIZE = 32    # lower if OOM
LR = 2e-5

# build training pairs (question, positive passage)
train_examples = []
for qid in train_q:
    qtext = df.iloc[int(qid[1:])]["question"]
    for pid in list(qid2docids[qid])[:1]:  # 1 positive per question to start
        train_examples.append(InputExample(texts=[qtext, corpus_dict[pid]]))

loader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE, drop_last=True)

model = SentenceTransformer(BASE_MODEL, device=device)
train_loss = losses.MultipleNegativesRankingLoss(model)

dev_evaluator = InformationRetrievalEvaluator(
    queries=dev_split["queries"], corpus=corpus_dict, relevant_docs=dev_split["relevant"], name="dev"
)

warmup_steps = int(len(loader) * EPOCHS * 0.1)
SAVE_PATH = "/content/retriever_ft"

model.fit(
    train_objectives=[(loader, train_loss)],
    epochs=EPOCHS,
    optimizer_params={"lr": LR},
    warmup_steps=warmup_steps,
    evaluator=dev_evaluator,
    evaluation_steps=max(50, len(loader)//2),
    output_path=SAVE_PATH,
    show_progress_bar=True
)


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Dev Cosine Accuracy@1,Dev Cosine Accuracy@3,Dev Cosine Accuracy@5,Dev Cosine Accuracy@10,Dev Cosine Precision@1,Dev Cosine Precision@3,Dev Cosine Precision@5,Dev Cosine Precision@10,Dev Cosine Recall@1,Dev Cosine Recall@3,Dev Cosine Recall@5,Dev Cosine Recall@10,Dev Cosine Ndcg@10,Dev Cosine Mrr@10,Dev Cosine Map@100
164,No log,No log,0.656762,0.790095,0.829714,0.866286,0.656762,0.403302,0.301257,0.19139,0.239281,0.373494,0.431463,0.50558,0.514847,0.729959,0.42753
328,No log,No log,0.660952,0.792381,0.832,0.868952,0.660952,0.411556,0.308038,0.196152,0.238922,0.377516,0.438199,0.513071,0.521787,0.73416,0.434599


In [8]:
from sentence_transformers import SentenceTransformer
ft_model = SentenceTransformer("/content/retriever_ft", device=device)

test_eval = InformationRetrievalEvaluator(
    queries=test_split["queries"],
    corpus=corpus_dict,
    relevant_docs=test_split["relevant"],
    name="test"
)
scores = test_eval(ft_model, output_path=None)
{k:v for k,v in scores.items() if any(x in k for x in ["mrr@10","ndcg@10","map","recall@10"])}


{'test_cosine_recall@10': 0.5123365846945027,
 'test_cosine_ndcg@10': 0.5188079360700242,
 'test_cosine_mrr@10': 0.7411528007506074,
 'test_cosine_map@100': 0.42470657083689756}

In [9]:
from sentence_transformers import CrossEncoder, InputExample
from torch.utils.data import DataLoader
import random

pairs = []
for qid in train_q:
    qtext = df.iloc[int(qid[1:])]["question"]
    pos_ids = list(qid2docids[qid])
    if not pos_ids:
        continue
    # 1 positive
    pairs.append(InputExample(texts=[qtext, corpus_dict[pos_ids[0]]], label=1.0))
    # a few negatives
    all_ids = list(corpus_dict.keys())
    negs = [pid for pid in random.sample(all_ids, k=min(3, len(all_ids))) if pid not in pos_ids]
    for pid in negs:
        pairs.append(InputExample(texts=[qtext, corpus_dict[pid]], label=0.0))

loader = DataLoader(pairs, batch_size=16, shuffle=True)
ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", num_labels=1, device=device)
ce.fit(train_dataloader=loader, epochs=1, warmup_steps=int(0.1*len(loader)), output_path="/content/crossencoder_ft")
print("Cross-encoder saved to /content/crossencoder_ft")


config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Token indices sequence length is longer than the specified maximum sequence length for this model (515 > 512). Running this sequence through the model will result in indexing errors


Step,Training Loss
500,0.3402
1000,0.1678
1500,0.1611
2000,0.155
2500,0.151


Cross-encoder saved to /content/crossencoder_ft


In [10]:
import os, faiss, numpy as np
from sentence_transformers import SentenceTransformer
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
FT_RETRIEVER = "/content/retriever_ft"  # from Cell 5
BASE_MODEL = "pritamdeka/S-PubMedBert-MS-MARCO"  # or your best from Cell 4

# Load fine-tuned retriever if present, else baseline
ret_model = SentenceTransformer(FT_RETRIEVER if os.path.exists(FT_RETRIEVER) else BASE_MODEL, device=device)

passage_ids = list(corpus_dict.keys())
passages = [corpus_dict[pid] for pid in passage_ids]

# Embed passages in batches
emb = ret_model.encode(
    passages,
    batch_size=256,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True
)

# Cosine similarity via normalized dot product
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)

def dense_retrieve(query, topn=200):
    q = ret_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    D, I = index.search(q, topn)
    return [(passage_ids[i], float(D[0][j])) for j, i in enumerate(I[0])]


Batches:   0%|          | 0/342 [00:00<?, ?it/s]

In [11]:
from sentence_transformers import CrossEncoder
import numpy as np

FT_CE = "/content/crossencoder_ft"  # from Cell 7
if os.path.exists(FT_CE):
    ce = CrossEncoder(FT_CE, device=device)
else:
    ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device=device)  # fallback

def ask(query, topn=200, topk=5):
    # 1) dense retrieve
    cand = dense_retrieve(query, topn=topn)       # [(pid, score), ...]
    texts = [corpus_dict[pid] for pid,_ in cand]
    # 2) cross-encode rerank
    pairs = [(query, t) for t in texts]
    ce_scores = ce.predict(pairs)                 # higher = more relevant
    order = np.argsort(-ce_scores)[:topk]
    results = [{"doc_id": cand[i][0],
                "score": float(ce_scores[i]),
                "text": texts[i]} for i in order]
    answer = results[0]["text"]
    return answer, results


In [12]:
tests = [
    "Can ibuprofen cause stomach pain?",
    "What is a normal A1C level?",
    "Do antihistamines cause drowsiness?",
]
for q in tests:
    ans, res = ask(q, topn=200, topk=5)
    print("Q:", q)
    print("A:", ans)
    print("Top1 score:", round(res[0]["score"], 3))
    print("-"*80)


Q: Can ibuprofen cause stomach pain?
A: The most common symptoms of IBS are abdominal pain or discomfort, often reported as cramping, along with diarrhea, constipation, or both. IBS may give a sensation of bloating because of increased sensitivity to normal amounts of gas. Eating a lot of fatty food can delay stomach emptying and cause bloating and discomfort, but not necessarily too much gas.
Top1 score: 1.496
--------------------------------------------------------------------------------
Q: What is a normal A1C level?
A: The higher the percentage, the higher a persons blood glucose levels have been. A normal A1C level is below 5.7 percent. An A1C of 5.7 to 6.4 percent indicates prediabetes.
Top1 score: 8.386
--------------------------------------------------------------------------------
Q: Do antihistamines cause drowsiness?
A: However, antihistamines often cause drowsiness, and a person should take antihistamines just before bedtime to help with nighttime itching. A health care pr

In [13]:
import numpy as np

def mrr_at_k(relevant_set, ranked_ids, k=10):
    for i, pid in enumerate(ranked_ids[:k], 1):
        if pid in relevant_set:
            return 1.0 / i
    return 0.0

def recall_at_k(relevant_set, ranked_ids, k=10):
    return len(relevant_set & set(ranked_ids[:k])) / (len(relevant_set) or 1)

def ndcg_at_k(relevant_set, ranked_ids, k=10):
    dcg = 0.0
    for i, pid in enumerate(ranked_ids[:k], 1):
        rel = 1.0 if pid in relevant_set else 0.0
        dcg += rel / np.log2(i + 1)
    idcg = sum(1.0 / np.log2(i + 1) for i in range(1, min(len(relevant_set), k) + 1))
    return dcg / (idcg or 1.0)

mrr = rec = ndcg = 0.0
N = min(500, len(test_q))  # cap for speed; remove cap for full test

for qid in test_q[:N]:
    query = df.iloc[int(qid[1:])]["question"]
    _, res = ask(query, topn=200, topk=10)  # ask returns reranked top-10
    ranked = [r["doc_id"] for r in res]
    rel = qid2docids[qid]
    mrr  += mrr_at_k(rel, ranked, 10)
    rec  += recall_at_k(rel, ranked, 10)
    ndcg += ndcg_at_k(rel, ranked, 10)

print({"MRR@10": round(mrr/N, 4), "Recall@10": round(rec/N, 4), "nDCG@10": round(ndcg/N, 4)})


{'MRR@10': 0.6837, 'Recall@10': 0.4304, 'nDCG@10': np.float64(0.433)}


In [14]:
import json, os
SAVE_DIR = "/content/medquad_artifacts"
os.makedirs(SAVE_DIR, exist_ok=True)

# Save retriever (if you fine-tuned)
if os.path.exists("/content/retriever_ft"):
    from sentence_transformers import SentenceTransformer
    SentenceTransformer("/content/retriever_ft").save(SAVE_DIR + "/retriever_ft")

# Save cross-encoder (if you fine-tuned)
if os.path.exists("/content/crossencoder_ft"):
    from sentence_transformers import CrossEncoder
    CrossEncoder("/content/crossencoder_ft").save(SAVE_DIR + "/crossencoder_ft")

# Save FAISS + id mapping
faiss.write_index(index, SAVE_DIR + "/corpus.index")
with open(SAVE_DIR + "/passage_ids.json", "w") as f:
    json.dump(passage_ids, f)

print("Saved to", SAVE_DIR)


Saved to /content/medquad_artifacts


In [17]:
# ✅ Replace your previous chatbot cell with this one

import gradio as gr, re, numpy as np

UNSURE_THR = 0.35  # raise to be more cautious
EMERGENCY_PAT = re.compile(r"\b(chest pain|stroke|not breathing|unconscious|overdose|suicid|self[- ]?harm)\b", re.I)
DOSING_PAT    = re.compile(r"\b(\d+\s*mg|\d+\s*mcg|dos(e|ing)|how much should i take)\b", re.I)

def safety_gate(msg: str):
    if EMERGENCY_PAT.search(msg or ""):
        return ("⚠️ I may be detecting an urgent/high-risk topic. "
                "Please contact emergency services or a healthcare professional immediately. "
                "I can’t provide crisis guidance.")
    if DOSING_PAT.search(msg or ""):
        return ("⚠️ I can’t provide dosing instructions. Please consult a clinician or pharmacist.")
    return None

def chat_fn(message, history):
    flagged = safety_gate(message)
    if flagged:
        return flagged

    try:
        answer, results = ask(message, topn=200, topk=5)   # uses your FAISS + (fine-tuned) models
    except NameError:
        return "Setup missing: please run the retrieval cells (8–9) first."
    except Exception as e:
        return f"Something went wrong while searching: {e}"

    conf = float(results[0]["score"]) if results else 0.0

    if conf < UNSURE_THR:
        bullets = "\n".join([f"- {r['text']}" for r in results[:5]])
        return (f"❓ I'm not fully confident. Here are the most relevant passages I found:\n\n"
                f"{bullets}\n\n_Not medical advice._")

    top_snips = "\n".join([f"- {r['text']}" for r in results[:3]])
    return f"{answer}\n\n_Not medical advice._  \nConfidence: {conf:.2f}\n\n**Top sources**\n{top_snips}"

examples = [
    "Can ibuprofen cause stomach pain?",
    "What is a normal A1C level?",
    "Do antihistamines cause drowsiness?",
    "What does a lipid panel measure?",
    "What is GERD?",
    "Signs of dehydration in adults?"
]

gr.ChatInterface(
    fn=chat_fn,
    title="🩺 MedQuAD QA Chat (RAG)",
    description=("Ask general health questions. Answers are grounded to your dataset. "
                 "Educational only — not medical advice. No emergencies or dosing."),
    examples=examples
).launch()


  self.chatbot = Chatbot(


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://0660dbb0a747dc673e.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [18]:
# SAVE all artifacts for inference-only reuse
import os, json, shutil, pandas as pd, faiss

SAVE_DIR = "/content/medquad_artifacts"
os.makedirs(SAVE_DIR, exist_ok=True)

# 1) Retriever (fine-tuned) — if you trained it in Cell 5
if os.path.exists("/content/retriever_ft"):
    from sentence_transformers import SentenceTransformer
    SentenceTransformer("/content/retriever_ft").save(f"{SAVE_DIR}/retriever_ft")
    print("Saved retriever_ft")

# 2) Cross-encoder (fine-tuned) — if you trained it in Cell 7
if os.path.exists("/content/crossencoder_ft"):
    from sentence_transformers import CrossEncoder
    ce = CrossEncoder("/content/crossencoder_ft")
    ce.save(f"{SAVE_DIR}/crossencoder_ft")
    print("Saved crossencoder_ft")

# 3) FAISS index (already built in Cell 8)
faiss.write_index(index, f"{SAVE_DIR}/corpus.index")
print("Saved FAISS index")

# 4) Passage id order + passage texts (needed to map FAISS rows back to text)
import pandas as pd
pass_df = pd.DataFrame({"doc_id": passage_ids, "text": [corpus_dict[p] for p in passage_ids]})
pass_df.to_parquet(f"{SAVE_DIR}/passages.parquet", index=False)
with open(f"{SAVE_DIR}/passage_ids.json","w") as f:
    json.dump(passage_ids, f)
print("Saved passages & ids")

# 5) Simple config (thresholds etc.)
cfg = {"UNSURE_THR": 0.35, "topn": 200, "topk": 5}
with open(f"{SAVE_DIR}/config.json","w") as f:
    json.dump(cfg, f, indent=2)
print("Saved config")

print("All artifacts saved to:", SAVE_DIR)


Saved retriever_ft
Saved FAISS index
Saved passages & ids
Saved config
All artifacts saved to: /content/medquad_artifacts


In [19]:
!cd /content && zip -rq medquad_artifacts.zip medquad_artifacts
from google.colab import files
files.download("/content/medquad_artifacts.zip")


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [20]:
# ✅ Replace your previous chatbot cell with this one

import gradio as gr, re, numpy as np

# ---- Confidence as probability (sigmoid) ----
PROB_THR = 0.80   # require ≥80% confidence to answer directly (tune as you like)

def score_to_prob(score: float) -> float:
    # Map raw cross-encoder score → [0,1] probability-ish
    return float(1 / (1 + np.exp(-score)))

# ---- Safety gates ----
EMERGENCY_PAT = re.compile(r"\b(chest pain|stroke|not breathing|unconscious|overdose|suicid|self[- ]?harm)\b", re.I)
DOSING_PAT    = re.compile(r"\b(\d+\s*mg|\d+\s*mcg|dos(e|ing)|how much should i take)\b", re.I)

def safety_gate(msg: str):
    if EMERGENCY_PAT.search(msg or ""):
        return ("⚠️ I may be detecting an urgent/high-risk topic. "
                "Please contact emergency services or a healthcare professional immediately. "
                "I can’t provide crisis guidance.")
    if DOSING_PAT.search(msg or ""):
        return ("⚠️ I can’t provide dosing instructions. Please consult a clinician or pharmacist.")
    return None

# ---- Chat handler ----
def chat_fn(message, history):
    flagged = safety_gate(message)
    if flagged:
        return flagged

    try:
        answer, results = ask(message, topn=200, topk=5)   # uses your FAISS + (fine-tuned) models
    except NameError:
        return "Setup missing: please run the retrieval cells (8–9) first."
    except Exception as e:
        return f"Something went wrong while searching: {e}"

    if not results:
        return "❓ I couldn't find supporting passages. Please try rephrasing your question."

    raw = float(results[0]["score"])
    prob = score_to_prob(raw)  # 0..1

    if prob < PROB_THR:
        bullets = "\n".join([f"- {r['text']}" for r in results[:5]])
        return (f"❓ I'm not fully confident. Here are the most relevant passages I found:\n\n"
                f"{bullets}\n\n_Not medical advice._")

    top_snips = "\n".join([f"- {r['text']}" for r in results[:3]])
    return f"{answer}\n\n_Not medical advice._  \nConfidence: {prob:.2%}\n\n**Top sources**\n{top_snips}"

examples = [
    "Can ibuprofen cause stomach pain?",
    "What is a normal A1C level?",
    "Do antihistamines cause drowsiness?",
    "What does a lipid panel measure?",
    "What is GERD?",
    "Signs of dehydration in adults?"
]

gr.ChatInterface(
    fn=chat_fn,
    title="🩺 MedQuAD QA Chat (RAG)",
    description=("Ask general health questions. Answers are grounded to your dataset. "
                 "Educational only — not medical advice. No emergencies or dosing."),
    examples=examples
).launch()


  self.chatbot = Chatbot(


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://cc21e7d7ca31db69c9.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [21]:
# save as scripts/clean_notebooks.py and run: python scripts/clean_notebooks.py
import nbformat as nbf, glob, os

paths = [p for p in glob.glob("**/*.ipynb", recursive=True) if ".ipynb_checkpoints" not in p]

for p in paths:
    nb = nbf.read(p, as_version=nbf.NO_CONVERT)
    if "widgets" in nb.get("metadata", {}):
        nb["metadata"].pop("widgets", None)
        nbf.write(nb, p)
        print("Cleaned:", p)
