In [None]:
# torch and related packages
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data.dataset import Dataset
import torch.utils.data.dataloader as dataloader

# general packages
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm

# LIME explainer for tabular data
from lime import lime_tabular

In [None]:
# Seed both PyTorch and NumPy so that repeated runs produce the same numbers.
torch.manual_seed(0)
np.random.seed(0)

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
class DatasetWBC(Dataset):
    """Breast-cancer diagnosis table served as standardised tensors.

    Reads ``breast-cancer.csv`` from ``dataset_root``, removes the non-feature
    columns, encodes the diagnosis as an integer class and standardises every
    remaining column to zero mean and unit standard deviation.  The statistics
    are computed over all rows of the file.

    Attributes:
        wbcd_df: DataFrame holding only the (un-normalised) feature columns.
        feature_names: Feature column names, in the column order of ``wbcd_data``.
        wbcd_data: Float tensor of standardised features, one row per record.
        wbcd_labels: Long tensor of class ids (1 for ``'M'``, 0 for ``'B'``).
    """

    def __init__(self, dataset_root):
        """Load and preprocess ``<dataset_root>/breast-cancer.csv``.

        Args:
            dataset_root: Directory that contains ``breast-cancer.csv``.

        Raises:
            ValueError: If the ``diagnosis`` column contains anything other
                than ``'M'`` or ``'B'``.
        """
        raw_df = pd.read_csv(os.path.join(dataset_root, "breast-cancer.csv"))

        # Remove the non-feature columns.  Columns that are absent are skipped
        # rather than raising KeyError, so the loader also accepts a CSV that
        # has no "Unnamed: 32" column.
        raw_df = raw_df.drop(columns=["id", "Unnamed: 32"], errors="ignore")

        # Encode the diagnosis: malignant 'M' -> 1, benign 'B' -> 0.
        # Series.map turns unknown values into NaN; fail loudly instead of
        # carrying invalid labels into training.
        labels = raw_df["diagnosis"].map({"M": 1, "B": 0})
        if labels.isna().any():
            unknown = sorted(raw_df.loc[labels.isna(), "diagnosis"].astype(str).unique())
            raise ValueError(
                f"Unexpected diagnosis value(s) {unknown}; expected 'M' or 'B'."
            )
        self.wbcd_labels = torch.tensor(labels.to_numpy(), dtype=torch.long)

        # Everything that is left is a feature column.
        self.wbcd_df = raw_df.drop(columns=["diagnosis"])
        self.feature_names = list(self.wbcd_df.columns)
        features = torch.tensor(self.wbcd_df.to_numpy(), dtype=torch.float32)

        # Standardise each column: centre first, then divide by the standard
        # deviation of the centred data.
        features = features - features.mean(0, keepdim=True)
        std = features.std(0, keepdim=True)
        # A constant column has std == 0 (and a single-row file gives NaN);
        # dividing by 1 there leaves the centred values instead of NaN/inf.
        std = torch.where(std > 0, std, torch.ones_like(std))
        self.wbcd_data = features / std

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.wbcd_data[index], self.wbcd_labels[index]

    def __len__(self):
        """Return the number of records."""
        return self.wbcd_data.shape[0]

In [None]:
# Folder that holds breast-cancer.csv.
data_set_root = "../data"
dataset = DatasetWBC(data_set_root)

# NOTE: despite its name, `validation_split` is the fraction of records that
# goes to the TRAINING subset; with 0.8 the split is 80% training / 20% validation.
validation_split = 0.8

n_train_examples = int(validation_split * len(dataset))
n_valid_examples = len(dataset) - n_train_examples

# A dedicated generator with a fixed seed keeps the random partition reproducible.
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset,
    lengths=[n_train_examples, n_valid_examples],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Wrap both subsets in PyTorch DataLoader objects so they can be iterated over
# in mini-batches of 128; only the training loader shuffles its samples.
train_loader = dataloader.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
valid_loader = dataloader.DataLoader(dataset=valid_dataset, batch_size=128, shuffle=False)

In [None]:
# Preview the first few records instead of rendering the whole frame.
# `wbcd_df` holds only the un-normalised feature columns here: the id and
# diagnosis columns were dropped when the dataset object was built.
dataset.wbcd_df.head()

