In [1]:
import os
import torch
import pickle
import pandas as pd

os.chdir("/new-stg/home/banghua/Transformer-Explainability")

from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification

from captum.attr import visualization
from tqdm import tqdm
from transformers import BertTokenizer

# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

# Tokenizer stored with the clinicalBERT pretraining checkpoint; input is lower-cased.
tokenizer_path = os.path.expanduser('~/med264/clinicalBERTs/pretraining_checkpoint/')
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained(tokenizer_path, do_lower_case=True)

# Roots for fine-tuned checkpoints and their saved predictions; each model index
# gets its own sub-directory underneath.
model_path = os.path.expanduser('~/med264/models_balanced_clinicalBERT/')
preds_path = os.path.expanduser('~/med264/preds_balanced_clinicalBERT/')

There are 1 GPU(s) available.
We will use the GPU: NVIDIA RTX A6000
Loading BERT tokenizer...


In [2]:
# Load the test split and down-sample the larger class so both labels appear equally often.
df_test = pd.read_csv("~/med264/Dataset1/day1_30mortality_test.csv", index_col=0)
num_test_pos, num_test_neg = (len(df_test[df_test['Label'] == lbl]) for lbl in (1, 0))
num_balanced = min(num_test_pos, num_test_neg)
# Positives first, then negatives; a fixed random_state keeps the subset reproducible.
df_test_pos, df_test_neg = (
    df_test[df_test['Label'] == lbl].sample(n=num_balanced, random_state=42)
    for lbl in (1, 0)
)
df_test = pd.concat([df_test_pos, df_test_neg])

# Report the number of sentences.
print('Number of testing sentences: {:,}\n'.format(df_test.shape[0]))

# Raw note text and gold labels as numpy arrays, in df_test row order.
sentences_test = df_test['TEXT'].values
labels_test = df_test['Label'].values

Number of testing sentences: 3,246



In [3]:
# Cached token ids for the balanced test sentences (special tokens added, no padding).
file_path_test = os.path.expanduser('~/med264/Dataset3/input_ids_test.pickle')

# Always encode with the current tokenizer. The result validates the cache in the
# next cell, and it is the fallback when no cache file exists. (Previously this loop
# only ran when the cache existed, so a missing cache left everything empty and
# crashed at max().)
input_ids_encode = [
    tokenizer.encode(sent, add_special_tokens=True) for sent in tqdm(sentences_test)
]

if os.path.exists(file_path_test):
    # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
    with open(file_path_test, 'rb') as f:
        input_ids_test = pickle.load(f)
    print('Loaded input_ids_test.')
else:
    input_ids_test = input_ids_encode
    print('No cache at ' + file_path_test + '; using freshly encoded input_ids_test.')

# default=0 keeps an empty test set from raising ValueError.
print('Max test sentence length: ', max((len(sen) for sen in input_ids_test), default=0))

Loaded input_ids_test.


100%|█████████████████████████████████████████████████████████████████████████████████████████| 3246/3246 [00:13<00:00, 244.88it/s]

Max test sentence length:  818





In [4]:
# Keep df_test's row labels so positional results can be mapped back to the CSV rows.
orig_index = df_test.index
# Plain Python list of the raw note strings, in df_test row order.
test_text = df_test['TEXT'].tolist()
df_test.head()

Unnamed: 0,TEXT,Label,ID
8858,left hemithorax. new hazy increased density ri...,1,106037
42635,there are radiopaque densities at the lung bas...,1,128774
76935,sinus tachycardia. intraventricular conduction...,1,151490
108877,placement of a right frontal approach ventricu...,1,173377
60101,atrial fibrillation with slow ventricular resp...,1,140092


In [5]:
# Sanity check: the cached ids must equal a fresh encoding of the same sentences.
# Compare lengths first; a row-wise loop alone would silently pass on a truncated
# or empty list.
assert len(input_ids_encode) == len(input_ids_test), (
    f'{len(input_ids_encode)} freshly encoded rows vs {len(input_ids_test)} cached rows'
)
for i, (encoded, cached) in enumerate(zip(input_ids_encode, input_ids_test)):
    assert encoded == cached, f'token ids differ at test row {i}'

In [6]:
# Model-ready ids and attention masks for all three splits. They are stacked into
# single tensors further down, so each split's rows must share one length.
# NOTE: this overwrites the unpadded input_ids_test validated above.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
file_path = os.path.expanduser('~/med264/Dataset3/processed_data.pickle')
with open(file_path, 'rb') as f:
    (input_ids_train, input_ids_valid, input_ids_test,
     attention_masks_train, attention_masks_valid, attention_masks_test) = pickle.load(f)

