In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install llama-index llama-index-embeddings-openai langchain_community igraph leidenalg cdlib chromadb qdrant-client




In [3]:
!apt-get install docker.io -y

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
docker.io is already the newest version (27.5.1-0ubuntu3~22.04.2).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


In [4]:
!pip install fiftyone



In [5]:
!pip install torchvision



# **Data Preparation & Graph Construction**



In [6]:
import pandas as pd
from llama_index.core.graph_stores.types import EntityNode, Relation

# Load the medical Knowledge Graph triplets
csv_path = "/content/drive/MyDrive/Medical_ontologies_v2_part2.csv"
df = pd.read_csv(csv_path).dropna(subset=["Entity", "Relationship", "Value"]).head(100)

# Build graph components
entity_nodes = {}
relations = []

for _, row in df.iterrows():
    head = str(row["Entity"]).strip()
    rel = str(row["Relationship"]).strip()
    tail = str(row["Value"]).strip()

    if head not in entity_nodes:
        entity_nodes[head] = EntityNode(name=head)
    if tail not in entity_nodes:
        entity_nodes[tail] = EntityNode(name=tail)

    relations.append(Relation(
        label=rel,
        source_id=entity_nodes[head].id,
        target_id=entity_nodes[tail].id,
        properties={"source": head, "target": tail}
    ))


# **CommunityGraphStore – Summary + Triples + Embeddings**

In [7]:
import os
os.environ["OPENAI_API_KEY"] = "sk-..."

In [8]:

import networkx as nx
import json
import os
import leidenalg
import igraph as ig
from llama_index.core.graph_stores import SimplePropertyGraphStore
from llama_index.core.llms import ChatMessage


