# Few-Shot Prompting: Fall Detection with Open-Source Large Language Models

In [2]:
import pandas as pd
import numpy as np
import re

import os
from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
import time

In [None]:
DATASET_PATH = 'INSERT DATASET PATH'  # replace with the path to your JSON dataset

print('Loading data...')
dataset = pd.read_json(DATASET_PATH)

# Required columns in `dataset`:
#   pat_deid           - patient ID
#   note_deid          - note ID
#   effective_time     - time the note was written
#   min_surg_date      - surgery date
#   regex_chunked_note - note text chunked with the regular-expression protocol
#   label              - binary ground-truth label (fall / no fall)

## Model, Prompt, and Chain

In [None]:
# Ollama model tag to query. Other tags evaluated in the study: "gemma:7b", "llama3".
model_name = "mixtral:8x7b"

In [None]:
llm = Ollama(model=model_name,
             callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]),
             temperature=0.0,
             top_p=0.25,
             num_predict=100)

In [None]:
# The few-shot prompt examples were removed as they contained protected health information. Please replace the [BRACKETED TEXT] with your examples.
# The template has a single input variable, {note}: the note text as passed in by
# detect_falls (a timing sentence followed by the chunked note). Each example and
# the requested answer use the "EXPLANATION: ...; OUTPUT: 0/1." format that
# extract_label searches for.

prompt = PromptTemplate(
    input_variables=["note"],
    template = """You are a clinician who is reading a clinical note and looking for fall events. You are noting whether a patient fell or not after surgery. Please output 1 if the patient fell or 0 if the patient did not fall. Please note that historical falls, fall risk/precautions, or other miscellaneous mentions of falls like blood pressure falling are not fall events and the output should be 0 unless a fall event is also indicated in the note. 

    Here is an example of a note containing a fall event after surgery: 
    [INSERT NOTE EXAMPLE HERE]
    EXPLANATION: [INSERT EXAMPLE EXPLANATION HERE].; OUTPUT: 1.
    
    Here is an example of a note containing a fall event: 
    [INSERT NOTE EXAMPLE HERE]
    EXPLANATION: [INSERT EXAMPLE EXPLANATION HERE.]; OUTPUT: 1.
    
    Here is an example of a note that does not contain a fall event: 
    [INSERT NOTE EXAMPLE HERE]
    EXPLANATION: [INSERT EXAMPLE EXPLANATION HERE.]; OUTPUT: 0.
    
    Please provide your response in the following format- EXPLANATION: ; OUTPUT: .
    Clinical note: {note}. 
    Response: 
    """,
    # NOTE(review): `stopwords` and `max_tokens` are not PromptTemplate fields, so
    # they presumably never reach the model. Response length is capped by
    # num_predict on the Ollama LLM instead. A stop sequence would have to be set
    # on the LLM itself. Confirm against the installed LangChain version.
    stopwords=["\n"],
    max_tokens=100
)

# Bind the prompt to the LLM. chain.invoke(text) fills {note} with `text`.
chain = LLMChain(llm=llm, prompt=prompt)

## Functions

In [None]:
# This function prepends a sentence stating when the note was written and when the
# patient had surgery, then runs the combined text through the chain.
def detect_falls(prompt, note_date, surg_date, note, llm_chain=None):
    """Run one clinical note through the few-shot fall-detection chain.

    The timing sentence lets the model judge whether a fall happened after
    surgery, as the prompt instructs.

    Args:
        prompt: Not used. The prompt template is the one already bound to the
            chain. The argument is kept so existing call sites keep working.
        note_date: When the note was written (rendered with ``str``).
        surg_date: Date of surgery (rendered with ``str``).
        note: Clinical note text.
        llm_chain: Object with an ``invoke`` method that receives the text.
            Defaults to the module-level ``chain``. Pass another chain (or a
            stub) to use a different model or to test without an LLM.

    Returns:
        Whatever ``llm_chain.invoke`` returns. For an ``LLMChain`` this is a
        dict whose ``'text'`` entry holds the model's response.
    """
    if llm_chain is None:
        llm_chain = chain
    timing = f"This note was written on {note_date!s}. The patient had surgery on {surg_date!s}. "
    return llm_chain.invoke(timing + note)

In [None]:
# This function extracts the output from the model response.
def extract_label(response):
    """Return the 0/1 fall label parsed from a model response.

    The prompt asks the model to answer as ``EXPLANATION: ...; OUTPUT: <0|1>.``.
    Matching is case-insensitive and tolerates any whitespace (including none)
    after the colon. If several ``OUTPUT`` markers appear, the last one is used.

    Args:
        response: The model's response text, or the dict returned by
            ``LLMChain.invoke`` (its ``'text'`` entry is used).

    Returns:
        ``0`` or ``1`` when a label is found. Otherwise the response text itself,
        which is also printed so it can be reviewed manually.
    """
    # LLMChain.invoke returns {input keys..., 'text': <generation>}; unwrap it.
    if isinstance(response, dict):
        response = response.get('text', '')
    labels = re.findall(r"OUTPUT:\s*([01])", response.upper())
    if not labels:
        print(response)
        return response
    return int(labels[-1])

## Few-shot Run

In [None]:
# This runs the notes through the chain once. Please adapt for running 5 times as was described in the manuscript.

start_time = time.time()
col_name = 'mixtral_7b_response'
# The dataset schema names the surgery-date column 'min_surg_date'. Fall back to
# 'surg_date' for datasets that use that name instead.
surg_col = 'min_surg_date' if 'min_surg_date' in dataset.columns else 'surg_date'
for ind, x in dataset.iterrows():
    print(ind, '--- ')
    s_time = time.time()
    output = detect_falls(prompt, x['effective_time'], x[surg_col], x['regex_chunked_note'])
    # LLMChain.invoke returns a dict (inputs plus generated 'text'). Store only the
    # generated text so each cell holds a plain string for extract_label.
    if isinstance(output, dict):
        output = output.get('text', output)
    dataset.loc[ind, col_name] = output
    print('-->', time.time() - s_time)

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Run took {elapsed_time:.2f} seconds")

### Output Processing

In [None]:
# Parse each stored response into a 0/1 label. extract_label returns the raw
# response text when it finds no "OUTPUT: 0/1" marker. Check for such rows before
# casting to int so they are reported clearly instead of failing inside astype.
dataset['mixtral_7b_label'] = dataset['mixtral_7b_response'].apply(extract_label)
unparsed = dataset.index[~dataset['mixtral_7b_label'].isin([0, 1])]
if len(unparsed) > 0:
    raise ValueError(
        f"{len(unparsed)} response(s) contain no 'OUTPUT: 0/1' label; "
        f"review them before evaluation (rows: {list(unparsed)})"
    )
dataset['mixtral_7b_label'] = dataset['mixtral_7b_label'].astype(int)
dataset['mixtral_7b_label'].value_counts()

### Performance Evaluation

In [None]:
from sklearn.metrics import classification_report, roc_auc_score
# Compare the parsed model labels with the ground-truth 'label' column of the
# dataset. This assumes 'label' is encoded 0/1 like the model output; confirm this.
print(classification_report(dataset.label, dataset.mixtral_7b_label))
# NOTE: with hard 0/1 predictions (no scores), ROC AUC reduces to balanced accuracy,
# i.e. (sensitivity + specificity) / 2, rather than a threshold-free AUC.
print(roc_auc_score(dataset.label, dataset.mixtral_7b_label))