In [1]:
import os
import torch
import pickle
import pandas as pd

os.chdir("/new-stg/home/banghua/Transformer-Explainability")

from BERT_explainability.modules.BERT.ExplanationGenerator import Generator
from BERT_explainability.modules.BERT.BertForSequenceClassification import BertForSequenceClassification

from captum.attr import visualization
from tqdm import tqdm
from transformers import BertTokenizer, AdamW, BertConfig
from transformers.modeling_outputs import SequenceClassifierOutput

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Choose the compute device: CUDA when torch reports it as available, else CPU.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

# Load the tokenizer shipped with the 'medicalai/ClinicalBERT' checkpoint.
# NOTE(review): the recorded output warns that this checkpoint declares a
# 'DistilBertTokenizer' while 'BertTokenizer' is used here -- confirm the two
# tokenize identically for this vocabulary.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained(
    'medicalai/ClinicalBERT',
    do_lower_case=True,
)

# Directories (with '~' expanded) for the fine-tuned models and their saved predictions.
model_path = os.path.expanduser('~/med264/models_balanced/')
preds_path = os.path.expanduser('~/med264/preds_balanced/')

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.


There are 1 GPU(s) available.
We will use the GPU: NVIDIA RTX A6000
Loading BERT tokenizer...


In [2]:
# Read the test split; the first CSV column is used as the index.
df_test = pd.read_csv("~/med264/Dataset1/day1_30mortality_test.csv", index_col=0)

# Separate the rows by label so both classes can be downsampled to the same size.
positive_rows = df_test[df_test['Label'] == 1]
negative_rows = df_test[df_test['Label'] == 0]
num_test_pos = positive_rows.shape[0]
num_test_neg = negative_rows.shape[0]
num_balanced = min(num_test_pos, num_test_neg)

# Draw the same number of rows from each class with a fixed seed, then stack
# them: all sampled positives first, followed by all sampled negatives.
df_test_pos = positive_rows.sample(n=num_balanced, random_state=42)
df_test_neg = negative_rows.sample(n=num_balanced, random_state=42)
df_test = pd.concat([df_test_pos, df_test_neg])

# Report the size of the balanced test set.
print('Number of testing sentences: {:,}\n'.format(df_test.shape[0]))

# Raw texts and labels of the balanced test set as arrays.
sentences_test = df_test['TEXT'].values
labels_test = df_test['Label'].values

Number of testing sentences: 3,246



In [3]:
# Correct the path by expanding the tilde to the user's home directory
# Cache file for the tokenized test sentences ('~' expanded to the home directory).
file_path_test = os.path.expanduser('~/med264/Dataset2/input_ids_test.pickle')

if not os.path.exists(file_path_test):
    # No cache yet: encode every test sentence (special tokens included)
    # and write the resulting id lists to the cache file.
    input_ids_test = []
    for sent in tqdm(sentences_test):
        encoded_sent = tokenizer.encode(sent, add_special_tokens = True)
        input_ids_test.append(encoded_sent)
    with open(file_path_test, 'wb') as f:
        pickle.dump(input_ids_test, f)
    print('Saved input_ids_test.')
else:
    # Reuse the cached token ids.
    with open(file_path_test, 'rb') as f:
        input_ids_test = pickle.load(f)
    print('Loaded input_ids_test.')

# Length of the longest token-id sequence in the list loaded/built above.
print('Max test sentence length: ', max(len(ids) for ids in input_ids_test))

# Load the preprocessed ids and attention masks for all three splits.
# Note that this rebinds `input_ids_test`, replacing the list loaded/built above.
file_path = os.path.expanduser('~/med264/Dataset2/processed_data.pickle')
with open(file_path, 'rb') as f:
    processed_data = pickle.load(f)
(
    input_ids_train,
    input_ids_valid,
    input_ids_test,
    attention_masks_train,
    attention_masks_valid,
    attention_masks_test,
) = processed_data

Loaded input_ids_test.
Max test sentence length:  878


In [4]:
# Keep the train/validation splits available under the names used elsewhere;
# only the test split is converted to tensors in this cell.
train_inputs, validation_inputs = input_ids_train, input_ids_valid
train_masks, validation_masks = attention_masks_train, attention_masks_valid

