In [None]:
import datetime
import json
import math
import openai
from openai.error import RateLimitError, InvalidRequestError
import os
import pandas as pd
from pathlib import Path
import random
import regex as re
import time
from utils import gpt3_response, get_patient_info
import generation_utils  # import date_delta, filter_encounters, get_current_items, labs_section, sdoh_section, vitals_section, pronoun, observation_types, fhir_to_section_category
import medspacy
from medspacy.section_detection import Sectionizer
import tqdm
# Blank medspaCy pipeline (enable=[] requests no optional components); presumably it just
# tokenizes the note text — confirm against the medspaCy version in use. Section detection
# is done separately by the sectionizer below.
nlp = medspacy.load(enable=[])
# Rule-based sectionizer built from the project's section_patterns.json. insert_observations()
# calls it directly on nlp(note) (it is not added to the pipeline) to split generated notes.
bwh_sectionizer = Sectionizer(nlp, rules='section_patterns.json')

## Create prompts according to a specified level of detail

In [None]:
def _join_english(items):
    """Join strings as an English list: "a", "a and b", or "a, b, and c" (Oxford comma)."""
    items = list(items)
    if len(items) >= 3:
        return ', '.join(items[:-1]) + ', and ' + items[-1]
    return ' and '.join(items)


def _subject_verb(nom_pronoun, plural_verb, singular_verb):
    """Return "<pronoun> <verb> " with subject-verb agreement and a trailing space.

    'They' takes the plural verb ("They were"); any other pronoun the singular ("She was").
    The trailing space lets the caller append the rest of the sentence directly.
    """
    verb = plural_verb if nom_pronoun == 'They' else singular_verb
    return f"{nom_pronoun} {verb} "


def form_prompt(patient_info, encounter_info, clinical_details: bool = False, history=None):
    """Build the GPT prompt requesting an oncology visit note for one encounter.

    Args:
        patient_info: patient record (as produced by ``get_patient_info``). Reads
            ['patient'] (birthdate, gender, intro_info, name_string),
            ['history']['cancer'], ['history']['conditions'] and ['first_cancer_diagnosis'].
        encounter_info: encounter dict; only ['datetime']['start'] (an ISO timestamp whose
            first 10 characters are the date) is read.
        clinical_details: if True, append instructions discouraging generic placeholder
            names and vitals.
        history: if not None, mention past/recent conditions and current medications.
            Medications come from history['medications']; conditions are read from
            patient_info['history']['conditions'].

    Returns:
        The prompt string.
    """
    patient = patient_info['patient']
    encounter_day = datetime.date.fromisoformat(encounter_info['datetime']['start'][:10])
    birth_day = datetime.date.fromisoformat(patient['birthdate'][:10])
    # Completed years: subtract one if the birthday has not yet come round in the encounter
    # year. (floor(days / 365) over-counts by the leap days accumulated before a birthday.)
    age = encounter_day.year - birth_day.year - (
        (encounter_day.month, encounter_day.day) < (birth_day.month, birth_day.day))
    # Record dates below are compared as 'YYYY-MM-DD' strings, which sort chronologically.
    encounter_date = encounter_info['datetime']['start'][:10]

    # Patient's introductory info.
    nom_pronoun = generation_utils.pronoun(patient['gender'])
    poss_pronoun = generation_utils.pronoun(patient['gender'], case='possessive')
    prompt = f"Write an oncological clinical visit summary note about a {age}-year-old {' '.join(patient['intro_info'])} patient named {patient['name_string']} who presents for {poss_pronoun.lower()} cancer treatment. "

    # Cancers recorded on/before the encounter and not ended by then, deduplicated by name.
    # Always bound (possibly empty): the history branch below filters against it.
    cancers = []
    for cond in patient_info['history']['cancer']:
        dates = cond['datetime']
        if dates['recorded'][:10] > encounter_date or cond['name'] in cancers:
            continue
        if 'end' in dates and dates['end'][:10] <= encounter_date:
            continue  # resolved before this visit
        cancers.append(cond['name'])

    if cancers:
        if len(cancers) == 1:
            prompt += f"The patient's current primary diagnosis is {cancers[0]}. "
        else:
            prompt += f"The patient's current primary diagnoses are {_join_english(cancers)}. "
        prompt += _subject_verb(nom_pronoun, 'were', 'was')
        prompt += f"first diagnosed {generation_utils.date_delta(patient_info['first_cancer_diagnosis'], encounter_date)} ago. "

    if history is not None:
        # "They have a history of ... [conditions up to that date] and are currently
        # prescribed ... [medications active at the time]".
        current_conditions, historic_conditions = generation_utils.get_current_items(
            patient_info['history']['conditions'], encounter_date=encounter_date, include_historic=True)

        if historic_conditions:
            prompt += _subject_verb(nom_pronoun, 'have', 'has')
            prompt += f"a history of {_join_english(historic_conditions)}. "

        # Cancers were already stated as the primary diagnosis; don't repeat them.
        current_conditions = [cond for cond in current_conditions if cond not in cancers]
        if current_conditions:
            prompt += _subject_verb(nom_pronoun, 'have', 'has')
            prompt += f"also recently experienced {_join_english(current_conditions)}. "

        medications, _ = generation_utils.get_current_items(history['medications'], encounter_date)
        if medications:
            prompt += _subject_verb(nom_pronoun, 'are', 'is')
            prompt += f"currently prescribed {_join_english(medications)}. "

    if clinical_details:
        prompt += f'Be varied and creative with the values you put in; don\'t use generic placeholders like "John Doe". Do *not* just use standard placeholder values for {poss_pronoun} vitals like "98.6°F".'
    return prompt

