In [7]:
import torch
import scipy.sparse.linalg
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
import numpy as np
import evaluate

import os
# Set the tokenizers library's parallelism switch to "false" for this process.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Checkpoint used for both the tokenizer and the classifier, and the maximum
# number of tokens kept per document when truncating.
model_name = "bert-base-uncased"
max_length = 512

# Accuracy scorer consumed by compute_metrics.
metric = evaluate.load("accuracy")

tokenizer = AutoTokenizer.from_pretrained(model_name)

def read_20newsgroups(test_size=0.2, random_state=42):
    """Fetch the 20 Newsgroups corpus and split it into train/validation parts.

    The full corpus (``subset="all"``) is requested with headers, footers and
    quoted text removed.

    Args:
        test_size: Forwarded to ``train_test_split``; share of the corpus held
            out for validation.
        random_state: Seed forwarded to ``train_test_split`` so that the same
            split is produced on every run. Pass ``None`` to get a different
            split on each call (the behaviour before this parameter existed).

    Returns:
        A pair ``(splits, target_names)`` where ``splits`` is
        ``[train_texts, valid_texts, train_labels, valid_labels]`` and
        ``target_names`` is the dataset's list of category names.
    """
    dataset = fetch_20newsgroups(subset="all", shuffle=True, remove=("headers", "footers", "quotes"))
    documents = dataset.data
    labels = dataset.target

    # Seeding the split keeps train/validation membership stable across kernel
    # restarts, so metrics from different runs are comparable.
    splits = train_test_split(
        documents,
        labels,
        test_size=test_size,
        random_state=random_state,
    )
    return splits, dataset.target_names

# Split the corpus, then unpack the four text/label partitions.
splits, target_names = read_20newsgroups()
train_texts, valid_texts, train_labels, valid_labels = splits

# Tokenize both partitions with the same options.
encode_options = {"truncation": True, "padding": True, "max_length": max_length}
train_encodings = tokenizer(train_texts, **encode_options)
valid_encodings = tokenizer(valid_texts, **encode_options)

class NewsGroupsDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with integer class labels.

    Args:
        encodings: Mapping from field name to an indexable collection holding
            one entry per sample; every field is indexed with the sample
            position.
        labels: Indexable collection of integer class ids, one per sample.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors, class id under ``"labels"``."""
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        # Emit a scalar (0-d) long tensor instead of a length-1 vector: stacking
        # such items gives labels of shape (batch,), which lines up one-to-one
        # with argmax predictions instead of carrying a trailing singleton axis.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        """Number of samples, taken from the label collection."""
        return len(self.labels)

# Wrap each tokenized split so it can be indexed sample by sample.
train_dataset = NewsGroupsDataset(train_encodings, train_labels)
eval_dataset = NewsGroupsDataset(valid_encodings, valid_labels)

# Sequence classifier with one output per newsgroup category.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(target_names))
# TrainingArguments are built once, further down in this cell; the extra
# instance that used to be created here was overwritten before any use.

def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    Args:
        eval_pred: Pair ``(logits, labels)``; it is unpacked positionally.

    Returns:
        The result of ``metric.compute`` for the predicted class ids against
        the reference class ids.
    """
    logits, labels = eval_pred
    # The predicted class is the index of the largest logit along the last axis.
    predictions = np.argmax(logits, axis=-1)
    # Labels can arrive with a trailing singleton axis (shape (N, 1)) when each
    # sample stores its label as a length-1 tensor; flatten them so there is
    # exactly one reference per prediction.
    references = np.asarray(labels).reshape(-1)
    return metric.compute(predictions=predictions, references=references)

# Training configuration: outputs go to "test_trainer", evaluation strategy
# is set to "epoch".
training_args = TrainingArguments(
    output_dir="test_trainer",
    eval_strategy="epoch",
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Collect everything the Trainer needs: the model, its training configuration,
# both dataset splits and the metric callback.
trainer_config = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "compute_metrics": compute_metrics,
}
trainer = Trainer(**trainer_config)

trainer.train()



Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 