In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("ag_news")

Generating train split: 100%|██████████| 120000/120000 [00:00<00:00, 2046318.05 examples/s]
Generating test split: 100%|██████████| 7600/7600 [00:00<00:00, 2286708.06 examples/s]


In [7]:
dataset["train"][:2]

{'text': ["Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
  'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.'],
 'label': [2, 2]}

In [4]:
dataset["train"].column_names

['text', 'label']

In [11]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [12]:
def tokenize_fn(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

In [13]:
tokenized_ds = dataset.map(tokenize_fn, batched=True)
tokenized_ds = tokenized_ds.remove_columns(["text"])
tokenized_ds = tokenized_ds.rename_column("label", "labels")
tokenized_ds.set_format("torch")


Map: 100%|██████████| 120000/120000 [00:03<00:00, 31451.35 examples/s]
Map: 100%|██████████| 7600/7600 [00:00<00:00, 31711.68 examples/s]


In [14]:
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=4
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = torch.argmax(torch.tensor(logits), dim=1)
    return accuracy.compute(predictions=preds, references=labels)


In [16]:
training_args = TrainingArguments(
    output_dir="./agnews_bert",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)


In [17]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)


  trainer = Trainer(


In [18]:
trainer.train()

2025/12/31 14:32:41 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/12/31 14:32:41 INFO mlflow.store.db.utils: Updating database tables
2025/12/31 14:32:41 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/31 14:32:41 INFO alembic.runtime.migration: Will assume non-transactional DDL.
2025/12/31 14:32:41 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/31 14:32:41 INFO alembic.runtime.migration: Will assume non-transactional DDL.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2114,0.181238,0.943553
2,0.1209,0.184403,0.9475


TrainOutput(global_step=15000, training_loss=0.17707151120503745, metrics={'train_runtime': 540.7641, 'train_samples_per_second': 443.816, 'train_steps_per_second': 27.739, 'total_flos': 1.578694680576e+16, 'train_loss': 0.17707151120503745, 'epoch': 2.0})

In [19]:
trainer.evaluate()

{'eval_loss': 0.18440325558185577,
 'eval_accuracy': 0.9475,
 'eval_runtime': 5.6412,
 'eval_samples_per_second': 1347.224,
 'eval_steps_per_second': 84.201,
 'epoch': 2.0}