In [28]:
import subprocess
import json
import os

import torch

In [31]:
# Work from the LXMERT `src/` directory so the project-local imports used below
# (`model`, `utils.*`) can be found from the working directory.
# Set LXMERT_SRC_DIR to override the author's machine-specific default.
SRC_DIR = os.environ.get('LXMERT_SRC_DIR', '/home/ssbae/bae/kg_txt_multimodal/lxmert/src/')
os.chdir(SRC_DIR)
# Experiment root (parent of `src/`); data and checkpoint paths are built from it.
EXP_PATH = os.path.abspath('..')

# Setup: diagnosis, procedures (`DB = 'dx,prx'`)

In [33]:
# ======================= CONFIG ==================== #
## GPU setting
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
## TASK & DB
TASK_NAME = 'generation'
DB = 'dx,prx'
DB_size = 2000
MODEL_TYPE = 'both'
Unified = True
Align = False
Relation_Classification = False
Scratch_Downstream = False
## Important Model Config
Dim_Hidden = 128
NUM_Layers = {'lang':2, 'kg':2, 'cross':4}
Dropout = 0.1
Num_Negatives = 1
Margin = 1.0
# ======================= CONFIG ==================== #
# Variables: name fragments derived from the config above
Var_TASK = {}  # NOTE(review): not referenced in the cells shown
Var_MODEL = {'both':'KGenc_LMinit', 'lm':'LMinit', 'kg':'KGenc', 'rand':'Randinit'}
Var_Unified = 'Unified' if Unified else ''
Var_Align = 'Align_' if Align else ''
Var_RC = 'RC_' if Relation_Classification else ''
assert MODEL_TYPE in Var_MODEL, f"Model not supported: {MODEL_TYPE!r}"
assert DB in ['px','dx,prx'], f"DB not supported: {DB!r}"
assert TASK_NAME in ['pretrain', 'binary_retrieval', 'generation', 'single_pretrain', 'single_binary_retrieval', 'single_generation'], f"Task not supported: {TASK_NAME!r}"
# Truthiness rather than `is True` / `is False`, so 0/1 flags are validated exactly like bools.
if Scratch_Downstream:
    assert not Align and not Relation_Classification, "Scratch start downstream task must turn off alignment prediction & relation classification"

# Model Name
## <LMinit & KGenc> : both, <LMinit only> : lm, <KGenc only> : kg, <RandomInit> : rand
## Unified(Placeholder) for Abstract Node : True or False
MODEL_NAME = f'{DB}_{Var_Unified}{"Uni" if MODEL_TYPE in ["both","kg"] else "No"}KGenc'
# e.g. 'dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000' for the settings above
RUN_NAME = f'{DB}/{Var_MODEL[MODEL_TYPE]}_H{Dim_Hidden}_L{NUM_Layers["lang"]},{NUM_Layers["kg"]},{NUM_Layers["cross"]}_{Var_Align}{Var_RC}{Var_Unified}{DB_size}'

In [34]:
# Arguments handed to the project's dataset helpers (wrapped as `args` in the next cell).
TRAINING_CONFIG = dict(
    seed=1234,
    model_type="lxmert",
    do_train=True,
    evaluate_during_training=True,
    do_eval=True,
    edge_cls=Relation_Classification,
    align=Align,
    n_negatives=Num_Negatives,
    prediction_loss_only=False,
    overwrite_output_dir=False,
    mlm_probability=0.15,
    block_size=512,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=1e-4,
    num_train_epochs=40,
    num_log_per_epoch=20,
    num_save_per_epoch=-1,
    num_eval_per_epoch=2,
    task=TASK_NAME,
    # <split>_data_file -> <EXP_PATH>/data/<DB>_<DB_size>/<MODEL_NAME>/<folder>
    **{
        f"{split}_data_file": os.path.join(EXP_PATH, f"data/{DB}_{DB_size}/{MODEL_NAME}/{folder}")
        for split, folder in (("train", "train"), ("eval", "valid"), ("test", "test"))
    },
    run_name=f"{TASK_NAME}_{RUN_NAME}",
)

In [35]:
# Wrap the config dict so its fields are also readable as attributes
# (e.g. `args.eval_data_file`, used when loading the node vocabulary below).
import easydict
args = easydict.EasyDict(TRAINING_CONFIG)

In [36]:
# Fine-tuned checkpoint for this task/run, derived from the config instead of a
# backslash-continued absolute path. For the current settings this is
# <EXP_PATH>/pretrained_models/generation/dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000
temp_model_name_or_path = os.path.join(EXP_PATH, 'pretrained_models', TASK_NAME, RUN_NAME)

In [40]:
from transformers import AutoConfig, AutoTokenizer
from model import LxmertForGeneration

# Tokenizer, config and fine-tuned generation model all come from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(temp_model_name_or_path)
config = AutoConfig.from_pretrained(temp_model_name_or_path)
model = LxmertForGeneration.from_pretrained(temp_model_name_or_path, tokenizer=tokenizer)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/generation/dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/generation/dx,prx/KGenc_LM

In [41]:
from utils.dataset import get_dataset

# Only the validation split is inspected in this notebook, so train/test datasets are not built.
eval_dataset = get_dataset(
    args,
    tokenizer=tokenizer,
    token_type_vocab=config.token_type_vocab,
    evaluate=True,
)

In [42]:
from utils.data_collator import UniLM_DataCollator

# NOTE(review): `prediction=True` presumably configures the collator for inference — confirm in utils/data_collator.py.
data_collator = UniLM_DataCollator(
    tokenizer=tokenizer,
    kg_special_token_ids=config.kg_special_token_ids,
    prediction=True,
)

In [49]:
# The node vocabulary sits next to the split folders: <...>/<MODEL_NAME>/unified_node.
# os.path.dirname drops only the final 'valid' component, whereas str.replace('/valid', '')
# would also rewrite any earlier '/valid' substring in the path.
# torch.load unpickles the file, so it must come from a trusted source (here: local experiment data).
node2id = torch.load(os.path.join(os.path.dirname(args.eval_data_file), 'unified_node'))
# id -> readable name; the part after '^^' (presumably a datatype tag) is dropped.
id2node = {idx: key.split('^^')[0] for key, idx in node2id.items()}

# generation

In [60]:
def _prepare_inputs(inputs, device):
    if isinstance(inputs, dict):
        for k,v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)
    return inputs

def _prepare_model(model, device):
    model.to(device)
    return model

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

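## sample1: eval_dataset, idx=21

Note: the "normal case" and "random init" cells (In [146]–[151]) were executed before this sample was collated (In [160]); re-run this section top to bottom to reproduce its outputs.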

In [160]:
base_idx = 21
# Collate only the example at `base_idx` (a length-1 slice of the validation set).
sample = data_collator(eval_dataset[slice(base_idx, base_idx + 1)])

In [161]:
# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = _prepare_model(model, device)
sample = _prepare_inputs(sample, device)

### normal case

In [146]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [147]:
# Graph nodes with IDs above config.num_relations (presumably entities rather than relations /
# special tokens — confirm), paired with their names. One .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[2057, '"intestinal infection due to clostridium difficile"'],
 [478, '"injection or infusion of oxazolidinone class of antibiotics"'],
 [7089, '"atrial fibrillation"'],
 [1722, '"edema"'],
 [7443, '"parenteral infusion of concentrated nutritional substances"'],
 [2690, '"arterial catheterization"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [5809, '"unspecified septicemia"'],
 [119, '"unspecified protein-calorie malnutrition"'],
 [947, '"other bipolar disorders"'],
 [6138, '"sepsis"'],
 [7000, '"other disorders of plasma protein metabolism"'],
 [275, '"pneumonia, organism unspecified"']]

original text

In [148]:
# results[2]: reference (ground-truth) text ids
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] pna c diff colitis paroxysmal a fib pancytopenia bipolar disorder [SEP] none [SEP]'

generation output

In [149]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. pneumonia 2. c. difficile colitis 3. atrial fibrillation 4. central'

### Random init (`perturb_type='init_all'`)

In [150]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, perturb_type='init_all', know_gt_leng=False)

generation output

In [151]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. left lower extremity 2. left lower extremity thrombosis 3. left lower ex'

### Use only the pre-training checkpoint (no generation fine-tuning)

In [164]:
# Pre-training checkpoint for the same run configuration. The path is built explicitly instead of
# reading the private `config._name_or_path`, which in the recorded run resolved to this same path.
pretrain_model_path = os.path.join(EXP_PATH, 'pretrained_models', 'pretrain', RUN_NAME)
model_pt = LxmertForGeneration.from_pretrained(
    pretrain_model_path,
    tokenizer=tokenizer
)
model_pt = _prepare_model(model_pt, device)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit

In [165]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model_pt(**sample, given_lang_tokens=1, perturb_type=None, know_gt_leng=False)

generation output

In [166]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] sepsis sepsis sepsis sepsis sepsis sepsis sepsis sepsis sepsis sepsis sepsis sep'

### perturbation

Perturb the graph by removing one node (ID 275, "pneumonia, organism unspecified").

In [171]:
# Shallow copy of the batch: the perturbation cell rebinds 'kg_input_ids' on this copy, so
# `sample` keeps its original entries (the tensors themselves are shared, not cloned).
sample_perturb = sample.copy()

In [172]:
# original graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[2057, '"intestinal infection due to clostridium difficile"'],
 [478, '"injection or infusion of oxazolidinone class of antibiotics"'],
 [7089, '"atrial fibrillation"'],
 [1722, '"edema"'],
 [7443, '"parenteral infusion of concentrated nutritional substances"'],
 [2690, '"arterial catheterization"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [5809, '"unspecified septicemia"'],
 [119, '"unspecified protein-calorie malnutrition"'],
 [947, '"other bipolar disorders"'],
 [6138, '"sepsis"'],
 [7000, '"other disorders of plasma protein metabolism"'],
 [275, '"pneumonia, organism unspecified"']]

In [174]:
# Remove node 275 ("pneumonia, organism unspecified") by overwriting its ID with 0.
# Built from `sample`, so re-running this cell gives the same result; the assert catches a
# mistyped ID, which would otherwise make the perturbation a silent no-op.
# NOTE(review): only kg_input_ids changes; any mask/edge tensors in the batch still refer to
# that position — confirm the model treats ID 0 as KG padding.
perturbed_ids = sample['kg_input_ids']
for node_id in (275,):
    assert (perturbed_ids == node_id).any(), f"node {node_id} is not in this graph"
    perturbed_ids = perturbed_ids.masked_fill(perturbed_ids == node_id, 0)
sample_perturb['kg_input_ids'] = perturbed_ids

In [182]:
# perturbed graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample_perturb['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[2057, '"intestinal infection due to clostridium difficile"'],
 [478, '"injection or infusion of oxazolidinone class of antibiotics"'],
 [7089, '"atrial fibrillation"'],
 [1722, '"edema"'],
 [7443, '"parenteral infusion of concentrated nutritional substances"'],
 [2690, '"arterial catheterization"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [5809, '"unspecified septicemia"'],
 [119, '"unspecified protein-calorie malnutrition"'],
 [947, '"other bipolar disorders"'],
 [6138, '"sepsis"'],
 [7000, '"other disorders of plasma protein metabolism"']]

In [185]:
# Same generation on the original and the perturbed batch; no_grad skips autograd bookkeeping.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)
    results_perturb = model(**sample_perturb, given_lang_tokens=1, know_gt_leng=False)

In [189]:
# generate text given an original graph
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. pneumonia 2. c. difficile colitis 3. atrial fibrillation 4. central'

In [190]:
# generate text given a perturbed graph
temp = results_perturb[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. c. difficile colitis 2. atrial fibrillation 3. difficile'

In [192]:
# ground truth text
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] pna c diff colitis paroxysmal a fib pancytopenia bipolar disorder [SEP] none [SEP]'

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

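## sample2: eval_dataset, idx=1822

Note: cells In [304]–[338] were executed before In [453]–[456], so this section's outputs come from an out-of-order run. For example, the normal-case generation differs between In [305] and In [336]. Re-run top to bottom before drawing conclusions.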

In [453]:
base_idx = 1822
# Collate only the example at `base_idx` (a length-1 slice of the validation set).
sample = data_collator(eval_dataset[slice(base_idx, base_idx + 1)])

In [454]:
# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = _prepare_model(model, device)
sample = _prepare_inputs(sample, device)

### normal case

In [455]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [456]:
# Graph nodes with IDs above config.num_relations (presumably entities rather than relations /
# special tokens — confirm), paired with their names. One .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[3886, '"closed fracture of shaft of humerus"'],
 [1255, '"open reduction of fracture with internal fixation, humerus"'],
 [5449, '"other incision of cranial and peripheral nerves"'],
 [536, '"unspecified fall"'],
 [5815, '"alcohol withdrawal delirium"'],
 [6781, '"alcohol abuse, continuous"'],
 [2325, '"acute posthemorrhagic anemia"']]

original text

In [304]:
# results[2]: reference (ground-truth) text ids
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. right humerus fracture 2. alcohol withdrawal [SEP] 1. orif of right humerus and radial nerve neuroplasty [SEP]'

generation output

In [305]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. alcohol withdrawal 2. alcohol withdrawal 3. alcohol 1. orif right humerus fracture [SEP]'

### Random init (`perturb_type='init_all'`)

In [306]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, perturb_type='init_all', know_gt_leng=False)

generation output

In [307]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. acute on chronic systolic heart failure 2. none [SEP]'

### Use only the pre-training checkpoint (no generation fine-tuning)

In [308]:
# Pre-training checkpoint for the same run configuration. The path is built explicitly instead of
# reading the private `config._name_or_path`, which in the recorded run resolved to this same path.
pretrain_model_path = os.path.join(EXP_PATH, 'pretrained_models', 'pretrain', RUN_NAME)
model_pt = LxmertForGeneration.from_pretrained(
    pretrain_model_path,
    tokenizer=tokenizer
)
model_pt = _prepare_model(model_pt, device)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit

In [309]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model_pt(**sample, given_lang_tokens=1, perturb_type=None, know_gt_leng=False)

generation output

In [310]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] left hip fracture left hip fracture left hip fracture left hip fracture left hip fracture left hip fracture left hip fracture left hip fracture left hip fracture left'

### perturbation

Perturb the graph by removing two nodes (IDs 5815, "alcohol withdrawal delirium", and 6781, "alcohol abuse, continuous").

In [311]:
# Shallow copy of the batch: the perturbation cell rebinds 'kg_input_ids' on this copy, so
# `sample` keeps its original entries (the tensors themselves are shared, not cloned).
sample_perturb = sample.copy()

In [321]:
# original graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[3886, '"closed fracture of shaft of humerus"'],
 [1255, '"open reduction of fracture with internal fixation, humerus"'],
 [5449, '"other incision of cranial and peripheral nerves"'],
 [536, '"unspecified fall"'],
 [5815, '"alcohol withdrawal delirium"'],
 [6781, '"alcohol abuse, continuous"'],
 [2325, '"acute posthemorrhagic anemia"']]

In [333]:
# Remove nodes 5815 ("alcohol withdrawal delirium") and 6781 ("alcohol abuse, continuous")
# by overwriting their IDs with 0. Built from `sample`, so re-running this cell gives the same
# result; the assert catches a mistyped ID, which would otherwise be a silent no-op.
# NOTE(review): only kg_input_ids changes; any mask/edge tensors in the batch still refer to
# those positions — confirm the model treats ID 0 as KG padding.
perturbed_ids = sample['kg_input_ids']
for node_id in (5815, 6781):
    assert (perturbed_ids == node_id).any(), f"node {node_id} is not in this graph"
    perturbed_ids = perturbed_ids.masked_fill(perturbed_ids == node_id, 0)
sample_perturb['kg_input_ids'] = perturbed_ids

In [334]:
# perturbed graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample_perturb['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[3886, '"closed fracture of shaft of humerus"'],
 [5449, '"other incision of cranial and peripheral nerves"'],
 [536, '"unspecified fall"'],
 [2325, '"acute posthemorrhagic anemia"']]

In [335]:
# Same generation on the original and the perturbed batch; no_grad skips autograd bookkeeping.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)
    results_perturb = model(**sample_perturb, given_lang_tokens=1, know_gt_leng=False)

In [336]:
# generate text given an original graph
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. alcohol withdrawal 2. right humerus fracture [SEP]'

In [337]:
# generate text given a perturbed graph
temp = results_perturb[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] s / p motor vehicle crash left inguinal hernia [SEP]'

In [338]:
# ground truth text
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. right humerus fracture 2. alcohol withdrawal [SEP] 1. orif of right humerus and radial nerve neuroplasty [SEP]'

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

## sample3: eval_dataset, idx=788

In [350]:
base_idx = 788
# Collate only the example at `base_idx` (a length-1 slice of the validation set).
sample = data_collator(eval_dataset[slice(base_idx, base_idx + 1)])

In [351]:
# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = _prepare_model(model, device)
sample = _prepare_inputs(sample, device)

### normal case

In [352]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [356]:
# Graph nodes with IDs above config.num_relations (presumably entities rather than relations /
# special tokens — confirm), paired with their names. One .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[4270, '"closed fracture of orbital floor (blow-out)"'],
 [4439, '"pneumonitis due to inhalation of food or vomitus"'],
 [261, '"toxic encephalopathy"'],
 [2341, '"temporary tracheostomy"'],
 [4626,
  '"chronic viral hepatitis b without mention of hepatic coma without mention of hepatitis delta"'],
 [4314,
  '"continuous invasive mechanical ventilation for less than 96 consecutive hours"'],
 [5603, '"closed [endoscopic] biopsy of bronchus"'],
 [1764,
  '"closed fracture of base of skull with subarachnoid, subdural, and extradural hemorrhage, with no loss of consciousness"'],
 [3176,
  '"continuous invasive mechanical ventilation for 96 consecutive hours or more"'],
 [6274, '"chronic hepatitis c without mention of hepatic coma"'],
 [4865,
  '"pseudomonas infection in conditions classified elsewhere and of unspecified site"'],
 [5066, '"fiber-optic bronchoscopy"'],
 [2431, '"insertion of endotracheal tube"'],
 [570, '"acute alcoholic hepatitis"'],
 [5815, '"alcohol withdrawal delirium"'

original text

In [359]:
# results[2]: reference (ground-truth) text ids
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] [SEP] percutaneous tracheostomy percutaneous gastric tube placement [SEP]'

generation output

In [360]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] s / p tracheostomy peg placement [SEP]'

### Random init (`perturb_type='init_all'`)

In [361]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, perturb_type='init_all', know_gt_leng=False)

generation output

In [362]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. 1. left subclavian central venous catheter placement 2'

### Use only the pre-training checkpoint (no generation fine-tuning)

In [363]:
# Pre-training checkpoint for the same run configuration. The path is built explicitly instead of
# reading the private `config._name_or_path`, which in the recorded run resolved to this same path.
pretrain_model_path = os.path.join(EXP_PATH, 'pretrained_models', 'pretrain', RUN_NAME)
model_pt = LxmertForGeneration.from_pretrained(
    pretrain_model_path,
    tokenizer=tokenizer
)
model_pt = _prepare_model(model_pt, device)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit

In [364]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model_pt(**sample, given_lang_tokens=1, perturb_type=None, know_gt_leng=False)

generation output

In [365]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] pneumonia aspiration aspiration aspiration aspiration aspiration aspiration aspiration as'

### perturbation

Perturb the graph by removing one node (ID 6047; it does not appear in the truncated listing below).

In [366]:
# Shallow copy of the batch: the perturbation cell rebinds 'kg_input_ids' on this copy, so
# `sample` keeps its original entries (the tensors themselves are shared, not cloned).
sample_perturb = sample.copy()

In [367]:
# original graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[4270, '"closed fracture of orbital floor (blow-out)"'],
 [4439, '"pneumonitis due to inhalation of food or vomitus"'],
 [261, '"toxic encephalopathy"'],
 [2341, '"temporary tracheostomy"'],
 [4626,
  '"chronic viral hepatitis b without mention of hepatic coma without mention of hepatitis delta"'],
 [4314,
  '"continuous invasive mechanical ventilation for less than 96 consecutive hours"'],
 [5603, '"closed [endoscopic] biopsy of bronchus"'],
 [1764,
  '"closed fracture of base of skull with subarachnoid, subdural, and extradural hemorrhage, with no loss of consciousness"'],
 [3176,
  '"continuous invasive mechanical ventilation for 96 consecutive hours or more"'],
 [6274, '"chronic hepatitis c without mention of hepatic coma"'],
 [4865,
  '"pseudomonas infection in conditions classified elsewhere and of unspecified site"'],
 [5066, '"fiber-optic bronchoscopy"'],
 [2431, '"insertion of endotracheal tube"'],
 [570, '"acute alcoholic hepatitis"'],
 [5815, '"alcohol withdrawal delirium"'

In [374]:
# Remove node 6047 by overwriting its ID with 0. It does not appear in the truncated graph
# listing, so the assert confirms it is actually present (otherwise this is a silent no-op).
# Built from `sample`, so re-running this cell gives the same result.
# NOTE(review): only kg_input_ids changes; any mask/edge tensors in the batch still refer to
# that position — confirm the model treats ID 0 as KG padding.
perturbed_ids = sample['kg_input_ids']
for node_id in (6047,):
    assert (perturbed_ids == node_id).any(), f"node {node_id} is not in this graph"
    perturbed_ids = perturbed_ids.masked_fill(perturbed_ids == node_id, 0)
sample_perturb['kg_input_ids'] = perturbed_ids

In [375]:
# perturbed graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample_perturb['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[4270, '"closed fracture of orbital floor (blow-out)"'],
 [4439, '"pneumonitis due to inhalation of food or vomitus"'],
 [261, '"toxic encephalopathy"'],
 [2341, '"temporary tracheostomy"'],
 [4626,
  '"chronic viral hepatitis b without mention of hepatic coma without mention of hepatitis delta"'],
 [4314,
  '"continuous invasive mechanical ventilation for less than 96 consecutive hours"'],
 [5603, '"closed [endoscopic] biopsy of bronchus"'],
 [1764,
  '"closed fracture of base of skull with subarachnoid, subdural, and extradural hemorrhage, with no loss of consciousness"'],
 [3176,
  '"continuous invasive mechanical ventilation for 96 consecutive hours or more"'],
 [6274, '"chronic hepatitis c without mention of hepatic coma"'],
 [4865,
  '"pseudomonas infection in conditions classified elsewhere and of unspecified site"'],
 [5066, '"fiber-optic bronchoscopy"'],
 [2431, '"insertion of endotracheal tube"'],
 [570, '"acute alcoholic hepatitis"'],
 [5815, '"alcohol withdrawal delirium"'

In [376]:
# Same generation on the original and the perturbed batch; no_grad skips autograd bookkeeping.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)
    results_perturb = model(**sample_perturb, given_lang_tokens=1, know_gt_leng=False)

In [377]:
# generate text given an original graph
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] s / p tracheostomy peg placement [SEP]'

In [378]:
# generate text given a perturbed graph
temp = results_perturb[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] s / p tracheostomy [SEP]'

In [379]:
# ground truth text
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] [SEP] percutaneous tracheostomy percutaneous gastric tube placement [SEP]'

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

## sample4: eval_dataset, idx=325

In [458]:
base_idx = 325
# Collate only the example at `base_idx` (a length-1 slice of the validation set).
sample = data_collator(eval_dataset[slice(base_idx, base_idx + 1)])

In [459]:
# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = _prepare_model(model, device)
sample = _prepare_inputs(sample, device)

### normal case

In [460]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)

original graph

In [461]:
# Graph nodes with IDs above config.num_relations (presumably entities rather than relations /
# special tokens — confirm), paired with their names. One .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[6202, '"old myocardial infarction"'],
 [5909, '"hematoma complicating a procedure"'],
 [256, '"hypovolemia"'],
 [5806, '"personal history of malignant neoplasm of large intestine"'],
 [1849, '"graft of muscle or fascia"'],
 [7790, '"pyogenic arthritis, lower leg"'],
 [1821, '"attachment of pedicle or flap graft to other sites"'],
 [2559, '"personal history of malignant neoplasm of cervix uteri"'],
 [4239, '"transfusion of packed cells"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [2483, '"other iatrogenic hypotension"'],
 [7431,
  '"other specified bacterial infections in conditions classified elsewhere and of unspecified site, other specified bacteria"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [6207, '"transfusion of other serum"'],
 [7135, '"other arthrotomy, knee"'],
 [553,
  '"other specified surgical operations and procedures causing abnormal patient reaction, or later complication, without mention of misadventure at time of operati

original text

In [462]:
# results[2]: reference (ground-truth) text ids
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] r knee effusion, bleeding, hypotension, exposed hardware r knee [SEP] [SEP]'

generation output

In [463]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. left knee arthroplasty 2. left knee arthroplasty 3.'

### Random init (`perturb_type='init_all'`)

In [464]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, perturb_type='init_all', know_gt_leng=False)

generation output

In [465]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. acute on chronic systolic heart failure 2. pulmonary edema 3. pulmonary ed none'

### Use only the pre-training checkpoint (no generation fine-tuning)

In [466]:
# Pre-training checkpoint for the same run configuration. The path is built explicitly instead of
# reading the private `config._name_or_path`, which in the recorded run resolved to this same path.
pretrain_model_path = os.path.join(EXP_PATH, 'pretrained_models', 'pretrain', RUN_NAME)
model_pt = LxmertForGeneration.from_pretrained(
    pretrain_model_path,
    tokenizer=tokenizer
)
model_pt = _prepare_model(model_pt, device)

Some weights of the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit_H128_L2,2,4_Unified2000 were not used when initializing LxmertForGeneration: ['classifier.weight', 'classifier.bias', 'edge_classifier.0.weight', 'edge_classifier.0.bias', 'edge_classifier.2.weight', 'edge_classifier.2.bias']
- This IS expected if you are initializing LxmertForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LxmertForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LxmertForGeneration were not initialized from the model checkpoint at /home/ssbae/bae/kg_txt_multimodal/lxmert/pretrained_models/pretrain/dx,prx/KGenc_LMinit

In [467]:
# Inference only: no_grad avoids building an autograd graph for the whole generation pass.
with torch.no_grad():
    results = model_pt(**sample, given_lang_tokens=1, perturb_type=None, know_gt_leng=False)

generation output

In [468]:
# results[0]: generated text ids
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] cad cad hypertension diabetes type 2 type 2 type 2 diabetes type 2 type 2 diabetes type 2 type'

### perturbation

Perturb the graph by removing one node (ID 7135, "other arthrotomy, knee").

In [469]:
# Shallow copy of the batch: the perturbation cell rebinds 'kg_input_ids' on this copy, so
# `sample` keeps its original entries (the tensors themselves are shared, not cloned).
sample_perturb = sample.copy()

In [470]:
# original graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[6202, '"old myocardial infarction"'],
 [5909, '"hematoma complicating a procedure"'],
 [256, '"hypovolemia"'],
 [5806, '"personal history of malignant neoplasm of large intestine"'],
 [1849, '"graft of muscle or fascia"'],
 [7790, '"pyogenic arthritis, lower leg"'],
 [1821, '"attachment of pedicle or flap graft to other sites"'],
 [2559, '"personal history of malignant neoplasm of cervix uteri"'],
 [4239, '"transfusion of packed cells"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [2483, '"other iatrogenic hypotension"'],
 [7431,
  '"other specified bacterial infections in conditions classified elsewhere and of unspecified site, other specified bacteria"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [6207, '"transfusion of other serum"'],
 [7135, '"other arthrotomy, knee"'],
 [553,
  '"other specified surgical operations and procedures causing abnormal patient reaction, or later complication, without mention of misadventure at time of operati

In [477]:
# Remove node 7135 ("other arthrotomy, knee") by overwriting its ID with 0.
# Built from `sample`, so re-running this cell gives the same result; the assert catches a
# mistyped ID, which would otherwise make the perturbation a silent no-op.
# NOTE(review): only kg_input_ids changes; any mask/edge tensors in the batch still refer to
# that position — confirm the model treats ID 0 as KG padding.
perturbed_ids = sample['kg_input_ids']
for node_id in (7135,):
    assert (perturbed_ids == node_id).any(), f"node {node_id} is not in this graph"
    perturbed_ids = perturbed_ids.masked_fill(perturbed_ids == node_id, 0)
sample_perturb['kg_input_ids'] = perturbed_ids

In [478]:
# perturbed graph
# Nodes with IDs above config.num_relations; one .tolist() replaces a host sync per element.
[[node_id, id2node[node_id]] for node_id in sample_perturb['kg_input_ids'][0].tolist() if node_id > config.num_relations]

[[6202, '"old myocardial infarction"'],
 [5909, '"hematoma complicating a procedure"'],
 [256, '"hypovolemia"'],
 [5806, '"personal history of malignant neoplasm of large intestine"'],
 [1849, '"graft of muscle or fascia"'],
 [7790, '"pyogenic arthritis, lower leg"'],
 [1821, '"attachment of pedicle or flap graft to other sites"'],
 [2559, '"personal history of malignant neoplasm of cervix uteri"'],
 [4239, '"transfusion of packed cells"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [2483, '"other iatrogenic hypotension"'],
 [7431,
  '"other specified bacterial infections in conditions classified elsewhere and of unspecified site, other specified bacteria"'],
 [3183, '"venous catheterization, not elsewhere classified"'],
 [6207, '"transfusion of other serum"'],
 [553,
  '"other specified surgical operations and procedures causing abnormal patient reaction, or later complication, without mention of misadventure at time of operation"'],
 [463, '"urge incontinence"'],

In [479]:
# Same generation on the original and the perturbed batch; no_grad skips autograd bookkeeping.
with torch.no_grad():
    results = model(**sample, given_lang_tokens=1, know_gt_leng=False)
    results_perturb = model(**sample_perturb, given_lang_tokens=1, know_gt_leng=False)

In [480]:
# generate text given an original graph
temp = results[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. left knee arthroplasty 2. left knee arthroplasty 3.'

In [481]:
# generate text given a perturbed graph
temp = results_perturb[0][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] 1. left lateral wall wall 2. left lateral wall wall 3. left lateral wall wall wall 1'

In [482]:
# ground truth text
temp = results[2][0]
# Decode the leading non-pad (id != 0) tokens; one tensor reduction instead of Python's per-element sum().
tokenizer.decode(temp[:int(temp.ne(0).sum())])

'[CLS] r knee effusion, bleeding, hypotension, exposed hardware r knee [SEP] [SEP]'