In [None]:
# Re-import edited modules (e.g. the local pyroml package) automatically before
# every cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

import torch

# Make the parent directory importable, presumably so the local pyroml checkout is
# used (the path is resolved against the kernel's working directory). Guarded so
# that re-running this cell does not keep appending duplicate sys.path entries.
if "../" not in sys.path:
    sys.path.append("../")

import pyroml as p
from pyroml.template.iris import IrisModel, IrisDataset

In [None]:
# Single seed for the whole notebook. pyroml's seed_everything presumably seeds the
# python / numpy / torch RNGs — TODO confirm against the pyroml implementation.
SEED = 42
p.seed_everything(SEED)

In [None]:
ds = IrisDataset()

# 50% train / 20% validation / 30% test.
n_total = len(ds)
n_tr = int(0.5 * n_total)
n_ev = int(0.2 * n_total)
# The test split takes the remainder so the lengths always sum to len(ds):
# truncating every fraction independently can drop samples, and random_split
# raises a ValueError when the lengths do not add up to the dataset size.
n_te = n_total - n_tr - n_ev

# A dedicated, seeded generator makes the split reproducible even when this cell
# is re-run on its own (the global torch RNG state would have advanced by then).
tr_ds, ev_ds, te_ds = torch.utils.data.random_split(
    ds, [n_tr, n_ev, n_te], generator=torch.Generator().manual_seed(SEED)
)
len(tr_ds), len(ev_ds), len(te_ds)

In [None]:
from pyroml.loop import Loop


class ScheduledIrisNet(IrisModel):
    """IrisModel trained with AdamW under a cosine one-cycle learning-rate schedule.

    The schedule's divisor hyper-parameters are class attributes, so a subclass or
    a notebook cell can override them without rewriting ``configure_optimizers``.
    """

    # OneCycleLR warm-up starts at max_lr / DIV_FACTOR.
    DIV_FACTOR: float = 1e2
    # OneCycleLR ends at (max_lr / DIV_FACTOR) / FINAL_DIV_FACTOR. A value < 1 makes
    # the final lr *higher* than the starting lr: with lr=0.005 the schedule goes
    # 5e-5 -> 5e-3 -> 1e-3. NOTE(review): PyTorch's default is 1e4; confirm that
    # 0.05 is intentional.
    FINAL_DIV_FACTOR: float = 0.05

    def configure_optimizers(self, loop: "Loop"):
        """Create ``self.optimizer`` (AdamW) and ``self.scheduler`` (OneCycleLR).

        Args:
            loop: the pyroml loop being configured; ``loop.total_steps`` is used as
                the length of the one-cycle schedule.
        """
        tr = self.trainer
        # OneCycleLR overwrites this lr with max_lr / div_factor when it is built.
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=tr.lr)
        # When total_steps is given, OneCycleLR ignores epochs/steps_per_epoch, so
        # only total_steps is passed.
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=tr.lr,
            total_steps=loop.total_steps,
            anneal_strategy="cos",
            cycle_momentum=False,
            div_factor=self.DIV_FACTOR,
            final_div_factor=self.FINAL_DIV_FACTOR,
        )

In [None]:
# Build the scheduled model; the parenthesised assignment binds `model` and also
# leaves the instance as the cell's displayed value.
(model := ScheduledIrisNet())

In [None]:
import logging

from pyroml.callbacks.progress import TQDMProgress

trainer = p.Trainer(
    # hardware / numerics
    device="cpu",
    dtype=torch.bfloat16,
    compile=True,
    num_workers=0,
    # optimisation
    lr=0.005,
    batch_size=4,
    max_epochs=12,
    # validation cadence (presumably: evaluate after every epoch)
    evaluate_on="epoch",
    evaluate_every=1,
    # logging / progress reporting
    log_level=logging.NOTSET,
    wandb=False,
    callbacks=[TQDMProgress()],
)

In [None]:
# Train on the train split, validating on the eval split; display the records table.
tr_tracker = trainer.fit(
    tr_dataset=tr_ds,
    ev_dataset=ev_ds,
    model=model,
)
tr_tracker.records

In [None]:
tr_records = tr_tracker.records
# Keep per-step training rows only; rows with step == -1 are presumably epoch-level
# aggregates — TODO confirm against pyroml's tracker.
is_train_step = (tr_records["stage"] == p.Stage.TRAIN.value) & (tr_records["step"] != -1)
ax = tr_records.loc[is_train_step, ["step", "loss"]].plot(x="step", legend=False)
ax.set(title="Training loss per step", xlabel="step", ylabel="loss");

In [None]:
# Bar chart of the validation-stage metrics recorded during fit (epoch=True).
tr_tracker.plot(kind="bar", epoch=True, stage=p.Stage.VAL)

In [None]:
# Evaluate the trained model on the held-out test split and chart the result.
te_metrics = trainer.evaluate(model, te_ds)
te_metrics.plot(kind="bar", epoch=True)