In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import AutoTokenizer

from tklearn.metrics import Accuracy, ArrayAccumulator
from tklearn.nn import Trainer, Evaluator
from tklearn.nn.optim import BERTAdamW
from tklearn.nn.callbacks import ProgbarLogger, EarlyStopping
from tklearn.nn.transformers import TransformerForSequenceClassification

In [None]:
# Hugging Face Hub identifiers: the pretrained checkpoint to fine-tune and the
# dataset passed to `load_dataset`.
MODEL_NAME_OR_PATH, DATASET = (
    "google-bert/bert-base-uncased",
    "yelp_review_full",
)

In [None]:
# Load every split of the configured dataset into a single split-keyed object.
dataset = load_dataset(path=DATASET)

# Peek at one raw training record to see its fields (e.g. "text", "label").
dataset["train"][100]

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)


def tokenize_function(examples):
    """Tokenize a batch of raw records for the model.

    Meant to be used with ``Dataset.map(..., batched=True)``, where
    ``examples["text"]`` is a list of strings. Sequences are truncated and
    padded to a fixed length (``padding="max_length"``), so every row has the
    same number of tokens.

    Returns the tokenizer's encoding (input ids, attention mask, ...).
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length")
    return encoded

In [None]:
# Tokenize every split, rename the target column "label" -> "labels"
# (presumably the key the model uses to compute its loss), then have the
# splits return torch tensors.
# NOTE(review): this tokenizes the *full* train/test splits although only 1000
# examples of each are used later; subsetting before `map` would be far cheaper.
tokenized_datasets = dataset.map(tokenize_function, batched=True).rename_column(
    "label", "labels"
)
tokenized_datasets.set_format(type="torch")

In [None]:
# Reproducible 1000-example subsets of the train and test splits.
# NOTE(review): the test split doubles as the validation set (early stopping
# watches it), so scores measured on it later are optimistic.
small_train_dataset, small_eval_dataset = (
    tokenized_datasets[split].shuffle(seed=42).select(range(1000))
    for split in ("train", "test")
)

In [None]:
# Training batches are reshuffled on every pass; validation keeps a fixed order
# and uses a larger batch size.
train_dataloader = DataLoader(small_train_dataset, batch_size=16, shuffle=True)
valid_dataloader = DataLoader(small_eval_dataset, batch_size=32, shuffle=False)

In [None]:
NUM_EPOCHS = 20
SEED = 42

# Seed torch's global RNG before building the model, so that any freshly
# initialised weights (presumably the 5-way classification head) and the
# DataLoader's RandomSampler shuffling are reproducible across kernel restarts.
# The RandomSampler draws its seed from the global RNG when iterated.
torch.manual_seed(SEED)

model = TransformerForSequenceClassification.from_pretrained(
    MODEL_NAME_OR_PATH, num_labels=5
)

# Pick the best available accelerator instead of hard-coding "mps", which
# fails on machines without Apple-silicon GPU support.
# NOTE(review): assumes tklearn's Trainer/Evaluator move batches to the
# model's device — confirm.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

model.to(DEVICE)

# model = model.compile()

# Schedule length = optimizer steps per epoch * epochs; the warmup fraction is
# applied against it.
optimizer = BERTAdamW(
    model.parameters(), lr=2e-6, warmup=0.1, t_total=len(train_dataloader) * NUM_EPOCHS
)

In [None]:
# BREAK

In [None]:
# Validation loop over `valid_dataloader`, reporting a progress bar and
# accuracy. The "valid_" prefix is presumably prepended to logged metric names,
# which matches the "valid_loss" column plotted from the model history.
evaluator = Evaluator(
    model,
    valid_dataloader,
    callbacks=[ProgbarLogger()],
    # Key was misspelt "acuracy"; fixed so the logged metric is named
    # "accuracy" (presumably "valid_accuracy" once prefixed).
    metrics={"accuracy": Accuracy()},
    prefix="valid_",
)

# Training loop for up to NUM_EPOCHS epochs using `evaluator` for validation.
# EarlyStopping(patience=5) presumably stops after 5 evaluations without
# improvement — TODO confirm which quantity tklearn's EarlyStopping monitors.
trainer = Trainer(
    model,
    train_dataloader,
    optimizer=optimizer,
    callbacks=[ProgbarLogger(), EarlyStopping(patience=5)],
    evaluator=evaluator,
    epochs=NUM_EPOCHS,
)

In [None]:
# Run the fine-tuning loop configured above (optimizer, validation evaluator,
# progress logging and the EarlyStopping callback).
trainer.train()

In [None]:
# Record (shape, element count) for every parameter outside the classification
# head, presumably the parameters whose shapes survive the head resize below.
# NOTE(review): relies on head parameters being named "classification.*"; if
# the tklearn model uses another name, nothing is skipped — confirm.
grad_dims = [
    (param.data.shape, param.data.numel())
    for name, param in model.named_parameters()
    if not name.startswith("classification.")
]

In [None]:
import copy

# Independent snapshot of the current model, taken before its head is resized,
# so parameter shapes can later be compared against the resized model.
old_model = copy.deepcopy(model)

In [None]:
# Keep the pre-resize bookkeeping under its own name. The later cell that
# recomputes `grad_dims` rebinds that name to a new list, so this still refers
# to the old entries. Despite the name, these are (shape, numel) pairs, not
# gradients.
old_grads = grad_dims

In [None]:
# Switch the classifier from the 5 labels used above to 10 output labels.
# NOTE(review): tklearn API — presumably resizes the classification head in
# place; confirm how existing rows/weights are carried over.
model.set_num_labels(10)

In [None]:
# For every parameter of the resized model, record the flattened shape of the
# portion that lines up with the pre-resize snapshot (`old_model`), together
# with the parameter's full element count.
grad_dims = []
for pn, pp in model.named_parameters():
    ppold = old_model.get_parameter(pn)
    flat = pp.data.view(-1)
    if ppold.shape != pp.shape:
        # Resized parameter (e.g. the classification head): keep only the
        # leading elements matching the old parameter. Slice by the old
        # parameter's total element count. The previous `ppold.shape[0]` is
        # only the row count, which for a 2-D weight kept e.g. 5 elements
        # instead of 5 * hidden_size.
        # NOTE(review): this assumes the resize appended rows along dim 0, so
        # the old weights occupy the leading flat elements — confirm against
        # set_num_labels.
        flat = flat[: ppold.numel()]
    grad_dims.append((flat.shape, pp.data.numel()))

In [None]:
# grad_dims

In [None]:
# Compare training loss with validation loss from the recorded history; a
# widening gap (validation loss rising while training loss falls) signals
# over-fitting. The trailing ";" suppresses the Axes/Text repr so only the
# figure is shown.
loss_ax = model.history.to_pandas()[["loss", "valid_loss"]].plot.line()
loss_ax.set(title="Training vs. validation loss", ylabel="loss");

In [None]:
# e = ValueError("hahah!")

# f"message: {e!s}"