In [16]:
# Cell 1: Install packages
!pip install -q sentence-transformers faiss-cpu transformers accelerate
!pip install -q unstructured pandas numpy tqdm

In [17]:
# Cell 2: Upload and unzip dataset
from google.colab import files
import os
import zipfile

# Upload the zip file (only needed once per session)
print("Upload your 'Finished.zip' file.")
uploaded = files.upload()  # Select 'finished.zip' from your computer

# Unzip it
with zipfile.ZipFile('Finished.zip', 'r') as zip_ref:
    zip_ref.extractall('data')

print("Dataset extracted to 'data/finished/'")

Upload your 'Finished.zip' file.


Saving Finished.zip to Finished (2).zip
Dataset extracted to 'data/finished/'


In [18]:
# Cell 3: Preprocess dataset into a list of documents
import os
import json
from pathlib import Path

def parse_json_file(file_path):
    """Extract and combine input fields into a single clinical note."""
    with open(file_path, 'r') as f:
        data = json.load(f)

    # Combine input fields with section labels
    note_parts = []
    if 'input1' in data and data['input1'].strip():
        note_parts.append(f"[Chief Complaint / Symptoms]:\n{data['input1'].strip()}")
    if 'input2' in data and data['input2'].strip():
        note_parts.append(f"[History of Present Illness]:\n{data['input2'].strip()}")
    if 'input3' in data and data['input3'].strip():
        note_parts.append(f"[Past Medical History / Allergies]:\n{data['input3'].strip()}")
    if 'input4' in data and data['input4'].strip():
        note_parts.append(f"[Review of Systems]:\n{data['input4'].strip()}")
    if 'input5' in data and data['input5'].strip():
        note_parts.append(f"[Physical Exam]:\n{data['input5'].strip()}")
    if 'input6' in data and data['input6'].strip():
        note_parts.append(f"[Relevant Labs / Diagnostics]:\n{data['input6'].strip()}")

    full_text = "\n\n".join(note_parts)
    return full_text

# Walk through the 'finished' directory
documents = []
root_dir = Path("data/Finished")

for json_path in root_dir.rglob("*.json"):
    # Skip .DS_Store or hidden files
    if json_path.name.startswith('.'):
        continue

    try:
        # Build diagnosis category from folder path (e.g., Asthma/Severe Asthma Exacerbation)
        rel_path = json_path.relative_to(root_dir)
        diagnosis_category = str(rel_path.parent).replace(os.sep, " / ")

        text = parse_json_file(json_path)
        if not text.strip():
            continue  # skip empty files

        doc = {
            "id": json_path.stem,  # e.g., "11535902-DS-14"
            "diagnosis_category": diagnosis_category,
            "text": text,
            "file_path": str(json_path)
        }
        documents.append(doc)
    except Exception as e:
        print(f"Error processing {json_path}: {e}")

print(f"Preprocessed {len(documents)} clinical notes.")

Preprocessed 511 clinical notes.


In [19]:
# Cell 4: Save documents as JSONL
import json

with open('preprocessed_documents.jsonl', 'w') as f:
    for doc in documents:
        f.write(json.dumps(doc) + '\n')

print("Saved preprocessed data to 'preprocessed_documents.jsonl'")

Saved preprocessed data to 'preprocessed_documents.jsonl'


In [20]:
# Cell 5: View sample preprocessed documents
import json

# Print first 2 documents as examples
for i in range(min(2, len(documents))):
    print("="*80)
    print(f" Document {i+1}: ID = {documents[i]['id']}")
    print(f" Diagnosis Category: {documents[i]['diagnosis_category']}")
    print("\n Clinical Note:\n")
    print(documents[i]['text'])
    print("\n" + "="*80 + "\n")

 Document 1: ID = 14545508-DS-6
 Diagnosis Category: Upper Gastrointestinal Bleeding

 Clinical Note:

[Chief Complaint / Symptoms]:
malaise, melanotic stools

[History of Present Illness]:
A man with recent admission here for C. Cath given NSTEMI, s/p BMS to LAD, c/b hematemesis requiring transfusions. Pt. has complained of ongoing dark stools, malaise at home, so reported to ED. Found to have Guaiac positive stools, Hct down slightly from baseline, no evidence of brisk bleeding seen. AF, VSS in ED. Denies CP, SOB, hematemesis, lightheadedness, hematochezia. All other ROS negative.

