In [1]:
startN = 1000
endN = 1002


In [2]:
import traceback
import logging


In [3]:
import sys
import os
import pandas as pd
import tiktoken
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
    RetryError
)  # for exponential backoff

sys.path.append('/home/vs428/Documents/DischargeMe/hail-dischargeme/scoring')

In [4]:
import re
import openai
from datasets import load_dataset
import pandas as pd

import tiktoken

In [5]:
from openai.error import InvalidRequestError

In [6]:
%load_ext dotenv
%dotenv /vast/palmer/home.mccleary/vs428/Documents/DischargeMe/hail-dischargeme/.env

In [7]:
def num_tokens_from_string(string: str, encoding_name: str="cl100k_base") -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens

In [8]:
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
    return openai.ChatCompletion.create(**kwargs)


In [9]:
from IPython.display import display, HTML

def add_line_breaks(text):
    return text.replace('\n', '<br>')


def pretty_print(df):
    return display( HTML( df.to_html().replace("\\n","<br>") ) )

In [10]:
challenge_data_fp = "/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/"


In [11]:
test_dataset = load_dataset('json', data_files="/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/data_processing/simple_test.json", split="train")

In [12]:
openai.api_type = "azure"
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT_CANADAEAST")
openai.api_version = "2024-02-15-preview"
openai.api_key = os.getenv("AZURE_OPENAI_KEY_CANADAEAST")
engine = "decile-gpt-35-turbo-16k-0125"

# Integrate Service/Ward info

This section will integrate the service-file to match the structure of the note as much as possible. 

In [13]:
target_test = pd.read_csv(challenge_data_fp + "test_phase_1/discharge_target.csv.gz", keep_default_na=False)


In [14]:
# ward transfers
transfers = pd.read_pickle('/gpfs/gibbs/project/rtaylor/shared/DischargeMe/mimiciv/hosp/cohort_transfers.pkl')

# higher-level services (ICU, CARD, etc)
services = pd.read_pickle('/gpfs/gibbs/project/rtaylor/shared/DischargeMe/mimiciv/hosp/cohort_services.pkl')


In [15]:
transfers = transfers[transfers['eventtype'] != "discharge"]

In [16]:
transfers = transfers.sort_values(['hadm_id', "intime"])

In [17]:
discharging_transfer = transfers.groupby("hadm_id").last().reset_index()

In [18]:
discharging_service = services.sort_values(['hadm_id', "transfertime"]).groupby("hadm_id").last().reset_index()

In [19]:
test_dataset_df = test_dataset.to_pandas(); test_dataset_df.shape

(14702, 4)

In [20]:
test_dataset_df = test_dataset_df.merge(discharging_transfer[['hadm_id', 'careunit', 'eventtype']], on="hadm_id", how="left"); test_dataset_df.shape

(14702, 6)

In [21]:
test_dataset_df = test_dataset_df.merge(discharging_service[['hadm_id', 'curr_service']], on="hadm_id", how="left"); test_dataset_df

Unnamed: 0,note_id,hadm_id,input,output,careunit,eventtype,curr_service
0,19766998-DS-20,26231944,\nName: ___ Unit No: ___\n ...,"___ with PMH HCV, ETOH cirrhosis with ascites,...",Med/Surg,transfer,MED
1,10336082-DS-11,28542384,\nName: ___ Unit No: ___\n \nA...,___ yo F with recent lumbar laminectomy c/b MS...,Neurology,admit,MED
2,10481170-DS-22,26489329,\nName: ___ Unit No: ___\n ...,The patient was seen ___ the emergency departm...,Med/Surg/GYN,transfer,SURG
3,11576109-DS-14,22641254,\nName: ___ Unit No: ___\n ...,"Ms. ___ is a ___ woman with a history of T2DM,...",Medicine/Cardiology,transfer,CMED
4,19249697-DS-17,29265750,\nName: ___ Unit No: ...,___ with IDDM and h/o provoked PE on apixiban ...,Vascular,admit,MED
...,...,...,...,...,...,...,...
14697,18995174-DS-17,27445071,\nName: ___ Unit No: ___\...,___ M PMHx dilated non-ischemic cardiomyopathy...,Medicine/Cardiology,admit,CMED
14698,13588195-DS-9,26192891,\nName: ___ Unit No: ___\n ...,"Mr. ___ is a ___ year-old man with CAD, HTN, H...",Med/Surg,admit,MED
14699,10873131-DS-17,26584893,\nName: ___ Unit No: _...,___ year old female with past medical history ...,Medicine,transfer,MED
14700,12332377-DS-11,25623241,\nName: ___ Unit No: ___\n ...,Mr. ___ is a ___ male with a past medical hist...,Med/Surg/Trauma,admit,MED


