# Clinical Practice Guideline (CPG) to Knowledge Graph

This notebook loads a section of the AAP's [Clinical Practice Guideline for Screening and Management of High Blood Pressure in Children and Adolescents](https://publications.aap.org/pediatrics/article/140/3/e20171904/38358/Clinical-Practice-Guideline-for-Screening-and?autologincheck=redirected) into a Graph DB and then uses an LLM to enrich it. The section used is "4.3 Patient Management on the Basis of Office BP". The text for this section is provided in `./BP_Ped_graph.txt`, and it has been converted into Cypher in `./BP_Ped_graph.cy`.

You can then use the code in `./KG_to_BPM.ipynb` to generate BPMN XML from the graph.

This notebook is configured to use Llama 3 via Ollama. It could be changed to use other LLMs (see Define Method to Call LLM). 

## Disclaimer
Nothing provided here is guaranteed or warrantied to work. It is provided as is. Using this notebook is at the risk of the user. 

The CPG used is published by the American Academy of Pediatrics and is solely their product. Nothing here should be taken to supersede what it says.

**Nothing here should be used for delivering care.** It is in no way certified by any accredited medical organization or professional. It is provided solely for research purposes.

## SETUP

### 1. Neo4J
This notebook needs an instance of [Neo4J](https://www.neo4j.com) to talk to. I used Neo4J's desktop app, but you could also use Docker to run Neo4J locally with the following command:
```
docker run --name testneo4j -p7474:7474 -p7687:7687 -d \
    -v $HOME/neo4j/data:/data \
    -v $HOME/neo4j/logs:/logs \
    -v $HOME/neo4j/import:/var/lib/neo4j/import \
    -v $HOME/neo4j/plugins:/plugins \
    --env NEO4J_AUTH=neo4j/password \
    neo4j:latest
``` 

Alternatively, you can use a cloud-hosted Neo4J Aura instance.

### 2. Jupyter Environment for DB Connection

For the code to connect to the DB, the following environment variables must be defined in the Jupyter runtime. Reading credentials from the environment avoids checking them in to GitHub.

| Variable     | Description                          | Sample Value          |
|--------------|--------------------------------------|-----------------------|
| CPG_URL      | Where to find the instance of Neo4j. | bolt://localhost:7687 |
| CPG_USER     | The username for the database.       | neo4j                 |
| CPG_PASSWORD | The password for the database.       | password              |
| CPG_DATABASE | The name of the database instance.   | neo4j                 |
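
Because `os.getenv` silently returns `None` for an unset variable, a missing value only surfaces later as a confusing connection error. A small check like the one below could report every missing variable up front. It is a sketch only: `load_db_config` and `REQUIRED_VARS` are hypothetical names, not part of this notebook.

```python
import os
from collections.abc import Mapping

# Hypothetical helper: the notebook itself simply calls os.getenv() for each name.
REQUIRED_VARS = ('CPG_URL', 'CPG_USER', 'CPG_PASSWORD', 'CPG_DATABASE')


def load_db_config(env: Mapping = os.environ) -> dict:
    """Return the Neo4j connection settings, raising if any variable is unset or empty."""
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError('Missing environment variables: ' + ', '.join(missing))
    return {name: env[name] for name in REQUIRED_VARS}
```

In the connection cell, `config = load_db_config()` followed by `NEO4J_Graph.Graph(config['CPG_URL'], config['CPG_USER'], config['CPG_PASSWORD'], config['CPG_DATABASE'])` would replace the four `os.getenv` calls.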


### 3. Ollama
This notebook uses [Ollama](https://ollama.ai/) to run LLMs locally. It could be modified to use OpenAI (sample code is provided below) or any other LLM; that is left to the user. To use this notebook as is, you will need to install Ollama and pull the Llama 3 model (`ollama pull llama3`).


## Install Python Libraries

This cell installs some needed libraries.

In [None]:
!pip3 install ollama
!pip3 install neo4j

## Import Python Libraries

In [None]:
import os
import re

import ollama

import NEO4J_Graph

## Connect to the Graph DB

This cell establishes the connection to the Graph DB (see Setup 1 & 2 above).

In [None]:
NEO4J_URI = os.getenv('CPG_URL')
USERNAME = os.getenv('CPG_USER')
PASSWORD = os.getenv('CPG_PASSWORD')
DATABASE = os.getenv('CPG_DATABASE')

graph = NEO4J_Graph.Graph(NEO4J_URI, USERNAME, PASSWORD, DATABASE)

## Delete All Nodes & Edges from the Graph DB

This cell will clear the DB. This is useful when repeating the process. 

In [None]:
graph.wipe_database()

## Load the CPG Section into Graph DB

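This cell runs a Cypher script, one statement at a time, to load the section of the CPG into the Graph DB. Statements in the script are separated by a `;` at the end of a line.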

In [None]:
with open('BP_Ped_graph.cy') as file:
    for statement in file.read().split(';\n'):
        if statement.strip():
            graph.query(statement)

graph.database_metrics()
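Splitting on the exact string `';\n'` assumes Unix line endings: if the `.cy` file is saved with Windows (`\r\n`) line endings, no split happens and the whole script is sent as a single query, which is likely to fail. The sketch below keeps the same assumption that each statement ends with `;` at the end of a line, but also tolerates `\r\n` and a final statement with no trailing newline. `split_cypher_script` is a hypothetical name, not part of this notebook.

```python
import re


def split_cypher_script(script: str) -> list:
    """
    Split a Cypher script into individual statements.

    Assumes, like the notebook, that every statement ends with ';' at the end of
    a line. A string literal that itself ends a line with ';' would be mis-split.
    """
    # Split where ';' is followed by optional spaces/tabs and then a line break
    # (Unix or Windows) or the end of the input. The ';' itself is dropped.
    parts = re.split(r';[ \t]*(?:\r?\n|$)', script)
    return [part.strip() for part in parts if part.strip()]
```

It could replace the `file.read().split(';\n')` call in the cell above, with each returned statement passed to `graph.query`.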

## Find Relevant Nodes

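The nodes we want to enrich are the ones that have text properties. This cell finds those nodes and any tables they reference. A node that references several tables comes back once per table, so the rows are then merged into one group per node: the node followed by its tables.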

In [None]:
nodes_with_text, _ = graph.query('''
match (n)
where n.text is not null 
optional match (n)-[:REFERENCES]->(t) 
return n, t
order by n.order desc
''')

print(len(nodes_with_text))

def find_node(nodes, node):
    for i, a_node in enumerate(nodes):
        if a_node[0].element_id == node.element_id:
            return i
    return -1

nodes_reduced = []
for node in nodes_with_text:
    i = find_node(nodes_reduced, node[0])
    if i > -1:
        nodes_reduced[i].append(node[1])
    else:
        nodes_reduced.append(node)

print(len(nodes_reduced))

nodes_with_text = nodes_reduced
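The `find_node` lookup above rescans the reduced list for every row. That is harmless for one guideline section, but the same grouping can be done in a single pass with a dictionary keyed on each node's identity. The sketch below uses plain values in place of Neo4j records; `group_rows` is a hypothetical name, not part of this notebook.

```python
from typing import Any, Callable, Iterable, List, Tuple


def group_rows(rows: Iterable[Tuple[Any, Any]],
               key: Callable[[Any], Any] = lambda node: node) -> List[list]:
    """
    Collapse (node, table) rows into one [node, table1, table2, ...] group per node.

    `key` returns a hashable identity for a node; for Neo4j nodes this would be
    `lambda n: n.element_id`. A None table (from the OPTIONAL MATCH) is kept, as in
    the notebook, because nodes_to_text() skips falsy entries.
    """
    groups = {}  # dicts keep insertion order, so the query's ORDER BY is preserved
    for node, table in rows:
        node_key = key(node)
        if node_key in groups:
            groups[node_key].append(table)
        else:
            groups[node_key] = [node, table]
    return list(groups.values())
```

In the cell above this would be `nodes_with_text = group_rows(nodes_with_text, key=lambda n: n.element_id)`, assuming each row unpacks into a `(node, table)` pair.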

## Define Helper Functions

This cell contains some functions that will be useful later. 

In [None]:
def replace_table(node_text: str, tables: dict)->str:
    """
    Replace table references with the actual text of the table. 
    
    :param node_text: The text that is in the node. 
    :param tables: A dict mapping each table's heading (as it appears in the text) to the table's text.
    :return: The text with the table references replaced with table text.
    """
    for heading, table_text in tables.items():
        node_text = node_text.replace(heading, f'\n{table_text}\n')
    return node_text
        

def nodes_to_text(nodes)->str:
    """
    Turn a list of nodes into a text string that contains headers, text, and tables in places they are referenced. 
    
    :param nodes: List of nodes, including associated tables.
    :return: A string containing the headers, text, and tables in place. 
    """
    tables = {}
    for node in nodes:
        if node and node['table']:
            tables[node['heading']] = f'''{node["heading"]}: {node["name"]}
            {node["table"]}
            '''

    text = ''
    for node in nodes:
        if node:
            if node['table']:
                continue
            if node['heading']:
                text += f'{node["heading"]}: '
            if node['name']:
                text += f'{node["name"]}'
            if node['text']:
                text += f'\n{replace_table(node["text"], tables)}'
            text += '\n\n'
    return text

def parse_list(text:str)->list:
    """
    Parse a list from some text. It assumes each element in the list 
    starts with '*' or a number (optionally followed by '.'), possibly
    indented. Lines that don't start with one of those characters will be ignored.
    
    :param text: String containing a list.
    :return: A list of strings parsed from the text.
    """
    items = []
    for line in text.split('\n'):
        stripped = line.strip()  # list items may be indented, as in the prompt examples
        if re.match(r'^[*0-9]', stripped):
            # Remove the bullet or the whole number, including multi-digit ones like '10.'
            items.append(re.sub(r'^(\*|[0-9]+\.?)', '', stripped).strip())
    return items

## Define Method to Call LLM

This cell sets up calls to Llama 3 through Ollama running locally.

In [None]:
ollama_model = 'llama3'
def ask_llm(prompt:str)->str:
    llm_response = ollama.generate(model=ollama_model, prompt=prompt)
    return llm_response['response']

This alternative cell is set up to use OpenAI. It requires an API key in the `OPENAI_API_KEY` environment variable. It is commented out but provided for experimenting. It has not been extensively tested; you may need to adjust the imports and so on.

In [None]:
# !pip3 install langchain-openai
# from langchain_openai import ChatOpenAI

# OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# 
# llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, model_name='gpt-4')

# def ask_llm(prompt:str)->str:
#     llm_response = llm.invoke(prompt)
#     return llm_response.content

## ASK ABOUT FUNCTIONS

The first question asks the LLM for a JavaScript expression that evaluates to true when a patient falls into a given section of the CPG.

In [None]:
# Helper to remove Function nodes if you want to repeat the ask without completely deleting the graph. 
graph.wipe_label('Function')

In [None]:
def prompt_for_function(node_text):
    return f'''
Given the following section of a Clinical Practice Guideline.

{node_text}

Create a JavaScript expression that returns true if this section applies to a patient. Create only one expression for this section. Do not create a function. 
Be careful about compound medical concepts. Blood pressure is made up of two components, systolic and diastolic, so the expression should compare them separately. For example, if the section applies to patients whose BP is greater than 130/80 then the expression should say:
    (systolic_blood_pressure > 130 && diastolic_blood_pressure > 80)
Be careful to consider the age of the patient if needed. For example, if the section says that patients under 13 should use 120/70 but patients 13 and over should use 130/80, then the expression should be: 
    (age < 13 && (systolic_blood_pressure > 120 && diastolic_blood_pressure > 70)) || (age >= 13 && (systolic_blood_pressure > 130 && diastolic_blood_pressure > 80))
    
If an age range calls for percentiles, use the numbers from the other age range. For example, if the section says that patients under 13 should use the 90th percentile but patients 13 and over should use 130/80, then the expression should be: 
    (age < 13 && (systolic_blood_pressure > 130 && diastolic_blood_pressure > 80)) || (age >= 13 && (systolic_blood_pressure > 130 && diastolic_blood_pressure > 80))
    
Make sure to include upper and lower bounds. For example, if the section applies to BP of 120/70 to 130/80 then the expression should say:
    ((systolic_blood_pressure >= 120 && diastolic_blood_pressure >= 70) && (systolic_blood_pressure < 130 && diastolic_blood_pressure < 80))
    
For example, if the section is for "Stage 1 HTN" then the expression should be:
    (age < 13 && (((systolic_blood_pressure >= 130 && diastolic_blood_pressure >= 80) && (systolic_blood_pressure < 140 && diastolic_blood_pressure < 90)))) || (age >= 13 && (((systolic_blood_pressure >= 130 && diastolic_blood_pressure >= 80) && (systolic_blood_pressure < 140 && diastolic_blood_pressure < 90))))
    
For example, if the section is for "Normal BP", which has no lower bound, then the expression should be:
    (age < 13 && (systolic_blood_pressure <= 120 && diastolic_blood_pressure <= 80)) || (age >= 13 && (systolic_blood_pressure <= 120 && diastolic_blood_pressure <= 80))
    
Please answer with only the expression. Do not include a description or tell me that this is the expression. Only show me the JavaScript.
'''

def create_function(function, parent_node, _graph:NEO4J_Graph.Graph):
    cypher = f'''
        match (parent) 
        where elementId(parent) = '{parent_node.element_id}'
        create (parent)-[:FUNCTION]->(:Function {{name: '{_graph.escape_string(function)}', function: '{_graph.escape_string(function)}'}})
    '''
    _graph.query(cypher)

print(graph.database_metrics())
for node_group in nodes_with_text:
    node_text = nodes_to_text(node_group)
    prompt = prompt_for_function(node_text)
    # print('PROMPT:')
    # print(prompt)
    answer = ask_llm(prompt)
    print('\nANSWER:')
    print(answer)
    print('\n------\n\n')
    create_function(answer, node_group[0], graph)
print(graph.database_metrics())
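Despite the instruction to answer with only the expression, a local model may still wrap its answer in a Markdown code fence (for example one tagged `javascript`), and `create_function` would then store the fence characters verbatim. The sketch below is one possible cleanup step to apply to `answer` before storing it; `strip_code_fence` is a hypothetical helper, not part of this notebook, and it only handles an answer that is entirely one fenced block.

```python
import re

TICKS = '`' * 3  # a Markdown code fence, built this way so this snippet stays fence-safe

# The whole answer is one fenced block, optionally tagged with a language name.
_FENCED = re.compile(
    r'^\s*' + TICKS + r'[\w+-]*[ \t]*\n(.*?)\n?[ \t]*' + TICKS + r'\s*$', re.DOTALL
)


def strip_code_fence(answer: str) -> str:
    """
    Return the body of an answer wrapped in a Markdown code fence; otherwise
    return the answer unchanged apart from surrounding whitespace.
    """
    match = _FENCED.match(answer)
    return (match.group(1) if match else answer).strip()
```

Used in the loop above, the call would become `create_function(strip_code_fence(answer), node_group[0], graph)`.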

## ASK ABOUT INPUTS to FUNCTIONS

The second question is to ask what are the inputs to the functions captured above. 
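Because the expressions were generated in the previous step, the variables they use can also be read off deterministically and compared with the list the LLM returns. The sketch below does this with a regular expression; `expression_variables` and `JS_NON_VARIABLES` are hypothetical names, and it assumes simple boolean expressions without string literals.

```python
import re

# Words that may appear in a JavaScript boolean expression but are not patient data.
JS_NON_VARIABLES = {'true', 'false', 'null', 'undefined', 'Math'}


def expression_variables(expression: str) -> set:
    """
    Return the identifiers used in a simple JavaScript boolean expression.
    Assumes no string literals; member names after '.' (e.g. Math.max) are skipped.
    """
    # An identifier starts with a letter, '_' or '$' and is not preceded by a word
    # character or '.', so the 'e5' in 1e5 and the 'max' in Math.max are not picked up.
    names = re.findall(r'(?<![\w$.])[A-Za-z_$][\w$]*', expression)
    return {name for name in names if name not in JS_NON_VARIABLES}
```

Comparing `expression_variables(function_node[0]["function"])` with the names parsed from the LLM's answer would flag variables that were missed or invented.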

In [None]:
# Helper to remove Input nodes if you want to repeat the ask without completely deleting the graph.
graph.wipe_label('Input')

In [None]:
function_nodes, _ = graph.query('''
match (n:Function)
return n
''')

def prompt_for_inputs_for_functions(node_text):
    return f'''
The following is an expression that determines if a patient qualifies for a section of a Clinical Practice Guideline:

{node_text}

Please tell me the list of variables needed to evaluate this expression and the associated human readable question to ask the data. 
The results should be in a list with 
  * "variable_name": "question"
for example, if the expression requires patient_height, the result should include:
  * patient_height: What is the height of the patient?

The list should only be based on what is explicitly called for in the expression.
Questions should be as short as possible. Do not include any reasons or explanations.
Please do not use any abbreviations in the questions, so BP should be blood pressure. 
'''

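def create_inputs_for_functions(names, parent_node, _graph:NEO4J_Graph.Graph):
    for name in names:
        # Split on the first ':' only, so questions that contain ':' stay intact.
        split_name = name.split(":", 1)
        if len(split_name) < 2:
            continue  # skip lines that are not "variable_name: question" pairs
        # The prompt shows quotes around both parts, so strip any the LLM echoes back.
        var_name = split_name[0].strip().strip('"')
        question = split_name[1].strip().strip('"')
        cypher = f'''
            match (parent) 
            where elementId(parent) = '{parent_node.element_id}'
            create (parent)<-[:INPUT]-(:Input {{name: '{_graph.escape_string(var_name)}', question: '{_graph.escape_string(question)}'}})
        '''
        _graph.query(cypher)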

print(graph.database_metrics())
for function_node in function_nodes:
    prompt = prompt_for_inputs_for_functions(function_node[0]["function"])
    # print('PROMPT:')
    # print(prompt)
    answer = ask_llm(prompt)
    print('\nANSWER:')
    print(answer)
    print('\n------\n\n')
    input_list = parse_list(answer)
    create_inputs_for_functions(input_list, function_node[0], graph)
print(graph.database_metrics())
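`create_inputs_for_functions` builds its Cypher by string interpolation and relies on `escape_string` for quoting. The sketch below builds the same statement with Cypher query parameters instead, which leaves quoting to the database driver. `input_node_query` is a hypothetical name, and whether the local `NEO4J_Graph.Graph.query` accepts parameters is not shown in this notebook.

```python
def input_node_query(parent_element_id: str, item: str):
    """
    Build a parameterized Cypher statement that attaches an Input node to its
    Function node. `item` is one parsed list line such as
    'patient_height: What is the height of the patient?'.
    Returns (query, parameters), or None if the line has no ':' separator.
    """
    name, sep, question = item.partition(':')  # split on the first ':' only
    if not sep:
        return None
    query = (
        'MATCH (parent) WHERE elementId(parent) = $parent_id '
        'CREATE (parent)<-[:INPUT]-(:Input {name: $name, question: $question})'
    )
    parameters = {
        'parent_id': parent_element_id,
        # The prompt shows "variable_name": "question", so strip stray quotes too.
        'name': name.strip().strip('"'),
        'question': question.strip().strip('"'),
    }
    return query, parameters
```

With the official `neo4j` Python driver (5.x), the pair could be run as `driver.execute_query(query, parameters, database_=DATABASE)`.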

## ASK TO REDUCE INPUTS and GIVE HUMAN READABLE QUESTION

This section asks the LLM to reduce the inputs from every function into a smaller list of global inputs. At the same time, it asks the LLM to return, for each remaining input, a human-readable question that asks for its value.

In [None]:
# Helper to remove Global Input nodes if you want to repeat the ask without completely deleting the graph.
graph.wipe_label('GlobalInput')

In [None]:
input_nodes, _ = graph.query('''
match (n:Input)
return n
''')

root_node, _ = graph.query('''
match (n:Header1)
return n
''')

root_node = root_node[0][0]

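def create_global_inputs_for_functions(questions, parent_node, _graph:NEO4J_Graph.Graph):
    for question in questions:
        # Split on the first ':' only, so questions that contain ':' stay intact.
        split_question = question.split(":", 1)
        if len(split_question) < 2:
            continue  # skip lines that are not "variable_name: question" pairs
        # Strip any quotes the LLM copies from the requested "variable_name": "question" format.
        var_name = split_question[0].strip().strip('"')
        text = split_question[1].strip().strip('"')
        cypher = f'''
            match (parent) 
            where elementId(parent) = '{parent_node.element_id}'
            create (parent)<-[:GLOBAL_INPUT]-(:GlobalInput {{name: '{_graph.escape_string(var_name)}', question: '{_graph.escape_string(text)}'}})
        '''
        _graph.query(cypher)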

inputs = ''
for input_node in input_nodes:
    inputs += f'* {input_node[0]["name"]}: {input_node[0]["question"]}\n'

print('\nINPUTS:')
print(inputs)
print('\n------\n\n')

prompt_for_in_list = f'''
The following is a list of questions:
{inputs}
Please remove the duplicates from this list.
Also, consider duplicates that are close in concept, such as "how old" and "age".
The results should be in a list with 
  * "variable_name": "question"
for example, if the expression requires patient_height, the result should include:
  * patient_height: What is the height of the patient?
Please do not add questions, only return questions from the list.
'''

print(graph.database_metrics())
response = ask_llm(prompt_for_in_list)
print('\nANSWER:')
print(response)
print('\n------\n\n')
input_list = parse_list(response)
create_global_inputs_for_functions(input_list, root_node, graph)
print(graph.database_metrics())
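Because every section's expression is likely to reuse variables such as `age`, the list sent to the LLM above can contain exact repeats. A deterministic pre-pass could remove those first, leaving only conceptual duplicates (such as "how old" and "age") for the LLM to merge. The sketch below assumes `(name, question)` pairs; `unique_inputs` is a hypothetical helper, not part of this notebook.

```python
def unique_inputs(inputs):
    """
    Drop exact duplicates by variable name (case-insensitive), keeping the first
    (name, question) pair for each name and preserving order. Conceptual
    duplicates such as 'age' vs 'patient_age' are left for the LLM to merge.
    """
    seen = {}
    for name, question in inputs:
        key = name.strip().lower()
        if key not in seen:
            seen[key] = (name.strip(), question.strip())
    return list(seen.values())
```

For example, `unique_inputs((n[0]["name"], n[0]["question"]) for n in input_nodes)` could feed the loop that builds the `inputs` string.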


## ASK ABOUT OUTPUTS

This section asks the LLM which actions should be taken when a patient falls into a section of the CPG.

In [None]:
# Helper to remove Output nodes if you want to repeat the ask without completely deleting the graph.
graph.wipe_label('Output')

In [None]:
def prompt_for_output(node_text):
    return f'''
Given the following section of a Clinical Practice Guideline.

{node_text}

Assuming this section of the guideline applies to a patient, please list the actions that should be taken.
The list should be in the form of:
 1. do this
 2. do that
 
Items in the list should be as short as possible. Do not explain.
Please keep the list as short as you can. 
Do not invent anything not in the supplied guideline section. In other words, only include actions indicated in the supplied text.
'''

def create_output(actions, parent_node, _graph:NEO4J_Graph.Graph):
    for i, action in enumerate(actions):
        cypher = f'''
            match (parent) 
            where elementId(parent) = '{parent_node.element_id}'
            create (parent)-[:OUTPUT]->(:Output {{name: '{_graph.escape_string(action)}', order: {i}, output: '{_graph.escape_string(action)}'}})
        '''
        _graph.query(cypher)

print(graph.database_metrics())
for node_group in nodes_with_text:
    prompt = prompt_for_output(nodes_to_text(node_group))
    # print('PROMPT:')
    # print(prompt)
    answer = ask_llm(prompt)
    print('\nANSWER:')
    print(answer)
    print('\n------\n\n')
    list_actions = parse_list(answer)
    create_output(list_actions, node_group[0], graph)
print(graph.database_metrics())

**Disclaimer:** Nothing provided here is guaranteed or warrantied to work. It is provided as is and has not been tested extensively. Using this notebook is at the risk of the user. Further, it is provided for research only and should not be used for treatment. 

Copyright &copy; 2024 Sam Schifman