In [1]:
import os
import sys
import argparse
import json

import pandas as pd

import torch
from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments

from data import LineByLineTextDataset
from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from utils import DATA_PATH, make_dirs

In [2]:
tokenizer = WordLevelBertTokenizer(create_vocab(merged=True, uni_diag=True))    
len(tokenizer)

86803

In [101]:
user_group = [str(i) for i in range(10)]

vocab, size = {}, 0
for group in user_group:
    read = os.path.join(DATA_PATH, f'group_{group}.csv')

    with open(read, 'r') as raw:
        for line in raw:
            if '2860253' in line:
                print(line)
                break
            line = line.replace('\n', '')
            user, tokens = line.split(',')
            tokens = tokens.strip()
            token_list = tokens.split(' ')

            for token in token_list:
                if token not in ['[SEP]', 'document', '']:
                    if token in vocab:
                        # If a token is existed, don't do anything.
                        pass
                    else:
                        # A new token: tokens value will start from 0
                        vocab[token] = size
                        size += 1

for j, v in enumerate(['[UNK]', '[SEP]', '[CLS]']):
    vocab[v] = size + j

In [103]:
icd9_vocab, icd10_vocab = {}, {}
for k in vocab:
    if '10_' in k:
        icd10_vocab[k] = 1
    elif '9_' in k:
        icd9_vocab[k] = 1

'/nfs/turbo/lsa-regier/emr-data/vocabs/vocab_merged.json'

In [68]:
diag_write = os.path.join('/home/liutianc/emr/data/', 'diag_stat.json')
proc_write = os.path.join('/home/liutianc/emr/data/', 'proc_stat.json')
pharm_write = os.path.join('/home/liutianc/emr/data/', 'pharm_stat.json')

with open(diag_write, 'r') as to_write:
    diags = json.load(to_write)
        
with open(proc_write, 'r') as to_write:
    procs = json.load(to_write)
    
with open(pharm_write, 'r') as to_write:
    pharms = json.load(to_write)

In [95]:
diag_tokens = {token for token in vocab if 'diag' in token}
diag_tokens = set([token.replace('diag:', '') for token in diag_tokens])

miss_pv = 0
miss_key = []
for key in set(diags).difference(diag_tokens_list):
    miss_pv += diags[key]
    miss_key.append(key)
print(miss_pv)
print(len(miss_key))

232
180


In [94]:
proc_tokens = {token for token in vocab if 'proc' in token}
proc_tokens = set([token.replace('proc:', '') for token in proc_tokens])

miss_pv = 0
miss_key = []
for key in set(procs).difference(proc_tokens):
    miss_pv += procs[key]
    miss_key.append(key)
print(miss_pv)
print(len(miss_key))

50
43


In [96]:
pharm_tokens = set({token for token in vocab if not 'proc' in token and not 'diag' in token})
# proc_tokens = set([token.replace('proc:', '') for token in proc_tokens])

miss_pv = 0
miss_key = []
for key in set(pharms).difference(pharm_tokens):
    miss_pv += pharms[key]
    miss_key.append(key)

print(miss_pv)
print(len(miss_key))

8
8


In [78]:
invalid = {}
valid = set({token for token in vocab if 'nan' not in token})
for key in set({token for token in vocab if 'nan' in token}):
    code = key.split('_')[1]
    for token in valid:
        if 'diag' in token or 'proc' in token:
            if code == token.split('_')[1]:
                invalid[key] = token
#                 print(f'miss: {code}, valid: {token}')
invalid        

{'icd:nan_diag:5601': 'icd:9_diag:5601',
 'icd:nan_proc:0DJD8ZZ': 'icd:10_proc:0DJD8ZZ',
 'icd:nan_diag:5733': 'icd:9_diag:5733',
 'icd:nan_proc:0DP67UZ': 'icd:10_proc:0DP67UZ',
 'icd:nan_proc:0DHA7UZ': 'icd:10_proc:0DHA7UZ',
 'icd:nan_proc:B2151ZZ': 'icd:10_proc:B2151ZZ',
 'icd:nan_proc:3E0H8GC': 'icd:10_proc:3E0H8GC',
 'icd:nan_proc:B2111ZZ': 'icd:9_proc:B2111ZZ',
 'icd:nan_proc:4A023N7': 'icd:10_proc:4A023N7',
 'icd:nan_proc:0DB68ZX': 'icd:10_proc:0DB68ZX',
 'icd:nan_proc:02HV33Z': 'icd:10_proc:02HV33Z',
 'icd:nan_diag:2630': 'icd:9_diag:2630',
 'icd:nan_diag:5770': 'icd:9_diag:5770'}

In [79]:
test = '560499293410900,2018-01-15,icd:10_diag:K56690 icd:10_diag:K760 icd:10_diag:I10 icd:10_diag:K56699 icd:10_diag:R1084 icd:10_diag:E119 icd:10_diag:E785 icd:10_diag:E871 icd:10_diag:G8929 icd:10_diag:I10 icd:10_diag:I2510 icd:10_diag:K56600 icd:10_diag:M549 icd:10_diag:N151 icd:10_diag:N179 icd:10_diag:N319 icd:10_diag:Z794 icd:10_diag:Z7982 icd:10_diag:Z79899 icd:10_diag:Z8711 icd:10_diag:Z950 icd:10_diag:Z955 icd:10_diag:Z981 icd:10_diag:Z993 icd:nan_proc:02HV33Z icd:10_proc:02HV33Z'

In [85]:
for key in invalid:
    test = test.replace(key, invalid[key])

