## BERT Fresh for Bias Detection + FC Classifier 

We load a BERT model that has not been pretrained on any bias-detection task and train it directly on bias classification. Training uses the weakly labeled dataset; evaluation uses the dev dataset.

In [1]:
import sys; sys.path.append("../../../../..")
import torch 
from src.experiment import ClassificationExperiment
from src.dataset import ExperimentDataset
from src.params import Params

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Load the experiment configuration. Later cells read `params.dataset` and
# `params.final_task` from the returned object.
PARAMS_FILE = "experiment_params.json"
params = Params.read_params(PARAMS_FILE)

In [4]:
# Build the dataset used in this experiment from the dataset section of the
# experiment params. Typically this is the small set of ground-truth labels.
dataset_params = params.dataset
dev_dataset = ExperimentDataset.init_dataset(dataset_params)

04/01/2020 19:05:31 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./cache/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
100it [00:00, 3955.62it/s]


In [5]:
import pickle

# Load the training dataset from the pickle file whose path is stored under
# the "weakly_labeled_data" entry of the dataset params.
# NOTE(security): pickle.load can execute arbitrary code while unpickling --
# only point this at files you produced yourself or otherwise trust.
weak_data_path = params.dataset["weakly_labeled_data"]
with open(weak_data_path, "rb") as weak_data_file:
    # The context manager guarantees the file handle is closed after loading.
    train_dataset = pickle.load(weak_data_file)

In [6]:
# Make the weak labels available under 'bias_label', which is the label key
# passed to train_model further down.
# NOTE(review): assumes get_val returns the values stored under a key and
# add_data stores values under the given key -- confirm in src.dataset.
weak_labels = train_dataset.get_val('weak_bias_label')
train_dataset.add_data(weak_labels, 'bias_label')

### Start of Classification

In [8]:
# Re-read the experiment params from the same file loaded near the top of the
# notebook; `params` is rebound to the fresh result.
# NOTE(review): on a clean Restart & Run All this looks redundant -- presumably
# kept so edits to the JSON can be picked up without re-running earlier cells.
params = Params.read_params("experiment_params.json")

In [9]:
# Shuffle both splits before batching. The return value of shuffle_data is
# not used, so it presumably shuffles the dataset in place -- TODO confirm.
for split in (train_dataset, dev_dataset):
    split.shuffle_data()

# Training batch size comes from the final-task training params; the dev
# loader uses a fixed batch size of 32.
train_batch_size = params.final_task['training_params']['batch_size']
train_dataloader = train_dataset.return_dataloader(batch_size=train_batch_size)
dev_dataloader = dev_dataset.return_dataloader(batch_size=32)

In [10]:
# Build the classification experiment from the final-task section of the
# experiment params.
final_task_params = params.final_task
classification_experiment = ClassificationExperiment.init_cls_experiment(final_task_params)

04/01/2020 19:05:32 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /sailhome/rdm/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
04/01/2020 19:05:32 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file /sailhome/rdm/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpdy18ll4s
04/01/2020 19:05:35 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 

In [11]:
# Run training with the train and dev dataloaders. Inputs are read from the
# "pre_ids" key and labels from the "bias_label" key.
# train_model returns two values, unpacked here as `losses` and `evals`.
losses, evals = classification_experiment.train_model(
    train_dataloader,
    dev_dataloader,
    input_key="pre_ids",
    label_key="bias_label",
    disable_tqdm=False,
    model_dtype=torch.long,
)

HBox(children=(FloatProgress(value=0.0, description='epochs', max=3.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))





In [12]:
# Display the evaluation results returned by train_model.
# NOTE(review): in the saved output all three outer entries are identical and
# AUC is 0.5 for every inner entry -- worth checking whether the model is
# predicting a single class / not learning.
evals

[[{'num_examples': 32, 'accuracy': 0.65625, 'auc': 0.5},
  {'num_examples': 32, 'accuracy': 0.625, 'auc': 0.5},
  {'num_examples': 17, 'accuracy': 0.4117647058823529, 'auc': 0.5}],
 [{'num_examples': 32, 'accuracy': 0.65625, 'auc': 0.5},
  {'num_examples': 32, 'accuracy': 0.625, 'auc': 0.5},
  {'num_examples': 17, 'accuracy': 0.4117647058823529, 'auc': 0.5}],
 [{'num_examples': 32, 'accuracy': 0.65625, 'auc': 0.5},
  {'num_examples': 32, 'accuracy': 0.625, 'auc': 0.5},
  {'num_examples': 17, 'accuracy': 0.4117647058823529, 'auc': 0.5}]]

In [13]:
from src.utils.classification_utils import average_data

In [14]:
# Apply average_data to each entry of `evals`, producing one summary per entry
# (the saved output suggests one entry per epoch -- TODO confirm).
avg_evaluations = list(map(average_data, evals))

In [15]:
# Display the averaged evaluation summaries.
# NOTE(review): the saved output is consistent with num_examples being summed
# (32 + 32 + 17 = 81) and accuracy being weighted by num_examples
# (48 / 81 ~= 0.593) -- confirm against average_data.
avg_evaluations

[{'num_examples': 81, 'accuracy': 0.5925925925925926, 'auc': 0.5},
 {'num_examples': 81, 'accuracy': 0.5925925925925926, 'auc': 0.5},
 {'num_examples': 81, 'accuracy': 0.5925925925925926, 'auc': 0.5}]