class CommunityGraphStore(SimplePropertyGraphStore):
    def __init__(self, entity_nodes, relations, llm, embed_model, cache_path="/content/drive/MyDrive/community_summaries.json"):
        super().__init__()
        self.llm = llm
        self.embed_model = embed_model
        self.community_summary = {}
        self.cache_path = cache_path

        for node in entity_nodes.values():
            self.graph.add_node(node)
        for rel in relations:
            self.graph.add_relation(rel)

        self.load_summaries()

    def _create_nx_graph(self):
        nx_graph = nx.Graph()
        for node in self.graph.nodes.values():
            nx_graph.add_node(node.id, name=node.name)
        for rel in self.graph.relations.values():
            nx_graph.add_edge(
                rel.source_id,
                rel.target_id,
                relationship=rel.label,
                description=f"{rel.properties['source']} {rel.label} {rel.properties['target']}"
            )
        return nx_graph

    def _nx_to_igraph(self, nx_graph):
        mapping = {node: i for i, node in enumerate(nx_graph.nodes())}
        reverse_mapping = {i: node for node, i in mapping.items()}
        edges = [(mapping[u], mapping[v]) for u, v in nx_graph.edges()]
        g = ig.Graph(edges=edges)
        return g, mapping, reverse_mapping

    def _chunk_triplets(self, triples, chunk_size=50):
        for i in range(0, len(triples), chunk_size):
            yield triples[i:i + chunk_size]

    def _generate_summary(self, triples, chunk_size=50):
        summaries = []
        for chunk in self._chunk_triplets(triples, chunk_size):
            text = "\n".join(chunk)
            prompt = f"""
You are a biomedical ontology summarizer.
Your job is to read a set of subject–predicate–object triplets extracted from a medical knowledge graph and summarize the high-level clinical patterns and relationships.

Be concise but insightful. Highlight key medical connections, recurring entities, and frequent relationships. Use medically appropriate terminology.

--- Ontology Triplets ---
{text}

--- Summary ---"""
            messages = [
                ChatMessage(role="system", content=prompt),
                ChatMessage(role="user", content="Summarize the above relationships into a meaningful medical insight.")
            ]
            try:
                response = self.llm.chat(messages)
                summaries.append(str(response).strip())
            except Exception as e:
                print(f"[LLM ERROR] Failed to summarize chunk: {e}")
                summaries.append("Summary chunk failed.")

        if len(summaries) == 1:
            return summaries[0]

        final_prompt = f"""
You are a clinical knowledge synthesis expert.
You are given several summaries of medical triplet clusters extracted from an ontology-based knowledge graph.
Merge them into a single cohesive summary, prioritizing clarity, correctness, and clinical relevance.

--- Partial Summaries ---
{chr(10).join(summaries)}

--- Final Summary ---"""
        messages = [
            ChatMessage(role="system", content=final_prompt),
            ChatMessage(role="user", content="Write the merged summary highlighting key biomedical insights.")
        ]
        try:
            final_response = self.llm.chat(messages)
            return str(final_response).strip()
        except Exception as e:
            print(f"[LLM ERROR] Failed to merge summaries: {e}")
            return "Final summary generation failed."

    def build_communities(self):
        nx_graph = self._create_nx_graph()
        ig_graph, mapping, reverse_mapping = self._nx_to_igraph(nx_graph)
        leiden_partition = leidenalg.find_partition(ig_graph, leidenalg.ModularityVertexPartition)

        community_id_counter = 0
        for cluster in leiden_partition:
            node_ids = [reverse_mapping[i] for i in cluster]
            triples = []
            for i in range(len(node_ids)):
                for j in range(i + 1, len(node_ids)):
                    if nx_graph.has_edge(node_ids[i], node_ids[j]):
                        edge = nx_graph.get_edge_data(node_ids[i], node_ids[j])
                        triples.append(
                            f"{nx_graph.nodes[node_ids[i]]['name']} --{edge['relationship']}--> {nx_graph.nodes[node_ids[j]]['name']}"
                        )

            if triples:
                summary = self._generate_summary(triples)
                try:
                    embedding = self.embed_model.get_text_embedding(summary)
                except Exception as e:
                    print(f"[Embedding ERROR] Could not compute embedding: {e}")
                    embedding = []

                self.community_summary[community_id_counter] = {
                    "summary": summary,
                    "triples": triples,
                    "embedding": embedding
                }
                community_id_counter += 1

        self.save_summaries()

    def get_community_summaries(self):
        if not self.community_summary:
            self.build_communities()
        return self.community_summary

    def save_summaries(self):
        try:
            with open(self.cache_path, "w") as f:
                json.dump(self.community_summary, f)
        except Exception as e:
            print(f"[File ERROR] Failed to save summaries: {e}")

    def load_summaries(self):
        if os.path.exists(self.cache_path):
            try:
                with open(self.cache_path, "r") as f:
                    self.community_summary = json.load(f)
            except Exception as e:
                print(f"[File ERROR] Failed to load summaries: {e}")


# **GraphRAGQueryEngine – Top-K Community Retrieval & QA**

In [21]:
import uuid
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from llama_index.core.base.llms.types import ChatMessage