[Past Medical History / Allergies]:
+NSTEMI, s/p LAD BMD
+severe esophagitis and ugib in setting of dual antiplatelet therapy following nstemi.
+hypertension  
+Hematuria 
+Glaucoma left eye  
+left cRenal mass - c/w malignancy  
+Colon adenoma  
+CAD - diagnosed by stress test  
+Depression  
+Elevated PSA (approx 21)

[Review of Systems]:
heart attack in patient's father. Patient's father died of aortic 

In [21]:
# Cell 6 (FIXED): Generate embeddings with separate variable
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm

print("Loading clinical sentence transformer...")
embedding_model = SentenceTransformer('pritamdeka/S-Biomed-Roberta-snli-multinli-stsb')  # ← renamed!
embedding_model.eval()

texts = [doc["text"] for doc in documents]

print(f"Generating embeddings for {len(texts)} documents...")
embeddings = embedding_model.encode(  # ← use embedding_model
    texts,
    batch_size=16,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)
embedding_model.save("embedding_model")

print(f" Embeddings shape: {embeddings.shape}")

Loading clinical sentence transformer...
Generating embeddings for 511 documents...


Batches:   0%|          | 0/32 [00:00<?, ?it/s]

 Embeddings shape: (511, 768)


In [22]:
# Cell 7: Build FAISS index (using inner product = cosine similarity)
import faiss

# Get embedding dimension
dim = embeddings.shape[1]

# Create FAISS index (IndexFlatIP = inner product for normalized vectors = cosine sim)
index = faiss.IndexFlatIP(dim)  # IP = Inner Product (cosine similarity when normalized)

# Add embeddings to index
index.add(embeddings.astype('float32'))
faiss.write_index(index, "index.faiss")
print(f" FAISS index built with {index.ntotal} documents.")

 FAISS index built with 511 documents.


In [23]:
# Cell 8 (FIXED): Use embedding_model, not global 'model'
def retrieve(query: str, k: int = 3):
    query_emb = embedding_model.encode(query, normalize_embeddings=True).astype('float32')  # ← embedding_model
    query_emb = np.expand_dims(query_emb, axis=0)

    scores, indices = index.search(query_emb, k)

    results = []
    for i in range(len(indices[0])):
        idx = indices[0][i]
        score = float(scores[0][i])
        results.append({
            "score": score,
            "document": documents[idx]
        })
    return results

In [24]:
# Cell 9: Test the retriever
sample_queries = [
    "What are common symptoms of COPD exacerbation?",
    "How is severe asthma diagnosed?",
    "Signs of acute coronary syndrome in elderly patients",
    "Lab findings in adrenal insufficiency"
]

for query in sample_queries:
    print(f"\n Query: {query}")
    print("-" * 60)
    results = retrieve(query, k=2)
    for i, res in enumerate(results):
        print(f"\n Result {i+1} (Score: {res['score']:.3f})")
        print(f"Diagnosis: {res['document']['diagnosis_category']}")
        # Show first 200 chars of note
        print(f"Text preview: {res['document']['text'][:200]}...")
    print("\n" + "="*80)


 Query: What are common symptoms of COPD exacerbation?
------------------------------------------------------------

 Result 1 (Score: 0.647)
Diagnosis: COPD
Text preview: [Chief Complaint / Symptoms]:
None

[History of Present Illness]:
The patient reports that over the past 1.5 months, she 
has experienced worsening shortness of breath and chest pressure with exertion...

 Result 2 (Score: 0.636)
Diagnosis: Asthma / COPD Asthma
Text preview: [Chief Complaint / Symptoms]:
Cough, SOB

[History of Present Illness]:
She is with history of asthma, anemia, MGUS, COPD,  and prior CVA, who presented with shortness of breath.  Patient reports that...


 Query: How is severe asthma diagnosed?
------------------------------------------------------------

 Result 1 (Score: 0.585)
Diagnosis: Asthma / Severe Asthma Exacerbation
Text preview: [Chief Complaint / Symptoms]:
asthma

[History of Present Illness]:
Female with history of asthma on 3L O2 for ambulation, OSA on CPAP, allergic rhinitis, at

In [25]:
# Cell 10: Load tokenizer and model without custom modules
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Clear any model references
torch.cuda.empty_cache()

