In [11]:
split = "valid"
version = "v1"

inputcol = "bhc_preceding_text"
outputcol = "brief_hospital_course"

In [2]:
import datasets
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm
from pathlib import Path


In [3]:
pipe = pipeline("text2text-generation", model="/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/template_code/bart-dischargeme-results_v2/checkpoint-16500", device=0)

In [4]:
from pynvml import *


def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.", flush=True)


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}", flush=True)
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}", flush=True)
    print_gpu_utilization()

## Data

---

In [5]:
import pandas as pd
from datasets import Dataset


In [6]:
print_gpu_utilization()

GPU memory occupied: 2139 MB.


In [7]:
if split == "valid":
    fname = "/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/valid/discharge_target_with_preceding_text+structured_data.pickle"
elif split == "test_phase_1":
    fname = "/gpfs/gibbs/project/rtaylor/shared/DischargeMe/public/test/discharge_target_with_preceding_text+structured_data.pickle"
else:
    raise Exception("Split doesn't exist!")

In [8]:
data = pd.read_pickle(fname)
ds = Dataset.from_pandas(data[['hadm_id', inputcol, outputcol]], split="test")


In [12]:
data[f'{outputcol}_word_count'].describe()

count    14719.000000
mean       325.122495
std        238.098437
min         10.000000
25%        160.000000
50%        278.000000
75%        439.000000
max       4414.000000
Name: brief_hospital_course_word_count, dtype: float64

In [None]:
ds

In [None]:
BATCH_SIZE = 16

In [None]:
%%time
# KeyDataset (only *pt*) will simply return the item in the dict returned by the dataset item
# as we're not interested in the *target* part of the dataset. For sentence pair use KeyPairDataset
outs = []
for idx, out in enumerate(pipe(KeyDataset(ds, inputcol), 
                     batch_size=BATCH_SIZE,
                     clean_up_tokenization_spaces=True,
                     truncation=True,
                     return_text=True,
                     max_length=1028,
                     no_repeat_ngram_size=5, 
                     num_beams=3, 
                     early_stopping=True
                    )):
    print(f"{idx}/920")
    outs.append(out)

In [None]:
print_gpu_utilization()

In [None]:
def flatten(xss):
    return [x for xs in xss for x in xs]

In [None]:
outs = flatten(outs)

In [None]:
data_out = pd.concat([data, pd.DataFrame.from_records(outs).rename({"generated_text": f"{outputcol}_BART"}, axis=1)], axis=1)

In [None]:
# print("Writing file to:")
# print(f"/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/outputs/bhc_scores/{Path(fname).stem}_BART_gen_{version}.pickle")
# data_out.to_pickle(f"/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/outputs/bhc_scores/{Path(fname).stem}_BART_gen_{version}.pickle")

# Check results

In [15]:
import pandas as pd

In [26]:
data_out = pd.read_pickle("/home/vs428/Documents/DischargeMe/hail-dischargeme/notebooks/brief_hospital_course/outputs/bhc_scores/discharge_target_with_preceding_text+structured_data_BART_gen_test_phase_1_v3.pickle")

In [27]:
data_out[[f"{outputcol}_BART", outputcol]]

Unnamed: 0,brief_hospital_course_BART,brief_hospital_course
0,"___ with PMH HCV, ETOH cirrhosis with ascites,...","___ with PMH HCV, ETOH cirrhosis with ascites,..."
1,___ yo F s/p L4-5 discectomy/fusion (___) by \...,___ yo F with recent lumbar laminectomy c/b MS...
2,The patient was admitted to the colorectal \ns...,The patient was seen ___ the emergency departm...
3,"___ w/ h/o CAD (known 3VD) and ESRD on HD, \nw...","Ms. ___ is a ___ woman with a history of T2DM,..."
4,Mr. ___ is a ___ year old man with a history o...,___ with IDDM and h/o provoked PE on apixiban ...
...,...,...
14697,___ M PMHx dilated non-ischemic cardiomyopathy...,___ M PMHx dilated non-ischemic cardiomyopathy...
14698,___ year old M with PMH of CAD s/p 2xMI and co...,"Mr. ___ is a ___ year-old man with CAD, HTN, H..."
14699,"___ with CHF (EF 60%), a-fib on coumadin, \nHT...",___ year old female with past medical history ...
14700,This is a ___ with PMHx of hypertension and \n...,Mr. ___ is a ___ male with a past medical hist...