class QdrantGraphRAGQueryEngine:
    def __init__(
        self,
        graph_store,
        llm,
        embed_model,
        top_k=5,
        collection_name="community_summaries",
        qdrant_url=None,
        qdrant_api_key=None,
    ):
        self.graph_store = graph_store
        self.llm = llm
        self.embed_model = embed_model
        self.top_k = top_k
        self.collection_name = collection_name

        # ✅ In-memory fallback for Colab/local environments
        if qdrant_url:
            self.qdrant = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)
        else:
            self.qdrant = QdrantClient(":memory:")

        # ✅ Create collection if not exists
        if not self.qdrant.collection_exists(collection_name):
            self.qdrant.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=getattr(self.embed_model, "embedding_size", 1536),
                    distance=Distance.COSINE,
                ),
            )

        self._index_communities()

    def _index_communities(self):
        summaries = self.graph_store.get_community_summaries()
        seen_ids = set()
        points = []

        for cid, data in summaries.items():
            embedding = data.get("embedding", [])
            summary = data.get("summary", "")
            if not embedding or not summary:
                continue

            raw_key = f"{cid}_{summary}"
            uid = str(uuid.uuid5(uuid.NAMESPACE_DNS, raw_key))

            if uid in seen_ids:
                continue
            seen_ids.add(uid)

            points.append(PointStruct(
                id=uid,
                vector=embedding,
                payload={"triples": data.get("triples", [])}
            ))

        if points:
            self.qdrant.upsert(collection_name=self.collection_name, points=points)

    def _convert_triplets_to_sentences(self, triplets):
        sentences = []
        for triplet in triplets:
            if "--" in triplet and "-->" in triplet:
                subject, rest = triplet.split("--", 1)
                predicate, obj = rest.split("-->", 1)
                sentence = f"{subject.strip()} {predicate.strip()} {obj.strip()}."
                sentences.append(sentence)
        return "\n".join(sentences)

    def query(self, query_str):
        query_embedding = self.embed_model.get_text_embedding(query_str)

        results = self.qdrant.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=self.top_k
        )

        triplet_blocks = []
        for res in results:
            triplets = res.payload.get("triples", [])
            if triplets:
                triplet_blocks.append(self._convert_triplets_to_sentences(triplets))

        triplet_context = "\n\n".join(triplet_blocks)

        prompt = f"""
You are **MedQuery**, a highly specialized biomedical reasoning assistant trained in clinical knowledge synthesis.

You are provided with structured biomedical knowledge derived from **subject–predicate–object triplets**. These have already been transformed into natural language facts describing relationships between clinical concepts, molecular mechanisms, diseases, symptoms, pathways, and therapeutics.

Your task is to deeply analyze this factual context and generate a medically **accurate**, **coherent**, and **comprehensive** response to the clinical question that follows. Use proper clinical reasoning, precise terminology, and explain mechanisms where relevant.

🧠 Guidelines:
- DO NOT copy or reference the original triplet structure.
- DO synthesize insights into fluent, human-like clinical sentences.
- DO include possible pathophysiological, pharmacological, or genetic explanations if applicable.
- DO maintain an academic and professional tone as if writing for a medical research analyst or clinician.
- DO cite multiple facts from the context when needed.
- DO NOT fabricate any information not grounded in the given knowledge.

===
📚 Clinical Knowledge:
{triplet_context}

===
❓ Clinical Question:
{query_str}

===
📝 Answer:
""".strip()


        messages = [
            ChatMessage(role="system", content=prompt),
            ChatMessage(
                role="user",
                content="Write a medically sound answer using the provided knowledge. Do not echo triplets."
            ),
        ]

        try:
            response = self.llm.chat(messages)
            return str(response).strip()
        except Exception as e:
            return f"Response generation failed: {e}"


# **End-to-End Pipeline Execution**

In [10]:
!pip install bitsandbytes accelerate transformers




In [11]:
from huggingface_hub import login

# Use your token here
login("hf_...")

In [12]:
from typing import List, Iterator, Any
from llama_index.core.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata, ChatMessage
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from pydantic import PrivateAttr
import torch


class MedGemmaLLM(CustomLLM):
    _model_id: str = PrivateAttr()
    _tokenizer: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _device: str = PrivateAttr()
    _eos_token_id: int = PrivateAttr()

    def __init__(self, model_variant: str = "4b-it", device: str = "auto", **kwargs):
        super().__init__(**kwargs)

        self._model_id = f"google/medgemma-{model_variant}"
        self._device = device

        # Use bfloat16 if supported, else float32
        major, _ = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
        torch_dtype = torch.bfloat16 if major >= 8 else torch.float32

        quant_config = BitsAndBytesConfig(load_in_4bit=True)

        self._tokenizer = AutoTokenizer.from_pretrained(self._model_id)
        self._model = AutoModelForCausalLM.from_pretrained(
            self._model_id,
            device_map=self._device,
            torch_dtype=torch_dtype,
            quantization_config=quant_config,
        )

        self._eos_token_id = self._tokenizer.eos_token_id

    def chat(self, messages: List[ChatMessage], **kwargs) -> str:
        prompt = self._tokenizer.apply_chat_template(
            [{"role": msg.role, "content": msg.content} for msg in messages],
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        ).to(self._model.device)

        input_len = prompt["input_ids"].shape[-1]

        with torch.inference_mode():
            output = self._model.generate(
                **prompt,
                max_new_tokens=kwargs.get("max_tokens", 500),
                do_sample=False,
                pad_token_id=self._eos_token_id
            )

        decoded = self._tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
        return decoded.strip()

    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
        messages = [ChatMessage(role="user", content=prompt)]
        output = self.chat(messages, **kwargs)
        return CompletionResponse(text=output)

    def stream_complete(self, prompt: str, **kwargs) -> Iterator[CompletionResponse]:
        output = self.complete(prompt, **kwargs)
        return iter([CompletionResponse(text=output.text, delta=output.text)])

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(
            model_name=self._model_id,
            is_chat_model=True,
            is_function_calling_model=False,
            is_streaming_model=False,
            context_window=8192,
            num_output=500,
        )