## Insert data associated with a given encounter into its associated note

In [None]:
# Categories treated as "the note's lab section" when replacing it with generated lab data.
# NOTE(review): the previous code removed 'Labs' sections but located 'laboratory' ones; both
# names are matched here — confirm which one section_patterns.json actually emits.
_LAB_CATEGORIES = frozenset({'laboratory', 'Labs'})


def _as_text(content):
    """Return generated section content as one string ('' for None).

    The generation_utils helpers may return a ready-made string or a list of lines; the
    previous code inserted SDOH/lab output as-is but newline-joined exam/allergy output.
    """
    if content is None:
        return ''
    if isinstance(content, str):
        return content
    return '\n'.join(content)


def _locate(sections, wanted):
    """Index of the first section whose category is in `wanted`, else the midpoint."""
    for i, (category, _text, _generated) in enumerate(sections):
        if category in wanted:
            return i
    return len(sections) // 2


def _place_section(sections, category, text, wanted, replace=False):
    """Insert generated `text` at the first section whose category is in `wanted` (or the midpoint).

    With replace=True the note's own (non-generated) sections in `wanted` are dropped, so the
    generated text takes their place. Inserted entries are flagged as generated so a later
    replacement never removes them. Mutates `sections` in place.
    """
    idx = _locate(sections, wanted)
    if replace:
        # idx is the first match (or the midpoint when nothing matched), so no entry before it
        # is filtered out and it stays a valid insertion point in the filtered list.
        sections[:] = [s for s in sections if s[2] or s[0] not in wanted]
    sections.insert(idx, (category, text, True))


