# One-vs-Rest Classifier

This notebook implements an one-vs-rest classifier that fine-tunes several BERT models to tell if a sentence contains problematic metaphors.

<div hidden>
TODO: add extend data3/data.json with better data in the same format that actually makes sense.
</div>

## Imports and Setup

In [1]:
%pip install transformers -Uqq
%pip install sklearn -Uqq
%pip install datasets -Uqq
%pip install torch -Uqq
%pip install numpy -Uqq
%pip install evaluate -Uqq

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import evaluate
import numpy as np
import torch
from datasets import Dataset, load_dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)
import os

In [3]:
os.environ['WANDB_DISABLED']='true'

## Loading Dataset

In [4]:
dataset = load_dataset("json", data_files="data/data.json", field="data")
dataset

Found cached dataset json (/root/.cache/huggingface/datasets/json/default-e7751a31f3f17887/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'agency', 'humanComparison', 'hyperbole', 'historyComparison', 'unjustClaims', 'deepSounding', 'sceptics', 'deEmphasize', 'performanceNumber', 'inscrutable'],
        num_rows: 329
    })
})

In [5]:
dataset["train"][0:3]

{'text': ['A new vision of artificial intelligence for the people',
  'The gig workers fighting back against the algorithms',
  'How the AI industry profits from catastrophe'],
 'agency': [False, True, False],
 'humanComparison': [False, True, False],
 'hyperbole': [False, True, True],
 'historyComparison': [False, False, False],
 'unjustClaims': [False, False, False],
 'deepSounding': [False, False, False],
 'sceptics': [False, False, False],
 'deEmphasize': [False, False, False],
 'performanceNumber': [False, False, False],
 'inscrutable': [False, False, False]}

In [6]:
labels = [label for label in dataset["train"].features.keys() if label not in ["text"]]

num_epochs = {
    "agency": 10,
    "humanComparison": 2,
    "hyperbole": 2,
    "historyComparison": 2,
    "unjustClaims": 5,
    "deepSounding": 2,
    "sceptics": 2,
    "deEmphasize": 7,
    "performanceNumber": 2,
    "inscrutable": 2,
}

labels

['agency',
 'humanComparison',
 'hyperbole',
 'historyComparison',
 'unjustClaims',
 'deepSounding',
 'sceptics',
 'deEmphasize',
 'performanceNumber',
 'inscrutable']

## Preprocess Data, Create Train/Test Split

In [7]:
processed_dataset = {}
for label in labels:
    projected_dataset = (
        dataset["train"]
        .map(remove_columns=[l for l in labels if l != label])
        .rename_column(label, "labels")
        .class_encode_column("labels")
    )
    processed_dataset[label] = projected_dataset.train_test_split(
        test_size=0.2, stratify_by_column="labels"
    )
    # print(f"{label}:\n\t{processed_dataset[label]['test'][0:3]}\n")

processed_dataset

Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-e7751a31f3f17887/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-604182da521c301e.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-e7751a31f3f17887/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-8d3dee6ea1033dcd.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-e7751a31f3f17887/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-839d38437908d34a.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-e7751a31f3f17887/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-fc7d1db672f398d9.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-e7751a31f3f17887/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/cache-36f8b4671a097526.arrow