In [22]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core.llms import ChatMessage

# Your custom LLM
llm = MedGemmaLLM()

# Embedding model (can also use SentenceTransformer)
embed_model = OpenAIEmbedding(model="text-embedding-3-small")

# Create the knowledge graph
graph_store = CommunityGraphStore(entity_nodes, relations, llm, embed_model)
graph_store.build_communities()

from qdrant_client import QdrantClient
qdrant = QdrantClient(location=":memory:")




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

In [23]:
query_engine = QdrantGraphRAGQueryEngine(
    graph_store=graph_store,
    llm=llm,
    embed_model=embed_model
)


In [24]:

response = query_engine.query("What disease requires evaluation for CSF sample?")
print(response)


  results = self.qdrant.search(
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Findings suggestive of acute meningitis or encephalitis (Disease or Syndrome) (Meningitis-Encephalitis_Panel_Algorithm) requires Evaluation for CSF sample (Diagnostic Procedure).


# **Load the summaries and instantiate query engine**

In [25]:
class LazyCommunityGraphStore:
    def __init__(self, llm, embed_model, cache_path="/content/drive/MyDrive/community_summaries.json"):
        self.llm = llm
        self.embed_model = embed_model
        self.community_summary = {}
        self.cache_path = cache_path
        self.load_summaries()

    def get_community_summaries(self):
        return self.community_summary

    def load_summaries(self):
        import json, os
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r") as f:
                self.community_summary = json.load(f)
        else:
            raise FileNotFoundError(f"{self.cache_path} not found")


In [27]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

llm = MedGemmaLLM()
embed_model = OpenAIEmbedding(model="text-embedding-3-small")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [28]:
graph_store = LazyCommunityGraphStore(
    llm=llm,
    embed_model=embed_model,
    cache_path="/content/drive/MyDrive/community_summaries.json"
)

query_engine = QdrantGraphRAGQueryEngine(
    graph_store=graph_store,
    llm=llm,
    embed_model=embed_model,
    qdrant_api_key=os.getenv("QDRANT_API_KEY")
)

In [29]:
response = query_engine.query("What is the role of TNF-alpha in autoimmune inflammation?")
print(response)

  results = self.qdrant.search(
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Tumor necrosis factor-alpha (TNF-α) plays a significant role in the pathogenesis of autoimmune inflammation. TNF-α is a pro-inflammatory cytokine that contributes to the amplification of the inflammatory response. It achieves this by binding to its receptor, TNF-R1 and TNF-R2, on target cells, leading to downstream signaling cascades that promote the activation of various immune cells, including T cells, B cells, and macrophages. This activation results in the release of additional inflammatory mediators, further perpetuating the inflammatory cycle.

TNF-α contributes to autoimmune inflammation by several mechanisms. It can directly stimulate the production of other cytokines, such as interleukin-1 (IL-1) and interleukin-6 (IL-6), which further amplify the inflammatory response. It can also enhance the expression of adhesion molecules on endothelial cells, facilitating the recruitment of leukocytes to sites of inflammation. Furthermore, TNF-α can induce the production of chemokines, wh

# **LLM for evaluation**

In [30]:
import openai

client = openai.OpenAI()

def llm_judge(question, correct_answer, model_answer):
    system_prompt = """
You are a clinical evaluation expert tasked with grading medical question-answering systems.
You will be shown a medical question, the correct reference answer, and a model-generated answer.

Assign a score from 0 to 10 based on the model's clinical correctness, factual alignment with the reference, clarity, and completeness.

Grading rubric:
- 10: Perfect — clinically precise and complete
- 8–9: Minor flaws but medically correct
- 6–7: Partially correct, some gaps
- 4–5: Important errors or omissions
- 1–3: Mostly wrong
- 0: Completely incorrect or hallucinated

Only return the score as a number.
"""

    user_prompt = f"""
Medical Question: {question}

Correct Answer:
{correct_answer}

Model's Answer:
{model_answer}

Score (integer from 0 to 10):
"""

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt.strip()},
                {"role": "user", "content": user_prompt.strip()}
            ],
            temperature=0.0
        )
        result = response.choices[0].message.content
        digits = [int(s) for s in result.split() if s.isdigit()]
        for d in digits:
            if 0 <= d <= 10:
                return d
    except Exception as e:
        print("GPT-4o judge error:", e)

    return -1