In [89]:
len(set({token for token in vocab if 'nan' in token}))
len(invalid)

13

In [None]:
dataset = LineByLineTextDataset(tokenizer=tokenizer, data_type='merged', max_length=max_length)

In [None]:
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15, )

In [None]:
from transformers import BertConfig
config = BertConfig(vocab_size=len(tokenizer), max_position_embeddings=max_length,
                    num_attention_heads=4,
                    num_hidden_layers=4,
                    hidden_size=128,
                    type_vocab_size=1, )

In [None]:
model = BertForMaskedLM(config=config)

In [None]:
model.num_parameters()

In [None]:
training_args = TrainingArguments(output_dir='./result-dev/MLM', overwrite_output_dir=True,
                              num_train_epochs=1,
                              per_device_train_batch_size=bsz,
                              save_steps=10_000,)

In [None]:
trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=mlm_collator,
                  train_dataset=dataset,
                  prediction_loss_only=True, )

In [11]:
dataloader = trainer.get_train_dataloader()

In [27]:
from pathlib import Path
import csv

input_path = '/nfs/turbo/lsa-regier/OPTUMInsight_csv/'
diags = [str(x) for x in Path(input_path).glob("**/diag_201[0-9].csv")]
procs = [str(x) for x in Path(input_path).glob("**/proc_201[0-9].csv")]
pharms = [str(x) for x in Path(input_path).glob("**/pharm_201[0-9].csv")]

diag_pattern = '{},' * 6 + '{}\n'
proc_pattern = '{},' * 5 + '{}\n'
pharm_pattern = '{},' * 5 + '{}\n'

print('Start: Select diag data.')
for file in diags:
    file_name = os.path.split(file)[1]

#     print(f'Start: {file_name}.')
    with open(file, newline='') as infile:
        diagreader = csv.reader(infile)
        row_num = 0
        for row in diagreader:
            row_num += 1 
#             row = [cell.replace('.0', '').strip() for cell in row]
            row = [cell.strip() for cell in row]
            patid = row[0].split('.0')[0].strip()
            claimid = row[2].split('.0')[0].strip()
            Diag = row[3].split('.0')[0].strip()
            Diag_Position = row[4].split('.0')[0].strip()
            Icd_Flag = row[5].split('.0')[0].strip()
            Loc_cd = row[6].split('.0')[0].strip()
            Fst_Dt = row[10].strip()

            select_row = diag_pattern.format(patid, claimid, Diag, Diag_Position, Icd_Flag, Loc_cd, Fst_Dt)
            print(select_row)
            
            if row_num >= 2:
                break

print('*' * 200)
print('Start: Select proc data.')
for file in procs:
    file_name = os.path.split(file)[1]

#     print(f'Start: {file_name}.')
    with open(file, newline='') as infile:
        procreader = csv.reader(infile)

        row_num = 0
        for row in procreader:
            row_num += 1
            
#                 row = [cell.replace('.0', '').strip() for cell in row]
            row = [cell.strip() for cell in row]

            patid = row[0].split('.0')[0].strip()
            claimid = row[2].split('.0')[0].strip()
            Icd_Flag = row[3].split('.0')[0].strip()
            Proc = row[4].split('.0')[0].strip()
            Proc_Position = row[5].split('.0')[0].strip()
            Fst_Dt = row[8]
            select_row = proc_pattern.format(patid, claimid, Icd_Flag, Proc, Proc_Position, Fst_Dt)
            print(select_row)
            if row_num >= 2:
                break


# logger.info('Start: Select pharm data.')
# for file in pharms:
#     file_name = os.path.split(file)[1]
#     output_file = os.path.join(output_path, file_name)

#     logger.info(f'Start: {file_name}.')
#     with open(file, newline='') as infile:
#         with open(output_file, 'w') as outfile:
#             pharmreader = csv.reader(infile)

#             for row in pharmreader:
#                 row = [cell.strip() for cell in row]

#                 patid = row[0].split('.0')[0].strip()
#                 claimid = row[7].split('.0')[0].strip()
#                 Fill_Date = row[14].split('.0')[0].strip()
#                 Gnrc_Nm = row[19].split('.0')[0].strip()
#                 Quantity = row[25].split('.0')[0].strip()
#                 Rfl_Nbr = row[26].split('.0')[0].strip()

#                 Gnrc_Nm = Gnrc_Nm.replace('"', '').replace(',', '_').replace(' ', '_').strip()

#                 select_row = pharm_pattern.format(patid, claimid, Fill_Date, Gnrc_Nm, Quantity, Rfl_Nbr)
#                 outfile.write(select_row)

#     logger.info(f'Finish: {file_name}.')
# logger.info('Finish: Select pharm data.')

Start: Select diag data.
Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,N9NFNVL9L8,J029,01,10,2,2015-12-17

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,339J9VJJ3R,2724,01,9,2,2013-11-19

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,NF9NVLNF3V,E782,02,10,2,2017-04-24

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,N38RF9RJLV,D125,01,10,2,2016-11-21

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200782112,39J998VRFO,2572,01,9,2,2011-11-15

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,39LR9FFVVJ,79981,01,9,2,2012-11-06

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,NLVFLV98N3,H6982,01,10,2,2018-02-26

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200782112,JF39FJO39F,72252,01,9,2,2010-06-11

Patid,Clmid,Diag,Diag_Position,Icd_Flag,Loc_cd,Fst_Dt

560499200160862,38RO383O38,2724,02,9,2,2014-11-06

**************

In [None]:
for data in dataloader:
    print(data)
    break

In [None]:
data['labels'].shape