# LLM Review of Disease-Centric Splits

Run `uv run cli split` first.

In [1]:
import logging
import os
import shutil

import numpy as np
import pandas as pd
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from tqdm import tqdm

from src.config import conf

# Notebook-wide logger used by the cells below for progress and summary messages.
# NOTE(review): the `_logger.info(...)` calls only produce visible output if logging has been
# configured at INFO level somewhere (possibly by `src.config`) — confirm.
_logger = logging.getLogger(__name__)

Read in the knowledge graph nodes and edges, and the disease splits.

In [2]:
# Force the index columns to plain integers so later joins/comparisons on them are exact.
node_dtypes = {"node_index": int}
edge_dtypes = {"edge_index": int, "x_index": int, "y_index": int}

# Knowledge graph tables; low_memory=False makes pandas read each file in one pass.
nodes = pd.read_csv(conf.paths.kg.nodes_path, dtype=node_dtypes, low_memory=False)
edges = pd.read_csv(conf.paths.kg.edges_path, dtype=edge_dtypes, low_memory=False)

# One row per (disease split, candidate disease) pair, produced by the split step.
disease_splits_path = conf.paths.splits_dir / "disease_splits.csv"
disease_splits = pd.read_csv(disease_splits_path)

Set up the LLM client using LangChain. The model is read from `conf.llm.model_name` (Gemini 3 by default), so switching to another model only requires a configuration change.

In [3]:
# LangChain wraps the provider, so switching models only means changing conf.llm.model_name.
# Example model names:
#   - "gemini-2.5-flash" (fast, cost-effective)
#   - "gemini-2.5-pro" (more capable)
#   - "gemini-3-pro-preview" (latest, best reasoning)
# The API key is taken from conf.GOOGLE_API_KEY
# (presumably populated from the GOOGLE_API_KEY environment variable — confirm in src.config).
llm_settings = {
    "model": conf.llm.model_name,
    "api_key": conf.GOOGLE_API_KEY,
    "temperature": conf.llm.temperature,
    "max_tokens": conf.llm.max_tokens,
    "max_retries": 2,
}
llm = ChatGoogleGenerativeAI(**llm_settings)

Use LLM to evaluate disease splits.

In [4]:
# Prepare all messages upfront
SYSTEM_PROMPT = """You are a helpful biomedical expert with an understanding of disease mechanisms, treatment options for every disease, and deep clinical knowledge of disease symptoms, phenotypes, genotypes, and drug treatments."""

# Build the node_index -> node_name lookup once instead of scanning `nodes` with a boolean
# mask for every row. drop_duplicates keeps the first occurrence, which is the same row the
# previous `.values[0]` lookup selected.
node_name_by_index = nodes.drop_duplicates(subset="node_index").set_index("node_index")["node_name"]


def _response_text(response):
    """Return the plain-text body of a chat model response.

    Handles the shapes a LangChain message may present: a string-like ``text``
    attribute, a ``text`` attribute that has to be called, a plain string
    ``content``, or a ``content`` list of strings / ``{"type": "text", ...}`` blocks.
    """
    text = getattr(response, "text", None)
    if callable(text) and not isinstance(text, str):
        # In some langchain-core versions `text` is a method, not a str-like property.
        text = text()
    if isinstance(text, str):
        return text

    content = response.content
    if isinstance(content, str):
        return content

    # Content-block list: keep only the textual parts.
    parts = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(block.get("text", ""))
    return "".join(parts)


all_messages = []
for _i, row in tqdm(disease_splits.iterrows(), desc="Preparing prompts", total=len(disease_splits)):
    # Construct message; a split index missing from `nodes` raises KeyError here.
    split_disease = node_name_by_index.loc[row["disease_split_index"]]
    candidate_disease = row["node_name"]
    split_disease = split_disease.replace("(disease)", "").strip()
    candidate_disease = candidate_disease.replace("(disease)", "").strip()

    # NOTE(review): the prompt contains a doubled "that that"; it is left untouched so that
    # results stay comparable with earlier runs.
    user_prompt = f"""Rank on a scale from 1 to 5 how closely related {split_disease} and {candidate_disease} are. 1 is not related at all, 4 is that that they are closely related (e.g., a drug that treats {split_disease} could also treat {candidate_disease}), 5 is that they are the same disease or subtypes of the same disease. Respond with a number from 1-5 only, no other text."""

    all_messages.append([
        SystemMessage(content=SYSTEM_PROMPT),
        HumanMessage(content=user_prompt),
    ])

# Batch inference with concurrency
_logger.info("Running batch inference on %d prompts...", len(all_messages))
responses = llm.batch(all_messages, config={"max_concurrency": 8})

# Extract results: the stripped answer text and the token count reported per response
# (0 when the response carries no usage metadata).
llm_ranks = [_response_text(response).strip() for response in responses]
tokens_used = [
    response.usage_metadata.get("total_tokens", 0) if response.usage_metadata else 0
    for response in responses
]

# Add to disease splits
disease_splits["llm_rank"] = llm_ranks
disease_splits["tokens_used"] = tokens_used

 ... (more hidden) ...


