In [105]:
import os
import csv
import sys
import argparse
import json

import pandas as pd
import numpy as np

import torch
from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments

from data import LineByLineTextDataset
from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from utils import DATA_PATH, make_dirs

# Convert Diag code

In [117]:
diagmap_path = os.path.join('/home/liutianc/emr/data/', 'diag_9to10.csv')
icd_csv = csv.reader(open(diagmap_path), delimiter='|')

In [250]:
icd_map = []
for code in open(diagmap_path).readlines():
    code=code.strip()
    if code.split("|")[1] == 'Flags':
        header = code.split("|")
        print(code)
    else:
        icd_map.append(code.split("|"))
        
icd_map_df = pd.DataFrame(icd_map).loc[:,:1]
icd_map_df.columns = ['icd9', 'icd10']

# Drop some outlier characters.
icd_map_df['icd9'] = icd_map_df['icd9'].str.replace("'", "", regex=False)
icd_map_df['icd10'] = icd_map_df['icd10'].str.replace("'", "", regex=False)

# Only keep rows with defined mapping relationships: coverage rate is 14,660 / 15,086 = 95.85%
icd_map_df = icd_map_df.loc[icd_map_df['icd10'] != '']

# Only keep rows with UNIQUE mapping relationships: coverage rate is 13,516 / 14,660 = 92.20%
icd9_mapnum = icd_map_df.groupby('icd9').count().reset_index()
icd9_unimap = set(list(icd9_mapnum.loc[icd9_mapnum['icd10'] == 1, 'icd9']))
icd_map_df = icd_map_df.loc[icd_map_df['icd9'].apply(lambda x: x in icd9_unimap)]

# Cast icd10 encode to our data version: coverage rate is 9261 / 9802 = 94.48%
icd_map_df['icd10'] = icd_map_df['icd10'].str.replace('.', '', regex=False).str.upper()

icd_map_df.shape

(13516, 2)

In [261]:
icd_map_df = icd_map_df.set_index('icd9')
icd_map_raw = icd_map_df.to_dict(orient='index')

In [None]:
icd_map = {}
for idx, icd9 in enumerate(icd_map_raw):
    value = test[icd9]['icd10']
    
    _icd9 = icd9.replace('.', '')
    keys = [_icd9]
    # Add zero: We only add 0 to the end.
    keys += [_icd9 + '0']
    # If the last two chars are not 00, then we will add double 00 to the end.
    if _icd9[-2:] != '00':
        keys += [_icd9 + '00']

    # Strip ALL 0.
    keys += [_icd9.strip('0')]

    for key in set(keys):
        if key in icd_map:
            icd_map[key] = None
        else:
            icd_map[key] = [value, icd9]


In [474]:
icd9 = set(diag_icd9)

vc = [k for k in icd_map if icd_map[k] is not None]
ivc = [k for k in icd_map if icd_map[k] is None]

print(f'''
Total icd9 diag we have: {len(icd9)},
Caught by valid mapping: {len(icd9.difference(set(vc)))},
Caught by invalid mapping: {len(icd9.intersection(ivc))},
Not caught: {len(icd9.difference(set(icd_map)))}
''')


Total icd9 diag we have: 26963,
Caught by valid mapping: 11962,
Caught by invalid mapping: 976,
Not caught: 10986



In [465]:
icd9 = set(diag_icd9)
icd10 = set(diag_icd10)

vc = [k for k in icd_map if icd_map[k] is not None]
ivc = [k for k in icd_map if icd_map[k] is None]

fail, success = [], []
for t in icd9:
    is_succ = 0
    if t in vc:
        success.append(t)
        is_succ = 1
    
    # Since we didn't add 0 to the beginning, removing possible extra 0 is done here.
    while t[0] == '0':
        t = t[1:]
        if t in vc:
            success.append(t)
            is_succ = 1
            continue
    if is_succ == 0:
        fail.append(t)

fail = set(fail)

print(f'''
Total icd9 diag we have: {len(icd9)},
Caught by invalid mapping: {len(success)},
Caught by invalid mapping: {len(fail.intersection(ivc))},
Not caught: {len(fail.difference(set(icd_map)))}
Actually ICD10: {len(fail.intersection(icd10))}
''')

15659
11729


In [481]:
# real_miss = fail.difference(set(icd_map)).difference(icd10)


