## Averaging attention distributions from the last four layers, then classifying with a linear network

We extract the attention distributions directly from a model that has been pretrained on bias detection.

We then determine the predicted biased word and use only the attention distribution associated with that word for further classification.

We average the attention distributions from the last four layers (layers 8 to 11). The classification model is a simple, shallow neural network.
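
As a rough illustration of the averaging step, the sketch below takes one attention row per layer and computes their element-wise mean. The helper `average_distributions` and the numbers are hypothetical; the notebook itself relies on `reduce_attention_dist`, whose internals are not shown here.

```python
def average_distributions(distributions):
    """Element-wise mean of equally long attention distributions."""
    n = len(distributions)
    return [sum(values) / n for values in zip(*distributions)]


# Made-up attention rows for a single token, one per layer (8, 9, 10, 11).
layer_rows = [
    [0.70, 0.20, 0.10],
    [0.50, 0.30, 0.20],
    [0.60, 0.10, 0.30],
    [0.20, 0.40, 0.40],
]

averaged = average_distributions(layer_rows)
print(averaged)  # approximately [0.5, 0.25, 0.25]
```

Because each input row sums to 1, the averaged row also sums to 1, so the result can still be read as a distribution over tokens.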

A random baseline achieves an accuracy of 0.59.

In [1]:
import sys; sys.path.append("../../../../..")
import torch 
from src.experiment import AttentionExperiment, ClassificationExperiment
from src.dataset import ExperimentDataset
from src.params import Params
from src.utils.attention_utils import reduce_attention_dist, return_idx_attention_dist
from src.utils.classification_utils import run_bootstrapping
from src.utils.shared_utils import get_bias_predictions

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
params = Params.read_params("linear-params.json")
print("model = {}".format(params.final_task['model']))
print("layers = {}".format(params.intermediary_task["attention"]["layers"]))
print("reducer = {}".format(params.intermediary_task["attention"]["reducer"]))

model = shallow_nn
layers = [8, 9, 10, 11]
reducer = avg


In [4]:
# Load the dataset used in this experiment;
# typically this is the small set of ground-truth labels.
dev_dataset = ExperimentDataset.init_dataset(params.dataset)

03/22/2020 16:36:14 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
100it [00:00, 3883.40it/s]


In [5]:
import pickle
# Load the weakly labeled training dataset from disk
with open(params.dataset["weakly_labeled_data"], "rb") as f:
    train_dataset = pickle.load(f)

### Attention Experiment: 
* A class that wraps useful methods for extracting attention distributions from a given BERT-based model.
* The user has to provide two config files: one specifying how the attention scores should be extracted and combined, and another specifying the intermediary model from which the attention scores are extracted.
* The user needs to instantiate the attention experiment with a function that specifies how to run
 inference on the given model. The function header is specified below:
 
 ``` def initialize_attention_experiment(cls, intermediary_task_params, dataset_params, verbose=False) ```
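
For orientation, here is a hypothetical sketch of what the parameters read from `linear-params.json` might look like, written as a Python dict. Only the keys that this notebook actually accesses are included. The `model`, `layers` and `reducer` values come from the output of cell 3, while the file path and batch size are placeholders, and the real file very likely contains more fields.

```python
import json

# Hypothetical structure; only keys accessed in this notebook are shown.
example_params = {
    "dataset": {
        "weakly_labeled_data": "path/to/weakly_labeled.pkl",  # placeholder path
    },
    "intermediary_task": {
        "attention": {"layers": [8, 9, 10, 11], "reducer": "avg"},
    },
    "final_task": {
        "model": "shallow_nn",
        "training_params": {"batch_size": 32},  # placeholder value
    },
}

print(json.dumps(example_params, indent=2))
```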
 


In [6]:
attention_experiment = AttentionExperiment.initialize_attention_experiment(params.intermediary_task, params.dataset, verbose=True)

03/22/2020 16:36:18 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
03/22/2020 16:36:19 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
03/22/2020 16:36:19 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmp42smtxf9
03/22/2020 16:36:22 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_d

Instantiated joint model with pretrained weights.
Succesfully loaded in attention experiment!


In [7]:
attention_dataloader_dev = dev_dataset.return_dataloader() 
attention_dataloader_train = train_dataset.return_dataloader()

```extract_attention_scores()``` works out of the box because the attention experiment stores the config and therefore knows which BERT model to load, which layers to extract the attention scores from, and which inference function to use with this particular BERT model.

The returned attention scores are a list of dictionaries. The keys of each dictionary are the layers of the BERT model, and the values are the attention distributions extracted from those layers.
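
To make that structure concrete, the sketch below builds a small stand-in for the list of dictionaries and keeps only the attention row of one chosen token per example. It assumes each layer's value is a `seq_len x seq_len` matrix with attention heads already collapsed. That layout, the helper `select_token_rows` and the numbers are all illustrative assumptions, not the behaviour of `return_idx_attention_dist`.

```python
# Hypothetical stand-in: one dict per example, mapping a layer index
# to a (seq_len x seq_len) attention matrix. Values are made up.
attention_scores = [
    {
        8: [[0.9, 0.1], [0.4, 0.6]],
        9: [[0.8, 0.2], [0.3, 0.7]],
    },
]


def select_token_rows(scores, token_indices):
    """For each example, keep the attention row of the chosen token in every layer."""
    return [
        {layer: matrix[idx] for layer, matrix in example.items()}
        for example, idx in zip(scores, token_indices)
    ]


rows = select_token_rows(attention_scores, [1])
print(rows)  # [{8: [0.4, 0.6], 9: [0.3, 0.7]}]
```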

In [8]:
attention_scores_dev = attention_experiment.extract_attention_scores(attention_dataloader_dev)


In [9]:
# Saving out attention weights from the train dataset to facilitate future runs
import os
attention_weights_file = "train_attention.weights"
if os.path.exists(attention_weights_file):
    # Reuse the cached attention scores from a previous run
    with open(attention_weights_file, "rb") as f:
        attention_scores_train = pickle.load(f)
else:
    attention_scores_train = attention_experiment.extract_attention_scores(attention_dataloader_train)
    with open(attention_weights_file, "wb") as f:
        pickle.dump(attention_scores_train, f)

Next, we retrieve the biased-word predictions from the BERT detection module.

In [10]:
bias_predictions_dev = get_bias_predictions(dev_dataset, params.intermediary_task, params.dataset, batch_size=8)
bias_predictions_train = get_bias_predictions(train_dataset, params.intermediary_task, params.dataset, batch_size=16)

03/22/2020 16:37:19 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
03/22/2020 16:37:20 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
03/22/2020 16:37:20 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpnqz2pyid
03/22/2020 16:37:23 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_d


03/22/2020 16:37:36 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
03/22/2020 16:37:36 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
03/22/2020 16:37:36 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpdclusmy8
03/22/2020 16:37:40 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_d


In [11]:
bias_indices_dev = torch.argmax(bias_predictions_dev == 1, dim=1).tolist()
bias_indices_train = torch.argmax(bias_predictions_train == 1, dim=1).tolist()
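
The `argmax` over `predictions == 1` picks, for each sentence, the position of the first token flagged as biased. One consequence worth keeping in mind: if no token is flagged, the result is index 0, assuming `argmax` returns the first maximal index as documented for recent PyTorch versions. The plain-Python helper below (`first_flagged_index`, a hypothetical name) mirrors that behaviour for a single sentence.

```python
def first_flagged_index(predictions):
    """Index of the first 1 in a 0/1 sequence, mirroring argmax over `predictions == 1`."""
    flags = [int(p == 1) for p in predictions]
    # list.index returns the first position of the maximum, like a first-hit argmax.
    return flags.index(max(flags))


print(first_flagged_index([0, 0, 1, 0, 1]))  # 2: the first flagged token wins
print(first_flagged_index([0, 0, 0]))        # 0: nothing flagged, falls back to index 0
```

If index 0 corresponds to a special token in the tokenized input, sentences without any flagged token would silently use that token's attention row, so it may be worth checking how often this case occurs.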

In [13]:
attention_scores_indexed_train = return_idx_attention_dist(attention_scores_train, bias_indices_train)
attention_scores_indexed_dev = return_idx_attention_dist(attention_scores_dev, bias_indices_dev)

reduced_attention_train = reduce_attention_dist(attention_scores_indexed_train, params.intermediary_task["attention"]["reducer"])
reduced_attention_dev = reduce_attention_dist(attention_scores_indexed_dev, params.intermediary_task["attention"]["reducer"])

stacked_reduced_attention_train = torch.stack(reduced_attention_train).squeeze()
stacked_reduced_attention_dev = torch.stack(reduced_attention_dev).squeeze()

train_dataset.add_data(stacked_reduced_attention_train, "attention_dist")
train_dataset.shuffle_data()

dev_dataset.add_data(stacked_reduced_attention_dev, "attention_dist")

In [18]:
train_dataset

Length: 52275 Keys: dict_keys(['pre_ids', 'masks', 'pre_lens', 'post_in_ids', 'post_out_ids', 'pre_tok_label_ids', 'post_tok_label_ids', 'rel_ids', 'pos_ids', 'categories', 'index', 'pos_features', 'marta_features', 'bert_embeddings', 'weak_bias_label', 'attention_dist'])

In [22]:
dev_dataset.add_data(dev_dataset.get_val('bias_label'),'weak_bias_label')

In [73]:
dev_dataset.shuffle_data()
train_dataset.shuffle_data()

### This is where the classification experiment starts

We create a classification experiment that contains useful methods for classifying bias based on the attention distributions. 

In [82]:
params = Params.read_params("linear-params.json")

In [83]:
train_dataloader = train_dataset.return_dataloader(batch_size=params.final_task['training_params']['batch_size'])
dev_dataloader = dev_dataset.return_dataloader(batch_size=81)

In [84]:
classification_experiment = ClassificationExperiment.init_cls_experiment(params.final_task)

In [85]:
losses, evals = classification_experiment.train_model(train_dataloader, dev_dataloader, input_key="attention_dist", label_key="weak_bias_label")

In [86]:
from src.utils.classification_utils import average_data

In [87]:
evals

[[{'num_examples': 81,
   'accuracy': 0.5925925925925926,
   'auc': 0.48547979797979796}],
 [{'num_examples': 81,
   'accuracy': 0.5925925925925926,
   'auc': 0.48232323232323226}],
 [{'num_examples': 81,
   'accuracy': 0.5925925925925926,
   'auc': 0.481060606060606}],
 [{'num_examples': 81,
   'accuracy': 0.5925925925925926,
   'auc': 0.481060606060606}],
 [{'num_examples': 81,
   'accuracy': 0.5925925925925926,
   'auc': 0.47285353535353536}]]
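
All five epochs report the same accuracy, 0.5926 on 81 examples, which corresponds to 48 correct predictions and equals the baseline quoted at the top of this notebook, while the AUC stays slightly below 0.5. One possible reading, which the notebook does not confirm, is that the classifier always predicts the same class. The sketch below shows how such a constant-prediction baseline could be computed. The 48/33 label split is a hypothetical chosen to reproduce that figure, not the actual dev-set distribution.

```python
from collections import Counter


def majority_baseline_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    most_common_count = Counter(labels).most_common(1)[0][1]
    return most_common_count / len(labels)


# Hypothetical label split (48 vs. 33) used only to illustrate the arithmetic.
labels = [0] * 48 + [1] * 33
baseline = majority_baseline_accuracy(labels)
print(round(baseline, 4))  # 0.5926
```

Comparing the model's predictions against this kind of baseline, for example by counting how many distinct classes it predicts on the dev set, would help tell whether training has learned anything beyond the label prior.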

Changing Epoch Count 

In [72]:
classification_experiment.save_model_weights("linear-attention.weights")