In [1]:
import subprocess
import json
import os

import torch

In [2]:
# Work from the project's `src/` directory (presumably so the project-local imports used
# further down, `model` and `utils.*`, resolve from the working directory).
# The location can be overridden with the KG_TXT_SRC_DIR environment variable; the default
# is the original author's checkout, so existing runs behave exactly as before.
os.chdir(os.environ.get('KG_TXT_SRC_DIR', '/home/ssbae/bae/kg_txt_multimodal/lxmert/src/'))
# Experiment root = parent of the working directory; the data paths are built from it.
EXP_PATH = os.path.abspath('..')

# prescriptions

In [3]:
# ======================= CONFIG ==================== #
# GPU selection (restricts which CUDA devices this process can see).
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Task and database.
TASK_NAME = 'generation'
DB = 'px'
DB_size = 1000
MODEL_TYPE = 'both'              # key into Var_MODEL: both / lm / kg / rand
Unified = True                   # Unified (placeholder) abstract node: True or False
Align = False
Relation_Classification = False
Scratch_Downstream = False

# Important model hyper-parameters.
Dim_Hidden = 128
NUM_Layers = dict(lang=2, kg=2, cross=4)
Dropout = 0.1
Num_Negatives = 1
Margin = 1.0
# ======================= CONFIG ==================== #

# Name fragments derived from the switches above.
Var_TASK = dict()
# <LMinit & KGenc> : both, <LMinit only> : lm, <KGenc only> : kg, <RandomInit> : rand
Var_MODEL = dict(both='KGenc_LMinit', lm='LMinit', kg='KGenc', rand='Randinit')
Var_Unified = 'Unified' if Unified else ''
Var_Align = 'Align_' if Align else ''
Var_RC = 'RC_' if Relation_Classification else ''

# Fail fast on unsupported combinations.
assert MODEL_TYPE in Var_MODEL, "Model not supported"
assert DB in ('px', 'dx,prx'), "DB not supported"
assert TASK_NAME in (
    'pretrain', 'binary_retrieval', 'generation',
    'single_pretrain', 'single_binary_retrieval', 'single_generation',
), "Task not supported"
assert Scratch_Downstream is not True or (Align is False and Relation_Classification is False), \
    "Scratch start downstream task must turn off alignment prediction & relation classification"

# MODEL_NAME names the data sub-directory, RUN_NAME the checkpoint/run identifier.
MODEL_NAME = '{}_{}{}KGenc'.format(
    DB, Var_Unified, 'Uni' if MODEL_TYPE in ('both', 'kg') else 'No'
)
RUN_NAME = '{db}/{init}_H{hidden}_L{lang},{kg},{cross}_{align}{rc}{unified}{size}'.format(
    db=DB,
    init=Var_MODEL[MODEL_TYPE],
    hidden=Dim_Hidden,
    align=Var_Align,
    rc=Var_RC,
    unified=Var_Unified,
    size=DB_size,
    **NUM_Layers,  # supplies the lang / kg / cross layer counts
)

In [4]:
# Run settings collected in one dict; the next cell wraps it for attribute access.
# The three *_data_file entries all live under <EXP_PATH>/data/<DB>_<DB_size>/<MODEL_NAME>/.
TRAINING_CONFIG = dict(
    seed=1234,
    model_type="lxmert",
    do_train=True,
    evaluate_during_training=True,
    do_eval=True,
    edge_cls=Relation_Classification,
    align=Align,
    n_negatives=Num_Negatives,
    prediction_loss_only=False,
    overwrite_output_dir=False,
    mlm_probability=0.15,
    block_size=512,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=1e-4,
    num_train_epochs=40,
    num_log_per_epoch=20,
    num_save_per_epoch=-1,
    num_eval_per_epoch=2,
    task=TASK_NAME,
    train_data_file=os.path.join(EXP_PATH, f"data/{DB}_{DB_size}/{MODEL_NAME}/train"),
    eval_data_file=os.path.join(EXP_PATH, f"data/{DB}_{DB_size}/{MODEL_NAME}/valid"),
    test_data_file=os.path.join(EXP_PATH, f"data/{DB}_{DB_size}/{MODEL_NAME}/test"),
    run_name=f"{TASK_NAME}_{RUN_NAME}",
)

In [5]:
import easydict
# Wrap the config dict so entries can be read as attributes
# (a later cell reads `args.eval_data_file`).
args = easydict.EasyDict(TRAINING_CONFIG)

In [24]:
# Checkpoint directory of the generation model. Adjacent string literals are joined
# at compile time, so the value is one path with no embedded newline or spaces.
temp_model_name_or_path = (
    '/home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/'
    'generation/px/KGenc_LMinit_H128_L2,2,4_Unified1000'
)

In [25]:
from transformers import AutoConfig, AutoTokenizer

from model import LxmertForGeneration

# Tokenizer, config and generation model are all read from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(temp_model_name_or_path)
config = AutoConfig.from_pretrained(temp_model_name_or_path)
# The tokenizer is handed to the project's model class as an extra keyword argument.
model = LxmertForGeneration.from_pretrained(temp_model_name_or_path, tokenizer=tokenizer)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/generation/px/KGenc_LMinit_H128_L2,2,4_Unified1000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/generation/px/KGenc_LMinit_H12

In [26]:
from utils.dataset import get_dataset

# Only the validation split is loaded for these qualitative generation checks.
# The train and test splits used the same call (train: no extra flag, test: `test=True`)
# and are intentionally left disabled in this notebook.
eval_dataset = get_dataset(
    args,
    tokenizer=tokenizer,
    token_type_vocab=config.token_type_vocab,
    evaluate=True,
)

In [27]:
from utils.data_collator import UniLM_DataCollator

# Collator used to turn dataset items into model inputs.
# NOTE(review): `prediction=True` presumably switches the collator to inference-time
# behaviour — confirm in utils.data_collator.
data_collator = UniLM_DataCollator(
    tokenizer=tokenizer,
    kg_special_token_ids=config.kg_special_token_ids,
    prediction=True,
)

In [28]:
# The node vocabulary file ('unified_node') sits in the parent directory of the validation
# split. os.path.dirname strips only the final path component, whereas the previous
# str.replace('/valid', '') would also rewrite any earlier '/valid...' segment of the path.
# NOTE(review): torch.load unpickles the file — only load vocabularies from a trusted source.
node2id = torch.load(os.path.join(os.path.dirname(args.eval_data_file), 'unified_node'))
# Invert the mapping for display; node names keep only the text before the first '^^'.
id2node = {node_id: name.split('^^')[0] for name, node_id in node2id.items()}

# generation

In [29]:
def _prepare_inputs(inputs, device):
    if isinstance(inputs, dict):
        for k,v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)
    return inputs

def _prepare_model(model, device):
    """Call ``model.to(device)`` and return the same ``model`` object."""
    # The return value of .to() is ignored, so this relies on .to() acting in place
    # (which is how torch.nn.Module.to behaves).
    model.to(device)
    return model

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

## sample1: eval_dataset, idx=0

In [30]:
base_idx = 0
# Slice exactly one example out of the validation set and run it through the collator;
# the result is later unpacked as keyword arguments to the model.
sample = data_collator(eval_dataset[base_idx:base_idx+1])

In [31]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Put the batch and the model on that device.
sample = _prepare_inputs(sample, device)
model = _prepare_model(model, device)

### normal case

In [32]:
# Inference only: run the forward pass without autograd so no graph is built or kept.
# NOTE(review): the meaning of given_lang_tokens / know_gt_leng is defined inside
# LxmertForGeneration and is not visible from this notebook.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [33]:
# List [node_id, node_name] for every graph token whose id is above config.num_relations
# (an earlier variant of this cell filtered on ids > 0 instead).
# The ids are converted to plain ints once instead of calling .item() per element.
[[node_id, id2node[node_id]]
 for node_id in sample['kg_input_ids'][0].tolist()
 if node_id > config.num_relations]

[[111, '"ih"'],
 [93, '"1neb"'],
 [333, '"albuterol 0.083% neb soln"'],
 [628, '"albu3h"'],
 [320, '"lact30l"'],
 [128, '"30ml"'],
 [372, '"lactulose"'],
 [16, '"1000ml"'],
 [17, '"iv"'],
 [259, '"lr1000"'],
 [24, '"base"'],
 [427, '"lr"'],
 [295, '"15mmol"'],
 [626, '"kpho45i"'],
 [17, '"iv"'],
 [160, '"potassium phosphate"'],
 [223, '"morp2i"'],
 [17, '"iv"'],
 [177, '"morphine sulfate"'],
 [45, '"1mg"'],
 [280, '"hald5i"'],
 [17, '"iv"'],
 [522, '"haloperidol"'],
 [1045, '"0.25mg"'],
 [223, '"morp2i"'],
 [17, '"iv"'],
 [177, '"morphine sulfate"'],
 [108, '"2-4mg"'],
 [35, '"250ml"'],
 [188, '"ns"'],
 [24, '"base"'],
 [139, '"ns250"'],
 [17, '"iv"'],
 [115, '"25mg"'],
 [1182, '"quetiapine fumarate"'],
 [398, '"quet25"'],
 [62, '"insulin"'],
 [33, '"sc"'],
 [62, '"insulin"'],
 [53, '"0unit"'],
 [1707, '"morphine sulfate (concentrated oral soln)"'],
 [1586, '"morpconc"'],
 [597, '"5-10mg"'],
 [33, '"sc"'],
 [391, '"hepa5i"'],
 [263, '"heparin"'],
 [112, '"5000unit"'],
 [30, '"lidocaine

original text

In [34]:
# results[2][0]: token ids shown under the "original text" heading.
temp = results[2][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. bisacodyl 10 mg suppository sig : suppositorys rectal daily ( daily ) as needed for constipation. 2. ipratropium bromide 0. 02 % solution sig : one ( 1 ) nebulizer inhalation q6h ( every 6 hours ) as needed for wheezing. 3. albuterol sulfate 2. 5 mg / 3 ml solution for nebulization sig : one ( 1 ) solution inhalation q6h ( every 6 hours ) as needed for wheezing. 4. morphine concentrate 20 mg / ml solution sig : 5 - 10 mg po q4h ( every 4 hours ) as needed for pain : may shorten interval as needed to control pain. 5. olanzapine 5 mg tablet, rapid dissolve sig : 0. 5 tablet, rapid dissolve po qhs ( once a day ( at bedtime ) ) as needed for agitation. 6. tamsulosin 0. 4 mg capsule, sust. release 24 hr sig : one ( 1 ) capsule, sust. release 24 hr po daily ( daily ) : may be discontinued if patient not tolerating pills or refusing to take. [SEP]'

generation output

In [35]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. acetaminophen 325 mg tablet sig : 1 - 2 tablets po q6h ( every 6 hours ) as needed. 2. docusate sodium 50 mg / 5 ml liquid sig : one ( 650 ) mg po bid ( 2 times a day ). 3. senna 8. mg tablet sig : one ( 1 ) tablet po bid ( 2 times a day as.. 3. senna 8. 6 mg tablet si bid ( 2 ) tablet bid ( 2 times a day ) as needed. 4. bis. bisacodyl 5 mg tablet, delayed release ( e. c. ) sig : one ( 1 ) tablet, delayed release ( e. c. c. ) po daily ( daily ). 5. ) 5. docusate sodium 6 mg / 5 ml liquid sig : one ( 1 ) po bid ( 2 times a day as.. 6. bisacodyl 5 mg tablet, sig : one ( 1 tablet, delayed release ( e. c. ) po daily ( daily ). 7. bis. bisacodyl 5 mg tablet sig : one ( 1 ) tablet po daily ( daily ) 7. bis. 6. ace 6 mg tablet sig'

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

## sample2: eval_dataset, idx=822

In [73]:
base_idx = 822
# Slice exactly one example out of the validation set and run it through the collator;
# the result is later unpacked as keyword arguments to the model.
sample = data_collator(eval_dataset[base_idx:base_idx+1])

In [74]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Put the batch and the model on that device.
sample = _prepare_inputs(sample, device)
model = _prepare_model(model, device)

### normal case

In [75]:
# Inference only: run the forward pass without autograd so no graph is built or kept.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [76]:
# List [node_id, node_name] for every graph token whose id is above config.num_relations
# (an earlier variant of this cell filtered on ids > 0 instead).
# The ids are converted to plain ints once instead of calling .item() per element.
[[node_id, id2node[node_id]]
 for node_id in sample['kg_input_ids'][0].tolist()
 if node_id > config.num_relations]

[[20, '"atorvastatin"'],
 [43, '"20mg"'],
 [576, '"ator20"'],
 [39, '"2.5mg"'],
 [63, '"warf25"'],
 [419, '"warfarin"'],
 [873, '"ezetimibe"'],
 [159, '"10mg"'],
 [1496, '"ezet10"'],
 [3248, '"penicillin v potassium"'],
 [29, '"250mg"'],
 [3637, '"penv250"'],
 [364, '"325mg"'],
 [798, '"ferr325"'],
 [477, '"ferrous sulfate"'],
 [873, '"ezetimibe"'],
 [159, '"10mg"'],
 [1496, '"ezet10"'],
 [1057, '"amiodarone hcl"'],
 [315, '"amio150i"'],
 [281, '"150mg"'],
 [199, '"iv bolus"'],
 [17, '"iv"'],
 [127, '"1000mg"'],
 [244, '"vancomycin hcl"'],
 [233, '"vanc1f"'],
 [490, '"aspirin ec"'],
 [265, '"81mg"'],
 [25, '"asa81ec"'],
 [70, '"furo40i"'],
 [59, '"40mg"'],
 [49, '"furosemide"'],
 [17, '"iv"'],
 [328, '"ranitidine"'],
 [505, '"rani150"'],
 [281, '"150mg"'],
 [98, '"docusate sodium"'],
 [19, '"docu100"'],
 [96, '"100mg"'],
 [419, '"warfarin"'],
 [261, '"1dose"'],
 [829, '"warf0"'],
 [127, '"1000mg"'],
 [244, '"vancomycin hcl"'],
 [233, '"vanc1f"'],
 [17, '"iv"'],
 [508, '"calc500"'],
 [1

original text

In [77]:
# results[2][0]: token ids shown under the "original text" heading.
temp = results[2][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

"[CLS] 1. docusate sodium 100 mg capsule sig : one ( 1 ) capsule po bid ( 2 times a day ). 2. ranitidine hcl 150 mg tablet sig : one ( 1 ) tablet po once a day. 3. aspirin 81 mg tablet, delayed release ( e. c. ) sig : one ( 1 ) tablet, delayed release ( e. c. ) po daily ( daily ). 4. ferrous sulfate 325 ( 65 ) mg tablet sig : one ( 1 ) tablet po daily ( daily ). 5. levothyroxine 100 mcg tablet sig : one ( 1 ) tablet po daily ( daily ). 6. atorvastatin 20 mg tablet sig : one ( 1 ) tablet po daily ( daily ). 7. ezetimibe 10 mg tablet sig : one ( 1 ) tablet po daily ( daily ). 8. tramadol 50 mg tablet sig : one ( 1 ) tablet po q4 - 6h ( every 4 to 6 hours ) as needed. 9. amiodarone 200 mg tablet sig : two ( 2 ) tablet po once a day : 400 mg daily x 1 week, then 200 mg daily ongoing ( until dc'd by cardiologist ). 10. penicillin v potassium 250 mg tablet sig : one ( 1 ) tablet po q6h ( every 6 hours ). disp : * 120 tablet ( s ) * refills : * 2 * 11. warfarin 1 mg tablet sig : as directed t

generation output

In [78]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. docusate sodium 100 mg capsule sig : one ( 1 ) capsule po bid ( 2 times a day ). disp : * 60 capsule ( s ) * refills : * 0 * 2. aspirin 81 mg tablet, delayed release ( e release ( e c. c. ) sig : one ( 1 ) tablet, delayed release ( e. c. ) po daily ( daily daily ). disp : * tablet delayed release ( e. c. ) ( s ) * ref refills : * 0 * ref refills : * 0 * 2 * 2. atorvastatin 20 mg tablet sig : one ( 1 ) tablet po daily ( daily ( ). disp : * 30 30 tablet ( s ) * refills : * 2 * * 2 3. atorvastatin 20 mg tablet sig : one ( 1 ) tablet po daily ( daily ). disp : * 30 tablet ( s ) * refills : * 2 * 4 4. atorvastatin 20 mg tablet sig : one ( 1 ) tablet po daily ( daily ). disp : * 30 tablet ( s ) * refills : * 2 2 * 5. ox hydrocodone - acetaminophen 5 - 500 mg tablet sig : 1 2 tablets po q4h ( every 4 hours ) as needed. disp : * 50 50 tablet ( s ) * refills : 0 * * 6. metopoprolol tartrate 25 mg tablet sig : one ( 1 ) tablet po bid ( 2 times a day ). disp : * 60 tablet ( s ) * ref re

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

## sample3: eval_dataset, idx=777

In [110]:
base_idx = 777
# Slice exactly one example out of the validation set and run it through the collator;
# the result is later unpacked as keyword arguments to the model.
sample = data_collator(eval_dataset[base_idx:base_idx+1])

In [111]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Put the batch and the model on that device.
sample = _prepare_inputs(sample, device)
model = _prepare_model(model, device)

### normal case

In [112]:
# Inference only: run the forward pass without autograd so no graph is built or kept.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [113]:
# List [node_id, node_name] for every graph token whose id is above config.num_relations
# (an earlier variant of this cell filtered on ids > 0 instead).
# The ids are converted to plain ints once instead of calling .item() per element.
[[node_id, id2node[node_id]]
 for node_id in sample['kg_input_ids'][0].tolist()
 if node_id > config.num_relations]

[[1970, '"minx25"'],
 [162, '"5mg"'],
 [2219, '"minoxidil"'],
 [57, '"50ml"'],
 [24, '"base"'],
 [300, '"frbd50"'],
 [17, '"iv"'],
 [127, '"1000mg"'],
 [50, '"vancomycin"'],
 [233, '"vanc1f"'],
 [17, '"iv"'],
 [737, '"metr500pm"'],
 [17, '"iv"'],
 [241, '"metronidazole (flagyl)"'],
 [154, '"500mg"'],
 [17, '"iv"'],
 [285, '"dextrose 50%"'],
 [774, '"25gm"'],
 [264, '"dex50sy"'],
 [448, '"senn187"'],
 [248, '"1tab"'],
 [255, '"senna"'],
 [59, '"40mg"'],
 [164, '"pantoprazole"'],
 [701, '"pant40"'],
 [17, '"iv"'],
 [363, '"cefepime"'],
 [629, '"1g"'],
 [140, '"cefe1i"'],
 [117, '"500ml"'],
 [17, '"iv"'],
 [483, '"d5w500"'],
 [24, '"base"'],
 [89, '"5% dextrose"'],
 [17, '"iv"'],
 [122, '"ns/mbp100i"'],
 [240, '"100ml"'],
 [176, '"0.9% sodium chloride (mini bag plus)"'],
 [24, '"base"'],
 [121, '"piperacillin-tazobactam"'],
 [17, '"iv"'],
 [630, '"zosy2.25i"'],
 [635, '"2.25g"'],
 [159, '"10mg"'],
 [17, '"iv"'],
 [344, '"hydralazine"'],
 [461, '"hydz20i"'],
 [552, '"amlodipine"'],
 [245, 

original text

In [114]:
# results[2][0]: token ids shown under the "original text" heading.
temp = results[2][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. bisacodyl 5 mg tablet, delayed release ( e. c. ) : two ( 2 ) tablet, delayed release ( e. c. ) po daily ( daily ) as needed for constipation. 2. senna 8. 6 mg tablet : one ( 1 ) tablet po bid prn as needed for constipation. 3. amlodipine 5 mg tablet : two ( 2 ) tablet po daily ( daily ). 4. simvastatin 40 mg tablet : one ( 1 ) tablet po daily ( daily ). 5. metoprolol tartrate 50 mg tablet : one ( 1 ) tablet po bid ( 2 times a day ). 6. lisinopril 20 mg tablet : one ( 1 ) tablet po daily ( daily ). 7. minoxidil 2. 5 mg tablet : two ( 2 ) tablet po bid ( 2 times a day ). 8. lansoprazole 30 mg tablet, rapid dissolve, dr : one ( 1 ) tablet, rapid dissolve, dr daily ( daily ). 9. aspirin 81 mg tablet, chewable : one ( 1 ) tablet, chewable po daily ( daily ). 10. calcium acetate 667 mg capsule : two ( 2 ) capsule po tid w / meals ( 3 times a day with meals ). 11. memantine 5 mg tablet : one ( 1 ) tablet po qhs ( once a day ( at bedtime ) ). 12. docusate sodium 100 mg tablet : 1 - 2

generation output

In [115]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. amlodipine 5 mg tablet : one ( 1 ) tablet po daily ( daily ). 2. amlodipine 5 mg tablet : one ( 1 ) po daily ( daily ). 2. amlodipine 5 mg tablet : one ( 1 ) tablet po daily ( daily ( daily ). 3. amlodipine 5 mg tablet : one ( 1 ) tablet po daily daily ( daily ). 4. 4. amlodipine 5 mg tablet : one 1 ) tablet po daily ( daily ). 5. 5. am 5. amlodipine 5 mg tablet po daily ( daily ) 6. amlodipine 5 mg : one ( 1 ) tablet po daily ( daily ). 6. amlodipine 5 mg tablet : ( 5 tablets po daily ( daily ). 7. metoprolol tartrate 25 mg : 1. tablet po bid ( ). 8. amlodipine 5 mg : 1 5. 9. amlodipine 5 mg tablet : po daily 10. 9. amlodipine 5 mg tablet po daily 10. am. amdipine 5 mg tablet : one ( ) po daily ( daily ( daily ) 10.. metoprolol tartrate 25 mg po bid 12. 5 mg po bid 12. am.. 5 mg tablet : ( 5 mg tablet po bid ( 2xx week ( ).'

### random init

In [116]:
# "Random init" condition of this section: same batch, with perturb_type='init_all'.
# Inference only, so the forward pass runs without autograd.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, perturb_type='init_all', know_gt_leng=False)

generation output

In [117]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. fluticasone propionate 110 mcg / actuation aerosol sig : two ( 2 ) puff inhalation inhalation ( 2 times a day. ). 2. fluticasone propate 110 mcg / dose disk with device sig si : sig : one ( 1 ) inhalation ( 2 ) inhalation ( 2 times a day ). 3. 4. 5 mg capsule sig : one ( 1 ) capsule inhalation twice a day. 4. oxcocet 5 - 5 mg capsule sig : one ( ) capsule po twice a day. 5. oxcocet 5 mg capsule sig : one ( 1 ) capsule po twice a day. 5. 5 - 6. trodone 50 mg tablet sig one ( 1 ) tablet po hs as needed for insomn.'

### Use only the pre-training checkpoint

In [118]:
# Build a second generator from the checkpoint recorded in config._name_or_path
# (the load log for this cell shows a .../pretrained_models/pretrain/... directory)
# and place it on the same device as the batch.
model_pt = _prepare_model(
    LxmertForGeneration.from_pretrained(config._name_or_path, tokenizer=tokenizer),
    device,
)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/px/KGenc_LMinit_H128_L2,2,4_Unified1000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/px/KGenc_LMinit_H128_L2

In [119]:
# Generate with the pre-training checkpoint (model_pt) on the same batch, no perturbation.
# Inference only, so the forward pass runs without autograd.
with torch.no_grad():
    results = model_pt(**sample, given_lang_tokens=1, perturb_type=None, know_gt_leng=False)

generation output

In [120]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS]xy oxcodone 5 mg po q4h4 - 6 - 6h 2 mg po q4 - 6h4 - 6h4 6 -4 6 -4 q4 - 4 6 - 44 -4 - 64h 6 mg 6 q mg po q44 - 64 - 6h4 - 6h 6 -4 6 - 64 q4 - 64 - 6 - 6h4 - 64 -4 qh -4 q4 - 6h 4 - 6 4 - 6 - 6 -4 6 - 6 4 - 6 hours 6 4 hours - 6 6 every 4 - 6 - 6 every 6 - 64 - 6 - 6 6 - 6 - 6 every 4 - 6 - 6 - 6 every 4 - 6 - 6 - 64 6 - 6 every 4 - 6 - 64h 6 - 64 - 6 - 64 - 6 hours 6 hours 6 hours hours to 4 every 6 to 6 6 6 to 4 6 - 6 every 4 6 - 6 6 - 64 6 - 6 - 64 6 - 6 - 6 6 - 6 6 every 6 - 6 hours 6 hours 6 to 6 hours to 6 6 to 6 - 6 - 6 everyh4 - 6 - 6en 4 - 64 - 64 - 6 - 64 - 6h 6 - 6 q4 q 6 - 6 - 64 - 6 - 6 - 64 - 6 - every 4h 6 -4 qh4 - 6 - 64 - 6 6 - 6 4 - 64 -h 64 - 6h4 - 64 - 6 every 6 - 6 6 - 64 - 64 6 - 6'

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

## sample4: eval_dataset, idx=938

In [79]:
base_idx = 938
# Slice exactly one example out of the validation set and run it through the collator;
# the result is later unpacked as keyword arguments to the model.
sample = data_collator(eval_dataset[base_idx:base_idx+1])

In [80]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Put the batch and the model on that device.
sample = _prepare_inputs(sample, device)
model = _prepare_model(model, device)

### normal case

In [81]:
# Inference only: run the forward pass without autograd so no graph is built or kept.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [82]:
# List [node_id, node_name] for every graph token whose id is above config.num_relations
# (an earlier variant of this cell filtered on ids > 0 instead).
# The ids are converted to plain ints once instead of calling .item() per element.
[[node_id, id2node[node_id]]
 for node_id in sample['kg_input_ids'][0].tolist()
 if node_id > config.num_relations]

[[348, '"kcl40i"'],
 [85, '"40meq"'],
 [157, '"potassium chloride"'],
 [17, '"iv"'],
 [263, '"heparin"'],
 [33, '"sc"'],
 [112, '"5000unit"'],
 [391, '"hepa5i"'],
 [1571, '"liepi20i"'],
 [33, '"sc"'],
 [254, '"1ml"'],
 [2519, '"lidocaine 1%/epinephrine 1:100000"'],
 [527, '"dila100"'],
 [96, '"100mg"'],
 [712, '"phenytoin"'],
 [57, '"50ml"'],
 [17, '"iv"'],
 [484, '"sw50"'],
 [368, '"sw"'],
 [24, '"base"'],
 [17, '"iv"'],
 [85, '"40meq"'],
 [157, '"potassium chloride"'],
 [348, '"kcl40i"'],
 [23, '"ondansetron"'],
 [17, '"iv"'],
 [119, '"ondan4i"'],
 [92, '"4mg"'],
 [157, '"potassium chloride"'],
 [85, '"40meq"'],
 [204, '"pota20"'],
 [75, '"300mg"'],
 [527, '"dila100"'],
 [712, '"phenytoin"'],
 [518, '"200mg"'],
 [527, '"dila100"'],
 [712, '"phenytoin"'],
 [262, '"neut"'],
 [66, '"neutra-phos"'],
 [757, '"1pkt"'],
 [712, '"phenytoin"'],
 [527, '"dila100"'],
 [96, '"100mg"'],
 [24, '"base"'],
 [483, '"d5w500"'],
 [117, '"500ml"'],
 [14, '"d5w"'],
 [17, '"iv"'],
 [157, '"potassium chlor

original text

In [83]:
# results[2][0]: token ids shown under the "original text" heading.
temp = results[2][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. phenytoin sodium extended 100 mg capsule sig : three ( 3 ) take 2 capsules in the morning and 1 capsule at night. disp : * 90 capsule ( s ) * refills : * 2 * [SEP]'

generation output

In [84]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. phenytoin sodium extended 100 mg capsule sig : one ( 1 ) capsule po tid ( 3 times a day ). disp : * 60 capsule ( s ) * refills : * 0 *'

### random init

In [85]:
# "Random init" condition of this section: same batch, with perturb_type='init_all'.
# Inference only, so the forward pass runs without autograd.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, perturb_type='init_all', know_gt_leng=False)

generation output

In [86]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. fluticasone propionate 110 mcg / actuation aerosol sig : one ( 1 ) inhalation inhalation ( 2 times a day.. disp : * 1 1 bottle * refills :'

### Use only the pre-training checkpoint

In [87]:
# Build a second generator from the checkpoint recorded in config._name_or_path
# (the load log for this cell shows a .../pretrained_models/pretrain/... directory)
# and place it on the same device as the batch.
model_pt = _prepare_model(
    LxmertForGeneration.from_pretrained(config._name_or_path, tokenizer=tokenizer),
    device,
)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/px/KGenc_LMinit_H128_L2,2,4_Unified1000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/px/KGenc_LMinit_H128_L2

In [88]:
# Generate with the pre-training checkpoint (model_pt) on the same batch, no perturbation.
# Inference only, so the forward pass runs without autograd.
with torch.no_grad():
    results = model_pt(**sample, given_lang_tokens=1, perturb_type=None, know_gt_leng=False)

generation output

In [89]:
# results[0][0]: token ids shown under the "generation output" heading.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS]xy oxcodone 5 mg po q4h4 - 6 -4h 6 - 6h every 6 to 4 to 6 -4 6h 6 to 4 to 4 to 4 to 6 6 to 4 6 to 6h4'

### perturbation

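Perturb the graph by removing selected nodes (their ids are set to 0), then compare the text generated from the original and the perturbed graph.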

In [90]:
# Shallow copy: a new container whose values are still the same objects as in `sample`.
# The perturbation cell below rebinds the 'kg_input_ids' entry to a new tensor rather
# than editing the shared tensor in place, so `sample` itself stays unchanged.
sample_perturb = sample.copy()

In [91]:
# original graph: [node_id, node_name] for every graph token whose id is above
# config.num_relations. The ids are converted to plain ints once via .tolist().
[[node_id, id2node[node_id]]
 for node_id in sample['kg_input_ids'][0].tolist()
 if node_id > config.num_relations]

[[348, '"kcl40i"'],
 [85, '"40meq"'],
 [157, '"potassium chloride"'],
 [17, '"iv"'],
 [263, '"heparin"'],
 [33, '"sc"'],
 [112, '"5000unit"'],
 [391, '"hepa5i"'],
 [1571, '"liepi20i"'],
 [33, '"sc"'],
 [254, '"1ml"'],
 [2519, '"lidocaine 1%/epinephrine 1:100000"'],
 [527, '"dila100"'],
 [96, '"100mg"'],
 [712, '"phenytoin"'],
 [57, '"50ml"'],
 [17, '"iv"'],
 [484, '"sw50"'],
 [368, '"sw"'],
 [24, '"base"'],
 [17, '"iv"'],
 [85, '"40meq"'],
 [157, '"potassium chloride"'],
 [348, '"kcl40i"'],
 [23, '"ondansetron"'],
 [17, '"iv"'],
 [119, '"ondan4i"'],
 [92, '"4mg"'],
 [157, '"potassium chloride"'],
 [85, '"40meq"'],
 [204, '"pota20"'],
 [75, '"300mg"'],
 [527, '"dila100"'],
 [712, '"phenytoin"'],
 [518, '"200mg"'],
 [527, '"dila100"'],
 [712, '"phenytoin"'],
 [262, '"neut"'],
 [66, '"neutra-phos"'],
 [757, '"1pkt"'],
 [712, '"phenytoin"'],
 [527, '"dila100"'],
 [96, '"100mg"'],
 [24, '"base"'],
 [483, '"d5w500"'],
 [117, '"500ml"'],
 [14, '"d5w"'],
 [17, '"iv"'],
 [157, '"potassium chlor

In [104]:
# Set every occurrence of the chosen node ids to 0 in the perturbed graph.
# 712 is '"phenytoin"' in the node listing; 32 does not appear in the visible
# (truncated) listing.
for removed_id in (712, 32):
    # Boolean mask (True where the id is kept) multiplied back onto the ids zeroes the rest.
    kept = sample_perturb['kg_input_ids'].not_equal(removed_id)
    sample_perturb['kg_input_ids'] = kept * sample_perturb['kg_input_ids']

In [105]:
# perturbed graph: [node_id, node_name] for every remaining graph token whose id is above
# config.num_relations. The ids are converted to plain ints once via .tolist().
[[node_id, id2node[node_id]]
 for node_id in sample_perturb['kg_input_ids'][0].tolist()
 if node_id > config.num_relations]

[[348, '"kcl40i"'],
 [85, '"40meq"'],
 [157, '"potassium chloride"'],
 [17, '"iv"'],
 [263, '"heparin"'],
 [33, '"sc"'],
 [112, '"5000unit"'],
 [391, '"hepa5i"'],
 [1571, '"liepi20i"'],
 [33, '"sc"'],
 [254, '"1ml"'],
 [2519, '"lidocaine 1%/epinephrine 1:100000"'],
 [527, '"dila100"'],
 [96, '"100mg"'],
 [57, '"50ml"'],
 [17, '"iv"'],
 [484, '"sw50"'],
 [368, '"sw"'],
 [24, '"base"'],
 [17, '"iv"'],
 [85, '"40meq"'],
 [157, '"potassium chloride"'],
 [348, '"kcl40i"'],
 [23, '"ondansetron"'],
 [17, '"iv"'],
 [119, '"ondan4i"'],
 [92, '"4mg"'],
 [157, '"potassium chloride"'],
 [85, '"40meq"'],
 [204, '"pota20"'],
 [75, '"300mg"'],
 [527, '"dila100"'],
 [518, '"200mg"'],
 [527, '"dila100"'],
 [262, '"neut"'],
 [66, '"neutra-phos"'],
 [757, '"1pkt"'],
 [527, '"dila100"'],
 [96, '"100mg"'],
 [24, '"base"'],
 [483, '"d5w500"'],
 [117, '"500ml"'],
 [14, '"d5w"'],
 [17, '"iv"'],
 [157, '"potassium chloride"'],
 [129, '"kcl20p"'],
 [82, '"60meq"'],
 [415, '"diaz5"'],
 [358, '"diazepam"'],
 [162

In [106]:
# Generate from the original and from the perturbed graph with identical settings.
# Inference only, so both forward passes run without autograd.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)
    results_perturb = model(**sample_perturb, given_lang_tokens=1, know_gt_leng=False)

In [107]:
# Text generated from the original graph.
temp = results[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. phenytoin sodium extended 100 mg capsule sig : one ( 1 ) capsule po tid ( 3 times a day ). disp : * 60 capsule ( s ) * refills : * 0 *'

In [108]:
# Text generated from the perturbed graph.
temp = results_perturb[0][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. phenytoin sodium extended 100 mg capsule sig : one ( 1 ) capsule po tid ( 3 times a day ). disp : * 60 capsule ( s ) * refills : * 0 *'

In [109]:
# Ground-truth text.
temp = results[2][0]
# Decode the first k ids, where k is the number of non-zero ids (assumes id 0 is trailing
# padding). Tensor.sum() replaces Python's builtin sum(), which looped over every element.
tokenizer.decode(temp[:int(temp.not_equal(0).sum())])

'[CLS] 1. phenytoin sodium extended 100 mg capsule sig : three ( 3 ) take 2 capsules in the morning and 1 capsule at night. disp : * 90 capsule ( s ) * refills : * 2 * [SEP]'