In [477]:
# user_group = [str(i) for i in range(10)]

# for group in user_group:
#     read = os.path.join(DATA_PATH, f'group_{group}.csv')

#     with open(read, 'r') as raw:
#         for line in raw:
#             for token in icd10_9:
#                 if f'icd:9_diag:{token}' in line:
#                     print(line)
#                     break
    
#     print(f'Checked: {group}')

560499201141940,2014-10-02,icd:9_diag:25000

560499201141940,2014-10-28,icd:9_diag:25000 icd:9_diag:4011 icd:9_diag:4659 icd:9_diag:7862 AZITHROMYCIN HYDROCODONE/CHLORPHEN_P-STIREX

560499201141940,2014-11-06,icd:9_diag:25000 icd:9_diag:2724 icd:9_diag:4011 icd:9_diag:V5869

560499201141940,2014-11-11,icd:9_diag:25000 icd:9_diag:4011 icd:9_diag:41401 icd:9_diag:4659 AZITHROMYCIN METFORMIN_HCL ERGOCALCIFEROL_(VITAMIN_D2)

560499201141940,2014-12-09,icd:9_diag:25000 icd:9_diag:2724 icd:9_diag:4011 icd:9_diag:41401

560499201141940,2015-06-02,icd:9_diag:25000 icd:9_diag:41401 icd:9_diag:78652 icd:9_diag:7904 INSULIN_LISPRO NEEDLES__INSULIN_DISPOSABLE

560499201141940,2015-07-20,icd:9_diag:41401 icd:9_diag:412 icd:9_diag:79439 icd:9_diag:V4581

560499201299620,2010-09-27,icd:9_diag:V700 icd:9_diag:V700 icd:9_diag:V700

560499201299620,2014-02-20,icd:9_diag:V700 icd:9_diag:V700 icd:9_diag:V700

560499201299620,2015-04-30,icd:9_diag:59972 icd:9_diag:V700 icd:9_diag:V700

560499201299620,2015

KeyboardInterrupt: 

In [475]:
icd10_9 = fail.intersection(icd10)

In [476]:
icd10_9