In [None]:
class MLP(nn.Module):
    """Fully connected classifier with two ELU-activated hidden layers.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of both hidden layers.

    The network produces two unnormalised scores (logits) per input row.
    """

    def __init__(self, input_size=30, hidden_size=32):
        super().__init__()
        # input -> hidden -> hidden -> 2 output scores
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=2)
        # One activation module shared by both hidden layers.
        self.elu = nn.ELU()

    def forward(self, x):
        """Return the logits for ``x``; no softmax is applied here."""
        hidden = self.elu(self.fc1(x))
        hidden = self.elu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# Training configuration: 50 passes over the training set, cross-entropy on the
# network's logits, Adam with a learning rate of 1e-3.
num_epochs = 50

net = MLP()
net = net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

In [None]:
# Empty history lists.  The training loop appends one accuracy value per epoch
# to the two `*_acc_log` lists; `train_loss_log` is reserved for loss values.
train_loss_log, train_acc_log, valid_acc_log = [], [], []

In [None]:
# Train for `num_epochs` epochs.  Every epoch makes one optimisation pass over
# the training loader followed by one evaluation pass over the validation loader.
for epoch in trange(num_epochs, leave=False):
    # ---- training pass ----
    net.train()
    train_acc = 0  # number of correctly classified training samples this epoch
    for (data, labels) in tqdm(train_loader, leave=False):
        data = data.to(device)
        labels = labels.to(device)

        output = net(data)
        loss = criterion(output, labels)

        # Clear the gradients of the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the mini-batch loss (previously computed but never logged);
        # .item() stores a plain Python float rather than a tensor.
        train_loss_log.append(loss.item())
        train_acc += (output.argmax(1) == labels).sum().item()

    # Fraction of training samples classified correctly during this epoch.
    train_acc_log.append(train_acc / len(train_dataset))

    # ---- validation pass (no gradient tracking) ----
    net.eval()
    valid_acc = 0  # number of correctly classified validation samples
    with torch.no_grad():
        for (data, labels) in tqdm(valid_loader, leave=False):
            data = data.to(device)
            labels = labels.to(device)

            output = net(data)
            valid_acc += (output.argmax(1) == labels).sum().item()
    valid_acc_log.append(valid_acc / len(valid_dataset))

In [None]:
# Training vs. validation accuracy, one point per epoch.
fig, ax = plt.subplots()
ax.plot(train_acc_log, label="train")
ax.plot(valid_acc_log, label="validation")
ax.set_xlabel("epoch")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy per epoch")
ax.legend()
plt.show()

In [None]:
def mlp_predict(inp_array_numpy):
    """Return the class probabilities predicted by ``net`` for a batch of rows.

    This is the prediction callback handed to LIME.

    Args:
        inp_array_numpy: Array-like of feature rows with shape
            (n_samples, n_features).  A single 1-D row is also accepted and is
            treated as a batch of one.

    Returns:
        numpy.ndarray of shape (n_samples, n_classes); each row is the softmax
        of the network's logits.
    """
    net.eval()
    inp_tensor = torch.as_tensor(inp_array_numpy, dtype=torch.float32, device=device)
    if inp_tensor.dim() == 1:
        # softmax(dim=1) below needs a batch dimension.
        inp_tensor = inp_tensor.unsqueeze(0)

    # Pure inference: skip autograd bookkeeping for the perturbation batches.
    with torch.no_grad():
        logits = net(inp_tensor)
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()

In [None]:
# Class names indexed by label id: 0 -> benign ('B'), 1 -> malignant ('M').
wbcd_class_names = ["benign", "malignant"]

# `train_dataset` is a Subset created by random_split: its `.dataset` attribute
# is the FULL dataset, so the training rows must be selected through `.indices`.
# Reading `.dataset.wbcd_data` directly would also include the validation rows.
full_dataset = train_dataset.dataset
train_indices = list(train_dataset.indices)
train_data = full_dataset.wbcd_data[train_indices].numpy()
train_labels = full_dataset.wbcd_labels[train_indices].numpy()

feature_names = full_dataset.feature_names

# The explainer receives the training rows as its reference data.
explainer = lime_tabular.LimeTabularExplainer(
    train_data,
    mode="classification",
    class_names=wbcd_class_names,
    feature_names=feature_names,
)

In [None]:
# Explain the model's prediction for one row of `train_data`.
idx = 0
sample = train_data[idx]
inp_explainer = sample[np.newaxis, :]  # same row with a leading batch axis

explanation = explainer.explain_instance(
    sample,
    mlp_predict,
    num_features=len(feature_names),  # report a weight for every feature
)

predicted_class = np.argmax(mlp_predict(inp_explainer))
print("Prediction : ", wbcd_class_names[predicted_class])
print("Actual :     ", wbcd_class_names[train_labels[idx]])

explanation.show_in_notebook()