## Transformers for Counterfactual Recognition

#### This version uses a class-weighted cross-entropy loss

In [7]:
import torch
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
from datasets import load_dataset
# Load the subtask-1 training CSV; the resulting DatasetDict exposes it as
# the "train" split.
dataset = load_dataset(
    "csv",
    data_files="../input/counterfactualrecognition/subtask1_train_bert.csv",
)

In [3]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch of examples for BERT.

    Every sequence is padded or truncated to exactly 100 tokens so that all
    batches share a fixed length.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        max_length=100,
        truncation=True,
        padding="max_length",
    )

# batched=True hands tokenize_function whole batches (lists of texts).
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [4]:
# Hold out 15% of the labelled training data for evaluation. The fixed seed
# makes the split reproducible across notebook runs, so evaluation scores
# (and any checkpoint picked from them later) remain comparable.
train_test_dataset = tokenized_datasets["train"].train_test_split(test_size=0.15, seed=42)

In [5]:
# Display the resulting train/test DatasetDict as the cell output.
train_test_dataset

In [8]:
from transformers import AutoModelForSequenceClassification

# BERT encoder with a 2-label sequence-classification head, placed on the
# selected device (the expression is left as the cell output).
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
    num_labels=2,
)
model.to(device)

In [9]:
import numpy as np
from datasets import load_metric

# One metric object per score reported during evaluation, loaded in the
# order accuracy, F1, precision, recall.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the `evaluate` library — confirm the pinned version.
acc_metric, f1_metric, p_metric, r_metric = (
    load_metric(metric_name)
    for metric_name in ("accuracy", "f1", "precision", "recall")
)

In [10]:
def compute_metrics(eval_pred):
    """Compute accuracy, F1, precision and recall for a Trainer evaluation.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by ``Trainer``; the
            class dimension of ``logits`` is the last axis.

    Returns:
        dict mapping ``"accuracy"``, ``"f1"``, ``"precision"`` and
        ``"recall"`` to plain scalars. Each ``metric.compute`` call returns a
        one-entry dict keyed by the metric's name, so that value is unwrapped
        here; returning the nested dicts would make the logged ``eval_*``
        entries dicts rather than numbers.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    scores = {}
    for name, metric in (
        ("accuracy", acc_metric),
        ("f1", f1_metric),
        ("precision", p_metric),
        ("recall", r_metric),
    ):
        result = metric.compute(predictions=predictions, references=labels)
        scores[name] = result[name]
    return scores

In [11]:
from torch import nn
from transformers import Trainer, TrainingArguments

# Fine-tuning hyper-parameters: evaluate after every epoch, train 8 epochs at
# lr 2e-5, and save a checkpoint every 1500 steps under ./cr_bert_trainer.
# report_to must be the *string* "none" to disable logging integrations: in
# transformers 4.x a Python None falls back to "all" installed integrations
# (e.g. wandb), the opposite of the intent.
# NOTE(review): newer transformers renamed `evaluation_strategy` to
# `eval_strategy` — update if the environment is upgraded.
training_args = TrainingArguments(
    output_dir="cr_bert_trainer",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=8,
    save_steps=1500,
    report_to="none",
)

In [12]:
class CRTrainer(Trainer):
    """``Trainer`` whose loss is an optionally class-weighted cross-entropy.

    Args:
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``.
        class_weights: optional per-class weights for ``nn.CrossEntropyLoss``
            (one value per label id, in label-id order). ``None`` (default)
            weights all classes equally — the same loss the previous
            hard-coded ``[1.0, 1.0]`` produced — and also works for models
            with any number of labels.
    """

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = (
            None
            if class_weights is None
            else torch.as_tensor(class_weights, dtype=torch.float)
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the (weighted) cross-entropy loss, plus outputs if requested.

        ``**kwargs`` absorbs extra arguments that newer ``Trainer`` versions
        pass (e.g. ``num_items_in_batch``) so this override keeps working.
        """
        labels = inputs.get("labels")
        # Forward pass without labels so the model does not also compute its
        # own unweighted loss; the caller's dict is left unmodified.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        weight = None
        if self.class_weights is not None:
            # Follow the logits' device/dtype rather than a global device, so
            # the loss works under fp16 and any device placement.
            weight = self.class_weights.to(device=logits.device, dtype=logits.dtype)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# Assemble the custom trainer from the model, hyper-parameters, metric
# function and the train/test partitions; nothing is trained yet.
trainer = CRTrainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_test_dataset["train"],
    eval_dataset=train_test_dataset["test"],
)

In [13]:
# Run fine-tuning; evaluation and checkpointing follow the trainer's arguments.
trainer.train()

In [14]:
import pandas as pd

# Load the held-out subtask-1 test CSV into a DataFrame.
test_df = pd.read_csv(
    "../input/counterfactualrecognition/subtask1_test_bert.csv"
)

In [15]:
# Reload the tokenizer and a saved fine-tuning checkpoint for inference.
# NOTE(review): "checkpoint-6000" is hard-coded; which step numbers exist
# depends on dataset size, batch size and save_steps — confirm this is the
# intended checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path="/kaggle/working/cr_bert_trainer/checkpoint-6000",
    num_labels=2,
)
model.to(device)

In [16]:
# Tokenize the whole test set with the same padding/truncation as training
# (100 tokens), returning PyTorch tensors.
inputs = tokenizer(
    test_df["text"].tolist(),
    return_tensors="pt",
    max_length=100,
    truncation=True,
    padding="max_length",
)

In [17]:
# Batched inference over the tokenized test set.
test_size = len(test_df)
batch_size = 8
# Class ids from argmax are integers, so store them in an integer tensor.
pred_labels = torch.zeros(test_size, dtype=torch.long)

model.eval()  # disable dropout explicitly for deterministic predictions
with torch.no_grad():  # no autograd graph is needed; saves device memory
    for start in range(0, test_size, batch_size):
        stop = start + batch_size  # tensor slicing clamps at the end
        batch = {name: tensor[start:stop].to(device) for name, tensor in inputs.items()}
        logits = model(**batch).get("logits")
        pred_labels[start:stop] = logits.argmax(dim=-1).cpu()

In [18]:
# Materialize both label sets as independent NumPy arrays for scikit-learn.
pred_labels = pred_labels.cpu().numpy().copy()
gold_labels = test_df["label"].to_numpy(copy=True)

In [19]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

# Report binary-classification scores of the predictions against gold labels.
for caption, score_fn in (
    ("F1: ", f1_score),
    ("Accuracy: ", accuracy_score),
    ("Precision: ", precision_score),
    ("Recall: ", recall_score),
):
    print(caption, score_fn(gold_labels, pred_labels))

## False positives and negatives

In [None]:
# Two-digit code per item (tens digit = prediction, units digit = gold),
# kept for inspection.
preds = pred_labels * 10 + gold_labels
# Confusion-matrix cells as explicit boolean masks (positive class = 1).
# Testing each array directly avoids the aliasing of the code above when a
# label is not exactly 0/1 (e.g. pred 0 with gold 10 would equal 10 and be
# miscounted as a false positive).
tps = (pred_labels == 1) & (gold_labels == 1)
fps = (pred_labels == 1) & (gold_labels == 0)
fns = (pred_labels == 0) & (gold_labels == 1)
tns = (pred_labels == 0) & (gold_labels == 0)

In [None]:
# Show the texts of the false-positive test items as the cell output.
test_df.loc[fps, "text"].tolist()