In [None]:
!pip install transformers
!pip install transformers-interpret

In [None]:
# Mount Google Drive; data and model checkpoints are read from under /content/gdrive.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

In [None]:
import torch
import pandas as pd
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers_interpret import SequenceClassificationExplainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Seed torch's global RNG so the shuffled sanity-check batch drawn later is reproducible.
SEED = 42
torch.manual_seed(SEED);

# Runtime configuration shared by the cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2   # binary task: nocardio (0) vs cardio (1)
num_epochs = 101  # NOTE(review): not used in this notebook — likely a leftover from training
batch_size = 32
hidden_dim = 256  # width of the custom classifier head; must match the saved checkpoint
max_length = 256  # NOTE(review): not referenced below — confirm whether tokenization should use it

In [None]:
class MedicalTCDataset(Dataset):
    """Map-style dataset of (abstract text, condition label) pairs.

    Args:
        data: DataFrame with a ``medical_abstract`` text column and a
            ``condition_label`` column.
    """

    def __init__(self, data):
        # Discard the frame's original index so position i always maps to row i,
        # even when ``data`` was filtered or sampled and its index has gaps.
        self.data = data['medical_abstract'].reset_index(drop=True)
        self.labels = data['condition_label'].reset_index(drop=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # .iloc is positional; plain ``[]`` on a Series is a label lookup.
        return self.data.iloc[index], self.labels.iloc[index]

In [None]:
# All preprocessed CSVs live in a single folder on the mounted Drive.
DATA_DIR = "/content/gdrive/MyDrive/XAI/data"
train = pd.read_csv(f"{DATA_DIR}/medical_tc_train_cleaned.csv")
test = pd.read_csv(f"{DATA_DIR}/medical_tc_test_cleaned.csv")
labels = pd.read_csv(f"{DATA_DIR}/medical_tc_labels.csv")

# Fail fast if the columns MedicalTCDataset reads are missing from either split.
REQUIRED_COLUMNS = {"medical_abstract", "condition_label"}
assert REQUIRED_COLUMNS <= set(train.columns), "train split is missing required columns"
assert REQUIRED_COLUMNS <= set(test.columns), "test split is missing required columns"

In [None]:
# Wrap the raw frames so they can be batched.
train_dataset = MedicalTCDataset(train)
test_dataset = MedicalTCDataset(test)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order should be stable, so the test loader is not shuffled.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Pull one batch as a quick sanity check. Distinct names avoid overwriting the
# `labels` frame loaded from medical_tc_labels.csv.
batch_texts, batch_labels = next(iter(train_dataloader))

In [None]:
# Rebuild the fine-tuned BioBERT classifier: load the base checkpoint, swap in the
# same two-layer head used during fine-tuning, then restore the fine-tuned weights.
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
model = AutoModelForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-base-cased-v1.1", num_labels=num_classes
)
# The head architecture must match the saved state dict, otherwise
# load_state_dict() below reports missing/unexpected keys.
model.classifier = nn.Sequential(
    nn.Linear(model.config.hidden_size, hidden_dim),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(hidden_dim, num_classes),
)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
checkpoint = torch.load(
    "/content/gdrive/MyDrive/XAI/models/biobert_fine_tuned_epoch_20.pt",
    map_location=device,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# A newly constructed nn.Module (the head above) starts in training mode, so its
# Dropout would be active. eval() makes predictions and attributions deterministic.
# Trailing ';' suppresses the model repr.
model.eval();

In [None]:
# Readable class names for the explainer output (transformers-interpret reads model.config.id2label).
model.config.id2label = {0: 'nocardio', 1: 'cardio'}
# Build the inverse mapping from id2label so the two dictionaries cannot drift apart.
model.config.label2id = {name: idx for idx, name in model.config.id2label.items()}

In [None]:
# Word-attribution explainer over the fine-tuned model and its matching tokenizer.
cls_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)

## Cardio Data

In [None]:
# Explain the prediction for an abstract describing a cardiac condition.
# The text is already in the preprocessed form (lower-cased, stopwords stripped),
# presumably to match the *_cleaned.csv training data.
cardio_data = 'giant left atriuma case report seventysevenyear old woman with mitral stenosis presented with cardiomegaly evident her chest roentgenogram cardiac enlargement due giant left atrium that distorted cardiac structures echocardiogram and firstpass nuclear angiogram able delineate huge left atrium'
word_attributions = cls_explainer(cardio_data)
# Per-example filename so later explanations in this notebook don't overwrite this file.
cls_explainer.visualize("bioBert_cardio.html")

## Non-Cardio Data

In [None]:
# Explain the prediction for an abstract describing a non-cardiac condition.
# The text is already in the preprocessed form (lower-cased, stopwords stripped),
# presumably to match the *_cleaned.csv training data.
non_cardio_data = 'indomethacin responsive hypercalcaemia associated with renal sarcoma infant presented with nonmetastatic renal spindle cell sarcoma and hypercalcaemia which resolved after treatment with indomethacin there vivo and vitro evidence that hypercalcaemia mediated by circulatory prostaglandins'
word_attributions = cls_explainer(non_cardio_data)
# Per-example filename so this does not overwrite the visualization saved for another example.
cls_explainer.visualize("bioBert_non_cardio.html")