## Averaging the last 4 attention distributions using a GRU

We average the last 4 attention distributions of the input. Our classification model is an LSTM that reads the averaged attention distribution one word at a time. The maximum sequence length is 80, so the model reads 80 time steps, each an 80-dimensional vector.
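
As a rough illustration of the input shape described above, here is a minimal sketch of a recurrent classifier over one 80 x 80 attention matrix per example. The class name `AttentionRowClassifier`, the hidden size, and the single-logit output are assumptions for illustration and are not taken from the repository. The sketch uses `nn.GRU` to match the heading; an `nn.LSTM` would fit the same way, except that its second return value is a `(hidden, cell)` tuple.

```python
import torch
import torch.nn as nn


class AttentionRowClassifier(nn.Module):
    """Hypothetical recurrent classifier over one attention matrix per example."""

    def __init__(self, seq_len: int = 80, hidden_size: int = 64):
        super().__init__()
        # Each time step is one row of the attention matrix (seq_len values).
        self.rnn = nn.GRU(input_size=seq_len, hidden_size=hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, attention: torch.Tensor) -> torch.Tensor:
        # attention: [batch, seq_len, seq_len]
        _, last_hidden = self.rnn(attention)  # last_hidden: [1, batch, hidden_size]
        return self.out(last_hidden[-1]).squeeze(-1)  # logits: [batch]


model = AttentionRowClassifier()
batch = torch.softmax(torch.randn(4, 80, 80), dim=-1)  # stand-in attention rows
logits = model(batch)
print(logits.shape)  # torch.Size([4])
```

The final hidden state summarizes all 80 rows, so a single linear layer on top is enough to produce one binary logit per example.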


#### Notes
* An open question is how to experiment with different numbers of attention heads. The attention-score extraction function also has bugs that need ironing out, such as only accepting a batch size of 1.

* We would also like to eventually use a transformer architecture on top of the attention distributions.

In [1]:
import sys; sys.path.append("../../../../..")
import torch 
from src.experiment import AttentionExperiment, ClassificationExperiment
from src.dataset import ExperimentDataset
from src.params import Params

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
params = Params.read_params("experiment_params.json")

In [4]:
# Load the dataset used in this experiment;
# typically this is the small set of ground-truth labels.
dataset = ExperimentDataset.init_dataset(params)

02/19/2020 20:00:06 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
386it [00:00, 4913.93it/s]


In [5]:
dataset

Length: 324 Keys: dict_keys(['pre_ids', 'masks', 'pre_lens', 'post_in_ids', 'post_out_ids', 'pre_tok_label_ids', 'post_tok_label_ids', 'rel_ids', 'pos_ids', 'categories', 'index', 'bias_label'])

In [6]:
attention_dataloader = dataset.return_dataloader() 

Attention Experiment:
* A class that wraps useful methods for extracting attention distributions from a given BERT-based model.
* In the config file, the user needs to specify a .ckpt file for the trained BERT-based model from which we want to extract attention scores.
* The user needs to instantiate the attention experiment with a function that tells it how to run inference on the given model.

In [7]:
attention_experiment = AttentionExperiment.initialize_attention_experiment(params.intermediary_task, verbose=True)

02/19/2020 20:00:07 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


The len of our vocabulary is 30523
Cuda is set to true


02/19/2020 20:00:07 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
02/19/2020 20:00:07 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file ./cache/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmp0e0eyydy
02/19/2020 20:00:11 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

02/19/2020 20:00:18 - INFO - pytor

Succesfully loaded in attention experiment!


In [8]:
dataset

Length: 324 Keys: dict_keys(['pre_ids', 'masks', 'pre_lens', 'post_in_ids', 'post_out_ids', 'pre_tok_label_ids', 'post_tok_label_ids', 'rel_ids', 'pos_ids', 'categories', 'index', 'bias_label'])

extract_attention_scores() works out of the box because the attention experiment has the config file saved: it knows which BERT model to load, which layers to extract attention scores from, and which inference function to use on this particular BERT model.

attention_scores is then a list of dictionaries. The keys of each dictionary are specific layers of the BERT model, and the values are the attention distributions extracted from those layers.

In [9]:
attention_scores = attention_experiment.extract_attention_scores(attention_dataloader)


In [10]:
from src.utils.attention_utils import avg_attention_dist

In [11]:
avg_attention = avg_attention_dist(attention_scores)

In [12]:
avg_attention[0].shape

torch.Size([1, 1, 80, 80])
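
The implementation of avg_attention_dist is not shown in this notebook. The sketch below is one plausible way to average per-layer attention tensors, assuming each example is a dictionary mapping a layer index to a tensor of shape [1, 1, 80, 80]. The function name `average_layers`, the layer indices 8 to 11, and the per-layer shape are all assumptions for illustration.

```python
import torch


def average_layers(per_example_scores):
    """Average the per-layer attention tensors of each example.

    per_example_scores: list of dicts mapping a layer index to a tensor.
    Returns a list with one averaged tensor per example.
    """
    averaged = []
    for layer_to_attention in per_example_scores:
        # Stack layers along a new leading axis, then take the mean over it.
        stacked = torch.stack(list(layer_to_attention.values()), dim=0)
        averaged.append(stacked.mean(dim=0))
    return averaged


# Toy input: 3 examples, layers 8-11, each an assumed [1, 1, 80, 80] tensor.
toy_scores = [
    {layer: torch.softmax(torch.randn(1, 1, 80, 80), dim=-1) for layer in range(8, 12)}
    for _ in range(3)
]
toy_avg = average_layers(toy_scores)
print(toy_avg[0].shape)  # torch.Size([1, 1, 80, 80])
```

Because each layer's rows sum to 1, their mean does too, so the averaged matrix can still be read as one attention distribution per word.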

In [13]:
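# Stack to [324, 1, 1, 80, 80], then drop the two singleton axes --> [324, 80, 80].
# Squeezing dim 1 explicitly (twice) avoids collapsing any other size-1 axis.
stacked_avg_attention = torch.stack(avg_attention).squeeze(1).squeeze(1)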

In [14]:
dataset.add_data(stacked_avg_attention, "attention_dist")

In [15]:
dataset.shuffle_data()

In [16]:
dataset

Length: 324 Keys: dict_keys(['pre_ids', 'masks', 'pre_lens', 'post_in_ids', 'post_out_ids', 'pre_tok_label_ids', 'post_tok_label_ids', 'rel_ids', 'pos_ids', 'categories', 'index', 'bias_label', 'attention_dist'])

### This is where the classification experiment starts

We create a classification experiment that contains useful methods for classifying bias based on the attention distributions. 

In [17]:
classification_experiment = ClassificationExperiment.init_cls_experiment(params.final_task)

In [18]:
from src.utils.classification_utils import run_bootstrapping

In [19]:
stats = run_bootstrapping(classification_experiment, dataset, params.final_task, num_bootstrap_iters=3, input_key="attention_dist", label_key="bias_label")

HBox(children=(IntProgress(value=0, description='Cross Validation Iteration', max=3, style=ProgressStyle(descr…

HBox(children=(IntProgress(value=0, description='epochs', max=150, style=ProgressStyle(description_width='init…

Step: 0 ; Loss 0.6971462965011597 
Step: 3 ; Loss 0.692383885383606 
Step: 6 ; Loss 0.6503917574882507 
Step: 0 ; Loss 0.6304763555526733 
Step: 3 ; Loss 0.7080031633377075 
Step: 6 ; Loss 0.6238500475883484 
Step: 0 ; Loss 0.6031494736671448 
Step: 3 ; Loss 0.7197659611701965 
Step: 6 ; Loss 0.6223848462104797 
Step: 0 ; Loss 0.603582501411438 
Step: 3 ; Loss 0.7072267532348633 
Step: 6 ; Loss 0.6254379153251648 
Step: 0 ; Loss 0.6075143814086914 
Step: 3 ; Loss 0.6986983418464661 
Step: 6 ; Loss 0.6240379810333252 
Step: 0 ; Loss 0.6037139892578125 
Step: 3 ; Loss 0.6921854615211487 
Step: 6 ; Loss 0.6181116104125977 
Step: 0 ; Loss 0.594451904296875 
Step: 3 ; Loss 0.6795854568481445 
Step: 6 ; Loss 0.6126923561096191 
Step: 0 ; Loss 0.5786746144294739 
Step: 3 ; Loss 0.6381656527519226 
Step: 6 ; Loss 0.5993362665176392 
Step: 0 ; Loss 0.5655835270881653 
Step: 3 ; Loss 0.5424614548683167 
Step: 6 ; Loss 0.5923318862915039 
Step: 0 ; Loss 0.5652759075164795 
Step: 3 ; Loss 0.472257

Step: 3 ; Loss 0.28453874588012695 
Step: 6 ; Loss 0.237224280834198 
Step: 0 ; Loss 0.4630010426044464 
Step: 3 ; Loss 0.2432796210050583 
Step: 6 ; Loss 0.4021208584308624 
Step: 0 ; Loss 0.43308526277542114 
Step: 3 ; Loss 0.12418773770332336 
Step: 6 ; Loss 0.27421438694000244 
Step: 0 ; Loss 0.24169492721557617 
Step: 3 ; Loss 0.17862758040428162 
Step: 6 ; Loss 0.18237653374671936 
Step: 0 ; Loss 0.3625430166721344 
Step: 3 ; Loss 0.13350233435630798 
Step: 6 ; Loss 0.2703002095222473 
Step: 0 ; Loss 0.44789543747901917 
Step: 3 ; Loss 0.38834428787231445 
Step: 6 ; Loss 0.4613816440105438 
Step: 0 ; Loss 0.18570470809936523 
Step: 3 ; Loss 0.19440051913261414 
Step: 6 ; Loss 0.12191767245531082 
Step: 0 ; Loss 0.20709273219108582 
Step: 3 ; Loss 0.17909613251686096 
Step: 6 ; Loss 0.11623779684305191 
Step: 0 ; Loss 0.13306193053722382 
Step: 3 ; Loss 0.10298753529787064 
Step: 6 ; Loss 0.09625814855098724 
Step: 0 ; Loss 0.09912914037704468 
Step: 3 ; Loss 0.07964907586574554 


HBox(children=(IntProgress(value=0, description='epochs', max=150, style=ProgressStyle(description_width='init…

Step: 0 ; Loss 0.6203060746192932 
Step: 3 ; Loss 0.6005366444587708 
Step: 6 ; Loss 0.5153704881668091 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.45306241512298584 
Step: 3 ; Loss 0.4700659513473511 
Step: 6 ; Loss 0.4514164924621582 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.37552985548973083 
Step: 3 ; Loss 0.40404027700424194 
Step: 6 ; Loss 0.4061206579208374 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.4437530040740967 
Step: 3 ; Loss 0.41677039861679077 
Step: 6 ; Loss 0.4714903235435486 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.32585734128952026 
Step: 3 ; Loss 0.45106247067451477 
Step: 6 ; Loss 0.521379292011261 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.3385026752948761 
Step: 3 ; Loss 0.42848315834999084 
Step: 6 ; Loss 0.446835994720459 
All labels are of the same type – skipping AUC calculation
Step: 0 ; L

Step: 3 ; Loss 0.01815582625567913 
Step: 6 ; Loss 0.009822087362408638 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.01623791828751564 
Step: 3 ; Loss 0.016757501289248466 
Step: 6 ; Loss 0.009022911079227924 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.014559631235897541 
Step: 3 ; Loss 0.01793523132801056 
Step: 6 ; Loss 0.009444532915949821 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.016218220815062523 
Step: 3 ; Loss 0.016502689570188522 
Step: 6 ; Loss 0.008997628465294838 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.01312455628067255 
Step: 3 ; Loss 0.022040264680981636 
Step: 6 ; Loss 0.007807966321706772 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.014352790080010891 
Step: 3 ; Loss 0.0138898566365242 
Step: 6 ; Loss 0.007782271597534418 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.011601

Step: 6 ; Loss 0.5880367755889893 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.5401203632354736 
Step: 3 ; Loss 0.6106472015380859 
Step: 6 ; Loss 0.5804495215415955 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.5280779004096985 
Step: 3 ; Loss 0.6108323931694031 
Step: 6 ; Loss 0.5821245908737183 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.5236149430274963 
Step: 3 ; Loss 0.6262916326522827 
Step: 6 ; Loss 0.5667567253112793 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.522339403629303 
Step: 3 ; Loss 0.6176205277442932 
Step: 6 ; Loss 0.5599054098129272 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.5014303922653198 
Step: 3 ; Loss 0.6319559216499329 
Step: 6 ; Loss 0.5594376921653748 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.5016419887542725 
Step: 3 ; Loss 0.5942882895469666 
Step: 6 ; Loss 0.

Step: 6 ; Loss 0.33660799264907837 
All labels are of the same type – skipping AUC calculation
Step: 0 ; Loss 0.36221766471862793 
Step: 3 ; Loss 0.307378888130188 
Step: 6 ; Loss 0.3631622791290283 
All labels are of the same type – skipping AUC calculation


HBox(children=(IntProgress(value=0, description='epochs', max=150, style=ProgressStyle(description_width='init…

Step: 0 ; Loss 0.7427797317504883 
Step: 3 ; Loss 0.7244586944580078 
Step: 6 ; Loss 0.6962272524833679 
Step: 0 ; Loss 0.6670082807540894 
Step: 3 ; Loss 0.6188298463821411 
Step: 6 ; Loss 0.6358893513679504 
Step: 0 ; Loss 0.6057889461517334 
Step: 3 ; Loss 0.5435404777526855 
Step: 6 ; Loss 0.5908763408660889 
Step: 0 ; Loss 0.5493637323379517 
Step: 3 ; Loss 0.5105730295181274 
Step: 6 ; Loss 0.5732740759849548 
Step: 0 ; Loss 0.5334495902061462 
Step: 3 ; Loss 0.49505436420440674 
Step: 6 ; Loss 0.5762150287628174 
Step: 0 ; Loss 0.5289062261581421 
Step: 3 ; Loss 0.44685128331184387 
Step: 6 ; Loss 0.585297703742981 
Step: 0 ; Loss 0.5207058787345886 
Step: 3 ; Loss 0.44476521015167236 
Step: 6 ; Loss 0.5501129627227783 
Step: 0 ; Loss 0.5050820112228394 
Step: 3 ; Loss 0.4345044791698456 
Step: 6 ; Loss 0.5665379166603088 
Step: 0 ; Loss 0.5130108594894409 
Step: 3 ; Loss 0.4277929365634918 
Step: 6 ; Loss 0.5424851775169373 
Step: 0 ; Loss 0.5033265948295593 
Step: 3 ; Loss 0.4

Step: 3 ; Loss 0.16159795224666595 
Step: 6 ; Loss 0.2945966422557831 
Step: 0 ; Loss 0.2049943208694458 
Step: 3 ; Loss 0.1478855311870575 
Step: 6 ; Loss 0.42051148414611816 
Step: 0 ; Loss 0.3299188017845154 
Step: 3 ; Loss 0.1381010264158249 
Step: 6 ; Loss 0.28436699509620667 
Step: 0 ; Loss 0.24290552735328674 
Step: 3 ; Loss 0.13993990421295166 
Step: 6 ; Loss 0.290988564491272 
Step: 0 ; Loss 0.23732252418994904 
Step: 3 ; Loss 0.12863464653491974 
Step: 6 ; Loss 0.2624308466911316 
Step: 0 ; Loss 0.2251303344964981 
Step: 3 ; Loss 0.13168422877788544 
Step: 6 ; Loss 0.21910101175308228 
Step: 0 ; Loss 0.23520506918430328 
Step: 3 ; Loss 0.12152355164289474 
Step: 6 ; Loss 0.26561689376831055 
Step: 0 ; Loss 0.22166453301906586 
Step: 3 ; Loss 0.12412118166685104 
Step: 6 ; Loss 0.25902894139289856 
Step: 0 ; Loss 0.23977796733379364 
Step: 3 ; Loss 0.13097158074378967 
Step: 6 ; Loss 0.2600433826446533 
Step: 0 ; Loss 0.20946532487869263 
Step: 3 ; Loss 0.11169901490211487 
St
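
The repeated "All labels are of the same type" message in the log above is consistent with an evaluation split that contains only one class. ROC AUC is undefined in that case, because it compares scores of positive examples against scores of negative examples. How the repository computes AUC is not shown here. The following is a self-contained sketch of the guard, with the hypothetical helper name `roc_auc_or_none`.

```python
def roc_auc_or_none(labels, scores):
    """Return ROC AUC via pairwise comparison, or None if only one class is present."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        # With a single class there are no positive/negative pairs to rank.
        return None
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


print(roc_auc_or_none([1, 1, 1], [0.9, 0.8, 0.7]))          # None
print(roc_auc_or_none([0, 1, 0, 1], [0.1, 0.8, 0.4, 0.3]))  # 0.75
```

If single-class splits occur often, stratifying the split by label is one way to make sure every evaluation set contains both classes.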

In [20]:
stats

{'auc': [(0.6421503574746055, 0.8592986796739482), 0.7802021833439663],
 'accuracy': [(0.6508184523809524, 0.805920493197279), 0.7503543083900226]}
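
Each entry in stats appears to pair an interval with a central value, although how run_bootstrapping derives them is not shown in this notebook. As a sketch under that assumption, the helper below (hypothetical name `summarize_scores`) turns a list of per-iteration metric values into the same `[(lower, upper), mean]` layout using a percentile interval. The input values are made up for illustration.

```python
import statistics


def summarize_scores(scores, lower_pct=2.5, upper_pct=97.5):
    """Return [(lower, upper), mean] for a list of per-iteration metric values."""
    ordered = sorted(scores)

    def percentile(pct):
        # Linear interpolation between the two nearest order statistics.
        position = (len(ordered) - 1) * pct / 100.0
        low = int(position)
        high = min(low + 1, len(ordered) - 1)
        fraction = position - low
        return ordered[low] + (ordered[high] - ordered[low]) * fraction

    return [(percentile(lower_pct), percentile(upper_pct)), statistics.mean(scores)]


# Made-up per-iteration accuracies, purely for illustration.
example = summarize_scores([0.70, 0.75, 0.80])
print(example)
```

With only 3 bootstrap iterations, any interval of this kind is coarse. Raising num_bootstrap_iters would give a more stable estimate at the cost of training time.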