In [15]:
from typing import List
import re
import pandas as pd


# Startup banner, printed before any data is loaded or any model is created.
print("Welcome to the Medical Knowledge Assistant.")
print("Processing your request, please wait...\n")


def split_into_chunks(csv_file: str) -> List[str]:
    """Turn every row of a CSV file into one "column: value" text chunk.

    Args:
        csv_file: Location of the CSV file; handed directly to
            ``pandas.read_csv``.

    Returns:
        One string per row, in file order. Each string contains a
        ``"<column>: <value>"`` line for every cell of that row that pandas
        does not report as missing, joined with newlines. A row whose cells
        are all missing yields an empty string.
    """
    frame = pd.read_csv(csv_file)

    def render_row(record) -> str:
        # Missing cells are left out entirely rather than rendered as "nan".
        return "\n".join(
            f"{column}: {str(record[column])}"
            for column in frame.columns
            if not pd.isna(record[column])
        )

    return [render_row(record) for _, record in frame.iterrows()]

# Build one text chunk per CSV row. The bare file name is resolved relative to
# the current working directory.
chunks = split_into_chunks("mimic_dataset_600_keywords.csv")

# for i, chunk in enumerate(chunks[:3]):
#     print(f"[{i}] {chunk}\n")


from sentence_transformers import SentenceTransformer

# Sentence-embedding model shared by embed_chunk and by the bulk encoding of
# all chunks further down.
# NOTE(review): the model name indicates a Chinese-language model, while the
# query used later in this script is written in English — confirm that this is
# the intended model.
embedding_model = SentenceTransformer("shibing624/text2vec-base-chinese")

def embed_chunk(chunk: str) -> List[float]:
    """Embed one piece of text with the module-level ``embedding_model``.

    Args:
        chunk: Text to embed.

    Returns:
        The result of ``embedding_model.encode`` (called with
        ``normalize_embeddings=True``) converted with ``.tolist()``.
    """
    return embedding_model.encode(chunk, normalize_embeddings=True).tolist()

# embedding = embed_chunk("test")
# print(len(embedding))



# Encode all chunks in a single encode() call (same normalize_embeddings=True
# setting as embed_chunk), then convert the result with .tolist() before it is
# handed to the vector store.
embeddings = embedding_model.encode(chunks, normalize_embeddings=True)
embeddings = embeddings.tolist()

# print(len(embeddings))
# print(len(embeddings[0]))


import chromadb

# Vector store used by save_embeddings and retrieve.
# NOTE(review): EphemeralClient is understood to keep data in memory only, so
# the index would be rebuilt on every run — confirm against the Chroma docs.
chromadb_client = chromadb.EphemeralClient()
# get_or_create_collection reuses the "default3" collection if this client
# already has one, otherwise it creates it.
chromadb_collection = chromadb_client.get_or_create_collection(name="default3")

def save_embeddings(chunks: List[str], embeddings: List[List[float]]) -> None:
    """Store chunks and their embeddings in the module-level Chroma collection.

    Args:
        chunks: Document texts to store.
        embeddings: One embedding per chunk, in the same order as ``chunks``.

    Raises:
        ValueError: If ``chunks`` and ``embeddings`` differ in length.
    """
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"chunks and embeddings must have the same length "
            f"(got {len(chunks)} and {len(embeddings)})"
        )

    # Nothing to store, so skip the collection call entirely.
    if not chunks:
        return

    # Ids are the positional indices as strings, so a second call reuses the
    # ids "0".."n-1" instead of allocating fresh ones.
    ids = [str(i) for i in range(len(chunks))]
    chromadb_collection.add(
        documents=chunks,
        embeddings=embeddings,
        ids=ids,
    )

# Index every chunk together with its embedding.
save_embeddings(chunks, embeddings)


def retrieve(query: str, top_k: int) -> List[str]:
    """Look up stored chunks for a query in the module-level Chroma collection.

    Args:
        query: Question text; embedded with ``embed_chunk`` before the lookup.
        top_k: Number of results requested from the collection.

    Returns:
        The documents returned for the query.
    """
    query_vector = embed_chunk(query)
    hits = chromadb_collection.query(
        query_embeddings=[query_vector],
        n_results=top_k,
    )
    # "documents" holds one list per query embedding; exactly one was sent.
    documents_per_query = hits["documents"]
    return documents_per_query[0]


# Pick a query

