In [None]:
from google.colab import auth, drive
# Colab setup: authenticate the user, then mount Google Drive at /content/gdrive
# (OUTPUT_DIR in the next cell lives under this mount point).
auth.authenticate_user()
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
import os
OUTPUT_DIR = '/content/gdrive/My Drive/Prescribing Cascade/data/output'
# Output sub-directories whose names end in "10000". sorted() makes the order
# deterministic: os.listdir returns entries in arbitrary order, which would make
# the order of anything built by iterating data_dirs vary between runs.
data_dirs = sorted(dirname for dirname in os.listdir(OUTPUT_DIR) if dirname.endswith("10000"))

In [None]:
from typing import Dict, List
import numpy as np
from functools import reduce

# Confidence threshold used as the default `score_thresh` of valid_fn.
SCORE_THRESH: float = 0.9

def verify_sub_words_are_consecutive(stack):
    '''Check that the tokens of one entity stack have strictly consecutive indices.

    Args:
        stack: list of token dicts, each carrying an integer 'index' key.

    Raises:
        AssertionError: if two neighbouring tokens are not index-adjacent, i.e.
            ``stack[i + 1]['index'] != stack[i]['index'] + 1`` for some i.
    '''
    indices = [el['index'] for el in stack]
    for prev_idx, next_idx in zip(indices, indices[1:]):
        if next_idx != prev_idx + 1:
            # Raise instead of opening an interactive debugger, which would block
            # unattended / "Run All" executions of the notebook.
            raise AssertionError(f"Not consecutive! indices={indices}")


def word_accumulator_fn(info_dict: Dict[str, List], curr_tok):
    '''Reducer that groups BIO-tagged tokens into per-entity stacks.

    ``info_dict`` holds ``'word_list'`` (closed stacks) and ``'curr_stack'``
    (the stack currently being built). The dict is mutated in place and
    returned, so this function can be passed directly to ``functools.reduce``.

    * A token whose tag starts with ``I`` extends the current stack, or is
      dropped when no stack is open.
    * A token whose tag starts with ``B`` closes the current stack (if it is
      non-empty) and opens a new one containing just this token.

    Raises:
        AssertionError: if the tag starts with neither ``I`` nor ``B``.
    '''
    if curr_tok['entity'].startswith('I'):
        if info_dict['curr_stack']:
            info_dict['curr_stack'].append(curr_tok)
        # An I- token with no open stack is an orphan; it is ignored.
        return info_dict
    assert curr_tok['entity'].startswith(
        'B'), "THERE IS SOME SUSSY STUFF GOING ON WITH THE TAGGER. OUTPUT OTHER THAN I- OR B-???"
    # Only close non-empty stacks: otherwise the very first B- token would push
    # the initial empty [] into word_list.
    if info_dict['curr_stack']:
        info_dict['word_list'].append(info_dict['curr_stack'])
    info_dict['curr_stack'] = [curr_tok]
    return info_dict


def condense_stack(stack):
    '''Merge a stack of WordPiece tokens into one string plus its per-token scores.

    Tokens beginning with "##" are glued onto the preceding text (prefix
    stripped); every other token is preceded by a single space. Leading
    whitespace of the final string is removed.

    Returns:
        (condensed_word, scores): the merged string and a numpy array of each
        token's 'score', in stack order.
    '''
    pieces = []
    for part in stack:
        word = part['word']
        if word.startswith("##"):
            pieces.append(word[2:])
        else:
            pieces.append(' ' + word)
    condensed_word = ''.join(pieces).lstrip()
    scores = np.array([part['score'] for part in stack])
    return condensed_word, scores


def split_one_stack(stack):
    '''Split one token stack into runs of consecutive 'index' values.

    A new run starts whenever a token's index is not exactly one more than the
    previous token's index. Returns a list of runs (each a list of the original
    token dicts, in order); an empty stack yields an empty list.
    '''
    partitions = []
    previous_index = None
    for token in stack:
        continues_run = previous_index is not None and token['index'] == previous_index + 1
        if continues_run:
            partitions[-1].append(token)
        else:
            partitions.append([token])
        previous_index = token['index']
    return partitions


def split_stacks(stacks):
    '''Untangle every stack so that each returned stack holds exactly one entity.

    A stack may contain several entities whose token indices are not contiguous;
    each stack is split with split_one_stack and all resulting runs are returned
    as one flat list, preserving order.
    '''
    untangled = []
    for stack in stacks:
        untangled.extend(split_one_stack(stack))
    return untangled

