## Installs

In [7]:
#!pip install datasets

In [8]:
#!pip install 'transformers[torch]' -U

## Imports

In [2]:
from transformers import AutoTokenizer, BertForSequenceClassification
from transformers import BertModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import classification_report

In [3]:
from torch import nn
import torch

In [4]:
import pandas as pd

In [5]:
import numpy as np

In [6]:
from transformers import default_data_collator
from torch.utils.data import DataLoader

2024-05-16 13:00:59.653796: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Import Model

In [9]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [103]:
import sys
sys.path.append('model/code-bert/')
from temporal_relation_classification import TemporalRelationClassification
from temporal_relation_classification_config import TemporalRelationClassificationConfig

In [104]:
model_path = "saved_model/bert-base-uncased-saved-model"
model = TemporalRelationClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

You are using a model of type roberta to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.


KeyboardInterrupt: 

In [75]:
model.resize_token_embeddings(len(tokenizer))

Embedding(30526, 1024)

## Initilise Model

In [105]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [106]:
model.eval()

TemporalRelationClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50269, 1024)
      (position_embeddings): Embedding(514, 1024)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-05, 

## Import Instances

In [16]:
original = pd.read_csv('data/annotated/partitions.csv')
counterfactual = pd.read_csv('data/annotated/counterfactuals.csv')

In [18]:
counterfactual

Unnamed: 0,id,context,eventA,eventB,label_temp,label,counterfactual,new_label,new_temp
0,WSJ_20130322_159,Israeli Prime Minister Benjamin Netanyahu apol...,seemed,yield,AFTER,1,Israeli Prime Minister Benjamin Netanyahu apol...,0,BEFORE
1,WSJ_20130322_159,Israeli Prime Minister Benjamin Netanyahu apol...,wrapped,confirmed,BEFORE,0,Israeli Prime Minister Benjamin Netanyahu apol...,1,AFTER
2,WSJ_20130322_159,Israeli Prime Minister Benjamin Netanyahu apol...,told,spoken,AFTER,1,Israeli Prime Minister Benjamin Netanyahu apol...,0,BEFORE
3,WSJ_20130322_159,Israeli Prime Minister Benjamin Netanyahu apol...,expresses,spoken,AFTER,1,Israeli Prime Minister Benjamin Netanyahu apol...,0,BEFORE
4,WSJ_20130322_159,Israeli Prime Minister Benjamin Netanyahu apol...,spoken,discussed,BEFORE,0,Israeli Prime Minister Benjamin Netanyahu apol...,1,AFTER
...,...,...,...,...,...,...,...,...,...
199,CNN_20130322_248,The FAA on Friday announced it will close 149 ...,sparing,begin,BEFORE,0,The FAA on Friday announced it will close 149 ...,1,AFTER
200,AP_20130322,"The flu season is winding down, and it has kil...",dropping,appear,BEFORE,0,"The flu season is winding down, and it has kil...",1,AFTER
201,nyt_20130322_strange_computer,"Our digital age is all about bits, those preci...",expect,done,VAGUE,3,"Our digital age is all about bits, those preci...",0,BEFORE
202,nyt_20130321_cyprus,"A Cyprus exit from the euro union, if it comes...",struggling,mount,VAGUE,3,"A Cyprus exit from the euro union, if it comes...",0,BEFORE


In [19]:
counterfactual = counterfactual[['id', 'counterfactual', 'new_temp', 'new_label']]

In [20]:
counterfactual = counterfactual.rename(columns={'n': 'id', 'counterfactual': 'context', 'new_temp':'label_temp', 'new_label':'label'})

In [21]:
label_mapping = {
    'BEFORE': 0,
    'AFTER': 1,
    'EQUAL': 2,
    'VAGUE': 3
}

original['label'] = original['label_temp'].map(label_mapping)
counterfactual['label'] = counterfactual['label_temp'].map(label_mapping)

In [22]:
from datasets import Dataset, DatasetDict
dataset = DatasetDict({
    "original": Dataset.from_pandas(original),
    "counterfactual":  Dataset.from_pandas(counterfactual)
})

## Process Instances

In [None]:
dataset

In [112]:
import re
def annotate_text(row, column):
    context = row['context']
    eventA = re.escape(row['eventA'])
    eventB = re.escape(row['eventB'])

    # Make sure the longer event is replaced first if they overlap
    if len(eventA) > len(eventB):
        context = re.sub(eventA, f"[a1]{row['eventA']}[/a1]", context)
        context = re.sub(eventB, f"[a2]{row['eventB']}[/a2]", context)
    else:
        context = re.sub(eventB, f"[a2]{row['eventB']}[/a2]", context)
        context = re.sub(eventA, f"[a1]{row['eventB']}[/a1]", context)

    return context

