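# Set-up
This set-up assumes that the current working directory (`os.curdir`) is the directory containing this notebook. The project root is taken to be two levels above it and is added to `sys.path` so that the `src` packages can be imported.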

In [1]:
import os
import sys
# Assumes the notebook sits two directories below the project root; put that
# root on the import path (only once) so `src.*` modules can be imported.
this_notebook_dir = os.curdir
project_root_dir = os.path.relpath(
    os.path.join(os.pardir, os.pardir), start=this_notebook_dir
)
if project_root_dir not in sys.path:
    sys.path.append(project_root_dir)
from pprint import pprint

# Loading data and models
We now load both datasets, SST and AG News:

In [2]:
from src.data.dataload import *
# Load the two datasets one after the other: SST first, then AG News.
sst = load_sst()
agnews = load_agnews()
print(f'loaded datasets {DatasetSST.NAME} and {DatasetAGNews.NAME}')

loaded datasets sst and agnews


Creating and loading a BCN model for each dataset:

In [3]:
from src.models.bcnmodel import *
bcn_sst, bcn_ag = BCNModel(), BCNModel()
# Same steps for each (model, dataset) pair: report where the model file is
# expected, then load it. The SST model is handled first, then AG News.
for _bcn_model, _bcn_dataset in ((bcn_sst, sst), (bcn_ag, agnews)):
    _bcn_path = _bcn_model._get_model_filepath_for_dataset(_bcn_dataset)
    print(f'expecting location for the model file at "{_bcn_path}"')
    _bcn_model.load_model(_bcn_dataset)
# Drop the loop temporaries so they do not linger in the notebook namespace.
del _bcn_model, _bcn_dataset, _bcn_path
print(f'loaded BCN models for {sst.NAME}, {agnews.NAME}')

expecting location for the model file at "../../models/bcn-sst_output/model.tar.gz"
expecting location for the model file at "../../models/bcn-agnews_output/model.tar.gz"
loaded BCN models for sst, agnews


Loading a BERT model for each dataset:

In [4]:
from src.models.bertmodel import *
bert_sst, bert_ag = BertModel(), BertModel()
# Same steps for each (model, dataset) pair: report where the fine-tuned model
# is expected, then load it. The SST model is handled first, then AG News.
for _bert_model, _bert_dataset in ((bert_sst, sst), (bert_ag, agnews)):
    _bert_path = _bert_model._get_model_filepath_for_dataset(_bert_dataset)
    print(f'expecting location for the model file at "{_bert_path}"')
    _bert_model.load_model(_bert_dataset)
# Drop the loop temporaries so they do not linger in the notebook namespace.
del _bert_model, _bert_dataset, _bert_path
print(f'loaded BERT models for {sst.NAME}, {agnews.NAME}')

expecting location for the model file at "../../models/fine-tuned-bert-base-sst"
expecting location for the model file at "../../models/fine-tuned-bert-base-agnews"
loaded BERT models for sst, agnews


# Explainers

In [5]:
from src.explainers.explainers import *

#### BCN + SST

In [6]:
# LIME and AllenNLP explainers wrapping the BCN model loaded for SST
# (constructed left to right: LIME first).
lime_bcn_sst, anlp_bcn_sst = LimeExplainer(bcn_sst), AllenNLPExplainer(bcn_sst)

#### BCN + AG News

In [7]:
# LIME and AllenNLP explainers wrapping the BCN model loaded for AG News
# (constructed left to right: LIME first).
lime_bcn_ag, anlp_bcn_ag = LimeExplainer(bcn_ag), AllenNLPExplainer(bcn_ag)

#### BERT + SST

In [8]:
# LIME and SHAP explainers wrapping the BERT model loaded for SST
# (constructed left to right: LIME first).
lime_bert_sst, shap_bert_sst = LimeExplainer(bert_sst), SHAPExplainer(bert_sst)

#### BERT + AG News

In [9]:
# LIME and SHAP explainers wrapping the BERT model loaded for AG News
# (constructed left to right: LIME first).
lime_bert_ag, shap_bert_ag = LimeExplainer(bert_ag), SHAPExplainer(bert_ag)

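Explaining a few AG News test sentences with a randomly chosen explainer (LIME or AllenNLP on BCN, LIME or SHAP on BERT):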

In [10]:
import random

# Import numpy explicitly instead of relying on it leaking in through one of
# the `from src... import *` lines above.
import numpy as np

# Seed both RNGs so that Restart & Run All reproduces the same explainer choice
# and the same sampled test rows. Seeding numpy's global RNG also covers any
# explainer that draws from it internally (NOTE(review): confirm for LIME/SHAP).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

dataset = agnews
# Pick one of the four AG News explainers defined above.
explainer = random.choice([lime_bcn_ag, anlp_bcn_ag, lime_bert_ag, shap_bert_ag])
print(f'using explainer {type(explainer)} with model {explainer.model} and dataset {explainer.model.dataset_finetune.NAME}')
train_ag, val_ag, test_ag = dataset.train_val_test
# Sample 5 distinct index *labels* from the test split...
inds = np.random.choice(test_ag.index, 5, replace=False)
# ...and select those rows by label explicitly (plain `[]` relies on pandas'
# label/position fallback rules).
indices_ag, preds_ag = explainer.explain_instances(test_ag.sentence.loc[inds])
print(type(indices_ag), type(preds_ag))
indices_ag, preds_ag

Using custom data configuration default
Reusing dataset ag_news (/Users/Admin/.cache/huggingface/datasets/ag_news/default/0.0.0/17ec33e23df9e89565131f989e0fdf78b0cc4672337b582da83fc3c9f79fe34d)


using explainer <class 'src.explainers.explainers.SHAPExplainer'> with model <src.models.bertmodel.BertModel object at 0x14d3d2410> and dataset agnews


100%|██████████| 1/1 [00:01<00:00,  1.33s/it]
100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
100%|██████████| 1/1 [00:02<00:00,  2.48s/it]
100%|██████████| 1/1 [00:05<00:00,  5.16s/it]


HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))


  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:07<00:00,  7.52s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:07<00:00,  7.97s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:07<00:00,  7.79s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:12<00:00, 12.54s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:10<00:00, 10.66s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:08<00:00,  8.09s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:07<00:00,  7.23s/it]
Partition explainer:  20%|██        | 1/5 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.23s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.22s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:02<00:00,  2.58s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1

HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))



  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.44s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:06<00:00,  6.14s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:05<00:00,  5.82s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.61s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:08<00:00,  8.16s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.11s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:06<00:00,  6.99s/it]
Partition explainer:  60%|██████    | 3/5 [02:13<00:59, 29.57s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.26s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.31s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:02<00:00,  2.54s/it]

  0%|          | 0

HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))



  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:08<00:00,  8.23s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.38s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.96s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:08<00:00,  8.10s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.83s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.49s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.30s/it]
Partition explainer:  80%|████████  | 4/5 [03:18<00:40, 40.33s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.35s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.15s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:02<00:00,  2.31s/it]

  0%|          | 0

HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))



  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.49s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.28s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.37s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.25s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:09<00:00,  9.52s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.95s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:06<00:00,  6.55s/it]
Partition explainer: 100%|██████████| 5/5 [04:22<00:00, 47.35s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.01s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:01<00:00,  1.06s/it]

  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:02<00:00,  2.29s/it]

  0%|          | 0

HBox(children=(FloatProgress(value=0.0, max=48.0), HTML(value='')))



  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.31s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:06<00:00,  6.55s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.69s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.81s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.95s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:07<00:00,  7.09s/it]


  0%|          | 0/1 [00:00<?, ?it/s][A[A

100%|██████████| 1/1 [00:06<00:00,  6.49s/it]
Partition explainer: 6it [05:22, 64.48s/it]

<class 'numpy.ndarray'> <class 'numpy.ndarray'>





(array([list([' ', 'anaheim ', 'scored ', 'three ', 'runs ', 'in ', 'the ', 'eighth ', 'inning ', 'off ', 'oakland ', 'relieve ', '##rs ', 'to ', 'rally ', 'for ', 'a ', 'victory ', 'and ', 'clinch ', 'the ', 'american ', 'league ', 'west ', 'title ', '. ', ' ']),
        list([' ', 'ci ', '##ng ', '##ular ', 'wireless ', ', ', 'the ', 'nation ', "' ", '; ', 's ', 'largest ', 'wireless ', 'carrier ', 'following ', 'the ', 'company ', "' ", '; ', 's ', 'merger ', 'with ', 'at ', 'amp ', '; ', 't ', 'wireless ', ', ', 'said ', 'wednesday ', 'that ', 'it ', 'has ', 'completed ', 'integration ', 'activities ', 'ahead ', 'of ', 'schedule ', 'and ', 'now ', 'expects ', 'merger ', '- ', 'related ', 'cost ', 'savings ', 'to ', 'exceed ', 'prior ', 'estimates ', ' ']),
        list([' ', 'a ', 'yemen ', '##i ', 'court ', 'jailed ', 'five ', 'al ', 'qaeda ', 'supporters ', 'for ', '10 ', 'years ', 'saturday ', 'for ', 'the ', 'bombing ', 'of ', 'the ', 'french ', 'super ', '##tan ', '##ker ', 'l