# Convert the test ids, masks and labels to torch tensors, the datatype the
# model and TensorDataset work with.
test_inputs = torch.tensor(input_ids_test)
test_masks = torch.tensor(attention_masks_test)
test_labels = torch.tensor(labels_test)

# Batch size for iterating over the test set (no training happens in this notebook).
batch_size = 64

# Test DataLoader; SequentialSampler keeps the examples in dataset order.
# TensorDataset, SequentialSampler and DataLoader come from the import cell at
# the top of the notebook.
test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [5]:
# Load the saved predictions of model 4.
# SECURITY NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(os.path.join(preds_path, "4", "preds.pickle"), "rb") as f:
    preds_results = pickle.load(f)

test_true_labels = preds_results['test_true_labels']
test_preds = preds_results['test_preds']

# Both sequences are indexed position-by-position below, so they must line up.
if len(test_true_labels) != len(test_preds):
    raise ValueError(
        "test_true_labels and test_preds differ in length: {} vs {}".format(
            len(test_true_labels), len(test_preds)
        )
    )

# Positions of the true positives, true negatives, false positives and false
# negatives, collected in a single pass. Label/prediction pairs that are
# neither 0 nor 1 are left out of every list.
test_TPs, test_TNs, test_FPs, test_FNs = [], [], [], []
for idx in range(len(test_true_labels)):
    truth, pred = test_true_labels[idx], test_preds[idx]
    if truth == 1 and pred == 1:
        test_TPs.append(idx)
    elif truth == 0 and pred == 0:
        test_TNs.append(idx)
    elif truth == 0 and pred == 1:
        test_FPs.append(idx)
    elif truth == 1 and pred == 0:
        test_FNs.append(idx)

In [6]:
# Report how many test examples fall into each of the four outcome lists.
print(
    "test_TPs: ", len(test_TPs),
    "test_TNs: ", len(test_TNs),
    "test_FPs: ", len(test_FPs),
    "test_FNs: ", len(test_FNs),
)

test_TPs:  1362 test_TNs:  971 test_FPs:  652 test_FNs:  261


In [7]:
# Index of the fine-tuned model (sub-directory of models_balanced) to explain.
i = 4

print('Model ' + str(i) + ':')
model_path_i = os.path.expanduser('~/med264/models_balanced/' + str(i) + '/')
print('Loading model from ' + model_path_i)
model = BertForSequenceClassification.from_pretrained(
    model_path_i, num_labels = 2, output_attentions = False, output_hidden_states = False
    )
# Move the model to the device chosen in the setup cell instead of calling
# .cuda() unconditionally: the inputs are moved with .to(device), and .cuda()
# would fail on a machine where the CPU fallback was selected.
model.to(device)
# Evaluation mode.
model.eval()

# Class names indexed by predicted label id (0 -> NOT_DEAD, 1 -> DEAD).
classifications = ["NOT_DEAD", "DEAD"]