# The test ids/masks are paired row-for-row with labels_test (the balanced df_test).
# Fail here with a clear message rather than later inside TensorDataset.
assert len(input_ids_test) == len(attention_masks_test) == len(labels_test), (
    'processed test split does not match the balanced test set: '
    f'{len(input_ids_test)} ids, {len(attention_masks_test)} masks, {len(labels_test)} labels'
)

In [7]:
# Peek at the last raw test note, to compare with its decoded token ids.
test_text[len(test_text) - 1]

'involving the inner and outer table of the left frontal bone, extending to the roof of the left orbit. there are also fractures of the roof of the right orbit. there is a comminuted fracture of the lateral wall of the left orbit. both lamina papyracea are fractured. the right lamina papyracea is fractured posteriorly, with a bone fragment intruding on the apex formed by the extraocular muscles. there is a fracture of the lateral wall of the right orbit posteriorly, with air and blood seen within the middle cranial fossa just posterior to this fracture. there are fractures of the floors of both orbits. the inferior rectus muscles approach the fracture defect, but do not definitely cross through them. the infraorbital foramina are likely involved in the fractures bilaterally. there is a fracture of the roof of the right sphenoid air cell. there are fractures of the lateral and anterior walls of both maxillary sinuses, with an osseous fragment seen within the left maxillary sinus. the ma

In [8]:
# Decode the last processed test example back to text, to check it lines up with the raw note.
tokenizer.decode(input_ids_test[len(input_ids_test) - 1])

'[CLS] involving the inner and outer table of the left frontal bone, extending to the roof of the left orbit. there are also fractures of the roof of the right orbit. there is a comminuted fracture of the lateral wall of the left orbit. both lamina papyracea are fractured. the right lamina papyracea is fractured posteriorly, with a bone fragment intruding on the apex formed by the extraocular muscles. there is a fracture of the lateral wall of the right orbit posteriorly, with air and blood seen within the middle cranial fossa just posterior to this fracture. there are fractures of the floors of both orbits. the inferior rectus muscles approach the fracture defect, but do not definitely cross through them. the infraorbital foramina are likely involved in the fractures bilaterally. there is a fracture of the roof of the right sphenoid air cell. there are fractures of the lateral and anterior walls of both maxillary sinuses, with an osseous fragment seen within the left maxillary sinus. 

In [9]:
# Alias the loaded splits under the names used by the evaluation code.
train_inputs, validation_inputs = input_ids_train, input_ids_valid
train_masks, validation_masks = attention_masks_train, attention_masks_valid

# Only the test split is converted to torch tensors.
test_inputs = torch.tensor(input_ids_test)
test_labels = torch.tensor(labels_test)
test_masks = torch.tensor(attention_masks_test)


from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

batch_size = 64

# Sequential (unshuffled) loader over (ids, mask, label) triples of the test set.
test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [10]:
from tqdm import tqdm
import numpy as np

acc_results = {}
preds_results = {}

# Index of the fine-tuned checkpoint to explain (sub-directory of model_path).
i = 1
print('Model ' + str(i) + ':')
model_path_i = model_path + str(i) + '/'
print('Loading model from ' + model_path_i + '...')
model = BertForSequenceClassification.from_pretrained(
    model_path_i, num_labels=2, output_attentions=False, output_hidden_states=False
)
# Honour the device picked in the setup cell (CPU fallback included) instead of
# hard-coding CUDA, which crashes on machines without a GPU.
model.to(device)
model.eval()
# Class names indexed by predicted label id: 0 -> "NOT_DEAD", 1 -> "DEAD".
classifications = ["NOT_DEAD", "DEAD"]

Model 1:
Loading model from /new-stg/home/banghua/med264/models_balanced_clinicalBERT/1/...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /new-stg/home/banghua/med264/models_balanced_clinicalBERT/1/ and are newly initialized: ['bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def get_one_pred(i):
    """Predict the class of a single test example.

    Args:
        i: row position in ``test_inputs`` / ``test_masks`` / ``test_labels``.

    Returns:
        (pred, label_ids): ``pred`` is the argmax class id of the logits (an index
        into ``classifications``); ``label_ids`` is a length-1 numpy array with the
        gold label.
    """
    # Slice i:i+1 (not i) to keep a batch dimension of 1; move to the configured device.
    b_input_ids = test_inputs[i:i+1].to(device)
    b_input_mask = test_masks[i:i+1].to(device)
    # Gradients are NOT disabled here; the logits are detached afterwards instead.
    # NOTE(review): torch.no_grad() may be unsafe with the explainability BERT if its
    # forward registers gradient hooks on attention maps -- confirm before adding it.
    outputs = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)
    logits = outputs[0].detach().cpu().numpy()
    # Labels are only compared on the host, so read them straight from the CPU tensor.
    label_ids = test_labels[i:i+1].numpy()
    pred = np.argmax(logits)
    return pred, label_ids