In [53]:
from transformers import Trainer
evaluator = Trainer(
            model=model,
            eval_dataset=tokenized_datasets["original_with_key"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )

In [54]:
eval_mode = True
print('Evaluate:')
evaluator.evaluate(tokenized_datasets['original_with_key'])

Evaluate:


{'eval_runtime': 68.8437,
 'eval_samples_per_second': 2.048,
 'eval_steps_per_second': 0.261}

In [113]:
original['annotated_context'] = original.apply(lambda row: annotate_text(row, 'context'), axis=1)
counterfactual['annotated_context'] = counterfactual.apply(lambda row: annotate_text(row, 'context'), axis=1)

In [114]:
from datasets import Dataset, DatasetDict
dataset = DatasetDict({
    "original": Dataset.from_pandas(original),
    "counterfactual":  Dataset.from_pandas(counterfactual)
})

In [115]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenizer_class = str(type(tokenizer)).strip("><'").split('.')[-1]

In [121]:
def preprocess_function(examples):
  max_length_value = 508
  return tokenizer(examples["annotated_context"], truncation=True, max_length=max_length_value)

tokenized_datasets = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/204 [00:00<?, ? examples/s]

Map:   0%|          | 0/203 [00:00<?, ? examples/s]

In [122]:
tokenized_datasets

DatasetDict({
    original: Dataset({
        features: ['id', 'context', 'eventA', 'eventB', 'label_temp', 'label', 'pos_partition', 'neg_partition', 'annotated_context', 'input_ids', 'attention_mask'],
        num_rows: 204
    })
    counterfactual: Dataset({
        features: ['id', 'context', 'eventA', 'eventB', 'label_temp', 'label', 'pos_partition', 'neg_partition', 'annotated_context', 'input_ids', 'attention_mask'],
        num_rows: 203
    })
})

In [123]:
from datasets import DatasetDict

required_ids = {
50265,
50266,
50267,
50268}

def contains_all_required_ids(example):
    return required_ids.issubset(set(example['input_ids']))

filtered_original_dataset = tokenized_datasets['original'].filter(contains_all_required_ids)
tokenized_datasets['original'] = filtered_original_dataset
print("Number of instances after filtering: ", len(filtered_original_dataset))


Filter:   0%|          | 0/204 [00:00<?, ? examples/s]

Number of instances after filtering:  196


In [124]:
from datasets import DatasetDict

required_ids = {
50265,
50266,
50267,
50268}

def contains_all_required_ids(example):
    return required_ids.issubset(set(example['input_ids']))

filtered_counterfactual_dataset = tokenized_datasets['counterfactual'].filter(contains_all_required_ids)
tokenized_datasets['counterfactual'] = filtered_counterfactual_dataset
print("Number of instances after filtering: ", len(filtered_counterfactual_dataset))

Filter:   0%|          | 0/203 [00:00<?, ? examples/s]

Number of instances after filtering:  196


In [125]:
counterfactuals = filtered_counterfactual_dataset.to_pandas()
original = filtered_original_dataset.to_pandas()

## Evaluation Function

In [126]:
eval_mode = False


def compute_metrics(eval_preds):
    predictions, labels = eval_preds
    predictions = np.argmax(predictions, axis=1)

    if eval_mode:
        report = classification_report(y_true=labels, y_pred=predictions,
                                       target_names=['BEFORE', 'AFTER', 'EQUAL', 'VAGUE'])
        for i in range(labels.shape[0]):
            if labels[i] == 3 and predictions[i] != 3:
                labels[i] = predictions[i]
        report_no_vague = classification_report(y_true=labels, y_pred=predictions,
                                                target_names=['BEFORE', 'AFTER', 'EQUAL', 'VAGUE'])


        print(report)
        print(report_no_vague)

    results = \
        classification_report(y_true=labels, y_pred=predictions, target_names=['BEFORE', 'AFTER', 'EQUAL','VAGUE'],
                              output_dict=True)
    final_results = results['weighted avg']
    final_results.pop('support')
    final_results['BEFORE-f1'] = results['BEFORE']['f1-score']
    final_results['AFTER-f1'] = results['AFTER']['f1-score']
    final_results['EQUAL-f1'] = results['EQUAL']['f1-score']
    final_results['VAGUE-f1'] = results['VAGUE']['f1-score']
    return final_results

## Evaluating Original Datasplit

In [127]:
from transformers import Trainer
evaluator = Trainer(
            model=model,
            eval_dataset=tokenized_datasets["original"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )

In [128]:
eval_mode = True
print('Evaluate:')
evaluator.evaluate(tokenized_datasets['original'])

Evaluate:


              precision    recall  f1-score   support

      BEFORE       0.77      0.90      0.83        83
       AFTER       0.92      0.43      0.59        79
       EQUAL       0.25      0.15      0.19        26
       VAGUE       0.07      0.38      0.11         8

    accuracy                           0.59       196
   macro avg       0.50      0.47      0.43       196
weighted avg       0.73      0.59      0.62       196

              precision    recall  f1-score   support

      BEFORE       0.81      0.91      0.85        87
       AFTER       0.92      0.43      0.59        79
       EQUAL       0.31      0.19      0.23        27
       VAGUE       0.07      1.00      0.12         3

    accuracy                           0.62       196
   macro avg       0.53      0.63      0.45       196
weighted avg       0.77      0.62      0.65       196



{'eval_loss': 4.744256019592285,
 'eval_precision': 0.7722691206929545,
 'eval_recall': 0.6173469387755102,
 'eval_f1-score': 0.6493220270292387,
 'eval_BEFORE-f1': 0.854054054054054,
 'eval_AFTER-f1': 0.5862068965517242,
 'eval_EQUAL-f1': 0.2325581395348837,
 'eval_VAGUE-f1': 0.125,
 'eval_runtime': 285.3923,
 'eval_samples_per_second': 0.687,
 'eval_steps_per_second': 0.088}

## Evaluating Counterfactual Datasplit

In [129]:
from transformers import Trainer
evaluator = Trainer(
            model=model,
            eval_dataset=tokenized_datasets["counterfactual"],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )

In [130]:
eval_mode = True
print('Evaluate:')
evaluator.evaluate(tokenized_datasets['counterfactual'])

Evaluate:


              precision    recall  f1-score   support

      BEFORE       0.57      0.64      0.60        83
       AFTER       0.67      0.40      0.50        80
       EQUAL       0.14      0.09      0.11        33
       VAGUE       0.00      0.00      0.00         0

    accuracy                           0.45       196
   macro avg       0.34      0.28      0.30       196
weighted avg       0.54      0.45      0.48       196

              precision    recall  f1-score   support

      BEFORE       0.57      0.64      0.60        83
       AFTER       0.67      0.40      0.50        80
       EQUAL       0.14      0.09      0.11        33
       VAGUE       0.00      0.00      0.00         0

    accuracy                           0.45       196
   macro avg       0.34      0.28      0.30       196
weighted avg       0.54      0.45      0.48       196



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'eval_loss': 7.280460357666016,
 'eval_precision': 0.5364000438885231,
 'eval_recall': 0.4489795918367347,
 'eval_f1-score': 0.47749304267161413,
 'eval_BEFORE-f1': 0.6022727272727274,
 'eval_AFTER-f1': 0.5,
 'eval_EQUAL-f1': 0.10909090909090909,
 'eval_VAGUE-f1': 0.0,
 'eval_runtime': 348.6763,
 'eval_samples_per_second': 0.562,
 'eval_steps_per_second': 0.072}

## Predictions for Original Datasplit

In [None]:
def evaluate(model, dataloader):
    model.eval()
    device = next(model.parameters()).device
    all_predictions, all_labels = [], []

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1).cpu().numpy()
        labels = batch['labels'].cpu().numpy()

        all_predictions.extend(predictions)
        all_labels.extend(labels)

    return np.array(all_predictions), np.array(all_labels)