def valid_fn(element, score_thresh=SCORE_THRESH):
    '''Decide whether a condensed (text, scores) entity should be kept.

    Args:
        element: ``(text, scores)`` pair as produced by condense_stack, where
            ``scores`` is a numpy array of per-token confidences.
        score_thresh: minimum mean score (exclusive) required to keep the entity.

    Returns:
        True when text is non-empty, contains no '[PAD]' token, and the mean of
        its scores is strictly greater than ``score_thresh``.
    '''
    text, scores = element
    # Use the caller-supplied threshold; it was previously hard-coded to 0.9.
    return text != '' and '[PAD]' not in text and scores.mean() > score_thresh
def get_tokens_from_ner_specific(raw_outputs, entity_type):
    '''Extract merged entity strings of one type from raw token-level NER output.

    Args:
        raw_outputs: iterable of token dicts with at least the keys 'entity',
            'score', 'index' and 'word' (assumed to be per-token output of a
            token-classification pipeline -- confirm against whatever produced
            preds.pickle).
        entity_type: tag suffix to keep (e.g. 'ADR' or 'DRUG'); a token is kept
            when its 'entity' string ends with this suffix.

    Returns:
        list of str: one string per entity with "##" sub-words merged, keeping
        only entities accepted by valid_fn at its default threshold.

    Raises:
        AssertionError: from word_accumulator_fn on a tag that starts with
            neither 'I' nor 'B', or from verify_sub_words_are_consecutive.
            Errors propagate to the caller rather than opening a debugger.
    '''
    outputs = (output for output in raw_outputs if output['entity'].endswith(entity_type))
    word_stack = reduce(word_accumulator_fn, outputs, dict(word_list=[], curr_stack=[]))
    all_stacks = word_stack['word_list']
    # The last entity is still open when the tokens run out; close it as well.
    if len(word_stack['curr_stack']) > 0:
        all_stacks.append(word_stack['curr_stack'])
    # One stack can hold several non-adjacent entities; split them on index gaps.
    clean_stacks = split_stacks(all_stacks)
    # Sanity check: every split stack must now be index-contiguous.
    for stack in clean_stacks:
        verify_sub_words_are_consecutive(stack)
    condensed = [condense_stack(stack) for stack in clean_stacks]
    return [element[0] for element in condensed if valid_fn(element)]


def get_tokens_from_ner(raw_outputs, entity_list=('ADR', 'DRUG')):
    '''Map each entity type in ``entity_list`` to its extracted entity strings.

    ``raw_outputs`` is scanned once per entity type, so it must be re-iterable
    (e.g. a list, not a generator). The default is an immutable tuple rather
    than a list so the shared default object can never be mutated.
    '''
    return {entity_type: get_tokens_from_ner_specific(raw_outputs, entity_type) for entity_type in entity_list}

In [None]:
from os.path import join as pjoin
import pickle
from tqdm import tqdm
# Pickles an output sub-directory must contain for check_files_exist to accept it.
files_to_load = ["preds.pickle", "id_mappings.pickle"]
def load_pickle(filename, dir):
    '''Unpickle and return the object stored at OUTPUT_DIR/<dir>/<filename>.

    NOTE: pickle.load can execute arbitrary code; only use on trusted,
    self-produced files.
    '''
    path = pjoin(OUTPUT_DIR, dir, filename)
    with open(path, "rb") as handle:
        obj = pickle.load(handle)
    return obj
def check_files_exist(dir):
    '''Return True iff every file named in files_to_load exists under OUTPUT_DIR/<dir>.'''
    # all() replaces the hand-rolled reduce-based conjunction and stops at the
    # first missing file instead of checking every one.
    return all(os.path.exists(pjoin(OUTPUT_DIR, dir, fn)) for fn in files_to_load)
def load_pickles(filenames, dir):
    '''Load several pickles from the same output sub-directory, preserving order.'''
    loaded = []
    for filename in filenames:
        loaded.append(load_pickle(filename, dir))
    return loaded