In [5]:
# Parse the LLM answers into numeric ranks.
# A plain `astype(int)` aborts the notebook (after the paid batch run) as soon as one reply is
# not a bare integer, e.g. "4.", "**4**" or an empty string. Instead, pull out a standalone
# digit from 1-5 and treat everything else as missing.
raw_rank = disease_splits["llm_rank"]
parsed_rank = pd.to_numeric(
    raw_rank.astype(str).str.extract(r"(?<!\d)([1-5])(?!\d)", expand=False),
    errors="coerce",
)

unparsed = parsed_rank.isna()
if unparsed.any():
    # Missing ranks stay NaN (float column); `NaN >= 3` is False, so the thresholding cell
    # below evaluates these rows as "No".
    _logger.warning(
        "%d of %d LLM responses did not contain a rank from 1-5 (examples: %s)",
        int(unparsed.sum()),
        len(parsed_rank),
        raw_rank[unparsed].unique()[:5].tolist(),
    )
    disease_splits["llm_rank"] = parsed_rank
else:
    # Every reply parsed: keep the integer dtype the rest of the notebook expects.
    disease_splits["llm_rank"] = parsed_rank.astype(int)

_logger.info(disease_splits["llm_rank"].value_counts())

Convert the ranks to a binary evaluation: ranks of 3 or higher become `Yes`, lower ranks become `No`. Self comparisons are always set to `Yes`.

In [6]:
# Binarise the rank: 3 and above counts as related ('Yes'), anything lower as 'No'.
# A disease compared with itself is always 'Yes', whatever rank the LLM returned.
rank_is_related = disease_splits["llm_rank"].ge(3)
is_self_comparison = disease_splits["node_index"].eq(disease_splits["disease_split_index"])
disease_splits["llm_eval"] = np.where(rank_is_related | is_self_comparison, "Yes", "No")
_logger.info(disease_splits["llm_eval"].value_counts())

# Mirror the results under the legacy 'gpt_*' column names that existing code still reads.
for legacy_column, current_column in (("gpt_eval", "llm_eval"), ("gpt_rank", "llm_rank")):
    disease_splits[legacy_column] = disease_splits[current_column]

Compute total tokens used.

In [8]:
# Add up the per-response token counts recorded during batch inference.
token_counts = disease_splits["tokens_used"]
_logger.info("Total tokens used: %s", sum(token_counts))

Save file to CSV.

In [9]:
# Persist the reviewed splits (including the LLM rank/eval columns) next to the input file.
reviewed_splits_path = conf.paths.splits_dir / "disease_splits_GPT.csv"
disease_splits.to_csv(reviewed_splits_path, index=False)

## Save Disease Splits

Save each split to its own file.

In [10]:
# Start from an empty output directory so files from a previous run cannot linger.
# NOTE(review): assumes conf.paths.splits_dir is a pathlib.Path — confirm.
split_dir = conf.paths.splits_dir / "split_edges_GPT"
if split_dir.is_dir():
    shutil.rmtree(split_dir)
split_dir.mkdir()

In [11]:
# Get drug_disease_edges: edges whose x endpoint is a disease and whose y endpoint is a drug.
is_disease_drug_edge = (edges["x_type"] == "disease") & (edges["y_type"] == "drug")
drug_disease_edges = edges[is_disease_drug_edge]

# Keep only candidate diseases the LLM evaluated as 'Yes'
# ('gpt_eval' is the legacy alias of 'llm_eval').
disease_splits_filtered = disease_splits[disease_splits["gpt_eval"] == "Yes"]
disease_splits_grouped = disease_splits_filtered.groupby("disease_split_index")
edge_count = {}
empty_splits = []  # split indices that end up with no drug-disease edges

for disease_split, disease_split_df in tqdm(disease_splits_grouped, desc="Save splits"):

    # Edges whose disease endpoint is one of this split's member diseases
    disease_split_edges = drug_disease_edges[drug_disease_edges["x_index"].isin(disease_split_df["node_index"])]
    disease_split_edges = disease_split_edges.reset_index(drop=True)

    # If some edges exist
    if len(disease_split_edges) > 0:

        # Save to CSV
        disease_split_edges.to_csv(split_dir / f"{disease_split}.csv", index=False, encoding="utf-8-sig")
        edge_count[disease_split] = len(disease_split_edges)

    else:

        # Remember the split so it can be dropped below. Only this split is dropped: filtering
        # on its member node indices would also remove unrelated splits whose seed disease
        # happens to be one of this split's members.
        empty_splits.append(disease_split)
        tqdm.write(f"Removed split {disease_split} ({disease_split_df['disease_split'].values[0]}) as it has no edges.")

# Drop the empty splits in one pass (the groups above were fixed before the loop started,
# so deferring the filter does not change which splits get visited).
disease_splits_filtered = disease_splits_filtered[
    ~disease_splits_filtered["disease_split_index"].isin(empty_splits)
]

# Save all disease splits
all_split_edges = drug_disease_edges[drug_disease_edges["x_index"].isin(disease_splits_filtered["node_index"].unique())]
all_split_edges.to_csv(split_dir / "all.csv", index=False, encoding="utf-8-sig")

 ... (more hidden) ...

Removed split 34836 (Alexander disease) as it has no edges.


 ... (more hidden) ...
