In [65]:
!pip install nx-arangodb



In [66]:
import torch
print(torch.__version__)  # Should output 2.5.1+cu124

import networkx as nx
print(nx.__version__)     # Should output 3.4.2

2.5.1+cu124
3.4


In [67]:
# 2. Check if you have an NVIDIA GPU
# Note: If this returns "command not found", then GPU-based algorithms via cuGraph are unavailable

!nvidia-smi
!nvcc --version

Mon Mar 10 04:45:32 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   32C    P8              9W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [68]:
!pip install nx-cugraph-cu12 --extra-index-url https://pypi.nvidia.com # Requires CUDA-capable GPU

Looking in indexes: https://pypi.org/simple, https://pypi.nvidia.com


In [69]:
# prompt: install cugraph
!pip install cugraph-cu12 --extra-index-url=https://pypi.nvidia.com

Looking in indexes: https://pypi.org/simple, https://pypi.nvidia.com


In [70]:
!pip install langchain langgraph langchain-openai langchain-community langgraph



In [71]:
!pip install langgraph.prebuilt



In [72]:
!pip install nx-arangodb # Make sure nx_arangodb is installed
import nx_arangodb as nxadb # Import the libra



In [109]:
from arango import ArangoClient

# Initialize the ArangoDB client.
client = ArangoClient(hosts='https://ac057b4e6882.arangodb.cloud:8529')

# Connect to "_system" database as root user.
# This returns an API wrapper for "_system" database.
#sys_db = client.db('_system', username='root', password='KidVsaLuMx41AgOrCKlI')
'''
if not sys_db.has_database('MediGraph'):
    sys_db.create_database('MediGraph')
'''
# Connect to "test" database as root user.
# This returns an API wrapper for "test" database.
db = client.db('MediGraph', username='root', password='KidVsaLuMx41AgOrCKlI')

In [74]:
G_adb = nxadb.MultiGraph(name="MediGraph", db=db)

print(G_adb)

[04:45:47 +0000] [INFO]: Graph 'MediGraph' exists.
INFO:nx_arangodb:Graph 'MediGraph' exists.
[04:45:48 +0000] [INFO]: Default node type set to 'MediGraph_node'
INFO:nx_arangodb:Default node type set to 'MediGraph_node'


MultiGraph named 'MediGraph' with 8225 nodes and 75677 edges


In [75]:

import requests

# Hugging Face API Key (Replace with Your Own)
HF_API_KEY = "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"

# Define the API URL for a free model (Mistral-7B, Falcon, etc.)
HF_MODEL = "tiiuae/falcon-7b-instruct"  # Example: Free LLM
API_URL = f"https://api-inference.huggingface.co/models/{HF_MODEL}"

# Headers with API Key
headers = {"Authorization": f"Bearer {HF_API_KEY}"}

# Function to Call LLM
def call_llm(prompt):
    payload = {"inputs": prompt}
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

In [76]:
query = "What are the symptoms of diabetes?"
response = call_llm(query)

# Print Response
print(response)

[{'generated_text': 'What are the symptoms of diabetes?\nThe most common symptoms of diabetes include increased thirst and hunger, fatigue, frequent urination, blurry vision, slow healing sores or cuts, and unexplained weight loss. In addition, some people with diabetes also experience sexual dysfunction.'}]


In [77]:
from langchain.chat_models import ChatAnthropic
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.tools import Tool
from langgraph.graph import StateGraph, END
from transformers import pipeline
from arango import ArangoClient
import networkx as nx
import cugraph
import matplotlib.pyplot as plt


In [78]:
# Query State Class
class QueryState(dict):
    def __init__(self, query):
        super().__init__(query=query, result=None, execution_path=None)
        self["query"] = query
        self["result"] = None
        self["execution_path"] = None
        self.complex_query_keywords = [
            "community", "cluster", "network", "influence", "central",
            "risk", "readmission", "prediction", "similar", "connected",
            "pagerank", "louvain", "analytics", "path", "relationship"
        ]

    def is_complex_query(self) -> bool:
        """Determine if a query is complex based on keywords & structure."""
        nl_lower = self["query"].lower()
        if any(keyword in nl_lower for keyword in self.complex_query_keywords):
            return True
        entity_keywords = ["patient", "condition", "medication", "treatment"]
        entity_count = sum(1 for keyword in entity_keywords if keyword in nl_lower)
        return entity_count >= 2


In [79]:
# 🔹 Query Classification
def classify_query(state: dict) -> dict:
    """Classifies the query and selects the execution path."""
    if state["execution_path"] is None:  # Ensure execution_path is set
        state["execution_path"] = "cuGraph" if QueryState(state["query"]).is_complex_query() else "AQL"
    if "patient" in state["query"].lower() and "influence" in state["query"].lower():
        state["execution_path"] = "Hybrid"
    print(f"Query classified as: {state['execution_path']}")
    return state  # ✅ Return dict, not object


