In [1]:
import torch
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import Dataset, load_metric
import pandas as pd
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
import numpy as np



In [2]:
# Tokenizer and sequence-classification model are loaded from the same Hub checkpoint id.
model_name = 'mhr2004/BERT_BOOLQ_Circa_YN'
tokenizer, model = (
    BertTokenizer.from_pretrained(model_name),
    BertForSequenceClassification.from_pretrained(model_name),
)

  return self.fget.__get__(instance, owner)()


In [3]:
# Load the dataset
dataset = pd.read_parquet('Data/Circa_train.parquet')

# Fail fast if the columns used by the filtering and tokenization cells are missing,
# instead of hitting a KeyError later inside Dataset.map.
_required_columns = {'question-X', 'answer-Y', 'goldstandard2'}
_missing_columns = _required_columns.difference(dataset.columns)
assert not _missing_columns, f"Circa parquet is missing columns: {sorted(_missing_columns)}"

In [4]:
# Keep only rows whose goldstandard2 label is one of the four class ids used downstream.
# label_map is an identity mapping (raw label -> model label id); its keys are the
# single source of truth for which labels are kept, so the filter is derived from it.
label_map = {
    0: 0,
    1: 1,
    2: 2,
    3: 3
}
filtered_dataset = dataset[dataset['goldstandard2'].isin(list(label_map))]



In [5]:
# Split the data 60/20/20 into train/dev/test; the fixed random_state keeps the test split reproducible.
# NOTE(review): only the test split is evaluated below — presumably this seed mirrors the
# split used when the checkpoint was fine-tuned; confirm before changing it.
train_data, dev_data = train_test_split(filtered_dataset, test_size=0.4, random_state=42)
dev_data, test_data = train_test_split(dev_data, test_size=0.5, random_state=42)

# Convert pandas DataFrame to Hugging Face Dataset.
# preserve_index=False: after filtering/splitting the frame index is no longer a default
# RangeIndex, and would otherwise be carried along as an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(train_data, preserve_index=False)
dev_dataset = Dataset.from_pandas(dev_data, preserve_index=False)
test_dataset = Dataset.from_pandas(test_data, preserve_index=False)

In [6]:
# Preprocess function
def preprocess(examples):
    """Tokenize question/answer pairs and attach the goldstandard2 labels.

    Used with ``Dataset.map(..., batched=True)``, so ``examples`` is a batch:
    a mapping from column name to a list of values. Every pair is truncated and
    padded to exactly 128 tokens.
    """
    encoded = tokenizer(
        examples['question-X'],
        examples['answer-Y'],
        truncation=True,
        padding='max_length',
        max_length=128,
    )
    encoded['labels'] = examples['goldstandard2']
    return encoded

# Tokenize every split with the same batched preprocessing function
# (the generator is consumed in order: train, then dev, then test).
encoded_train_dataset, encoded_dev_dataset, encoded_test_dataset = (
    split.map(preprocess, batched=True)
    for split in (train_dataset, dev_dataset, test_dataset)
)

# Load metrics
# NOTE(review): `datasets.load_metric` is deprecated upstream in favour of the `evaluate`
# library and warns on every call (see this cell's output) — confirm against the installed version.
accuracy_metric = load_metric("accuracy")

Map:   0%|          | 0/19795 [00:00<?, ? examples/s]

Map:   0%|          | 0/6599 [00:00<?, ? examples/s]

Map:   0%|          | 0/6599 [00:00<?, ? examples/s]

  accuracy_metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [7]:
# Compute metrics function
def compute_metrics(p):
    """Turn the Trainer's raw evaluation outputs into accuracy and F1 scores.

    Args:
        p: ``(predictions, label_ids)`` pair supplied by ``Trainer`` -- the logits
            (one column per class) and the gold label ids.

    Returns:
        dict with
          - ``accuracy``: fraction of argmax predictions equal to the labels,
          - ``f1_weighted``: support-weighted F1,
          - ``f1_per_class``: ndarray with one F1 per class id ``0..num_classes-1``
            (kept as an array because the results cell scales it with ``* 100``),
          - ``f1_class_<i>``: the same per-class scores as plain floats, so the
            Trainer's logger can record them as scalars.
    """
    predictions, label_ids = p
    # If the model returns extra outputs, predictions may arrive as a tuple whose
    # first element holds the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    label_ids = np.asarray(label_ids)
    preds = np.argmax(predictions, axis=-1)

    # One class per logit column. Pinning the label set keeps per-class scores aligned
    # with class ids even when a class never occurs in label_ids or preds.
    class_ids = np.arange(predictions.shape[-1])

    # Plain numpy accuracy avoids depending on the deprecated datasets.load_metric.
    acc = float(np.mean(preds == label_ids))
    # zero_division=0 yields the same value as sklearn's default, without the warning.
    f1_weighted = f1_score(label_ids, preds, average='weighted', zero_division=0)
    f1_per_class = f1_score(label_ids, preds, labels=class_ids, average=None, zero_division=0)

    metrics = {
        'accuracy': acc,
        'f1_weighted': f1_weighted,
        'f1_per_class': f1_per_class,
    }
    # The logger drops the ndarray above as a non-scalar; these scalar copies get logged.
    metrics.update({f'f1_class_{i}': float(score) for i, score in enumerate(f1_per_class)})
    return metrics

# Evaluation-only configuration: no training step is run in this notebook.
# NOTE(review): per the TrainingArguments docs, do_train/do_eval are intended for
# training scripts rather than consumed by Trainer itself — confirm for the installed version.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    do_train=False,
    do_eval=True,
    per_device_eval_batch_size=16,
)

# Only an eval_dataset is supplied, and it is the held-out test split (dev is unused here).
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    eval_dataset=encoded_test_dataset,
)

In [8]:
# Run inference on the test split; metric keys come back prefixed with "eval_".
eval_result = trainer.evaluate(eval_dataset=encoded_test_dataset, metric_key_prefix="eval")

  0%|          | 0/413 [00:00<?, ?it/s]

Trainer is attempting to log a value of "[0.95728834 0.9520718  0.58252427 0.94128611]" of type <class 'numpy.ndarray'> for key "eval/f1_per_class" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.


In [9]:
# Print results
# Metrics from compute_metrics come back with an "eval_" prefix; scale them to percent.
acc = eval_result['eval_accuracy'] * 100 if 'eval_accuracy' in eval_result else None
f1_weighted = eval_result['eval_f1_weighted'] * 100 if 'eval_f1_weighted' in eval_result else None
# np.asarray: if the per-class scores were a plain list, `* 100` would repeat it instead of scaling.
f1_per_class = (
    np.asarray(eval_result['eval_f1_per_class']) * 100
    if 'eval_f1_per_class' in eval_result
    else None
)

print(f"Accuracy: {acc:.1f}%" if acc is not None else "Accuracy not found")
print(f"Weighted F1 Score: {f1_weighted:.1f}%" if f1_weighted is not None else "Weighted F1 Score not found")
if f1_per_class is not None:
    # Same one-decimal formatting as the aggregate scores, labelled by class id.
    per_class_text = ", ".join(
        f"{class_id}: {score:.1f}%" for class_id, score in enumerate(f1_per_class)
    )
    print(f"F1 Score for each class: {per_class_text}")
else:
    print("F1 Score for each class not found")

Accuracy: 94.5%
Weighted F1 Score: 94.3%
F1 Score for each class: [95.72883417 95.20717968 58.25242718 94.12861137]