Model 4:
Loading model from /new-stg/home/banghua/med264/models_balanced/4/


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /new-stg/home/banghua/med264/models_balanced/4/ and are newly initialized: ['bert.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
def process_one(input_ids, attention_mask, true_class, explanations):
    """Compute per-token relevance scores for a single input.

    Args:
        input_ids: Token-id tensor passed to both ``explanations.generate_LRP``
            and the global ``model``. ``get_many`` passes one example with a
            leading batch dimension added via ``unsqueeze(0)``.
        attention_mask: Attention-mask tensor passed alongside ``input_ids``.
        true_class: Unused; kept so existing callers keep working.
        explanations: Object exposing ``generate_LRP`` (a ``Generator`` in this
            notebook).

    Returns:
        A list of ``(token, score)`` tuples, one per token id in ``input_ids``.
        Scores are min-max normalized to ``[0, 1]`` and then negated (range
        ``[-1, 0]``) when the model predicts ``"DEAD"``.
    """
    # Relevance scores for the first (only) example in the batch.
    expl = explanations.generate_LRP(input_ids=input_ids, attention_mask=attention_mask, start_layer=0)[0]

    # Min-max normalize. When every score is identical the range is zero and
    # the division would produce NaNs, so fall back to all-zero scores.
    lowest, highest = expl.min(), expl.max()
    spread = highest - lowest
    if spread == 0:
        expl = torch.zeros_like(expl)
    else:
        expl = (expl - lowest) / spread

    # Predicted class of the model for this input. The forward pass is left
    # outside torch.no_grad() on purpose: the explainability model may register
    # gradient hooks during forward -- TODO confirm before changing.
    output = torch.nn.functional.softmax(model(input_ids=input_ids, attention_mask=attention_mask)[0], dim=-1)
    classification = output.argmax(dim=-1).item()
    class_name = classifications[classification]

    # Scores of inputs predicted as "DEAD" are negated so the two predicted
    # classes end up with opposite signs (useful for visualization).
    if class_name == "DEAD":
        expl = -expl

    tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
    # One host transfer for all scores instead of one .item() sync per token.
    scores = expl.detach().cpu().tolist()
    tokens_results = list(zip(tokens, scores))
    return tokens_results

In [9]:
# Explanation generator wrapping the loaded model.
explanations = Generator(model)

def get_many(inputs, masks, labels, explanations):
    """Run ``process_one`` on every row of ``inputs`` and collect the results.

    Args:
        inputs: Tensor whose first dimension enumerates the examples.
        masks: Attention masks, indexed in step with ``inputs``.
        labels: Labels, indexed in step with ``inputs``; forwarded to
            ``process_one`` unchanged.
        explanations: Explanation generator forwarded to ``process_one``.

    Returns:
        A list with one ``process_one`` result per row, in row order.
    """
    def _explain_row(row):
        # Add a batch dimension of 1 and move the example to the compute device.
        ids = inputs[row].unsqueeze(0).to(device)
        mask = masks[row].unsqueeze(0).to(device)
        return process_one(ids, mask, labels[row], explanations)

    return [_explain_row(row) for row in tqdm(range(inputs.shape[0]))]

# import pickle
# save_file_path = "/new-stg/home/banghua/med264/trans_inter/token_results_{}.pkl".format(i)
# with open(save_file_path, 'wb') as f:
#     pickle.dump(token_results, f)

In [10]:
# Explain the first 50 true-positive test examples and pickle the (token, score) lists.
_tp_rows = test_TPs[:50]
test_TPs_50_token_results = get_many(
    test_inputs[_tp_rows],
    test_masks[_tp_rows],
    test_labels[_tp_rows],
    explanations,
)
with open("/new-stg/home/banghua/med264/trans_inter/test_TPs_50_token_results.pkl", 'wb') as f:
    pickle.dump(test_TPs_50_token_results, f)

100%|█████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.00it/s]


In [10]:
# Explain the first 50 true-negative test examples and pickle the (token, score) lists.
_tn_rows = test_TNs[:50]
test_TNs_50_token_results = get_many(
    test_inputs[_tn_rows],
    test_masks[_tn_rows],
    test_labels[_tn_rows],
    explanations,
)
with open("/new-stg/home/banghua/med264/trans_inter/test_TNs_50_token_results.pkl", 'wb') as f:
    pickle.dump(test_TNs_50_token_results, f)

100%|█████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  3.86it/s]


In [12]:
# Explain the first 50 false-positive test examples and pickle the (token, score) lists.
_fp_rows = test_FPs[:50]
test_FPs_50_token_results = get_many(
    test_inputs[_fp_rows],
    test_masks[_fp_rows],
    test_labels[_fp_rows],
    explanations,
)
with open("/new-stg/home/banghua/med264/trans_inter/test_FPs_50_token_results.pkl", 'wb') as f:
    pickle.dump(test_FPs_50_token_results, f)

100%|█████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  3.95it/s]


In [10]:
# Explain the first 50 false-negative test examples and pickle the (token, score) lists.
_fn_rows = test_FNs[:50]
test_FNs_50_token_results = get_many(
    test_inputs[_fn_rows],
    test_masks[_fn_rows],
    test_labels[_fn_rows],
    explanations,
)
with open("/new-stg/home/banghua/med264/trans_inter/test_FNs_50_token_results.pkl", 'wb') as f:
    pickle.dump(test_FNs_50_token_results, f)

100%|█████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.03it/s]