def insert_observations(pt_info, encounter, note):
    """Splice structured patient/encounter data into a generated note.

    The note is split into sections with ``bwh_sectionizer``. Sections built from the
    encounter's SDOH and observations, the patient's allergies, conditions and medications, and
    the encounter's procedures are placed at the matching section of the note. Labs and
    exam-type observations replace the note's own section; the others are inserted before it.
    When the note has no matching section, the generated one goes at the middle of the note.

    Args:
        pt_info: patient record (as produced by ``get_patient_info``); reads
            ['history']['allergies'], ['history']['conditions'], ['history']['medications'].
        encounter: one encounter dict; reads 'sdoh', 'observations' (each with a 'category'),
            'procedures' (each with a 'name') and ['datetime']['start'].
        note: the generated note text.

    Returns:
        The augmented note: all sections joined with newlines.
    """
    doc = bwh_sectionizer(nlp(note))
    # Snapshot categories and span texts once into a single list that every insertion updates,
    # so each later lookup sees the current layout instead of the original doc's indices.
    sections = [(category, str(span), False)
                for category, span in zip(doc._.section_categories, doc._.section_spans)]

    # Social determinants of health: inserted before the note's social history section.
    if encounter['sdoh']:
        try:
            sdoh = generation_utils.sdoh_section(pt_info, encounter['sdoh'])
        except KeyError:
            print('sdoh observations not accounted for:')
            print(encounter['sdoh'])
            sdoh = None
        sdoh_text = _as_text(sdoh)
        if sdoh_text:
            _place_section(sections, 'social history', sdoh_text, {'social history'})

    # Labs: replace the note's lab section with the generated one.
    lab_observations = [obs for obs in encounter['observations'] if obs['category'] == 'laboratory']
    if lab_observations:
        try:
            labs = generation_utils.labs_section(lab_observations)
        except KeyError:
            print('lab observations not accounted for:')
            print(lab_observations)
            labs = None
        labs_text = _as_text(labs)
        if labs_text:
            _place_section(sections, 'laboratory', labs_text, _LAB_CATEGORIES, replace=True)

    # Exam-type observations (vitals etc.): each replaces its mapped section.
    for observation_type in generation_utils.observation_types:
        relevant_observations = [obs for obs in encounter['observations'] if obs['category'] == observation_type]
        if not relevant_observations:
            continue
        target = generation_utils.fhir_to_section_category[observation_type]
        try:
            rendered = generation_utils.observation_types[observation_type](relevant_observations)
        except KeyError:
            print('observations not accounted for:')
            print(relevant_observations)
            rendered = None
        rendered_text = _as_text(rendered)
        if rendered_text:
            _place_section(sections, target, rendered_text, {target}, replace=True)

    # Allergies.
    if pt_info['history']['allergies']:
        allergy_text = _as_text(generation_utils.allergies_section(pt_info))
        if allergy_text:
            _place_section(sections, 'allergies', allergy_text, {'allergies'})

    # Medical history. dict.fromkeys de-duplicates while keeping a deterministic order
    # (set() order varies between runs because of string hash randomisation).
    current, historic = generation_utils.get_current_items(
        pt_info['history']['conditions'], encounter['datetime']['start'], include_historic=True)
    if current or historic:
        history_lines = ['History:'] + list(dict.fromkeys(current)) + list(dict.fromkeys(historic)) + ['']
        _place_section(sections, 'history', '\n'.join(history_lines), {'history', 'History/Subjective'})

    # Medications active at the encounter.
    current, _ = generation_utils.get_current_items(pt_info['history']['medications'], encounter['datetime']['start'])
    if current:
        medication_lines = ['Medications:'] + list(dict.fromkeys(current)) + ['']
        _place_section(sections, 'medications', '\n'.join(medication_lines), {'medications'})

    # Procedures performed during the encounter.
    # NOTE(review): 'procedures' is assumed to be the category name in section_patterns.json —
    # confirm; when no such section exists the block goes to the middle of the note.
    if encounter['procedures']:
        procedure_lines = ['Procedures:'] + [str(procedure['name']) for procedure in encounter['procedures']] + ['']
        _place_section(sections, 'procedures', '\n'.join(procedure_lines), {'procedures'})

    return '\n'.join(text for _category, text, _generated in sections)

## Get patient information and make GPT calls

In [None]:
synthea_path = './test_output/'
output_path = './gpt_synth_notes/'


def _append_row(csv_path, row):
    """Append one record to `csv_path`, creating parent dirs and writing the header only for a new/empty file.

    Appending avoids re-reading and rewriting the whole CSV for every note (quadratic I/O).
    """
    csv_path = Path(csv_path)
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    write_header = not csv_path.exists() or csv_path.stat().st_size == 0
    pd.DataFrame([row]).to_csv(csv_path, mode='a', header=write_header, encoding='utf-8', index=False)


for fname in tqdm.tqdm(sorted(os.listdir(synthea_path))):  # for each patient (deterministic order):
    pt_info = get_patient_info(os.path.join(synthea_path, fname))
    pt_name = '_'.join([pt_info['patient']['name']['given'].replace(' ', '_'), pt_info['patient']['name']['family']])
    encounters = generation_utils.filter_encounters(pt_info)  # filter only encounters on or after their initial cancer diagnosis
    for encounter in encounters:  # for each of the patient's encounters:
        # ask gpt for summary note based on various levels of info
        # don't bother with "(True, None)"
        for details, history in [(False, None), (True, pt_info['history'])]:
            prompt = form_prompt(pt_info, encounter_info=encounter, clinical_details=details, history=history)
            messages = [{"role": "system", "content": "You are an oncologist."},
                        {'role': 'user', 'content': prompt}]
            try:
                response = gpt3_response(messages=messages)
            except InvalidRequestError:
                continue  # skip prompt variants the API rejects
            # Match form_prompt, which includes history whenever it is not None.
            intermediate_dir = ('detailed' if details else 'basic') + ('_withhistory' if history is not None else '')
            out_dir = os.path.join(output_path, intermediate_dir)
            row = {'patient_url': pt_info['patient']['url'],
                   'encounter_url': encounter['url'],
                   'prompt': prompt,
                   'response': response}
            _append_row(os.path.join(out_dir, intermediate_dir + '.csv'), row)

            augmented_row = dict(row, response=insert_observations(pt_info, encounter, response))
            _append_row(os.path.join(out_dir, intermediate_dir + '_augmented.csv'), augmented_row)