def pick_query_from_dataset(chunks: List[str]) -> str:
    """Derive a sample query string from the first chunk of the dataset.

    Args:
        chunks: Text chunks; only the first one is inspected.

    Returns:
        The first run of two or more CJK/ASCII-alphanumeric characters in the
        first chunk. If there is no such run, the first 20 characters of the
        chunk's first line. If there are no chunks, or the first chunk is an
        empty string, the default query ``"sepsis"``.
    """
    default_query = "sepsis"
    if not chunks:
        return default_query

    text = chunks[0]
    # re.search yields the same token as findall(...)[0] without building the
    # full list of matches.
    match = re.search(r"[\u4e00-\u9fffA-Za-z0-9]{2,}", text)
    if match:
        return match.group(0)

    # An empty chunk has no lines at all; indexing [0] would raise IndexError.
    lines = text.splitlines()
    if not lines:
        return default_query
    return lines[0][:20]

# Hard-coded question for this run; pick_query_from_dataset above is not called.
query = "Based on the retrieved cases, what symptoms and medications were observed for AMI inferior wall?"


# Fetch 10 candidate chunks from the vector store; they are narrowed down by
# the reranking step below.
retrieved_chunks = retrieve(query, 10)

# for i, chunk in enumerate(retrieved_chunks):
#     print(f"[{i}] {chunk}\n")


from sentence_transformers import CrossEncoder

# Cross-encoder used by rerank to score (query, chunk) pairs.
cross_encoder = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")

def rerank(query: str, retrieved_chunks: List[str], top_k: int) -> List[str]:
    """Re-order retrieved chunks by cross-encoder score and keep the best ones.

    Args:
        query: The user question each chunk is scored against.
        retrieved_chunks: Candidate chunks to score.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to ``top_k`` chunks, highest score first. Chunks with equal scores
        keep their original relative order. Returns an empty list when there
        are no candidates or ``top_k`` is not positive.
    """
    # Skip the model call when there is nothing to score or nothing to return.
    # A negative top_k would otherwise slice items off the tail of the ranking
    # instead of limiting it.
    if not retrieved_chunks or top_k <= 0:
        return []

    pairs = [(query, chunk) for chunk in retrieved_chunks]
    scores = cross_encoder.predict(pairs)

    # Sort on the score only; sorted() is stable, so ties keep retrieval order.
    ranked = sorted(
        zip(retrieved_chunks, scores),
        key=lambda scored: scored[1],
        reverse=True,
    )
    return [chunk for chunk, _ in ranked[:top_k]]

# Keep at most the 3 highest-scoring candidates.
reranked_chunks = rerank(query, retrieved_chunks, 3)

# Show the chunks that are later passed to generate() as context, numbered from 0.
print("Retrieved context:\n")
for i, chunk in enumerate(reranked_chunks):
    print(f"[{i}] {chunk}\n")


from dotenv import load_dotenv
from google import genai

# Load variables from a .env file (if one is found) before creating the client.
load_dotenv()
# No credentials are passed explicitly.
# NOTE(review): presumably the client picks up its API key from the
# environment populated above — confirm.
google_client = genai.Client()

# API call
def generate(query: str, chunks: List[str]) -> str:
    """Ask the Gemini model to answer a question using only the given chunks.

    Args:
        query: The user question, inserted verbatim into the prompt.
        chunks: Context passages; joined with blank lines and inserted into
            the prompt.

    Returns:
        The ``text`` attribute of the response from
        ``google_client.models.generate_content`` (model ``gemini-2.5-flash``).
    """
    # Join outside the f-string: a backslash inside an f-string replacement
    # field is a SyntaxError on Python versions before 3.12.
    context = "\n\n".join(chunks)
    prompt = f"""You are a medical knowledge assistant. Answer the user's question using ONLY the context below.

User question:
{query}

Context:
{context}

Rules:
- Do not invent facts.
- If the context is not enough, say you don't know based on the context.
- Keep it concise.
"""
    response = google_client.models.generate_content(
        model="gemini-2.5-flash",
        contents=prompt
    )
    return response.text
    
# Echo the question, then generate and print an answer from the reranked chunks.
print("Question:", query)
answer = generate(query, reranked_chunks)
print("\nAnswer:")
print(answer)

Welcome to the Medical Knowledge Assistant.
Processing your request, please wait...

Retrieved context:

[0] HADM_ID: 139560
Disease_Name: AMI inferior wall, init
SUBJECT_ID: 5599
Symptom_Keywords: woman known coronary artery disease who reported feeling well sister morning found weak without chest pain shortness breath taken where electrocardiogram showed inferior right sided myocardial infarction catheterization laboratory three vessel total occlusion left circumflex diffusely diseased anterior descending stented times intraaortic balloon pump placed hypotension subsequently went into ventricular fibrillation arrest intubated converted sinus rhythm defibrillation then transferred further management outside received heparin integrilin lidocaine digoxin transiently dopamine asp
Standard_Output: Diagnosis: AMI inferior wall, init | Medications: None

[1] HADM_ID: 168124
Disease_Name: AMI inferior wall, init
SUBJECT_ID: 43633
Prescriptions: 1/2 NS; Acetaminophen; Albuterol Inhaler; Aspir