In [1]:
import sys; sys.path.append("../../../../..")

import torch
import numpy as np

from src.experiment import AttentionExperiment, ClassificationExperiment
from src.dataset import ExperimentDataset
from src.params import Params
from src.utils.attention_utils import reduce_attention_dist
from src.utils.classification_utils import run_bootstrapping

In [104]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [231]:
params = Params.read_params("gru-params.json")
print("layers = {}".format(params.intermediary_task["attention"]["layers"]))
print("reducer = {}".format(params.intermediary_task["attention"]["reducer"]))
print("n_components = {}".format(params.intermediary_task["attention"]["n_components"]))
print("dropout = {}".format(params.final_task["dropout"]))
print("hidden_dim = {}".format(params.final_task["hidden_dim"]))
print("attention_units = {}".format(params.final_task["attention_units"]))

layers = [5, 6, 7]
reducer = avg
n_components = 40
dropout = 0.3
hidden_dim = 200
attention_units = 512


In [5]:
dataset = ExperimentDataset.init_dataset(params.dataset)
dataset

04/02/2020 15:55:56 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
386it [00:00, 3736.50it/s]


Length: 324 Keys: dict_keys(['pre_ids', 'masks', 'pre_lens', 'post_in_ids', 'post_out_ids', 'pre_tok_label_ids', 'post_tok_label_ids', 'rel_ids', 'pos_ids', 'categories', 'index', 'bias_label'])

In [6]:
attention_dataloader = dataset.return_dataloader() 
attention_experiment = AttentionExperiment.initialize_attention_experiment(params.intermediary_task, params.dataset, verbose=True)

04/02/2020 15:55:58 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
04/02/2020 15:55:59 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
04/02/2020 15:55:59 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmprkakx8s5
04/02/2020 15:56:07 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_d

Instantiated joint model with pretrained weights.
Succesfully loaded in attention experiment!


In [7]:
attention_scores = attention_experiment.extract_attention_scores(attention_dataloader)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [232]:
lengths = [int(d["pre_lens"].numpy()) for d in dataset]
reduced_attention = reduce_attention_dist(attention_scores, params.intermediary_task["attention"], lengths)
dataset.add_data(reduced_attention, "attention_dist")
dataset.shuffle_data()
print("reduced_attention.shape = {}".format(reduced_attention.shape))

reduced_attention.shape = torch.Size([324, 80, 40])


### This is where the classification experiment starts

We create a classification experiment that contains useful methods for classifying bias based on the attention distributions. 

In [233]:
classification_experiment = ClassificationExperiment.init_cls_experiment(params.final_task, params.intermediary_task["attention"])

In [234]:
stats = run_bootstrapping(classification_experiment, dataset, params.final_task, num_bootstrap_iters=5, input_key="attention_dist", label_key="bias_label")

HBox(children=(FloatProgress(value=0.0, description='Cross Validation Iteration', max=5.0, style=ProgressStyle…

HBox(children=(FloatProgress(value=0.0, description='epochs', max=150.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='epochs', max=150.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='epochs', max=150.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='epochs', max=150.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='epochs', max=150.0, style=ProgressStyle(description_width…




In [230]:
stats

{'auc': [(0.520740922164123, 0.5861565543896371), 0.5539473356336944],
 'accuracy': [(0.6326530612244897, 0.6428571428571429), 0.636734693877551]}

In [11]:
classification_experiment.save_model_weights("gru-attention.weights")