In [3]:
# load data from csv
# path for input and target data tables

diagnosis_path = '../data/diagnosis_hadm.csv'
discharge_path ='../data/discharge.csv'
edstays_path = '../data/edstays.csv'
radiology_path = '../data/radiology.csv'
triage_path = '../data/triage.csv'
target_path = '../data/discharge_target.csv'
discharge_sections_path = '../data/discharge_sections.csv'
radiology_sections_path = '../data/radiology_sections.csv'

import pandas as pd
import re
import os

# read data
diagnosis_df = pd.read_csv(diagnosis_path, keep_default_na=False)
discharge_df = pd.read_csv(discharge_path, keep_default_na=False)
edstays_df = pd.read_csv(edstays_path, keep_default_na=False)
radiology_df = pd.read_csv(radiology_path, keep_default_na=False)
triage_df = pd.read_csv(triage_path, keep_default_na=False)
target_df = pd.read_csv(target_path, keep_default_na=False)

discharge_sections_df = pd.read_csv(discharge_sections_path, keep_default_na=False)
radiology_sections_df = pd.read_csv(radiology_sections_path, keep_default_na=False)

In [None]:
# embed data frames into a vector store.  Meta data should include hadm id and column name for retrieval

In [5]:
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
import pandas as pd

# Load the model and tokenizer
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
def query_with_llama_raj(question, dataframes, column_lists, hadm_id, output_column):
    
    context_data = ""

    # Iterate through each DataFrame and its corresponding columns to create the context
    for df, cols in zip(dataframes, column_lists):
        # Filter the dataframe for the specific hadm_id
        target_df = df[df['hadm_id'] == hadm_id]
        if not target_df.empty:
            # Create a context string from the specified columns
            context_part = target_df[cols].apply(lambda x: ' '.join(x.dropna().astype(str)), axis=1).str.cat(sep=' ')
            context_data += " " + context_part
    
    if not context_data:
        dataframes[0].loc[dataframes[0]['hadm_id'] == hadm_id, output_column] = "No records found for the given HADM ID."
        return dataframes[0]

    # Combine the question with the aggregated context data
    inputs = tokenizer(question + " " + context_data, return_tensors="pt", truncation=True, max_length=1024)

    # Generate the response using the model
    with torch.no_grad():  # Disable gradient calculation for inference
        outputs = model.generate(**inputs, max_length=1024)
    
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    
    # Append the answer to the first DataFrame
    dataframes[0].loc[dataframes[0]['hadm_id'] == hadm_id, output_column] = answer
    return dataframes[0]

# Example usage
question = "Summarize the radiological tests and findings"
hadm_id = 24962904  # Specify the HADM ID you want to query
dfs = [discharge_sections_df, radiology_sections_df]  # List of DataFrames
relevant_cols = [['Pertinent Results'], ['EXAMINATION','INDICATION','IMPRESSION']]  # List of column lists for each DataFrame
output_col_name = 'radiology tests summary'  # Specify the name of the new column

# Assuming df1 is the primary DataFrame where the output should be stored
df1 = query_with_llama_raj(question, dfs, relevant_cols, hadm_id, output_col_name)
print(df1.loc[df1['hadm_id'] == hadm_id, output_col_name])

0    Summarize the radiological tests and findings ...
Name: radiology tests summary, dtype: object


In [11]:
def question_tester(prompt):
    
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=len(prompt))
    max_input_length = min(1024 + len(prompt), tokenizer.model_max_length) 
    # Generate the response using the model
    with torch.no_grad():  # Disable gradient calculation for inference
        outputs = model.generate(**inputs, max_length=max_input_length)
    
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    answer = answer[len(prompt):] #strip out prompt echo
    return answer

# prepare prompt

hadm_id = 24962904
# Specify the columns to use
df = discharge_sections_df
cols = ['HPI']
# Filter DataFrames for the given HADM ID
discharge_info = df[df['hadm_id'] == hadm_id][cols].fillna('').agg(' '.join, axis=1).values[0]

df = diagnosis_df
cols = ['icd_title']
diagnosis_info = df[df['hadm_id'] == hadm_id][cols].fillna('').agg(' '.join, axis=1).values[0]


# Combine all relevant information into one context string
context = f"The history of present illness is: {discharge_info} \n"
context = context + f"The patient recieved the following ICD diagnosis: {diagnosis_info} \n"
question = "Make a list of the patient's medical problems. Limit the list to 12 items"

# Prepare the final prompt for the model
final_prompt = f"This conversation is about clinical notes from a patient's hospital stay. Use only information given to answer the at the end. Context: {context} \n\nQuestion: Based on the information provided, {question} \n\nAnswer:"
ans = question_tester(final_prompt)
print(ans)


This conversation is about clinical notes from a patient's hospital stay. Use only information given to answer the at the end. Context: The history of present illness is: Ms. ___ is a ___ female with history of 
COPD on home O2, atrial fibrillation on apixaban, hypertension, 
CAD, and hyperlipidemia who presents with shortness of breath, 
cough, and wheezing for one day.

The patient reports shortness of breath, increased cough 
productive of ___ red-flected sputum, and wheezing since 
yesterday evening.  She has been using albuterol IH more 
frequently (___) with ipratropium nebs every 4 hours with 
minimal relief. She had to increase her O2 flow up to 4L without 
significant improvement. She was currently taking 10mg of 
prednisone. She has also been taking tiotropium IH, 
theophylline, advair IH at home as prescribed. She denies sick 
contacts. She quit smoking approximately 1 month ago.

She reports an episode of chest pain in waiting room while 
sitting down, non-exertional, resol

In [10]:
response = ans[len(final_prompt):]
print(response)



1. COPD
2. Atrial fibrillation
3. Hypertension
4. CAD
5. Hyperlipidemia
6. Shortness of breath
7. Cough
8. Wheezing
9. Increased cough productive of red-flected sputum
10. Chest pain
11. Steroid taper
12. Azithromycin therapy
