In [None]:
import torch

import pyroml as p
from pyroml.loop import Loop
from pyroml.callbacks.progress import TQDMProgress
from pyroml.template.iris import IrisModel, IrisDataset

In [None]:
# One seed for the whole run, applied before any data is split or any weights are created.
SEED = 42
p.seed_everything(SEED)

In [None]:
ds = IrisDataset()
# Fractional lengths (supported by random_split since torch 1.13) always add up to
# len(ds): any rounding remainder is distributed across the splits. The previous
# int(frac * len(ds)) lengths only summed correctly for some dataset sizes, and
# random_split raises a ValueError when they don't.
# A dedicated, seeded generator makes the partition identical on every re-run of
# this cell, regardless of how much global RNG state earlier cells consumed.
split_generator = torch.Generator().manual_seed(42)
tr_ds, ev_ds, te_ds = torch.utils.data.random_split(
    ds, [0.5, 0.2, 0.3], generator=split_generator
)
len(tr_ds), len(ev_ds), len(te_ds)

In [None]:
class ScheduledIrisNet(IrisModel):
    """Iris model trained with AdamW under a one-cycle (OneCycleLR) LR schedule.

    The peak learning rate is the trainer's ``lr``. Using OneCycleLR's definitions:
    - the schedule starts at ``lr / div_factor``;
    - it warms up to ``lr``;
    - it then cosine-anneals to ``lr / (div_factor * final_div_factor)``.

    The two factors are class attributes, so a subclass can change the schedule
    shape without rewriting ``configure_optimizers``.
    """

    # Peak LR / starting LR (OneCycleLR ``div_factor``).
    div_factor: float = 1e2
    # Starting LR / final LR (OneCycleLR ``final_div_factor``).
    # NOTE(review): a value below 1 makes the cycle end *above* its starting LR
    # (with the defaults: lr / 5 at the end vs lr / 100 at the start) — confirm intended.
    final_div_factor: float = 0.05

    def configure_optimizers(self, loop: "Loop"):
        """Create ``self.optimizer`` (AdamW) and ``self.scheduler`` (OneCycleLR).

        Args:
            loop: the running loop. Only ``loop.total_steps`` is read; it sets
                the length of the one-cycle schedule.
        """
        lr = self.trainer.lr
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=lr)
        # OneCycleLR ignores ``epochs`` / ``steps_per_epoch`` whenever ``total_steps``
        # is given. The steps_per_epoch previously passed here had no effect and
        # is therefore no longer forwarded.
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=lr,
            total_steps=loop.total_steps,
            anneal_strategy="cos",
            # Keep AdamW's betas fixed instead of letting OneCycleLR cycle beta1.
            cycle_momentum=False,
            div_factor=self.div_factor,
            final_div_factor=self.final_div_factor,
        )

In [None]:
# Build the scheduled model and display its repr (assignment expression = bind + show).
(model := ScheduledIrisNet())

In [None]:
# Trainer configuration, grouped by concern.
trainer = p.Trainer(
    # Optimisation: base LR (read back as `self.trainer.lr` by ScheduledIrisNet), epochs, batch size.
    lr=0.005,
    max_epochs=12,
    batch_size=4,
    # Evaluation cadence.
    evaluate_on="epoch",
    evaluate_every=1,
    # Execution: device, parameter/compute dtype, and compilation (presumably torch.compile — confirm).
    device="cpu",
    dtype=torch.bfloat16,
    compile=True,
    # Progress reporting.
    callbacks=[TQDMProgress()],
)

In [None]:
# Fit on the training split, with the evaluation split passed for periodic evaluation;
# the returned tracker is plotted in the next cell.
tr_tracker = trainer.fit(
    model=model,
    ev_dataset=ev_ds,
    tr_dataset=tr_ds,
)

In [None]:
# Plot the metrics tracked during training.
# NOTE(review): `epoch=True` presumably selects epoch-level rather than step-level records — confirm.
tr_tracker.plot(epoch=True)

In [None]:
# Evaluate the trained model on the held-out test split and chart the result as bars.
te_tracker = trainer.evaluate(model, te_ds)
te_tracker.plot(
    kind="bar",
    epoch=True,
)