In [26]:
with pd.option_context("display.max_colwidth", 2000):
    display(test_dataset_df[test_dataset_df['curr_service'] == "GU"].sample(4)['output'])

13522                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 Ms. ___ was admitted to the urology service under the care \nof Dr. ___. She was initially given IV pain medication and \nslowly transitioned to an adequate oral regimen that controlled \nher symptoms. Given difficulty with coping with her new \ndiagnosis, s

# N-shot

This time in the N-shot, instead of hard-coding a few examples, we will specifically pull the "supposed" discharing provider's notes as a sample so that we can match the format as closely as possible. 

In [231]:
ex1 = """Mr. ___ presented to the ED s/p mechanical fall with head 
strike. CTH at OSH with subacute right SDH for which patient was 
transferred to ___ for evaluation and neurosurgery was 
consulted on arrival to ___ ED. 

#Right SDH
Patient remained neurologically at his baseline in the ED. He 
was admitted to the neurosurgery service and transferred to the 
floor for continued neurological monitoring. Patient's INR on 
admission was 1.5 and he was given Vitamin K 10mg IV x1 and then 
vitamin K 5mg x2 for a total of 3 doses. Patient's INR was 
monitored closely during hospitalization and down trended 
appropriately. Repeat CTH showed a stable SDH size. Patient 
remained stable neurologically. He remained on TF as he was at 
home and will follow up with SLP and laryngologist outpatient. 
While hospitalized patient remained at his neurological 
baseline. 

#Hx colloid cyst rupture s/p bilateral VPS placement
On patient's initial and repeat NCHCTs, his ventricles were 
shown to be slit. His bilateral VPS were turned up from 1.0 to 
1.5. No scan was repeated following adjustment of his VPS 
settings. He remained neurologically stable following this 
change. 

#Hx of DVT
Patient with a history of bilateral femoral vein occlusive DVTs 
for which he was started on Eliquis. Due to this subacute SDH 
Eliquis was held and bilateral LENIs were obtained to assess 
DVTs. These were negative for thrombus and the Eliquis was 
stopped indefinitely. Patient was placed on ___ 24hrs s/p 
trauma. 

#Disposition
Patient was evaluated by physical therapy who recommended 
___ rehab. Patient was discharged to rehab on ___. 
Patient discharged with follow up instructions with Dr. ___ in 
4 weeks with a repeat non contrast CTH"""

ex2 = """Mr. ___ is a ___ y/o man w/ HTN who presented w/ sudden-onset LUE 
weakness ___ AM). Next day, presented to ED. Pt noted to have 
L facial droop, LUE weakness (more prominent distally) and 
decreased LUE sensation. CTA showed significant R ICA stenosis; 
U/S indicated occlusion between 80 and 100%. MRI brain showed 
acute small infarcts in R frontal and parietal lobes in the MCA 
territory, consistent w/ watershed infarcts. Started on ASA and 
clopidogrel."""

examples = [ex1, ex2]

In [232]:
# # in order to do the above, we need to find the latest order for any particular pt and find the ordering physician

# poe = pd.read_csv('/gpfs/gibbs/project/rtaylor/shared/DischargeMe/mimiciv/hosp/poe.csv.gz')
# providers = pd.read_csv("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/mimiciv/hosp/provider.csv.gz")


In [233]:
# def get_discharging_provider_note(hadm_id, poe, providers, N):
    

In [234]:
# discharge_details = pd.read_csv("/gpfs/gibbs/project/rtaylor/shared/DischargeMe/mimiciv/note/discharge_detail.csv.gz")

# discharge_details

In [235]:
# (num_tokens_from_string(ex1) + num_tokens_from_string(ex2)) * 0.003

In [236]:
test_dataset_df = test_dataset_df.iloc[startN:endN]

In [237]:
test_dataset_df.shape

(2, 7)

In [238]:
test_dataset_df['careunit'].value_counts()

careunit
Med/Surg/Trauma    1
Medicine           1
Name: count, dtype: int64

In [239]:
careunit_map = pd.read_csv("careunit_map.csv")