In [12]:
# Smoke test: predicted class and gold label for the first test example.
get_one_pred(i=0)



(1, array([1]))

In [13]:
# Saved predictions for checkpoint 1 (the same index as the model loaded above).
# NOTE: pickle.load can execute arbitrary code; only load trusted prediction files.
with open(os.path.join(preds_path, '1', 'preds.pickle'), 'rb') as f:
    preds_results = pickle.load(f)

test_true_labels = preds_results['test_true_labels']
test_preds = preds_results['test_preds']

# Every saved prediction must line up with one row of the balanced df_test;
# otherwise the orig_index lookups below fail or point at the wrong notes.
assert len(test_true_labels) == len(test_preds) == len(orig_index), (
    f'{len(test_true_labels)} labels / {len(test_preds)} preds vs {len(orig_index)} test rows'
)


def _confusion_positions(true_labels, preds, row_index, true_value, pred_value):
    """Select rows where gold == true_value and prediction == pred_value.

    Returns:
        (positions, index_labels): 0-based row positions (usable on test_inputs,
        test_masks, ...) and the matching labels looked up in ``row_index``.
    """
    positions = [pos for pos in range(len(true_labels))
                 if true_labels[pos] == true_value and preds[pos] == pred_value]
    return positions, [row_index[pos] for pos in positions]


# Label 1 is "DEAD" and label 0 is "NOT_DEAD" (see `classifications`).
test_TPs, test_TPs_orig_index = _confusion_positions(test_true_labels, test_preds, orig_index, 1, 1)
test_TNs, test_TNs_orig_index = _confusion_positions(test_true_labels, test_preds, orig_index, 0, 0)
test_FPs, test_FPs_orig_index = _confusion_positions(test_true_labels, test_preds, orig_index, 0, 1)
test_FNs, test_FNs_orig_index = _confusion_positions(test_true_labels, test_preds, orig_index, 1, 0)

In [14]:
# Sizes of the four confusion-matrix groups on the balanced test set.
print(
    "test_TPs: ", len(test_TPs),
    "test_TNs: ", len(test_TNs),
    "test_FPs: ", len(test_FPs),
    "test_FNs: ", len(test_FNs),
)

test_TPs:  1475 test_TNs:  785 test_FPs:  838 test_FNs:  148


In [15]:
def process_one(input_ids, attention_mask, true_class, explanations):
    """Compute per-token LRP relevance scores for a single example.

    Args:
        input_ids: token ids with a leading batch dimension of 1, already on the
            model's device.
        attention_mask: attention mask matching ``input_ids``.
        true_class: gold label; accepted for interface compatibility but not used here.
        explanations: ``Generator`` wrapping ``model``; row 0 of its
            ``generate_LRP(..., start_layer=0)`` output is used (no target class is passed).

    Returns:
        List of ``(token, score)`` pairs, one per position of ``input_ids``. Scores
        are min-max normalised to [0, 1] and negated when the model predicts "DEAD".
    """
    expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]

    # Min-max normalise. A constant relevance map would divide 0/0 and yield NaNs,
    # so map it to all zeros instead.
    expl_min = expl.min()
    expl_range = expl.max() - expl_min
    if expl_range == 0:
        expl = torch.zeros_like(expl)
    else:
        expl = (expl - expl_min) / expl_range

    # Predicted class (argmax of the softmax over the model's logits).
    output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
    classification = output.argmax(dim=-1).item()
    class_name = classifications[classification]
    # Sign convention: "DEAD" predictions get scores in [-1, 0] and "NOT_DEAD"
    # predictions keep [0, 1], so the predicted class can be read off the sign.
    # NOTE(review): confirm that flipping "DEAD" (label 1), rather than
    # "NOT_DEAD", is the intended visualisation convention.
    if class_name == "DEAD":
        expl = -expl

    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    # A single device->host copy instead of one .item() sync per token.
    scores = expl.tolist()
    return [(tokens[k], scores[k]) for k in range(len(tokens))]