def read_data(data_dir):
    '''Load one output sub-directory and return the predictions that contain entities.

    Reads preds.pickle and id_mappings.pickle from OUTPUT_DIR/<data_dir>, runs
    get_tokens_from_ner on every non-empty prediction, and keeps only those with
    at least one ADR or DRUG entity.

    Returns:
        list of ``(entities, id_map)`` pairs, where ``entities`` is
        ``{'ADR': [...], 'DRUG': [...]}`` and ``id_map`` is the matching entry
        of id_mappings.

    Raises:
        ValueError: if the two pickles have different lengths; they are paired
            by position, so a mismatch means the data is inconsistent.
    '''
    preds, id_mappings = load_pickles(['preds.pickle', 'id_mappings.pickle'], dir=data_dir)
    if len(preds) != len(id_mappings):
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(f"{data_dir}: {len(preds)} preds but {len(id_mappings)} id mappings")
    results = []
    # Single pass: avoids materialising every processed prediction before filtering.
    for pred, id_map in tqdm(zip(preds, id_mappings), total=len(preds)):
        if len(pred) == 0:
            continue
        entities = get_tokens_from_ner(pred)
        if entities['ADR'] or entities['DRUG']:
            results.append((entities, id_map))
    return results
# Concatenate the (entities, id_map) pairs of every output directory that has all required pickles.
data_all = [
    pair
    for data_dir in data_dirs
    if check_files_exist(data_dir)
    for pair in read_data(data_dir)
]

