In [None]:
import os

import numpy as np
import pandas as pd

### Choose dataset

In [None]:
from ipywidgets import Dropdown, Checkbox, Label, Box
from IPython.display import display

In [None]:
# Interactive selectors for the experiment. Their current values are read
# into DATASET / CALCULATE_PERFORMANCE by a later cell, so re-run that cell
# after changing a selection.
d_ = Dropdown(
    description='Dataset',
    options=['hatespeech', 'sst', 'snli'],
    value='sst',
)
p_ = Checkbox(
    description='Calculate performance',
    value=False,
)

options_box = Box([Label('Options:'), d_, p_])
display(options_box)

In [None]:
import prepare_experiment

from config import DEFAULT_HPARAMS, SEED, TRAINED_MODELS

from dataset import Hatespeech, SST, SNLI
from explanation_methods import SEDC, PWWSAntonym, eBERT, TextFooler
from models import train_test, HatespeechWhitebox, SSTWhitebox, SNLIWhitebox, InfersentModel, BERT

In [None]:
# Snapshot the widget selections into constants used by the rest of the notebook.
CALCULATE_PERFORMANCE = p_.value
DATASET = d_.value

In [None]:
# Registry: lower-case dataset key -> (dataset class, matching whitebox model class).
datasets = dict(
    hatespeech=(Hatespeech, HatespeechWhitebox),
    sst=(SST, SSTWhitebox),
    snli=(SNLI, SNLIWhitebox),
)

### Dataset descriptives

In [None]:
# Look up the selected dataset (case-insensitive key), instantiate it, and
# show its descriptive statistics. `dataset` ends up as the instance and
# `whitebox` as the dataset-specific whitebox model class.
dataset_cls, whitebox = datasets[str(DATASET).lower()]
dataset = dataset_cls()
dataset.describe()

### Train models for dataset

In [None]:
# Train/evaluate each candidate model on the selected dataset via `train_test`.
# Each entry of `trained_models` is [model_name, dataset, test_score, trained_model];
# the explanation cell further down unpacks exactly these four fields.
trained_models = []

models = [whitebox, InfersentModel, BERT]
for model in models:
    # Use the object's own name instead of string-parsing str(model), which only
    # works when the repr looks like "<class 'pkg.Name'>". Falls back to str()
    # for objects without __name__ (e.g. None).
    m_name = getattr(model, '__name__', str(model))
    print(f'> Model "{m_name}" on dataset "{dataset}"')

    if model is not None:
        model_name, test_score, trained_model = train_test(model, dataset, calculate_performance=CALCULATE_PERFORMANCE)
        new = [model_name, dataset, test_score, trained_model]
        # NOTE(review): test_score is presumably None when performance is not
        # calculated — confirm in train_test. Only real scores are printed.
        if test_score is not None:
            print(new)
        trained_models.append(new)
    print('')

In [None]:
# Persist the per-model test performance for this dataset (only when it was computed).
if CALCULATE_PERFORMANCE:
    # pandas does not create missing parent directories, so ensure results/ exists.
    os.makedirs('results', exist_ok=True)
    performance = pd.DataFrame(trained_models,
                               columns=['predictive_model', 'dataset', 'performance', 'model'])
    performance.to_csv(f'results/performance_{str(DATASET).lower()}.csv', index=False)

### Apply each explanation method per model
Repeat this for 5 different target seeds (`target_seed` = 0–4). Each seed draws a different set of target labels for the test split.

In [None]:
# Number of different random target-label draws the whole experiment is repeated for.
N_TARGET_SEEDS = 5

# Columns (in this order) written to each results JSON file.
RESULT_COLUMNS = ['model', 'dataset', 'explanation_method', 'seed', 'target_seed',
                  'similarity', 'similarity_std', 'X_sim', 'semantic', 'semantic_std',
                  'X_sem', 'performance_measure', 'fidelity', 'training_time',
                  'inference_time', 'counterfactuals', 'y_target', 'y_cf']


def _test_inputs(ds):
    """Return (X, y_true) for the test split of dataset `ds`.

    Uses the 'X' column when present; otherwise falls back to the
    ('X_premise', 'X_hypothesis') column pair.
    """
    test_data = ds.get(part='test')
    try:
        X = test_data['X']
    except KeyError:
        X = test_data[['X_premise', 'X_hypothesis']]
    return X, test_data['y']


def _explain(explanation_method, X, predict_fn, y_target, y_true, ds, target_seed):
    """Run one explanation method on (X, predict_fn) and return its annotated result record."""
    if hasattr(explanation_method, 'seed'):
        explanation_method.seed = target_seed
    print(f'|--> {explanation_method}')
    explanation_method.target_size = ds.target_size

    if explanation_method.provide_true_labels:
        p, counterfactuals, y_cf = explanation_method(X, predict_fn, y_target, y_true, return_y=True)
    else:
        p, counterfactuals, y_cf = explanation_method(X, predict_fn, y_target, return_y=True)

    p['model'] = str(predict_fn).lower()
    p['dataset'] = str(ds).lower()
    p['explanation_method'] = str(explanation_method).lower().split('(')[0]
    # Methods without a `seed` attribute were not re-seeded above; record None
    # instead of raising AttributeError (the original read `.seed` unguarded).
    p['seed'] = getattr(explanation_method, 'seed', None)
    p['similarity_std'] = np.std(p['X_sim'])
    p['semantic_std'] = np.std(p['X_sem'])
    p['target_seed'] = target_seed
    p['counterfactuals'] = counterfactuals
    p['y_target'] = y_target
    p['y_cf'] = y_cf
    return p


explanation_methods = [SEDC(), PWWSAntonym(), eBERT(batch_size=32), TextFooler()]
os.makedirs('results', exist_ok=True)

for target_seed in range(N_TARGET_SEEDS):
    # Named placeholders (not `_`) so IPython's last-output variable is not clobbered.
    for _model_name, model_dataset, _test_score, predict_fn in trained_models:
        print(f'> Model "{predict_fn}" on dataset "{model_dataset}" (target seed={target_seed})')
        # Reset the global NumPy RNG so every model/seed combination starts from the same state.
        np.random.seed(SEED)
        X, y_true = _test_inputs(model_dataset)
        y_target = model_dataset.target(part='test', seed=target_seed)

        records = []
        for explanation_method in explanation_methods:
            records.append(_explain(explanation_method, X, predict_fn, y_target, y_true,
                                    model_dataset, target_seed))

        results = pd.DataFrame(records)[RESULT_COLUMNS]
        # NOTE(review): the '+textfooler' suffix looks like a leftover from adding
        # TextFooler to an earlier run; kept unchanged so existing result readers
        # still match the filenames — confirm before renaming.
        results.to_json(f'results/counterfactuals_{str(DATASET).lower()}_{str(predict_fn).lower()}_seed-{target_seed}+textfooler.json')
        print('')

    print(f'\n... Finished seed {target_seed}!\n\n')