{'I4510',
 'I482',
 'C4440',
 'H402211',
 'H20013',
 'L270',
 'N870',
 'L080',
 'M84333A',
 'D61818',
 'S82402S',
 'M25762',
 'F0150',
 'W260XXA',
 'E0865',
 'H01003',
 'C786',
 'S83206A',
 'I479',
 'I77819',
 'R110',
 'Z9225',
 'R1031',
 'K9420',
 'S066X0A',
 'J939',
 'H401122',
 'L089',
 'I4949',
 'N261',
 'E11621',
 '7390',
 'S72002A',
 'I517',
 'Z3403',
 'H353124',
 'R2689',
 'E889',
 'W540XXA',
 'S39013A',
 'H179',
 'L89310',
 'S76812D',
 'S4292XA',
 'D0359',
 'E1343',
 'C8190',
 'M5414',
 'S0083XA',
 'C8512',
 'H5442',
 'N186',
 'S3314',
 '84409',
 'A4189',
 'M94261',
 'M9943',
 'Z452',
 'W19XXXA',
 'I87323',
 '5609',
 'K269',
 'M4316',
 'Z96649',
 'Z6835',
 'M8589',
 'J918',
 'T792XXS',
 'G301',
 'C7802',
 'H5710',
 'M25461',
 'M8710',
 'T889XXD',
 'F325',
 'M24172',
 'H04423',
 'F1010',
 'S61412A',
 'S12101A',
 'M2557',
 'E440',
 'M9904',
 'S60511A',
 'N6021',
 'I213',
 'M50222',
 'G0491',
 'I8390',
 'M990',
 'H10503',
 'I69319',
 'S065X9A',
 'C132',
 'Z6826',
 'M9240',
 'E2749

In [421]:
z = [t for t in diag_icd9 if t[0] == '0']
ext = [t.replace('.', '') for t in test]

In [424]:
# miss = []
# cat = 0
# for t in z:
#     if t in ext:
#         cat +=1 
#     else:
#         miss.append(t)

# miss

In [460]:
# import re

# result = []
# pattern = '^.*0{3}$'
# for diag in diag_icd9:
#     if re.search(pattern, diag):
#         result.append(diag)
# result

In [418]:
[token for token in diag_icd9 if '0350' in token]

['80350', '40350', '03500']

In [373]:
# [token for token in diag_icd9 if len(token) > 5]

[token for token in diag_icd9 if '024' in token]


['20240',
 '90240',
 'S0240CA',
 '40240',
 'H02403',
 'H02409',
 'H02401',
 '30240',
 '80240']

In [347]:
# [ key for key in icd_map if '.' not in key]

In [202]:
icd9 = icd_map_df['icd9']
icd9_points = icd9.str.split('.', expand=True)
icd9_p_len = icd9_points[1].apply(lambda x: len(x) if x else 0)

icd9_p_len.value_counts()

2    9065
1    5201
0     394
Name: 1, dtype: int64

In [None]:
mapping = {}
for icd9 in icd_map_df

In [2]:
icd_dict_path = os.path.join('/home/liutianc/emr/data/', 'ICD_9_10_d_v1.1.csv')
icd_csv = csv.reader(open(icd_dict_path), delimiter='|')

In [22]:
# icd_dict = []
# for code in open(icd_dict_path).readlines():
#     code=code.strip()
#     if code.split("|")[1] == 'Flags':
#         header = code.split("|")
#         print(code)
# #     else:
#     icd_dict[code.split("|")[1]] = code

TargetI9|Flags|I9Name


In [118]:
icd_map = []
for code in open(icd_dict_path).readlines():
    code=code.strip()
    if code.split("|")[1] == 'Flags':
        header = code.split("|")
        print(code)
    else:
        icd_map.append(code.split("|"))
        
icd_map_df = pd.DataFrame(icd_map)
icd_map_df.columns = ['icd10', 'icd9', 'i9name']

TargetI9|Flags|I9Name


In [410]:
vocabs['icd:9_diag:2860253']

99960

In [68]:
icd10 = icd_map_df['icd10']
icd10_points = icd10.str.split('.', expand=True)
icd10_p_len = icd10_points[1].apply(lambda x: len(x) if x else 0)

icd10_p_len.value_counts()

In [125]:
# icd_map_df.loc[icd_map_df['icd9'] == '0010']

In [72]:
icd9 = icd_map_df['icd9']
icd9_points = icd9.str.split('.', expand=True)
icd9_p_len = icd9_points[1].apply(lambda x: len(x) if x else 0)

icd9_p_len.value_counts()

0    108498
2     89130
1     63513
Name: 1, dtype: int64

In [75]:
vocab_path = os.path.join(DATA_PATH, 'vocabs', 'vocab_merged.json')

with open(vocab_path, 'r') as file:
    vocabs = json.load(file)
        

In [463]:
diag_icd9, diag_icd10 = [], []
proc_icd9, proc_icd10 = [], []
for token in vocabs:
    if 'diag' in token:
        if 'icd:9_' in token:
            token = token.split('_')[1]
            token = token.split(':')[1]
            diag_icd9.append(token)
        elif 'icd:10_' in token:
            token = token.split('_')[1]
            token = token.split(':')[1]
            diag_icd10.append(token)   
    elif 'proc' in token:
        if 'icd:9_' in token:
            token = token.split('_')[1]
            token = token.split(':')[1]
            proc_icd9.append(token)
        elif 'icd:10_' in token:
            token = token.split('_')[1]
            token = token.split(':')[1]
            proc_icd10.append(token)   
            
diag_icd9 = [token.replace('-', '') for token in diag_icd9]
diag_icd10 = [token.replace('-', '') for token in diag_icd10]

# diag_icd9 = dict.fromkeys(diag_icd9, 1)
# diag10 = dict.fromkeys(diag_icd10, 1)

In [462]:
[t for t in vocabs if 'W260XXA' in t]

['icd:10_diag:W260XXA', 'icd:9_diag:W260XXA']

In [249]:
map10 = icd_map_df['icd10']
map10 = [token.upper() for token in map10]
print(f'''
match: {len(set(diag_icd10).intersection(set(map10)))} / {len(set([token.upper() for token in map10]))} = \
{len(set(diag_icd10).intersection(set(map10))) / len(set([token.upper() for token in map10]))}, 
''')

# set(map10).difference(set(diag_icd10))?


match: 9261 / 9802 = 0.9448071822077128, 



In [292]:
'05310'.strip('0')

'531'