In [90]:
def execute_aql(state: dict) -> dict:
    """Executes an AQL query generated from natural language."""

    # 1️⃣ Convert NL Query → AQL
    aql_query = nl_to_aql(state["query"])

    if not aql_query:
        state["result"] = "AQL Conversion Error: No valid patient ID found."
        return state  # Return error state

    print("\n🔍 Executing AQL Query:\n", aql_query)  # Debugging log

    try:
        # 2️⃣ Execute AQL Query
        cursor = db.aql.execute(aql_query)
        results = list(cursor)  # Convert cursor to list

        # 3️⃣ Store results in state
        state["result"] = results if results else []
        print("✅ AQL Results:", state["result"])

    except Exception as e:
        print("❌ AQL Execution Error:", str(e))
        state["result"] = f"AQL Execution Error: {str(e)}"

    return state



In [81]:
# 🔹 Execute cuGraph Query (GPU-Accelerated Analysis)
def execute_cugraph(state: dict) -> dict:
    """Executes cuGraph analytics for influence ranking, clustering, etc."""
    G = nx.erdos_renyi_graph(200, 0.05)
    pr = nx.pagerank(G)
    state["result"] = sorted(pr.items(), key=lambda x: x[1], reverse=True)[:5]
    return state  # ✅ Return dict


In [82]:
# 🔹 Hybrid Execution: Combine AQL & cuGraph
def execute_hybrid(state: dict) -> dict:
    """Executes AQL to fetch data and cuGraph for analytics on the retrieved data."""
    aql_query = """
    FOR v, e IN 1..2 OUTBOUND 'patients/Alice' GRAPH 'MediGraph'
        RETURN { 'name': v.name, 'type': v.type }
    """
    cursor = db.aql.execute(aql_query)
    connections = list(cursor)
    G = nx.erdos_renyi_graph(200, 0.05)
    pr = nx.pagerank(G)
    ranked_connections = [(entry['name'], pr.get(entry['name'], 0)) for entry in connections]
    ranked_connections.sort(key=lambda x: x[1], reverse=True)
    state["result"] = ranked_connections[:5]
    return state  # ✅ Return dict


In [83]:
# 🔹 Call Hugging Face API for LLM Response
def call_llm(prompt: str) -> str:
    """Calls the Hugging Face inference API to generate a response."""
    payload = {"inputs": prompt}
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

# 🔹 Generate Response Using LLM
def generate_response(state: dict) -> dict:
    """Generates a natural language response using an open-source LLM."""
    context = f"Query: {state['query']}\nResults: {state['result']}"
    response = call_llm(context)
    generated_text = response[0]['generated_text'] if isinstance(response, list) else "Error in response"
    state["generated_response"] = generated_text
    return state  # ✅ Return dict

# 🔹 Execute a Query
def execute_medical_query(query: str):
    """Executes the MediGraphAI pipeline for a given query."""
    agent_executor = graph.compile()
    state = agent_executor.invoke({"query": query, "execution_path": None, "result": None})
    return state

In [145]:
# prompt: a simple function to convert nl queries to aql ( accoring to the structure of medigraph< the collections are : providers
# Procedures
# payers
# patients
# Patients
# organizations
# Medications
# Encounters
# Conditions

def nl_to_aql(nl_query):
    """
    Converts a natural language medical query into an AQL query for retrieving all connected data.

    Args:
        nl_query: Natural language query string
        query_config: Optional parameters for query customization:
            - max_depth: Depth of traversal (default: 2)
            - graph_name: Name of the graph in ArangoDB (default: 'MediGraph')
            - collection_name: Collection storing medical nodes (default: 'MediGraphNode')

    Returns:
        AQL query string or None if validation fails.
    """

    # Simple keyword-based mapping (expand this for more complex queries)
    aql_query = ""
    if "patient" in nl_query.lower():
      aql_query = "FOR patient IN patients RETURN patient" # Placeholder: Replace with actual AQL
      #Add more complex query structures
    elif "provider" in nl_query.lower():
        aql_query = "FOR provider IN providers RETURN provider"  # Placeholder
    # ... more mappings
    else:
        return None  # Indicate failure to convert
    return aql_query


In [93]:
# 🔹 Build LangGraph Workflow
graph = StateGraph(dict)

# Add Nodes
graph.add_node("classify", classify_query)
graph.add_node("aql_execution", execute_aql)
graph.add_node("cugraph_execution", execute_cugraph)
graph.add_node("hybrid_execution", execute_hybrid)
graph.add_node("response", generate_response)

# 🔹 Function to Determine Execution Path
def path_selector(state: dict):
    """Selects the next step based on execution path."""
    return state["execution_path"] if state["execution_path"] in ["AQL", "cuGraph", "Hybrid"] else END

# 🔹 Use add_conditional_edges() for Dynamic Routing
graph.add_conditional_edges("classify", path_selector)