# **Evaluation using MedBullets benchmark**

In [36]:
import requests
import pandas as pd
from tqdm import tqdm

# --- Fetch MedBullets Data ---
def fetch_medbullets_batch(offset=0, length=100):
    url = "https://datasets-server.huggingface.co/rows"
    params = {
        "dataset": "super-dainiu/medagents-benchmark",
        "config": "MedBullets",
        "split": "test_hard",
        "offset": offset,
        "length": length
    }

    try:
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()
        rows = data.get("rows", [])
        return pd.DataFrame([r["row"] for r in rows])
    except requests.exceptions.RequestException as e:
        print("Network error:", e)
        return pd.DataFrame()
    except ValueError as e:
        print("JSON decoding error:", e)
        return pd.DataFrame()


# --- Format MedBullets Question ---
def format_medbullets_question(row):
    choices = row.get("choices", [])
    options = "\n".join([f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)])
    return f"{row['question']}\n\nOptions:\n{options}"


# --- Evaluate MedBullets Batch ---
def evaluate_medbullets_batch(df, query_engine):
    results = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Evaluating MedBullets batch"):
        try:
            question = format_medbullets_question(row)
            model_answer = query_engine.query(question)
        except Exception as e:
            print("GraphRAG error:", e)
            model_answer = "Error"
            question = row.get("question", "N/A")

        correct_answer = row.get("answer", "Unknown")

        try:
            score = llm_judge(question, correct_answer, model_answer)
        except Exception as e:
            print("Scoring error:", e)
            score = -1

        results.append({
            "question": question,
            "correct_answer": correct_answer,
            "model_answer": model_answer,
            "score": score
        })

    return pd.DataFrame(results)


# --- Run MedBullets Evaluation ---
df = fetch_medbullets_batch(offset=0, length=100)

if df.empty:
    print("❌ No data loaded.")
else:
    result_df = evaluate_medbullets_batch(df, query_engine)
    result_df.to_csv("/content/drive/MyDrive/graphrag_medbullets_results.csv", index=False)

    # Summary
    print("Evaluation complete.")
    print("Average Score:", result_df["score"].mean())
    print("Total Evaluated:", len(result_df))


  results = self.qdrant.search(
The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating MedBullets batch:   1%|          | 1/89 [00:28<41:39, 28.40s/it]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating MedBullets batch:   2%|▏         | 2/89 [01:19<1:00:16, 41.57s/it]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating MedBullets batch:   3%|▎         | 3/89 [01:42<47:36, 33.21s/it]  The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating MedBullets batch:   4%|▍         | 4/89 [02:11<44:58, 31.75s/it]The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=inf

Evaluation complete.
Average Score: 1.9662921348314606
Total Evaluated: 89