In [16]:
explanations = Generator(model)


def get_many(inputs, masks, labels, explanations):
    """Run ``process_one`` on every row of ``inputs`` and collect the results.

    Args:
        inputs: tensor of token ids, one example per row.
        masks: attention masks aligned row-for-row with ``inputs``.
        labels: gold labels aligned with ``inputs`` (passed through as ``true_class``).
        explanations: the LRP ``Generator`` handed to ``process_one``.

    Returns:
        A list with one entry per row: that row's list of ``(token, score)`` pairs.
    """
    n_rows = inputs.shape[0]
    # Each row is re-batched to shape (1, ...) and moved to the configured device.
    return [
        process_one(
            inputs[row].unsqueeze(0).to(device),
            masks[row].unsqueeze(0).to(device),
            labels[row],
            explanations,
        )
        for row in tqdm(range(n_rows))
    ]

In [17]:
# Explain the first 50 true positives (gold 1, predicted 1) and pickle the token scores.
tp_rows = test_TPs[:50]
test_TPs_50_token_results = get_many(test_inputs[tp_rows], test_masks[tp_rows], test_labels[tp_rows], explanations)
test_TPs_50_token_results_obj = {
    'test_TPs_50_token_results': test_TPs_50_token_results,
    'test_TPs_50_orig_index': test_TPs_orig_index[:50]
}
# Save under ~/med264 like every other artifact here (no hard-coded home directory),
# creating the output directory if needed.
tp_out_dir = os.path.expanduser('~/med264/trans_inter_2')
os.makedirs(tp_out_dir, exist_ok=True)
with open(os.path.join(tp_out_dir, 'test_TPs_50_token_results.pkl'), 'wb') as f:
    pickle.dump(test_TPs_50_token_results_obj, f)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.10it/s]


In [17]:
# Explain the first 50 true negatives (gold 0, predicted 0) and pickle the token scores.
tn_rows = test_TNs[:50]
test_TNs_50_token_results = get_many(test_inputs[tn_rows], test_masks[tn_rows], test_labels[tn_rows], explanations)
test_TNs_50_token_results_obj = {
    'test_TNs_50_token_results': test_TNs_50_token_results,
    'test_TNs_50_orig_index': test_TNs_orig_index[:50]
}
# Save under ~/med264 like every other artifact here (no hard-coded home directory),
# creating the output directory if needed.
tn_out_dir = os.path.expanduser('~/med264/trans_inter_2')
os.makedirs(tn_out_dir, exist_ok=True)
with open(os.path.join(tn_out_dir, 'test_TNs_50_token_results.pkl'), 'wb') as f:
    pickle.dump(test_TNs_50_token_results_obj, f)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:13<00:00,  3.57it/s]


In [17]:
# Explain the first 50 false positives (gold 0, predicted 1) and pickle the token scores.
fp_rows = test_FPs[:50]
test_FPs_50_token_results = get_many(test_inputs[fp_rows], test_masks[fp_rows], test_labels[fp_rows], explanations)
test_FPs_50_token_results_obj = {
    'test_FPs_50_token_results': test_FPs_50_token_results,
    'test_FPs_50_orig_index': test_FPs_orig_index[:50]
}
# Save under ~/med264 like every other artifact here (no hard-coded home directory),
# creating the output directory if needed.
fp_out_dir = os.path.expanduser('~/med264/trans_inter_2')
os.makedirs(fp_out_dir, exist_ok=True)
with open(os.path.join(fp_out_dir, 'test_FPs_50_token_results.pkl'), 'wb') as f:
    pickle.dump(test_FPs_50_token_results_obj, f)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.09it/s]


In [17]:
# Explain the first 50 false negatives (gold 1, predicted 0) and pickle the token scores.
fn_rows = test_FNs[:50]
test_FNs_50_token_results = get_many(test_inputs[fn_rows], test_masks[fn_rows], test_labels[fn_rows], explanations)
test_FNs_50_token_results_obj = {
    'test_FNs_50_token_results': test_FNs_50_token_results,
    'test_FNs_50_orig_index': test_FNs_orig_index[:50]
}
# Save under ~/med264 like every other artifact here (no hard-coded home directory),
# creating the output directory if needed.
fn_out_dir = os.path.expanduser('~/med264/trans_inter_2')
os.makedirs(fn_out_dir, exist_ok=True)
with open(os.path.join(fn_out_dir, 'test_FNs_50_token_results.pkl'), 'wb') as f:
    pickle.dump(test_FNs_50_token_results_obj, f)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.11it/s]