100%|██████████| 178744/178744 [00:03<00:00, 53230.11it/s]
100%|██████████| 174799/174799 [00:03<00:00, 54627.93it/s]
100%|██████████| 176457/176457 [00:03<00:00, 54615.71it/s]
100%|██████████| 177532/177532 [00:03<00:00, 54015.54it/s]
100%|██████████| 176809/176809 [00:03<00:00, 48605.83it/s]
100%|██████████| 179231/179231 [00:03<00:00, 54423.48it/s]
100%|██████████| 177389/177389 [00:03<00:00, 52616.67it/s]
100%|██████████| 176475/176475 [00:03<00:00, 52691.43it/s]
100%|██████████| 175958/175958 [00:03<00:00, 53500.14it/s]
100%|██████████| 177191/177191 [00:05<00:00, 33961.44it/s]
100%|██████████| 176888/176888 [00:03<00:00, 52961.80it/s]
100%|██████████| 178775/178775 [00:03<00:00, 53561.36it/s]
100%|██████████| 178190/178190 [00:03<00:00, 53738.33it/s]
100%|██████████| 179570/179570 [00:03<00:00, 54482.64it/s]
100%|██████████| 177428/177428 [00:05<00:00, 30269.17it/s]
100%|██████████| 178787/178787 [00:03<00:00, 53666.49it/s]
100%|██████████| 174800/174800 [00:05<00:00, 29580.73it/

In [None]:
def save_pickle(obj, filename, dir=OUTPUT_DIR):
    '''Pickle ``obj`` to <dir>/<filename> (dir defaults to OUTPUT_DIR), overwriting any existing file.'''
    target_path = pjoin(dir, filename)
    with open(target_path, "wb") as out_file:
        pickle.dump(obj, out_file)



In [None]:
# Persist the filtered (entities, id_map) pairs next to the per-directory outputs.
save_pickle(obj=data_all, filename="data_all.pickle")

In [None]:
# Number of (entities, id_map) pairs that kept at least one ADR or DRUG entity.
n_pairs = len(data_all)
n_pairs

8011332

In [None]:
from collections import Counter
# Corpus-wide frequency of every extracted ADR / DRUG string.
all_adrs = Counter(adr for entities, _ids in data_all for adr in entities['ADR'])
all_drugs = Counter(drug for entities, _ids in data_all for drug in entities['DRUG'])

In [None]:
# Show only the most frequent ADR strings; the full Counter is far too long to read inline.
all_adrs.most_common(50)

[('pneumothorax', 175462),
 ('pericardial effusion', 121145),
 ('hypotension', 61609),
 ('mitral regurgitation', 60885),
 ('pleural effusion', 59778),
 ('aortic regurgitation', 46319),
 ('edema', 42664),
 ('fever', 35430),
 ('pneumonia', 34932),
 ('voiding', 34691),
 ('nausea', 33759),
 ('sepsis', 32399),
 ('pleural effusions', 31759),
 ('bilateral pleural effusions', 29613),
 ('respiratory failure', 28850),
 ('sinus tachycardia', 28295),
 ('pain', 28215),
 ('pulmonary edema', 27551),
 ('diarrhea', 27225),
 ('chest pain', 27193),
 ('atelectasis', 24836),
 ('voiding stooling', 24262),
 ('hypertension', 23309),
 ('diuresis', 22859),
 ('tachycardia', 22599),
 ('constipation', 21623),
 ('stooling', 20907),
 ('atrial fibrillation', 20339),
 ('cough', 19859),
 ('shortness breath', 19788),
 ('hemorrhage', 18888),
 ('mitral annular calcification', 18512),
 ('anemia', 17658),
 ('hydronephrosis', 17309),
 ('abdominal pain', 17217),
 ('aortic valve stenosis', 16897),
 ('vomiting', 16497),
 ('atri

In [None]:
# Show only the most frequent DRUG strings; the full Counter is far too long to read inline.
all_drugs.most_common(50)

[('fentanyl', 92364),
 ('propofol', 89226),
 ('vancomycin', 85308),
 ('heparin', 81379),
 ('coumadin', 76251),
 ('caffeine', 73493),
 ('aspirin', 63179),
 ('tylenol', 60670),
 ('metoprolol', 60670),
 ('lasix', 58297),
 ('insulin', 55629),
 ('morphine', 52723),
 ('amiodarone', 50408),
 ('lisinopril', 41910),
 ('levofloxacin', 40062),
 ('acetaminophen', 38520),
 ('albuterol', 37410),
 ('prednisone', 36017),
 ('creatinine', 35563),
 ('dilantin', 33822),
 ('lidocaine', 33324),
 ('hydralazine', 33057),
 ('ceftriaxone', 32541),
 ('furosemide', 30978),
 ('oxycodone', 29241),
 ('troponin', 26514),
 ('dopamine', 26081),
 ('ciprofloxacin', 23519),
 ('ativan', 22782),
 ('captopril', 20860),
 ('pantoprazole', 20858),
 ('atorvastatin', 19974),
 ('simvastatin', 19105),
 ('ampicillin', 18649),
 ('metoprolol tartrate', 17957),
 ('digoxin', 17723),
 ('metronidazole', 17025),
 ('lactulose', 16770),
 ('atenolol', 16759),
 ('famotidine', 16413),
 ('meropenem', 16230),
 ('levothyroxine', 16212),
 ('docusat

In [None]:
from collections import defaultdict
def groupby_rowid(data_all):
    '''Regroup per-item predictions by row id.

    Args:
        data_all: iterable of ``(pred, (row_id, sentence_id))`` pairs.

    Returns:
        defaultdict mapping ``row_id -> {sentence_id: pred}``. If the same
        (row_id, sentence_id) occurs more than once, the last pred wins.
    '''
    # `dict` as the default factory (instead of `lambda: dict()`) keeps the
    # returned defaultdict picklable.
    row_id_dict = defaultdict(dict)
    for pred, (row_id, sentence_id) in data_all:
        row_id_dict[row_id][sentence_id] = pred
    return row_id_dict

In [None]:
# row_id -> {sentence_id: {'ADR': [...], 'DRUG': [...]}}
groupedby = groupby_rowid(data_all=data_all)

In [None]:
# Save as a plain dict so the pickle does not depend on the defaultdict's default factory.
save_pickle(obj=dict(groupedby), filename="groupedby.pickle")

In [None]:
# Preview a few row ids only; displaying the whole mapping dumps every row's entities into the output.
N_PREVIEW_ROWS = 3
{row_id: groupedby[row_id] for row_id in list(groupedby)[:N_PREVIEW_ROWS]}

defaultdict(<function __main__.groupby_rowid.<locals>.<lambda>>,
            {21073: {7: {'ADR': ['nausea'], 'DRUG': []},
              10: {'ADR': ['dilaudid p . o'], 'DRUG': []},
              14: {'ADR': [], 'DRUG': ['doxycycline']},
              33: {'ADR': ['clubbing', 'cyanosis edema'], 'DRUG': []},
              35: {'ADR': ['cervical motion tenderness'], 'DRUG': []},
              38: {'ADR': ['uterine tenderness'], 'DRUG': []},
              42: {'ADR': [], 'DRUG': ['beta hcg']},
              43: {'ADR': ['hematosalpinx pyosalpinx'], 'DRUG': []},
              46: {'ADR': [], 'DRUG': ['doxycycline flagyl']},
              56: {'ADR': [], 'DRUG': ['doxycycline']},
              62: {'ADR': [], 'DRUG': ['percocet']},
              67: {'ADR': [], 'DRUG': ['ciprofloxacin']},
              71: {'ADR': [], 'DRUG': ['ciprofloxacin']},
              73: {'ADR': [], 'DRUG': ['percocet']},
              75: {'ADR': ['pain'], 'DRUG': []},
              76: {'ADR': [], 'DRUG': ['ibupro