# Scoring

In [28]:
import sys
sys.path.insert(0,"../../../scoring")

In [29]:
from scoring_bhc import calculate_scores_bhc
from scoring_bhc import compute_overall_score_bhc

In [30]:
from scoring_instructions import calculate_scores_instructions
from scoring_instructions import compute_overall_score_instructions

In [31]:
# This assumes that the output from BART has the column outputcol_BART (i.e. brief_hospital_course_BART)
bart_results = data_out[~data_out[f'{outputcol}_BART'].isna()]

In [32]:
print(bart_results['bhc_preceding_text'].sample(1).squeeze())

 
Name:  ___                  Unit No:   ___
 
Admission Date:  ___              Discharge Date:   ___
 
Date of Birth:  ___             Sex:   F
 
Service: MEDICINE
 
Allergies: 
No Known Allergies / Adverse Drug Reactions
 
Attending: ___
 
Chief Complaint:
Ascites and lower extremity edema
 
Major Surgical or Invasive Procedure:
Abdominal paracentesis ___ (2L removed) and ___ (2L 
removed)
 
 
History of Present Illness:
___ w/ HBV/HCV cirrhosis, HTN, DM who presented from her ___ 
clinic for worsening ascites and PCP concern for SBP. She was 
discharged from ___ on ___ after being admitted for a similar 
presentation. During that admission she had a 3L paracentesis 
and her diuretics were uptitrated at that time. She reports that 
her ascites has steadily worsened since discharge. Pt denies 
dietary salt indiscretion. She called her doctor on ___ who 
wanted to increase diuretics, but wanted to get blood work 
first. She followed up with her PCP today, who told her to come 
to ___ 

In [33]:
reference = bart_results[['hadm_id', outputcol]]
generated = bart_results[['hadm_id', f"{outputcol}_BART"]].rename({f"{outputcol}_BART":outputcol}, axis=1)

In [34]:
scores = calculate_scores_bhc(generated,
                 reference,
                 [
                     # "bleu", 
                  "rouge",
                     # "bertscore",
                     # "meteor"
                 ])

Beginning scoring...
rougeScorer initialized
Processed 128/14702 samples.
Processed 256/14702 samples.
Processed 384/14702 samples.
Processed 512/14702 samples.
Processed 640/14702 samples.
Processed 768/14702 samples.
Processed 896/14702 samples.
Processed 1024/14702 samples.
Processed 1152/14702 samples.
Processed 1280/14702 samples.
Processed 1408/14702 samples.
Processed 1536/14702 samples.
Processed 1664/14702 samples.
Processed 1792/14702 samples.
Processed 1920/14702 samples.
Processed 2048/14702 samples.
Processed 2176/14702 samples.
Processed 2304/14702 samples.
Processed 2432/14702 samples.
Processed 2560/14702 samples.
Processed 2688/14702 samples.
Processed 2816/14702 samples.
Processed 2944/14702 samples.
Processed 3072/14702 samples.
Processed 3200/14702 samples.
Processed 3328/14702 samples.
Processed 3456/14702 samples.
Processed 3584/14702 samples.
Processed 3712/14702 samples.
Processed 3840/14702 samples.
Processed 3968/14702 samples.
Processed 4096/14702 samples.
Pr

In [35]:
# write a few examples out to file
import tabulate
print(tabulate.tabulate(bart_results[[outputcol, outputcol]].sample(10), headers='keys', tablefmt='simple'))

       brief_hospital_course                                              brief_hospital_course
-----  -----------------------------------------------------------------  -----------------------------------------------------------------
13534  Mrs. ___ was admitted to the neurology service for evaluation      Mrs. ___ was admitted to the neurology service for evaluation
       of her diplopia. She had a brain MRI due to consideration of       of her diplopia. She had a brain MRI due to consideration of
       brainstem stroke causing CN III dysfunction (especially in the     brainstem stroke causing CN III dysfunction (especially in the
       setting of her family history and cardiovascular risk factors)     setting of her family history and cardiovascular risk factors)
       and this did not show any explanation, including stroke, for her   and this did not show any explanation, including stroke, for her
       vision problems. Therefore the most likely diagnosis became        vision

In [36]:
with open('sample_BART_results_v1.txt', 'w') as f:
    f.write(tabulate.tabulate(bart_results[[outputcol, f"{outputcol}_BART"]].sample(10), headers='keys', tablefmt='simple'))
