In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, set_seed

from tklearn.metrics import Accuracy
from tklearn.nn import Trainer, Evaluator
from tklearn.nn.callbacks import ProgbarLogger, EarlyStopping
from tklearn.nn.prototypes import PrototypeForSequenceClassification, PrototypeCallback

In [None]:
# Experiment configuration.
MODEL_NAME_OR_PATH = "google-bert/bert-base-uncased"  # Hugging Face Hub model id
DATASET = "yelp_review_full"  # Hugging Face Hub dataset id
SEED = 42

# Seed Python, NumPy and PyTorch RNGs up front so that any weights
# `from_pretrained` initialises fresh and the shuffled training DataLoader
# are reproducible across Restart & Run All.
set_seed(SEED)

In [None]:
# Fetch every split of the dataset (from the Hugging Face Hub or the local cache).
dataset = load_dataset(path=DATASET)

# Display one raw training record to sanity-check its fields ("text", "label").
dataset["train"][100]

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)


def tokenize_function(examples):
    """Tokenize a batch of raw examples for the model.

    Intended for ``Dataset.map(..., batched=True)``, so ``examples["text"]``
    is a list of review strings. Each sequence is truncated to the tokenizer's
    maximum length and padded up to that same fixed length, so every row has
    the same shape and the default DataLoader collation can stack them.

    Returns the tokenizer's encoding (e.g. ``input_ids``, ``attention_mask``).
    """
    batch_texts = examples["text"]
    encoded = tokenizer(batch_texts, padding="max_length", truncation=True)
    return encoded

In [None]:
# Tokenize every split, then rename "label" -> "labels" (presumably the keyword
# the model's forward pass expects — TODO confirm for the prototype model) and
# make rows come back as torch tensors.
# NOTE(review): this tokenizes every example in every split although only a
# 1,000-example subset of each is used below; subsampling before `.map` would
# be much cheaper.
tokenized_datasets = dataset.map(tokenize_function, batched=True).rename_column(
    "label", "labels"
)
tokenized_datasets.set_format("torch")

In [None]:
# Fixed 1,000-example subsets of each split for a quick experiment; the fixed
# shuffle seed selects the same subset on every run.
# NOTE(review): the "test" split doubles as the validation set used for
# early stopping, so final scores on it are not held-out estimates.
small_train_dataset, small_eval_dataset = (
    tokenized_datasets[split].shuffle(seed=42).select(range(1000))
    for split in ("train", "test")
)

In [None]:
# Select the accelerator at run time instead of hard-coding "mps", which fails
# on any machine without Apple's Metal backend. On Apple silicon this still
# resolves to MPS, matching the original behaviour.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# NOTE(review): no num_labels is passed even though the label column of
# yelp_review_full presumably has more than two classes; presumably the
# prototype head derives its classes from PrototypeCallback — confirm.
model = PrototypeForSequenceClassification.from_pretrained(MODEL_NAME_OR_PATH)
model.to(device)
# NOTE(review): assumes tklearn's Trainer/Evaluator move each batch to the
# model's device rather than to a fixed one — confirm.

# Training loader: reshuffled every epoch, 16 examples per batch.
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=16)

In [None]:
# Validation loader: not shuffled (DataLoader's default), 32 examples per batch.
# The Evaluator is built alongside the Trainer in the next cell; building it
# here too only produced a throwaway instance (plus an extra
# PrototypeCallback) that was immediately overwritten.
valid_dataloader = DataLoader(small_eval_dataset, batch_size=32)

In [None]:
optimizer = AdamW(model.parameters(), lr=1e-5)

# Validation loop. PrototypeCallback receives the *training* loader —
# presumably to build class prototypes from training examples before each
# evaluation (TODO confirm against tklearn). With prefix="valid_" the metric
# is presumably reported as "valid_accuracy" (the key was previously
# misspelled "acuracy").
evaluator = Evaluator(
    model,
    valid_dataloader,
    callbacks=[PrototypeCallback(train_dataloader), ProgbarLogger()],
    metrics={"accuracy": Accuracy()},
    prefix="valid_",
)

# Up to 10 epochs. EarlyStopping(patience=0) presumably stops at the first
# epoch whose monitored quantity does not improve; which quantity it monitors
# is tklearn's default — confirm.
trainer = Trainer(
    model,
    train_dataloader,
    optimizer=optimizer,
    callbacks=[ProgbarLogger(), EarlyStopping(patience=0)],
    evaluator=evaluator,
    epochs=10,
)

In [None]:
# Run training as configured on the Trainer (epochs=10, EarlyStopping callback,
# evaluator on the validation loader — presumably run after each epoch).
trainer.train()