In [None]:
from transformers import default_data_collator
counterfactual_data_collator = default_data_collator
counterfactual_dataloader = DataLoader(tokenized_datasets['counterfactual'], collate_fn=counterfactual_data_collator)

counterfactual_predictions, counterfactuals_labels = evaluate(model, counterfactual_dataloader)

counterfactuals_df = pd.DataFrame({
    'prediction': counterfactual_predictions,
    'true_label': counterfactuals_labels
})


In [None]:
original_data_collator = default_data_collator
original_dataloader = DataLoader(tokenized_datasets['original'], collate_fn=original_data_collator)

original_predictions, original_labels = evaluate(model, original_dataloader)

original_df = pd.DataFrame({
    'prediction': original_predictions,
    'true_label': original_labels
})

In [None]:
len(original_predictions)

196

## Saving Predictions

In [None]:
original_df.to_csv('/content/drive/My Drive/XAI/BERT-BASE/original_results/predictions-OG-bert-base.csv', index=False)
counterfactuals_df.to_csv('/content/drive/My Drive/XAI/BERT-BASE/counterfactual_results/predictions-CF-bert-base.csv', index=False)

In [None]:
df['counterfactual_predictions'] = counterfactual_predictions
df['original_predictions'] = original_predictions
df.rename(columns={'new_temp': 'label_temp_counterfactuals'}, inplace=True)
df.rename(columns={'new_label': 'label_counterfactuals'}, inplace=True)

In [None]:
df=df[['id', 'eventA', 'eventB',
       'context', 'label_temp', 'label', 'original_predictions',
       'counterfactual', 'label_temp_counterfactuals', 'label_counterfactuals', 'counterfactual_predictions']]