In [None]:
!pip install shap
!pip install transformers

In [None]:
# Mount Google Drive at /content/gdrive; the CSV and checkpoint paths used below live under it.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
import shap
import torch
import pandas as pd
import scipy as sp
import numpy as np
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset
# Notebook-wide configuration.
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output size of the classification head.
num_classes = 2
# Mini-batch size setting.
batch_size = 32
# Width of the hidden layer in the custom classification head.
hidden_dim = 256
# NOTE(review): shapCalculations tokenizes with max_length=128, not this value — confirm which is intended.
max_length = 256

In [None]:
class MedicalTCDataset(Dataset):
    """Map-style dataset of ``(abstract text, condition label)`` pairs.

    Parameters
    ----------
    data : pandas.DataFrame
        Must provide the columns ``medical_abstract`` and ``condition_label``.
    """

    def __init__(self, data):
        self.data = data['medical_abstract']
        self.labels = data['condition_label']

    def __len__(self):
        """Return the number of rows in the dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(text, label)`` for the row at *position* ``index``."""
        # DataLoader samplers hand out positions 0..len-1. Plain ``series[index]``
        # is a label-based lookup and raises KeyError as soon as the frame's index
        # is not exactly 0..n-1 (filtered, sampled or re-indexed frames), so index
        # positionally instead.
        return self.data.iloc[index], self.labels.iloc[index]

In [None]:
# Load the pre-cleaned train/test splits and the label table from the mounted Drive.
train = pd.read_csv("/content/gdrive/MyDrive/XAI/data/medical_tc_train_cleaned.csv")
test =  pd.read_csv("/content/gdrive/MyDrive/XAI/data/medical_tc_test_cleaned.csv")
# NOTE(review): the name `labels` is rebound later in the notebook to plain lists of class ids,
# so this DataFrame is not available under that name afterwards.
labels = pd.read_csv("/content/gdrive/MyDrive/XAI/data/medical_tc_labels.csv")

In [None]:
# Wrap both splits as torch Datasets and batch them with the notebook-wide `batch_size`.
train_dataset = MedicalTCDataset(train)
test_dataset = MedicalTCDataset(test)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation data does not need shuffling; a fixed order keeps runs comparable.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Smoke test: pull one batch to confirm the pipeline yields (texts, labels).
# Dedicated names are used so the `labels` DataFrame loaded from CSV is not overwritten.
sample_texts, sample_labels = next(iter(train_dataloader))

In [None]:
# Recreate the network layout before restoring weights: BioBERT encoder plus a custom
# two-layer classification head. `load_state_dict` needs the module layout to match
# the parameter names stored in the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model = AutoModelForSequenceClassification.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model.classifier = nn.Sequential(
    nn.Linear(model.config.hidden_size, hidden_dim),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(hidden_dim, num_classes)
)

# `map_location` lets a checkpoint that was saved on a GPU load on a CPU-only runtime,
# matching the CPU fallback of `device`.
state_dict = torch.load(
    "/content/gdrive/MyDrive/XAI/models/biobert_fine_tuned_epoch_20.pt",
    map_location=device,
)['model_state_dict']
model.load_state_dict(state_dict)

# The freshly built head starts in training mode, which would leave its Dropout active
# and make every forward pass (and therefore the SHAP values) stochastic.
# The assignment form also keeps the cell from echoing the full model repr.
model = model.to(device).eval()

## Cardio Data

In [None]:
# Five hand-picked abstracts to explain in this section.
cardio_data = ['giant left atriuma case report seventysevenyear old woman with mitral stenosis presented with cardiomegaly evident her chest roentgenogram cardiac enlargement due giant left atrium that distorted cardiac structures echocardiogram and firstpass nuclear angiogram able delineate huge left atrium',
               'management brucella endocarditis with aortic root abscess three cases brucella endocarditis with aortic root abscess reported two patients successfully managed by combination medical therapy and surgery third patient died suddenly 36 hours after admission hospital',
               'rib compression coronary arteries this report describes finding coronary artery narrowing caused by compression by overlying rib two patients with cardiomegaly there probably no clinical significance this finding primary differential diagnostic entity myocardial bridging',
               'dynamic cardiomyoplasty chronic chagas heart disease clinicopathological data we report 44yearold man with chronic chagasic cardiomyopathy who underwent latissimus dorsi dynamic cardiomyoplasty and died 4 months later clinicopathological findings discussed and literature reviewed',
               'autonomic dysfunction and guillainbarre syndrome use esmolol its management 17yearold girl with guillainbarre syndrome and autonomic dysfunction treated successfully with esmolol esmolol may be appropriate drug rapid assessment and control tachyarrhythmias critically ill patients']

# One class id per abstract above; passed to the explainer under the 'label' key.
# NOTE(review): presumably 1 is the cardiovascular class — confirm against medical_tc_labels.csv.
labels = [1, 1, 1, 1, 1]