# Load using official transformers code (no custom modeling_phi3.py)
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=False,  # ⚠️ SET TO FALSE!
    padding_side="left"
)

# Set pad_token if missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=False,  # ⚠️ CRITICAL: Use official transformers code
    low_cpu_mem_usage=True,
    use_safetensors=True
)

print(" Phi-3 loaded using official transformers implementation!")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



 Phi-3 loaded using official transformers implementation!


In [26]:
# Cell 11 (UPDATED): Phi-3 RAG prompt
def build_rag_prompt(query: str, retrieved_docs: list, max_length: int = 2000):
    context_parts = []
    for i, doc in enumerate(retrieved_docs):
        note = doc['document']['text']
        context_parts.append(f"Document {i+1} (Diagnosis: {doc['document']['diagnosis_category']}):\n{note}\n")

    context = "\n".join(context_parts)
    if len(context) > max_length:
        context = context[:max_length] + "... [truncated]"

    prompt = f"""<|system|>
You are a clinical assistant. Follow these rules strictly:
1. Answer ONLY using facts explicitly stated in the "Context" below.
2. NEVER mention lab tests (e.g., ESR, CRP, AFB smear, culture), imaging findings, or symptoms unless they appear verbatim in the Context.
3. NEVER use phrases like "typical findings", "may include", "commonly seen", or "laboratory findings may include".
4. If the Context does not contain information directly relevant to the question, respond EXACTLY:
   "The provided clinical notes do not contain sufficient information."
<|end|>
<|user|>
Context:
{context}

Question: {query}
<|end|>
<|assistant|>"""

    return prompt

In [27]:
# Cell 12 (UPDATED): RAG pipeline generating clean answer
def rag_pipeline(query: str, k: int = 3, max_new_tokens: int = 250):
    retrieved = retrieve(query, k=k)
    prompt = build_rag_prompt(query, retrieved)

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=2048
    ).to("cuda")

    with torch.no_grad():
        outputs = llm.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=False,
            return_dict_in_generate=False,
            output_attentions=False,
            output_hidden_states=False
        )

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Strip assistant special token
    if "<|assistant|>" in answer:
        answer = answer.split("<|assistant|>")[-1].strip()

    # Strip repeated instructions or question if echoed
    if "Question:" in answer:
        answer = answer.split("Question:")[-1].strip()

    # Optional: remove query text if repeated
    if answer.lower().startswith(query.lower()):
        answer = answer[len(query):].strip()

    return answer, retrieved


In [28]:
# ================================
# Cell 13 — Full Evaluation & Analysis
# ================================

import re
import numpy as np

# ---------------------------
# AUTO-GENERATE REAL EVAL LABELS
# ---------------------------
def extract_query_from_text(text):
    """Extract a meaningful query from each clinical note."""
    # Chief complaint extraction
    m = re.search(r"\[Chief Complaint.*?\]:\s*(.+)", text, re.I | re.S)
    if m:
        cc = m.group(1).strip().split(".")[0]
        if len(cc) > 6:
            return cc

    # Fallback: first sentence
    first_sentence = text.strip().split(".")[0]
    if len(first_sentence) > 6:
        return first_sentence

    return text[:80]


auto_test_queries = []
auto_ground_truth = []

N = 10  # number of evaluation samples

for doc in documents[:N]:
    q = extract_query_from_text(doc["text"])
    auto_test_queries.append(q)
    auto_ground_truth.append([doc["id"]])  # TRUE label is the same document

print("Generated", len(auto_test_queries), "real evaluation pairs.\n")



# ---------------------------
# RETRIEVAL METRICS
# ---------------------------
def evaluate_retrieval(test_queries, ground_truth_docs, k=3):
    precision_scores = []
    recall_scores = []

    for query, gt_list in zip(test_queries, ground_truth_docs):
        retrieved = retrieve(query, k=k)
        retrieved_ids = [r["document"]["id"] for r in retrieved]

        hits = sum([1 for gt in gt_list if gt in retrieved_ids])

        precision = hits / k
        recall = hits / len(gt_list)

        precision_scores.append(precision)
        recall_scores.append(recall)

    return {
        "precision": float(np.mean(precision_scores)),
        "recall": float(np.mean(recall_scores))
    }



# ---------------------------
# GENERATION EVALUATION
# ---------------------------
def text_contains_keywords(text, keywords):
    text = text.lower()
    return any(k.lower() in text for k in keywords)

