In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, AutoTokenizer
from qbert.modeling_qbert import QBertForSequenceClassification, QBertConfig
import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: A ``(predictions, label_ids)`` pair, as handed over by
            ``transformers.Trainer`` (an ``EvalPrediction``). ``predictions``
            holds per-class scores; when the model returns several outputs it
            is a tuple of arrays whose first element is taken as the logits.

    Returns:
        ``{"accuracy": float}`` -- the fraction of rows whose argmax over the
        last axis equals the corresponding label.
    """
    logits, labels = eval_pred
    # With extra model outputs (e.g. hidden states) Trainer passes a tuple of
    # arrays; for Hugging Face classification heads the class scores come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    return {"accuracy": float(np.mean(predictions == labels))}

class DummyDataset:
    """Minimal map-style dataset backed by an in-memory sequence.

    Implements ``__len__`` and ``__getitem__`` by delegating directly to the
    wrapped sequence, so indexing semantics (negative indices, slices,
    ``IndexError`` on overflow) are exactly those of ``data``.
    """

    def __init__(self, data):
        # Kept by reference (no copy): later changes to ``data`` show up here.
        self.data = data

    def __len__(self):
        records = self.data
        return len(records)

    def __getitem__(self, idx):
        records = self.data
        return records[idx]

# Dummy Data: tiny, already-tokenized examples so the training loop can be
# exercised without a real corpus.
def _make_example(input_ids, label):
    """Build one example with all-zero token_type_ids and an all-ones attention_mask."""
    width = len(input_ids)
    return {
        "input_ids": input_ids,
        "labels": label,
        "token_type_ids": [0] * width,
        "attention_mask": [1] * width,
    }


train_data = [_make_example([1, 2, 3], 0), _make_example([4, 5, 6], 1)]
eval_data = [_make_example([7, 8, 9], 1), _make_example([10, 11, 12], 0)]

train_dataset, eval_dataset = DummyDataset(train_data), DummyDataset(eval_data)

# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` and prefer `Trainer(processing_class=...)` over
# `tokenizer=...`. The older spellings are kept here for compatibility with
# older installs; switch them if your installed version rejects them.
training_args = TrainingArguments(
    output_dir="./test_trainer",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    metric_for_best_model="accuracy",
    # Declare the label key explicitly. Otherwise Trainer infers label names
    # from the model's forward() signature; if it finds none (e.g. forward
    # only takes **kwargs), labels are never collected during evaluation,
    # compute_metrics is skipped, and load_best_model_at_end fails because
    # "eval_accuracy" is missing. Confirm QBert's forward consumes `labels`.
    label_names=["labels"],
    remove_unused_columns=True,
    num_train_epochs=5,
)

tokenizer = AutoTokenizer.from_pretrained("google/bert_uncased_L-8_H-512_A-8")
# Pads each batch dynamically; all dummy examples already share length 3.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Size the classification head from the labels actually present (assumes
# integer labels numbered 0..k-1). The dummy data is binary, so this is 2;
# an oversized head lets argmax predict classes that never occur.
num_labels = 1 + max(example["labels"] for example in train_data + eval_data)

trainer = Trainer(
    model=QBertForSequenceClassification(config=QBertConfig(num_labels=num_labels)),  # Replace with a dummy model or your actual model
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()
# Print explicitly so the metrics are visible outside a notebook too.
eval_metrics = trainer.evaluate()
print(eval_metrics)