In [None]:
def shapCalculations(x):
    """Score a batch of texts for SHAP as the log-odds of class index 1.

    Relies on the notebook globals ``tokenizer``, ``model`` and ``device``.

    Parameters
    ----------
    x : iterable of str
        Texts to score; this function is handed to ``shap.Explainer`` as the model.

    Returns
    -------
    numpy.ndarray
        One value per text: ``logit(softmax(model_logits)[:, 1])``.
    """
    # NOTE(review): 128 differs from the notebook-level `max_length` (256) —
    # confirm which value matches the fine-tuning setup.
    token_ids = torch.tensor(
        [tokenizer.encode(v, padding='max_length', max_length=128, truncation=True) for v in x],
        device=device,  # follow the configured device instead of assuming CUDA is present
    )
    # Build the mask from the tokenizer's real pad id rather than assuming it is 0.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    attention_mask = (token_ids != pad_id).long()

    # SHAP needs a deterministic function: make sure dropout is off, and skip autograd
    # bookkeeping since only forward passes are required.
    model.eval()
    with torch.no_grad():
        logits = model(token_ids, attention_mask=attention_mask)[0]

    # logit(p1) = log p1 - log(1 - p1), where 1 - p1 is the summed probability of all
    # other classes. Working in log space avoids the +/-inf that logit() returns once
    # the softmax probability rounds to exactly 0 or 1.
    log_probs = torch.log_softmax(logits.float(), dim=-1)
    other_classes = torch.cat([log_probs[:, :1], log_probs[:, 2:]], dim=-1)
    val = log_probs[:, 1] - torch.logsumexp(other_classes, dim=-1)
    return val.cpu().numpy()

# Explain the cardio samples: the scoring function is paired with the tokenizer,
# then run over the sample texts.
cardio_batch = {'label': labels, 'text': cardio_data}
explainer = shap.Explainer(shapCalculations, tokenizer)
shap_values = explainer(cardio_batch, fixed_context=1)

In [None]:
# Waterfall plot for the explanation at index 3 (the fourth sample text).
shap.plots.waterfall(shap_values[3])

In [None]:
# Text plot for the explanations at indices 0-4.
shap.plots.text(shap_values[0:5])

In [None]:
# Bar plot of `shap_values.abs.max(0)`: absolute values reduced with a max over axis 0.
shap.plots.bar(shap_values.abs.max(0))

## Non-cardio Data

In [None]:
# Five hand-picked abstracts to explain in this section.
non_cardio_data = [
    'meningitis due protozoa and helminths this article reviews microbiology pathogenesis epidemiology clinical manifestations diagnostic tests and recent advances therapy protozoan and helminthic infections central nervous system with more emphasis given protozoan than helminthic infections',
    'indomethacin responsive hypercalcaemia associated with renal sarcoma infant presented with nonmetastatic renal spindle cell sarcoma and hypercalcaemia which resolved after treatment with indomethacin there vivo and vitro evidence that hypercalcaemia mediated by circulatory prostaglandins',
    'ten cases transitional cell carcinoma bladder causing ureteric obstruction review carried out 10 patients with superficial transitional cell carcinoma bladder ta lesions that causing ureteric obstruction evidence upper tract obstruction did not necessarily indicate deep invasion',
    'preventing colorectal cancer knowledgeable patients should not die colorectal cancer increasing intake dietary fiber decreasing fat consumption and increasing use modern technology detect adenomatous polyps and early cancer can greatly decrease mortality associated with colorectal cancer',
    'unusual complication ingested foreign body migration foreign body from mouth and throat subcutaneous tissue neck very rare we present case migrating foreign body piece straw from floor mouth neck our knowledge this second case reported english literature'
    ]

# One class id per abstract above; passed to the explainer under the 'label' key.
# NOTE(review): presumably 0 is the non-cardiovascular class — confirm against medical_tc_labels.csv.
labels = [0,0,0,0,0]


In [None]:
# `shapCalculations` is defined once in the "Cardio Data" section above and reused here.
# Redefining an identical copy would silently shadow it and let the two sections drift apart.
explainer = shap.Explainer(shapCalculations, tokenizer)
# NOTE: this rebinds `shap_values`, so the plots below show the non-cardio explanations.
shap_values = explainer({'label': labels, 'text': non_cardio_data}, fixed_context=1)

In [None]:
# Waterfall plot for the explanation at index 3 (the fourth sample text).
shap.plots.waterfall(shap_values[3])

In [None]:
# Text plot for the explanations at indices 0-4.
shap.plots.text(shap_values[0:5])

In [None]:
# Bar plot of `shap_values.abs.max(0)`: absolute values reduced with a max over axis 0.
shap.plots.bar(shap_values.abs.max(0))