In [1]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt

import torch 
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import AutoTokenizer

from loop_train_berts import (
    BertClassifier,
    preprocessing_for_bert,
    bert_predict
)

batch_size=8

if torch.cuda.is_available():       
    device = torch.device("cuda")
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
    
import shap
import scipy

There are 2 GPU(s) available.
Device name: GeForce RTX 2080 Ti


In [2]:
val = pd.read_csv('./data/Validation.tsv', sep='\t')
val['concat'] = val.Title.map(str) + " " + val.Abstract.fillna(' ').map(str)
val['bert'] = val['concat'].apply(lambda x: x.lower())

In [3]:
tokenizer = AutoTokenizer.from_pretrained('allenai/biomed_roberta_base', do_lower_case=True)

bert_classifier = torch.load('./models/' + 'full-allenai-biomed_roberta_base.pkl')

In [4]:
def predict_shap(x):
    
    val_inputs, val_masks = preprocessing_for_bert(tokenizer, x)

    val_data = TensorDataset(val_inputs, val_masks)
    val_sampler = SequentialSampler(val_data)
    val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size, num_workers=1)
    
    outputs = bert_classifier(val_inputs.to(device), val_masks.to(device)).detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = scipy.special.logit(scores[:, 1]) # use one vs rest logit units
    
    return val

In [5]:
# build an explainer using a token masker
explainer = shap.Explainer(predict_shap, tokenizer)

In [164]:
shap_values = explainer(val.bert[:200],
                        batch_size=batch_size,
                        error_bounds=True,
                        fixed_context=1)

Partition explainer: 201it [03:56,  1.22s/it]                         


In [268]:
shap.plots.text(shap_values[75])

In [None]:
val_inputs, val_masks = preprocessing_for_bert(tokenizer, val.bert[0])

val_data = TensorDataset(val_inputs, val_masks)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=1, num_workers=1)

In [69]:
def bert_predict(model, test_dataloader, device, return_hidden=False):

    model.eval()

    all_logits = []


    for batch in test_dataloader:

        b_input_ids, b_attn_mask = tuple(t.to(device) for t in batch)[:2]

        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)
        all_logits.append(logits)
        
        if return_hidden==True:
            with torch.no_grad():
                logits, model_out = model(b_input_ids, b_attn_mask, return_hidden_state=True)

                return model_out # hidden_state 
       
    all_logits = torch.cat(all_logits, dim=0)

    probs = F.softmax(all_logits, dim=1).cpu().numpy()

    return probs

In [284]:
positive = pd.read_csv('data/positive.tsv', sep='\t', index_col=0)
positive['target'] = 1
negative = pd.read_csv('data/negative.tsv', sep='\t', index_col=0)
negative['target'] = 0
data = positive.append(negative)
data['concat'] = data.Title.map(str) + " " + data.Abstract.fillna(' ').map(str)
data['bert'] = data['concat'].apply(lambda x: x.lower())

In [445]:
train_inputs, train_masks = preprocessing_for_bert(tokenizer, data.bert[:100])

train_data = TensorDataset(train_inputs, train_masks)
train_sampler = SequentialSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=8, num_workers=1)

In [446]:
last_layer = np.empty([0, 512, 768])
for batch in train_dataloader:
    model_output = bert_predict(bert_classifier, train_dataloader,
                                device, return_hidden=True).last_hidden_state.cpu().detach().numpy()
    last_layer = np.append(last_layer, model_output, axis=0)

In [447]:
train_encodings = [tokenizer(x, return_offsets_mapping=True, padding=True, truncation=True) for x in data.bert[:100]]

In [448]:
def single_word_tuples(encoding):
    desired_output = []
    for word_id in encoding.word_ids():
        if word_id is not None:
            start, end = encoding.word_to_tokens(word_id)
            desired_output.append((start, end))

    data = set(desired_output)
    sort = sorted(data, key=lambda tup: tup[0])
    
    return sort

def extract_word_vecs(sort, encodings, last_hidden_layer_output, strategy='avg'):
    words = []
    word_vecs = np.empty([0, 768])
    for (start, end) in sort:
        decode_lst = tokenizer.batch_decode(encodings.input_ids)
        
        if strategy == 'avg':
            vector = last_hidden_layer_output[start:end, :].mean(axis=0)
        elif strategy == 'max':
            vector = last_hidden_layer_output[start:end, :].max(axis=0)
        
        word_vecs = np.append(word_vecs, vector[None, ...], axis=0)
            
        words.append(''.join(decode_lst[start:end]).strip())
    
    return words, word_vecs


def get_full_words_and_vecs(train_encodings, bert_output):
    words = []
    vecs = np.empty([0, 768])
    for tokens, vectors in zip(train_encodings, bert_output):
        sort = single_word_tuples(tokens)
        sentence_words, sentence_word_vecs = extract_word_vecs(sort, tokens, vectors, strategy='max')
        words.extend(sentence_words)
        vecs = np.append(vecs, sentence_word_vecs, axis=0)
        
    return words, vecs

In [449]:
words, vecs = get_full_words_and_vecs(train_encodings, last_layer)

In [453]:
from sklearn.manifold import TSNE
import plotly.express as px

X_2D = TSNE(n_components=2, perplexity=50, random_state=42).fit_transform(vecs)
tsne_pd = pd.DataFrame(X_2D)
tsne_pd['word'] = words

In [451]:
fig = px.scatter(tsne_pd, 0, 1,
                 text = 'word', 
                 width=600, size_max=2)

fig.update_layout(
    autosize=False,
    margin=dict(l=20, r=20, t=20, b=20),
)

fig.show()

In [444]:
tsne_pd[tsne_pd['word'] == 'chemotherapy']

Unnamed: 0,0,1,word
248,53.942917,5.915286,chemotherapy
582,-60.452583,2.385796,chemotherapy
1314,44.227264,-5.434875,chemotherapy
1343,-53.886646,7.494791,chemotherapy
1681,-35.942814,-2.043328,chemotherapy
4342,20.892435,12.998721,chemotherapy
4778,57.303295,9.528056,chemotherapy


In [433]:
X_2D[::2]

array([[-37.89786 ,  31.823969],
       [-36.154022,  34.93345 ],
       [-33.97284 ,  33.56395 ],
       ...,
       [-12.334059,  64.78005 ],
       [-13.875397,  65.39892 ],
       [-14.301754,  65.52406 ]], dtype=float32)