<a href="https://colab.research.google.com/github/sangamsky/-WCCI-26-DL-NL--07-Attention-Faithfulness-in-Transformer-Based-Text-Classifiers-on-Real-Datasets/blob/main/code/notebooks/baseline_replication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


import numpy as np
import torch

from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    set_seed,
)

import evaluate




# Load the GLUE "sst2" configuration as a DatasetDict of splits; later cells
# use its "train" and "validation" splits.
dataset = load_dataset("glue", "sst2")

# Show the available splits, their columns and row counts.
print(dataset)




# Tokenizer matching the "bert-base-uncased" checkpoint loaded for the model
# below (this is the Python `BertTokenizer` class, not `BertTokenizerFast`).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")



def tokenize_function(example, *, tok=None, max_length=64):
    """Tokenize the ``"sentence"`` field of an example (or a batch of them).

    Designed for ``Dataset.map(tokenize_function, batched=True)``, where
    ``example["sentence"]`` is a list of strings.

    Sequences are truncated to ``max_length`` tokens but deliberately NOT
    padded here. Padding is left to ``DataCollatorWithPadding``, which pads
    each batch only to that batch's longest sequence. Padding every example to
    ``max_length`` up front would make the collator a no-op and spend compute
    on pad tokens.

    Args:
        example: Mapping containing a ``"sentence"`` key.
        tok: Tokenizer callable to use; defaults to the module-level
            ``tokenizer`` (looked up at call time).
        max_length: Maximum number of tokens kept per sentence.

    Returns:
        Whatever the tokenizer returns (e.g. ``input_ids``, ``attention_mask``).
    """
    if tok is None:
        tok = tokenizer  # module-level tokenizer defined earlier in the file
    return tok(
        example["sentence"],
        truncation=True,
        padding=False,  # dynamic padding happens in DataCollatorWithPadding
        max_length=max_length,
    )


# Tokenize every split in batches. The raw "sentence" text and the "idx" row
# id are dropped in the same pass. Then "label" is renamed to "labels", the
# keyword name the model's forward pass takes, and the splits are set to yield
# torch tensors.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["sentence", "idx"],
).rename_column("label", "labels")
tokenized_dataset.set_format("torch")



# Pads each batch at collation time using the tokenizer's pad token and
# padding rules.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Seed BEFORE loading the model. The sequence-classification head is not part
# of the "bert-base-uncased" checkpoint, so from_pretrained initialises it
# randomly. Trainer only calls set_seed(args.seed) inside its own __init__,
# which runs after this point. Without this call the head (and therefore the
# reported baseline accuracy) differs from run to run. 42 matches
# TrainingArguments' default seed, so both stages use the same value.
set_seed(42)
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2
)

# Accuracy metric consumed by compute_metrics during evaluation.
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score an evaluation pass by classification accuracy.

    ``eval_pred`` is either a plain ``(logits, labels)`` tuple or an object
    exposing ``.predictions`` and ``.label_ids`` (as Trainer passes).
    When the predictions are themselves a tuple, only its first element is
    treated as the logits. The predicted class is the arg-max over the last
    axis, scored with the module-level ``accuracy`` metric.
    """
    if not isinstance(eval_pred, tuple):
        raw_output, gold = eval_pred.predictions, eval_pred.label_ids
    else:
        raw_output, gold = eval_pred

    # Some models return extra outputs alongside the logits; keep the logits.
    scores = raw_output[0] if isinstance(raw_output, tuple) else raw_output
    predicted = np.argmax(scores, axis=-1)
    return accuracy.compute(predictions=predicted, references=gold)




# Configuration for a single-epoch SST-2 fine-tuning run.
training_args = TrainingArguments(
    # bookkeeping: output directory, no checkpoints, no reporting integrations
    output_dir="./sst2_results",
    save_strategy="no",
    report_to="none",
    # evaluation / logging cadence
    eval_strategy="epoch",
    logging_steps=100,
    # optimisation hyper-parameters
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
)



# Fine-tune on the train split and evaluate on the validation split.
# `processing_class` replaces Trainer's `tokenizer` keyword, which is
# deprecated since transformers 4.46 and scheduled for removal in v5. This
# keyword requires transformers >= 4.46.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)



trainer.train()


# Final validation pass. Trainer prefixes metric names with "eval_", so the
# accuracy from compute_metrics is reported under "eval_accuracy".
val_results = trainer.evaluate()
print("SST-2 Validation Results:", val_results)