# 🔹 Connect Execution Nodes to Response
graph.add_edge("aql_execution", "response")
graph.add_edge("cugraph_execution", "response")
graph.add_edge("hybrid_execution", "response")

# Set Start and End Points
graph.set_entry_point("classify")
graph.set_finish_point("response")


<langgraph.graph.state.StateGraph at 0x7b7aa1e96f10>

In [146]:
# 🔹 Example Query Execution
query = "list encounters of patients/0fef2411-21f0-a269-82fb-c42b55471405"

state = {"query": query, "execution_path": "AQL", "result": None}

state = execute_aql(state)
print(state)


🔍 Executing AQL Query:
 FOR patient IN patients RETURN patient
✅ AQL Results: [{'_key': '30a6452c4297a1ac977a6a23237c7b46', '_id': 'patients/30a6452c4297a1ac977a6a23237c7b46', '_rev': '_jTzVsci---', 'type': 'Patient', 'Id': '30a6452c-4297-a1ac-977a-6a23237c7b46', 'BIRTHDATE': '1994-02-06', 'DEATHDATE': None, 'SSN': '999-52-8591', 'DRIVERS': 'S99996852', 'PASSPORT': 'X47758697X', 'PREFIX': 'Mr.', 'FIRST': 'Joshua658', 'MIDDLE': 'Alvin56', 'LAST': 'Kunde533', 'SUFFIX': None, 'MAIDEN': None, 'MARITAL': 'M', 'RACE': 'white', 'ETHNICITY': 'nonhispanic', 'GENDER': 'M', 'BIRTHPLACE': 'Boston  Massachusetts  US', 'ADDRESS': '811 Kihn Viaduct', 'CITY': 'Braintree', 'STATE': 'Massachusetts', 'COUNTY': 'Norfolk County', 'FIPS': 25021, 'ZIP': 2184, 'LAT': 42.21114202874998, 'LON': -71.0458021760648, 'HEALTHCARE_EXPENSES': 56904.96, 'HEALTHCARE_COVERAGE': 18019.99, 'INCOME': 100511}, {'_key': '34a4dcc435fb6ad5ab98be285c586a4f', '_id': 'patients/34a4dcc435fb6ad5ab98be285c586a4f', '_rev': '_jTzVsey-

In [147]:
display_results(state)


--- Query Results ---
Query: list encounters of patients/0fef2411-21f0-a269-82fb-c42b55471405
Execution Path: AQL

Results (tabular format):
_key | _id | _rev | type | Id | BIRTHDATE | DEATHDATE | SSN | DRIVERS | PASSPORT | PREFIX | FIRST | MIDDLE | LAST | SUFFIX | MAIDEN | MARITAL | RACE | ETHNICITY | GENDER | BIRTHPLACE | ADDRESS | CITY | STATE | COUNTY | FIPS | ZIP | LAT | LON | HEALTHCARE_EXPENSES | HEALTHCARE_COVERAGE | INCOME
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
30a6452c4297a1ac977a6a23237c7b46 | patients/30a6452c4297a1ac977a6a23237c7b46 | _jTzVsci--- | Patient | 30a6452c-4297-a1ac-977a-6a23237c7b46 | 1994-02-06 | None | 999-52-8591 | S99996852 | X47758697X | Mr. | Joshua658 | Alvin56 | Kunde533 | None | None | M | white | no

In [130]:
# prompt: write a function to display the results in a good way presentable

def display_results(state):
    """Displays the results of a query in a user-friendly format."""

    print("\n--- Query Results ---")
    print(f"Query: {state['query']}")

    if "execution_path" in state:
        print(f"Execution Path: {state['execution_path']}")

    if "result" in state:
        if isinstance(state["result"], list):  # Check if results are a list
            if state["result"] and isinstance(state["result"][0], dict): # Check if list of dictionaries
                print("\nResults (tabular format):")
                # Determine the columns dynamically for the first item in the list
                if state["result"]:
                  headers = list(state["result"][0].keys())
                  # Print the header row
                  print(" | ".join(headers))
                  print("-" * (sum(len(header) + 3 for header in headers) - 3)) # Separator

                  for item in state["result"]:
                      row_values = [str(item.get(header, '')) for header in headers]
                      print(" | ".join(row_values))

            elif state["result"] and isinstance(state["result"][0], tuple):
              print("Results (tuples):")
              for item in state["result"]:
                  print(item)
            else:  # Handle other list formats or empty list
                print("\nResults:")
                for item in state["result"]:
                    print(item)
        elif isinstance(state["result"], str):  # Handle string results
            print("\nResult:")
            print(state["result"])
        else:  # For other result types
            print("\nResult:")
            print(state["result"])
    else:
        print("No results found.")

    if "generated_response" in state:
        print("\n--- Generated Response ---")
        print(state["generated_response"])

    print("--- End of Results ---\n")
