# FHIR on RAG

This notebook loads FHIR resources into a vector store and uses that store to help prompt an LLM to answer questions about the data. It first flattens the FHIR resources into text files, then uses [LlamaIndex](https://www.llamaindex.ai/) to load those text files into an in-memory vector store. Finally, it queries a Llama 2 model running locally under [Ollama](https://ollama.ai/), trying different [strategies](https://docs.llamaindex.ai/en/stable/module_guides/querying/response_synthesizers/root.html) for combining the FHIR text with the question in the prompt.

In [None]:
# Some constants to use throughout 

in_file_glob = './working/raw_fhir/*.json'
flat_file_path = './working/flat'
vector_store_file_path = './working/vector_store'

## Flatten FHIR

This cell reads every JSON file matched by `in_file_glob` and assumes that each file is a FHIR Bundle. It first finds the Patient resource and extracts some key information from it (first name, last name, and ID) to include in each text file it writes, one file per resource. This helps the RAG pipeline know which patient a resource belongs to. It then flattens each resource in the bundle.

Flattening means building, for each value, a path made of all the attribute names from the root of the resource down to that value. In the process, camel-case names are split into separate words. Finally, each path and value is written out to a text file in the form:
``` [path name] is [value]. ```
This creates a semi-English version of the resource that the embedding model can turn into a vector.
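
To make the output format concrete, here is a minimal, self-contained sketch of the same idea applied to a tiny resource. The `split_camel` and `flatten` helpers are simplified stand-ins for the functions in the code cell of this section, and the sample `Observation` is hand-written purely for illustration.

```python
import re


def split_camel(text):
    # Insert a space at each lower-to-upper boundary,
    # e.g. "valueQuantity" -> "value Quantity".
    return re.sub(r'([a-z0-9])([A-Z])', r'\1 \2', text)


def flatten(node, path=''):
    """Yield (path, value) pairs for every leaf of a nested dict/list."""
    if isinstance(node, dict):
        for key, child in node.items():
            yield from flatten(child, f'{path}{split_camel(key)} ')
    elif isinstance(node, list):
        for i, child in enumerate(node):
            yield from flatten(child, f'{path}{i} ')
    else:
        yield path.rstrip(), node


# Hypothetical, hand-written resource used only for illustration.
sample = {
    'resourceType': 'Observation',
    'code': {'coding': [{'display': 'Heart rate'}]},
    'valueQuantity': {'value': 72, 'unit': 'beats/min'},
}

flat_text = ' '.join(f'{path} is {value}.' for path, value in flatten(sample))
print(flat_text)
```

This prints `resource Type is Observation. code coding 0 display is Heart rate. value Quantity value is 72. value Quantity unit is beats/min.` List positions appear as numbers in the path, which is why `coding 0` shows up in the second sentence.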

**To use this project,** you will need to create the `working` and `working/raw_fhir` directories and populate `raw_fhir` with FHIR Bundles. I used [Synthea](https://synthea.mitre.org/) to generate synthetic data for my testing.
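
As an optional convenience, the sketch below creates those directories and reports whether each JSON file looks like a Bundle that contains a Patient, since the flattening code raises an exception when no Patient is found. The `check_bundle` helper and the `raw_dir` variable are hypothetical names introduced here; they are not part of the notebook's own code.

```python
import glob
import json
import os

raw_dir = './working/raw_fhir'  # assumed to match in_file_glob from the first cell

# Create ./working and ./working/raw_fhir if they do not exist yet.
os.makedirs(raw_dir, exist_ok=True)


def check_bundle(path):
    """Return a short problem description, or None if the file looks usable."""
    with open(path) as f:
        data = json.load(f)
    if not isinstance(data, dict) or data.get('resourceType') != 'Bundle':
        return 'not a FHIR Bundle'
    has_patient = any(
        entry.get('resource', {}).get('resourceType') == 'Patient'
        for entry in data.get('entry', [])
    )
    return None if has_patient else 'no Patient resource'


for path in sorted(glob.glob(os.path.join(raw_dir, '*.json'))):
    problem = check_bundle(path)
    print(f'{path}: {problem or "ok"}')
```

This only checks the two assumptions the flattening step relies on; it is not a FHIR validator.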

In [None]:
import glob
import os
import json
import re

camel_pattern1 = re.compile(r'(.)([A-Z][a-z]+)')
camel_pattern2 = re.compile(r'([a-z0-9])([A-Z])')


def split_camel(text):
    new_text = camel_pattern1.sub(r'\1 \2', text)
    new_text = camel_pattern2.sub(r'\1 \2', new_text)
    return new_text


def handle_special_attributes(attrib_name, value):
    if attrib_name == 'resource Type':
        return split_camel(value)
    return value


def flatten_fhir(nested_json):
    out = {}

    def flatten(json_to_flatten, name=''):
        if type(json_to_flatten) is dict:
            for sub_attribute in json_to_flatten:
                flatten(json_to_flatten[sub_attribute], name + split_camel(sub_attribute) + ' ')
        elif type(json_to_flatten) is list:
            for i, sub_json in enumerate(json_to_flatten):
                flatten(sub_json, name + str(i) + ' ')
        else:
            attrib_name = name[:-1]
            out[attrib_name] = handle_special_attributes(attrib_name, json_to_flatten)

    flatten(nested_json)
    return out


def filter_for_patient(entry):
    return entry['resource']['resourceType'] == "Patient"


def find_patient(bundle):
    patients = list(filter(filter_for_patient, bundle['entry']))
    if len(patients) < 1:
        raise Exception('No Patient found in bundle!')
    else:
        patient = patients[0]['resource']

        patient_id = patient['id']
        first_name = patient['name'][0]['given'][0]
        last_name = patient['name'][0]['family']

        return {'PatientFirstName': first_name, 'PatientLastName': last_name, 'PatientID': patient_id}


def flat_to_string(flat_entry):
    output = ''

    for attrib in flat_entry:
        output += f'{attrib} is {flat_entry[attrib]}. '

    return output


def flatten_bundle(bundle_file_name):
    # Base name without directory or extension; works with any path separator.
    file_name = os.path.splitext(os.path.basename(bundle_file_name))[0]
    with open(bundle_file_name) as raw:
        bundle = json.load(raw)
        patient = find_patient(bundle)
        flat_patient = flatten_fhir(patient)
        for i, entry in enumerate(bundle['entry']):
            flat_entry = flatten_fhir(entry['resource'])
            with open(f'{flat_file_path}/{file_name}_{i}.txt', 'w') as out_file:
                out_file.write(f'{flat_to_string(flat_patient)}\n{flat_to_string(flat_entry)}')


# Create the output directory (and any missing parents) if needed.
os.makedirs(flat_file_path, exist_ok=True)

for file in glob.glob(in_file_glob):
    flatten_bundle(file)

## Set up the Gen AI with RAG

This section uses LlamaIndex to construct the vector store and tie it to the LLM.

In [None]:
!pip install llama-index
!pip install transformers

I tried a couple of different models for the embedding step, i.e. turning the flattened FHIR text into vectors. I would like to experiment with others, but haven't had time. In the end, `BAAI/bge-large-en-v1.5` was too big to run on my local machine, so I did most of my testing with `BAAI/bge-small-en-v1.5`.
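
Whichever model is used, its role is the same: each flattened text file becomes a vector, and at query time the question is embedded the same way and compared against the stored vectors. The sketch below illustrates that comparison using cosine similarity, a common choice for this kind of search, on made-up three-dimensional vectors. Real embeddings are much longer (384 dimensions for `bge-small-en-v1.5`), and the file names and numbers here are hypothetical.

```python
import math


def cosine_similarity(a, b):
    # Dot product divided by the product of the vector lengths.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec, store, k):
    """Return the names of the k stored vectors most similar to query_vec."""
    scored = [(cosine_similarity(query_vec, vec), name) for name, vec in store.items()]
    return [name for _, name in sorted(scored, reverse=True)[:k]]


# Toy "vector store": file name -> made-up embedding.
store = {
    'patient_a_3.txt': [0.9, 0.1, 0.0],
    'patient_a_7.txt': [0.1, 0.9, 0.1],
    'patient_b_2.txt': [0.0, 0.2, 0.9],
}
query = [0.8, 0.3, 0.0]  # made-up embedding of a question

print(top_k(query, store, k=2))
```

With these numbers the two closest files are `patient_a_3.txt` and `patient_a_7.txt`, in that order. The `similarity_top_k=5` argument used later in this notebook plays the same role as `k` here.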

In [None]:
from llama_index.embeddings import HuggingFaceEmbedding

# loads BAAI/bge-small-en
# embed_model = HuggingFaceEmbedding()

embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

# embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")

# embed_model = HuggingFaceEmbedding(model_name="medicalai/ClinicalBERT")

In [None]:
# This code is to play with embedding if desired. It is not needed and can remain commented out.

# embeddings = embed_model.get_text_embedding("Hello World!")
# print(len(embeddings))
# print(embeddings[:5])

In [None]:
from llama_index.llms import Ollama

# Llama 2 is running locally, using Ollama.
llm = Ollama(model="llama2")

In [None]:
# This is a test prompt, just to prove that Ollama is working. It can remain commented out.

# resp = llm.complete("Who is Paul Graham?")
# print(resp)

In [None]:
from llama_index import ServiceContext, VectorStoreIndex, SimpleDirectoryReader, SummaryIndex
from llama_index import set_global_service_context

service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
set_global_service_context(service_context)

In [None]:
# This code loads the flat FHIR text files. 

documents = SimpleDirectoryReader(flat_file_path).load_data()
print(len(documents))

In [None]:
# Load those flat FHIR text files into the vector store.

vector_index = VectorStoreIndex.from_documents(documents, show_progress=True)


# if not os.path.exists(vector_store_file_path):
#     os.mkdir(vector_store_file_path)
# vector_index.vector_store.persist(f'{vector_store_file_path}/FHIR_RAG.vs')

In [None]:
# I tried to play with summary indexes, but it took too long to get a response on my machine. 

# summary_index = SummaryIndex.from_documents(documents)

In [None]:
from llama_index.response.notebook_utils import display_response
import logging
import sys
from IPython.display import Markdown, display

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

## Actually do RAG

This is the code block that actually puts the questions to the LLM.

In my tests, I used synthetic FHIR generated by [Synthea](https://github.com/synthetichealth/synthea/wiki/Basic-Setup-and-Running). I had Synthea generate two patients, so I asked questions about each of them. If you want to replicate this work, you will need to change the patient names to match whatever data you have available.

In [None]:
def display_source_text(response):
    for ind, source_node in enumerate(response.source_nodes):
        display(Markdown("---"))
        display(Markdown(f"**`Source Node {ind + 1}/{len(response.source_nodes)}`**"))
        text_md = (
            f'**File:** {source_node.node.metadata["file_name"]}<br>'
            f'**Text:** {source_node.node.get_content().strip()}'
        )
        display(Markdown(text_md))


def ask_question(index, response_mode, question, show_sources=False):
    query_engine = index.as_query_engine(response_mode=response_mode, similarity_top_k=5)
    response = query_engine.query(question)
    display(Markdown(f'### Answer for {response_mode}'))
    if show_sources:
        display_source_text(response)
    else:
        display_response(response, show_source=False, show_metadata=False, show_source_metadata=False)


def ask_question_all_modes(person, question):
    display(Markdown(f'# Asking about {person}\n<br>**Question:** {question}'))
    ask_question(vector_index, 'no_text', question, show_sources=True)
    ask_question(vector_index, 'simple_summarize', question)
    ask_question(vector_index, 'compact', question)
    ask_question(vector_index, 'refine', question)
    ask_question(vector_index, 'tree_summarize', question)
    ask_question(vector_index, 'accumulate', question)
    ask_question(vector_index, 'compact_accumulate', question)


ask_question_all_modes('Arnold',
                       'What can you tell me about Arnold338 Wilkinson796 heart? For example, does he have hypertension?')
ask_question_all_modes('Ashley', 'What can you tell me about Ashley34 Bergstrom287 allergies?')
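
The response modes compared above differ mainly in how the retrieved chunks are packed into prompts and how the intermediate answers are combined. The sketch below is a conceptual illustration only, not LlamaIndex's implementation: `fake_llm` is a hypothetical stand-in that records each prompt instead of calling a model, and the two functions mimic the general shape of `refine` (one call per chunk, each revising the previous answer) and `accumulate` (one independent call per chunk, with the answers joined at the end). The prompt wording is invented.

```python
calls = []


def fake_llm(prompt):
    # Hypothetical stand-in for a real model: records the prompt
    # and returns a numbered placeholder answer.
    calls.append(prompt)
    return f'answer{len(calls)}'


def refine(question, chunks):
    """Sequentially revise one answer, feeding it forward with each new chunk."""
    answer = None
    for chunk in chunks:
        if answer is None:
            prompt = f'Context: {chunk}\nQuestion: {question}'
        else:
            prompt = (f'Existing answer: {answer}\nNew context: {chunk}\n'
                      f'Refine the answer to: {question}')
        answer = fake_llm(prompt)
    return answer


def accumulate(question, chunks):
    """Answer the question against each chunk independently, then join."""
    answers = [fake_llm(f'Context: {chunk}\nQuestion: {question}') for chunk in chunks]
    return '\n'.join(answers)


chunks = ['chunk A', 'chunk B', 'chunk C']
question = 'Does the patient have hypertension?'

refined = refine(question, chunks)
refine_prompts = list(calls)  # keep a copy before resetting the log

calls.clear()
accumulated = accumulate(question, chunks)
accumulate_prompts = list(calls)

print(len(refine_prompts), len(accumulate_prompts))
```

Both functions make one call per chunk here. The difference is that `refine` threads each answer into the next prompt, while `accumulate` keeps the answers independent. As I understand the LlamaIndex documentation linked at the top of this notebook, the `compact` variants reduce the number of calls by packing several chunks into each prompt, and `no_text` runs only the retriever without calling the LLM, which is why it is used above to show the source nodes.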