{'agency': DatasetDict({
     train: Dataset({
         features: ['text', 'labels'],
         num_rows: 263
     })
     test: Dataset({
         features: ['text', 'labels'],
         num_rows: 66
     })
 }),
 'humanComparison': DatasetDict({
     train: Dataset({
         features: ['text', 'labels'],
         num_rows: 263
     })
     test: Dataset({
         features: ['text', 'labels'],
         num_rows: 66
     })
 }),
 'hyperbole': DatasetDict({
     train: Dataset({
         features: ['text', 'labels'],
         num_rows: 263
     })
     test: Dataset({
         features: ['text', 'labels'],
         num_rows: 66
     })
 }),
 'historyComparison': DatasetDict({
     train: Dataset({
         features: ['text', 'labels'],
         num_rows: 263
     })
     test: Dataset({
         features: ['text', 'labels'],
         num_rows: 66
     })
 }),
 'unjustClaims': DatasetDict({
     train: Dataset({
         features: ['text', 'labels'],
         num_rows: 263
     })
     t

In [8]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def preprocess_data(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

In [9]:
tokenized_dataset = {
    k: ds.map(
        preprocess_data,
        remove_columns="text",
        batched=True,
    )
    for k, ds in processed_dataset.items()
}

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

Map:   0%|          | 0/263 [00:00<?, ? examples/s]

Map:   0%|          | 0/66 [00:00<?, ? examples/s]

### Verify dataset

In [10]:
example = tokenized_dataset["agency"]["train"][0]
print(example.keys())

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])


In [11]:
tokenizer.decode(example["input_ids"])

'[CLS] The White House just unveiled a new AI Bill of Rights [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

In [12]:
example["labels"]

0

## Load Pre-Trained Model

In [13]:
# use_fast uses fast tokenizers backed by rust. Remove it if it causes errors
# model = AutoModelForSequenceClassification.from_pretrained(
    # "bert-base-cased",
    # num_labels=2,
# )

### Verify data-model interaction

In [14]:
# forward pass
# outputs = model(
# input_ids=tokenized_dataset[labels[0]]["train"]["input_ids"][0],
# labels=tokenized_dataset[labels[0]]["train"][0]["labels"],
# )
# outputs

## Define Metrics

In [15]:
metrics = {
    "accuracy": evaluate.load("accuracy"),
    "presicion": evaluate.load("precision"),
    "recall": evaluate.load("recall"),
    "f1": evaluate.load("f1"),
}

In [16]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        name: metric.compute(predictions=predictions, references=labels)
        for name, metric in metrics.items()
    }

In [17]:
class HighPrecisionTrainer(Trainer):
    """A trainer class, which computes loss based on a weighted MSE, where the error for the positive labels is
    weighted more than the error for the negative labels, leading to a higher precision
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        label_mask = torch.FloatTensor(1, 2).cuda().zero_()
        label_mask[0, labels[0]] = 1
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")

        # compute custom loss (suppose one has 3 labels with different weights)
        loss = torch.sum(torch.tensor([[1, 300]]).cuda() * ((logits - label_mask) ** 2))
        return (loss, outputs) if return_outputs else loss

## Train the Model

In [18]:
batch_size = 1  # TODO: increase if we have more data
# metric_name = "f1"

In [19]:
for label in ['agency']:  # labels:
    print(f"training model for {label}")

    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased",
        num_labels=2,
    )

    training_args = TrainingArguments(
        f"aihype_{label}-vs-rest",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs[label],
        # weight_decay=0.05,
        report_to="all",
        # load_best_model_at_end=True,
        # metric_for_best_model=metric_name,
        # push_to_hub=True,  # TODO: enable once model seems good
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset[label]["train"],
        eval_dataset=tokenized_dataset[label]["test"],
        compute_metrics=compute_metrics,
    )

    trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Presicion,Recall,F1
1,No log,1.223269,{'accuracy': 0.7878787878787878},{'precision': 0.0},{'recall': 0.0},{'f1': 0.0}
2,1.037700,0.875177,{'accuracy': 0.7878787878787878},{'precision': 0.0},{'recall': 0.0},{'f1': 0.0}
3,1.037700,0.978453,{'accuracy': 0.7878787878787878},{'precision': 0.5},{'recall': 0.21428571428571427},{'f1': 0.3}
4,0.491000,1.233475,{'accuracy': 0.803030303030303},{'precision': 0.5454545454545454},{'recall': 0.42857142857142855},{'f1': 0.4799999999999999}
5,0.491000,1.684321,{'accuracy': 0.7727272727272727},{'precision': 0.42857142857142855},{'recall': 0.21428571428571427},{'f1': 0.2857142857142857}
6,0.026100,1.690034,{'accuracy': 0.7727272727272727},{'precision': 0.42857142857142855},{'recall': 0.21428571428571427},{'f1': 0.2857142857142857}
7,0.026100,1.668422,{'accuracy': 0.803030303030303},{'precision': 0.5555555555555556},{'recall': 0.35714285714285715},{'f1': 0.43478260869565216}
8,0.000000,1.684913,{'accuracy': 0.803030303030303},{'precision': 0.5555555555555556},{'recall': 0.35714285714285715},{'f1': 0.43478260869565216}
9,0.000000,1.691925,{'accuracy': 0.8333333333333334},{'precision': 0.6363636363636364},{'recall': 0.5},{'f1': 0.56}
10,0.000000,1.698746,{'accuracy': 0.8333333333333334},{'precision': 0.6363636363636364},{'recall': 0.5},{'f1': 0.56}


## Upload the Model

In [20]:
# agency-vs-rest/checkpoint-263: 0.75 precision, 0.85 recall
#

In [21]:
# trainer.push_to_hub()