def evaluate_generation(test_queries, ground_truth_docs):
    results = []

    for query, gt in zip(test_queries, ground_truth_docs):
        answer, retrieved_docs = rag_pipeline(query, k=3)

        # expected text: from the true document
        true_doc = next(d for d in documents if d["id"] == gt[0])
        expected_keywords = true_doc["text"].split()[:12]  # top 12 words

        accuracy = 1 if text_contains_keywords(answer, expected_keywords) else 0
        coherence = 1 if len(answer.split()) > 3 else 0
        relevance = 1 if any(word.lower() in answer.lower() for word in query.lower().split()) else 0

        results.append({
            "question": query,
            "answer": answer,
            "accuracy": accuracy,
            "coherence": coherence,
            "relevance": relevance
        })

    return results



# ---------------------------
# ERROR ANALYSIS
# ---------------------------
def print_error_analysis(gen_results):
    print("\n===== ERROR ANALYSIS =====\n")

    for r in gen_results:
        print("Question:", r["question"])
        print("Answer:", r["answer"][:250], "...")

        if r["accuracy"] == 0:
            print(" Accuracy failure: model answer didn't contain expected info.")
        print("--------------------------------------------")



# ---------------------------
# OPTIONAL SCENARIO SIMULATION
# ---------------------------
def simulate_clinical_scenario():
    prompt = "Patient presents with cough, fever, and chest tightness. Based on similar notes, list what information is usually documented."
    answer, _ = rag_pipeline(prompt, k=3)
    return answer



# ---------------------------
# RUN EVERYTHING
# ---------------------------
print("\nRunning evaluation...\n")

retrieval_scores = evaluate_retrieval(
    test_queries=auto_test_queries,
    ground_truth_docs=auto_ground_truth,
    k=3
)

print("Retrieval Precision:", retrieval_scores["precision"])
print("Retrieval Recall:", retrieval_scores["recall"])

gen_results = evaluate_generation(auto_test_queries, auto_ground_truth)

print("\nGeneration evaluation results:")
for r in gen_results:
    print(r)

# Error analysis
print_error_analysis(gen_results)

# Scenario simulation
print("\n===== CLINICAL SCENARIO SIMULATION =====\n")
print(simulate_clinical_scenario())

print("\nEvaluation complete!")


Generated 10 real evaluation pairs.


Running evaluation...

Retrieval Precision: 0.33333333333333337
Retrieval Recall: 1.0

Generation evaluation results:
{'question': 'malaise, melanotic stools\n\n[History of Present Illness]:\nA man with recent admission here for C', 'answer': 'The provided clinical notes do not contain sufficient information.', 'accuracy': 1, 'coherence': 1, 'relevance': 1}
{'question': 'None\n\n[History of Present Illness]:\nMale with recurrent symptomatic dizziness over the past four months', 'answer': 'The provided clinical notes do not contain sufficient information.', 'accuracy': 0, 'coherence': 1, 'relevance': 1}
{'question': 'CC: coffee ground emesis, melena\n\n[History of Present Illness]:\nA female with a history of PUD, hiatal hernia, chronic back pain, presenting as a transfer from after melena/hematemesis and then a syncopal episode and fall', 'answer': 'The provided clinical notes do not contain sufficient information.', 'accuracy': 0, 'coherence': 1, 

In [29]:
# Cell 13: Final test
query = "What did the chest CT show in the patient with a history of treated tuberculosis who presented with respiratory symptoms"
answer, docs = rag_pipeline(query, k=2)
print(" Answer:\n", answer)

 Answer:
 The chest CT showed granulomatous changes, cavitary lesions, and lymphadenopathy.


In [30]:
!pip install huggingface_hub




In [31]:
from huggingface_hub import notebook_login
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [33]:
from huggingface_hub import HfApi, upload_folder

repo_id = "zentom/embedding_model"  # rename this

# Create repo if it doesn't exist


# Upload the folder
upload_folder(
    folder_path="embedding_model",
    repo_id=repo_id,
    commit_message="Upload embedding model"
)

print("Uploaded successfully!")


Processing Files (0 / 0)      : |          |  0.00B /  0.00B            

New Data Upload               : |          |  0.00B /  0.00B            

  ...g_model/model.safetensors:   0%|          |  550kB /  499MB            

Uploaded successfully!