In [240]:
prompts = pd.read_csv("bhc_prompts.csv")

In [241]:
bhc_gpt4_simple_medical_prompt = prompts[prompts['prompt_name'] == "bhc_gpt4_simple_medical"]['prompt'].squeeze()
bhc_gpt4_simple_surgical_prompt = prompts[prompts['prompt_name'] == "bhc_gpt4_simple_surgical"]['prompt'].squeeze()

In [242]:
bhc_gpt4_simple_medical_prompt, bhc_gpt4_simple_surgical_prompt

("Summarize the following patient hospital encounter into a brief hospital course. All brief hospital courses start with 1-2 sentences of introduction, describing why the patient arrived to the ED and any relevant features of the initial presentation. Break down the rest of the brief hospital course by condition. Include section headers for 'ACTIVE ISSUES:', 'CHRONIC ISSUES:', and 'TRANSITIONAL ISSUES:' where appropriate. MAKE SURE TO recede each condition with a #. For example, after the introduction, the first section might be '# Right SDH'. \r\n                     \r\nDescribe the course of events that the patient went through during their stay in sequential order, relating relevant labs, procedures, and medications to diagnoses of note found in the encounter information. The goal is to describe the clinical reasoning for various procedures, medications, imaging, and labs as you would to another physician on your care team. Organize your thoughts and then write the brief hospital c

In [243]:
gpt_inputs = []
for _, row in test_dataset_df.iterrows():
    message_text = [{"role":"system","content":"You are a physician generating a summary brief hospital course from the patient encounter information given."}]

    careunit_category = careunit_map[careunit_map['careunit'] == row['careunit']]['category'].squeeze()

    if careunit_category == "medical":
        intro_prompt = bhc_gpt4_simple_medical_prompt
    elif careunit_category == "surgical":
        intro_prompt = bhc_gpt4_simple_surgical_prompt
    else:
        intro_prompt = bhc_gpt4_simple_medical_prompt
        
    gpt_service_prompt = {"role":"user",
                     "content": intro_prompt + f"""\n\nEXAMPLE\n-------------\n{'\n\nEXAMPLE\n-------------\n'.join(examples)}\n\n\n\nPATIENT ENCOUNTER INFORMATION:\n----------------\n\nDischarging Service: {row['careunit']}\n\n{row['input']}."""}
    
    message_text.append(gpt_service_prompt)
    
    gpt_inputs.append(message_text)

In [244]:
test_dataset_df['gpt_input'] = gpt_inputs

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_dataset_df['gpt_input'] = gpt_inputs


### Start Token Counts

In [245]:
in_tokens = [num_tokens_from_string(x[1]['content']) for x in test_dataset_df['gpt_input'].tolist()]

In [246]:
pd.Series(in_tokens).describe()

count       2.000000
mean     2255.000000
std       224.859956
min      2096.000000
25%      2175.500000
50%      2255.000000
75%      2334.500000
max      2414.000000
dtype: float64

In [247]:
output = pd.read_csv("comparison_simple_gpt3.5_2shot_v2.csv")

In [248]:
out_tokens = [num_tokens_from_string(x) for x in output['gpt_completion'].tolist()]

In [249]:
sum(in_tokens), sum(out_tokens)

(4510, 150054)

In [250]:
pd.Series(in_tokens).describe(), pd.Series(out_tokens).describe()

(count       2.000000
 mean     2255.000000
 std       224.859956
 min      2096.000000
 25%      2175.500000
 50%      2255.000000
 75%      2334.500000
 max      2414.000000
 dtype: float64,
 count    500.000000
 mean     300.108000
 std       84.556442
 min      102.000000
 25%      241.750000
 50%      292.000000
 75%      354.000000
 max      767.000000
 dtype: float64)

In [251]:
(sum(out_tokens) / 1000 * 0.004) + (sum(in_tokens) / 1000 * 0.004)

0.6182559999999999

In [252]:
preceding_text_tokens = [num_tokens_from_string(x) for x in test_dataset_df['input'].tolist()]

In [253]:
pd.Series(preceding_text_tokens).describe()

count       2.00000
mean     1478.50000
std       228.39549
min      1317.00000
25%      1397.75000
50%      1478.50000
75%      1559.25000
max      1640.00000
dtype: float64

In [254]:
instructions_input = 796.8636363636364
instructions_output = 540.6898395721926

In [255]:
bhc_input = 2783.676000
bhc_output = 306.312000

In [256]:
((instructions_input * 24000)/1000 * 0.0005) + ((instructions_output * 24000)/1000 * 0.0015)

29.027197860962566

In [257]:
((bhc_input * 24000)/1000 * 0.0005) + ((bhc_output * 24000)/1000 * 0.0015)

44.43134400000001

### End Token Counts

In [258]:
len(gpt_inputs)

2

In [259]:
completions = []
for idx, (row_idx, row) in enumerate(test_dataset_df.iterrows()):
    try:
        completion = completion_with_backoff(
          engine=engine,
          messages = row['gpt_input'],        
        )
    except (RetryError, InvalidRequestError) as e:
        print(f"error in: {idx} with row index: {row_idx}, note index: {row['note_id']} and hadm_id: {row['hadm_id']}")
        print("Error with OpenAI", traceback.format_exc())
        completions.append("ERROR CONTENT POLICY OPENAI")
        continue
    
    print(idx)
    try:
        completions.append(completion['choices'][0]['message']['content'])
    except Exception as e:
        print(completion)

0
1


19


20


21


22


23


24


25


26


27


28


29


30


31


32


33


34


35


36


37


38


39


40


41


42


43


44


45


46


47


48


49


50


51


52


53


54


55


56


57


58


59


60


61


62


63


64


65


66


67


68


69


70


71


72


73


74


75


76


77


78


79


80


81


82


83


84


85


86


87


88


89


90


91


92


93


94


95


96


97


98


99


100


101


102


103


104


105


106


107


108


109


110


111


112


113


114


115


116


117


118


119


120


121


122


123


124


125


126


127


128


129


130


131


132


133


134


135


136


137


138


139


140


141


142


143


144


145


146


147


148


149


150


151


152


153


154


155


156


157


158


159


160


161


162


163


164


165


166


167


168


169


170


171


172


173


174


175


176


177


178


179


180


181


182


183


184


185


186


187


188


189


190


191


192


193


194


195


196


197


198


199


200


201


202


203


204


205


206


207


208


209


210


211


212


213


214


215


216


217


218


219


220


221


222


223


224


225


226


227


228


229


230


231


232


233


234


235


236


237


238


239


240


241


242


243


244


245


246


247


248


249


250


251


252


253


254


255


256


257


258


259


260


261


262


263


264


265


266


267


268


269


270


271


272


273


274


275


276


277


278


279


280


281


282


283


284


285


286


287


288


289


290


291


292


293


294


295


296


297


298


299


300


301


302


303


304


305


306


307


308


309


310


311


312


313


314


315


316


317


318


319


320


321


322


323


324


325


326


327


328


329


330


331


332


333


334


335


336


337


338


339


340


341


342


343


344


345


346


347


348


349


350


351


352


353


354


355


356


357


358


359


360


361


362


363


364


365


366


367


368


369


370


371


372


373


374


375


376


377


378


379


380


381


382


383


384


385


386


387


388


389


390


391


392


393


394


395


396


397


398


399


400


401


402


403


404


405


406


407


408


409


410


411


412


413


414


415


416


417


418


419


420


421


422


423


424


425


426


427


428


429


430


431


432


433


434


435


436


437


438


439


440


441


442


443


444


445


446


447


448


449


450


451


452


453


454


455


456


457


458


459


460


461


462


463


464


465


466


467


468


469


470


471


472


473


474


475


476


477


478


479


480


481


482


483


484


485


486


487


488


489


490


491


492


493


494


495


496


497


498


499


500


501


502


503


504


505


506


507


508


509


510


511


512


513


514


515


516


517


518


519


520


521


522


523


524


525


526


527


528


529


530


531


532


533


534


535


536


537


538


539


540


541


542


543


544


545


546


547


548


549


550


551


552


553


554


555


556


557


558


559


560


561


562


563


564


565


566


567


568


569


570


571


572


573


574


575


576


577


578


579


580


581


582


583


584


585


586


587


588


589


590


591


592


593


594


595


596


597


598


599


600


601


602


603


604


605


606


607


608


609


610


611


612


613


614


615


616


617


618


619


620


621


622


623


624


625


626


627


628


629


630


631


632


633


634


635


636


637


638


639


640


641


642


643


644


645


646


647


648


649


650


651


652


653


654


655


656


657


658


659


660


661


662


663


664


665


666


667


668


669


670


671


672


673


674


675


676


677


678


679


680


681


682


683


684


685


686


687


688


689


690


691


692


693


694


695


696


697


698


699


700


701


702


703


704


705


706


707


708


709


710


711


712


713


714


715


716


717


718


719


720


721


722


723


724


725


726


727


728


729


730


731


732


733


734


735


736


737


738


739


740


741


742


743


744


745


746


747


748


749


750


751


752


753


754


755


756


757


758


759


760


761


762


763


764


765


766


767


768


769


770


771


772


773


774


775


776


777


778


779


780


781


782


783


784


785


786


787


788


789


790


791


792


793


794


795


796


797


798


799


800


801


802


803


804


805


806


807


808


809


810


811


812


813


814


815


816


817


818


819


820


821


822


823


824


825


826


827


828


829


830


831


832


833


834


835


836


837


838


839


840


841


842


843


844


845


846


847


848


849


850


851


852


853


854


855


856


857


858


859


860


861


862


863


864


865


866


867


868


869


870


871


872


873


874


875


876


877


878


879


880


881


882


883


884


885


886


887


888


889


890


891


892


893


894


895


896


897


898


899


900


901


902


903


904


905


906


907


908


909


910


911


912


913


914


915


916


917


918


919


920


921


922


923


924


925


926


927


928


929


930


931


932


933


934


935


936


937


938


939


940


941


942


943


944


945


946


947


948


949


950


951


952


953


954


955


956


957


958


959


960


961


962


963


964


965


966


967


968


969


970


971


972


973


974


975


976


977


978


979


980


981


982


983


984


985


986


987


988


989


990


991


992


993


994


995


996


997


998


999


In [260]:
# completions#[187]

In [261]:
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', "rouge2", 'rougeL'], use_stemmer=True)


In [262]:
all_scores = []
for gen, ref in zip(completions, test_dataset['output']):
    scores = scorer.score(gen, ref)
    print(scores)
    scores_dict = {"rouge1_precision":scores['rouge1'][0], 
                  "rouge1_recall":scores['rouge1'][1], 
                  "rouge1_f1":scores['rouge1'][2], 
                  "rouge2_precision":scores['rouge2'][0], 
                  "rouge2_recall":scores['rouge2'][1], 
                  "rouge2_f1":scores['rouge2'][2], 
                  "rougeL_precision":scores['rougeL'][0], 
                  "rougeL_recall":scores['rougeL'][1], 
                  "rougeL_f1":scores['rougeL'][2], }
    all_scores.append(scores_dict)

{'rouge1': Score(precision=0.22413793103448276, recall=0.07103825136612021, fmeasure=0.10788381742738591), 'rouge2': Score(precision=0.0, recall=0.0, fmeasure=0.0), 'rougeL': Score(precision=0.13793103448275862, recall=0.04371584699453552, fmeasure=0.06639004149377593)}
{'rouge1': Score(precision=0.21788990825688073, recall=0.4846938775510204, fmeasure=0.30063291139240506), 'rouge2': Score(precision=0.04597701149425287, recall=0.10256410256410256, fmeasure=0.06349206349206349), 'rougeL': Score(precision=0.0871559633027523, recall=0.19387755102040816, fmeasure=0.12025316455696203)}


In [263]:
comparison_df = pd.DataFrame([completions, test_dataset['output']]).T.rename({0:"GPT", 1:"gold-standard"}, axis=1)

In [264]:
test_dataset_df['gpt_completion'] = completions

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_dataset_df['gpt_completion'] = completions


In [265]:
test_dataset_df.to_csv(f"comparison_simple_gpt3.5_2shot_{startN}-{endN}.csv")

In [266]:
comparison_df['GPT'] = comparison_df['GPT'].str.replace("\n", "<br>")
comparison_df['gold-standard'] = comparison_df['gold-standard'].str.replace("\n", "<br>")

In [267]:
comparison_df.to_html(f"comparison_simple_gpt3.5_2shot_{startN}-{endN}.html", escape=False)

In [268]:
pd.DataFrame.from_records(all_scores).to_csv(f"scores_{startN}-{endN}.csv")

In [269]:
pd.DataFrame.from_records(all_scores).mean()

rouge1_precision    0.221014
rouge1_recall       0.277866
rouge1_f1           0.204258
rouge2_precision    0.022989
rouge2_recall       0.051282
rouge2_f1           0.031746
rougeL_precision    0.112543
rougeL_recall       0.118797
rougeL_f1           0.093322
dtype: float64