In [1]:
import os
import sys
import argparse

import torch
from transformers import DataCollatorForLanguageModeling, BertForMaskedLM
from transformers import Trainer, TrainingArguments

from data import LineByLineTextDataset
from tokens import WordLevelBertTokenizer
from vocab import create_vocab
from utils import DATA_PATH, make_dirs

In [2]:
vocab = create_vocab(merged=True)
vocab

'/nfs/turbo/lsa-regier/emr-data/vocabs/vocab_merged.json'

In [3]:
tokenizer = WordLevelBertTokenizer(vocab)
    
len(tokenizer)

In [5]:
user_group = [str(i) for i in range(10)]

vocab, size = {}, 0
for group in user_group:
    read = os.path.join(DATA_PATH, f'group_{group}_merged.csv')

    with open(read, 'r') as raw:
        for line in raw:
            line = line.replace('\n', '')
            user, tokens = line.split(',')
            tokens = tokens.strip()
            token_list = tokens.split(' ')

            for token in token_list:
                if token not in ['[SEP]', 'document', '']:
                    if token in vocab:
                        # If a token is existed, don't do anything.
                        pass
                    else:
                        # A new token: tokens value will start from 0
                        vocab[token] = size
                        size += 1

for j, v in enumerate(['[UNK]', '[SEP]', '[CLS]']):
    vocab[v] = size + j

In [6]:
vocab

{'SIMVASTATIN': 0,
 'METOPROLOL_SUCCINATE': 1,
 'LISINOPRIL': 2,
 'ESCITALOPRAM_OXALATE': 3,
 'METFORMIN_HCL': 4,
 'SITAGLIPTIN_PHOSPHATE': 5,
 'INSULIN_GLARGINE_HUM.REC.ANLOG': 6,
 'LEVOTHYROXINE_SODIUM': 7,
 'icd:9.0_diag:25000': 8,
 'icd:9.0_diag:4011': 9,
 'icd:9.0_diag:4659': 10,
 'icd:9.0_diag:7862': 11,
 'AZITHROMYCIN': 12,
 'HYDROCODONE/CHLORPHEN_P-STIREX': 13,
 'icd:9.0_diag:2724': 14,
 'icd:9.0_diag:V5869': 15,
 'icd:9.0_diag:41401': 16,
 'ERGOCALCIFEROL_(VITAMIN_D2)': 17,
 'OXYCODONE_HCL/ACETAMINOPHEN': 18,
 'INSULIN_LISPRO': 19,
 'NEEDLES__INSULIN_DISPOSABLE': 20,
 'AMOXICILLIN': 21,
 'METHYLPREDNISOLONE': 22,
 'BLOOD_SUGAR_DIAGNOSTIC': 23,
 'LANCETS': 24,
 'icd:9_diag:2724': 25,
 'icd:9_diag:25000': 26,
 'icd:9_diag:41401': 27,
 'icd:9_diag:78652': 28,
 'icd:9_diag:7904': 29,
 'icd:9_diag:V4581': 30,
 'icd:9_diag:412': 31,
 'icd:9_diag:79439': 32,
 'PEN_NEEDLE__DIABETIC': 33,
 'icd:10_diag:J029': 34,
 'icd:10_diag:J309': 35,
 'icd:10_diag:J320': 36,
 'CEFUROXIME_AXETIL': 3

In [None]:
dataset = LineByLineTextDataset(tokenizer=tokenizer, data_type='merged', max_length=max_length)

In [None]:
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15, )

In [None]:
from transformers import BertConfig
config = BertConfig(vocab_size=len(tokenizer), max_position_embeddings=max_length,
                    num_attention_heads=4,
                    num_hidden_layers=4,
                    hidden_size=128,
                    type_vocab_size=1, )

In [None]:
model = BertForMaskedLM(config=config)

In [None]:
model.num_parameters()

In [None]:
training_args = TrainingArguments(output_dir='./result-dev/MLM', overwrite_output_dir=True,
                              num_train_epochs=1,
                              per_device_train_batch_size=bsz,
                              save_steps=10_000,)

In [None]:
trainer = Trainer(model=model,
                  args=training_args,
                  data_collator=mlm_collator,
                  train_dataset=dataset,
                  prediction_loss_only=True, )

In [11]:
dataloader = trainer.get_train_dataloader()

In [None]:
trainer.train()

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=1.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=31284.0, style=ProgressStyle(description_…



In [None]:
for data in dataloader:
    print(data)
    break

In [None]:
data['labels'].shape