In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import torch.nn as nn

import sys
sys.path.append('../')


import pyroml as p
from pyroml.template.iris import IrisNet, IrisDataset, load_dataset

In [3]:
SEED = 42
p.seed_everything(SEED)

In [4]:
ds = load_dataset()
ds = ds.shuffle(seed=SEED)
tr_ds, ev_ds, te_ds = np.split(ds, [int(0.6 * len(ds)), int(0.7 * len(ds))])

tr_ds = IrisDataset(tr_ds)
ev_ds = IrisDataset(ev_ds)
te_ds = IrisDataset(te_ds)

In [5]:
class ScheduledIrisNet(IrisNet):
    def configure_optimizers(self, loop: "p.Loop"):
        tr = self.trainer
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=tr.lr)
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=tr.lr,
            total_steps=loop.total_steps,
            steps_per_epoch=loop.steps_per_epochs,
            epochs=tr.max_epochs,
            anneal_strategy="cos",
            cycle_momentum=False,
            div_factor=1e2,
            final_div_factor=0.05,
        )

    def forward(self, x):
        return super().forward(x)

In [6]:
model = ScheduledIrisNet()
model

ScheduledIrisNet(
  (module): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=3, bias=True)
    (5): Softmax(dim=1)
  )
)

In [7]:
import logging
trainer = p.Trainer(
    compile=True,
    loss=nn.CrossEntropyLoss(),
    max_epochs=12,
    batch_size=16,
    lr=0.005,
    evaluate=True,
    evaluate_every=12,
    wandb=False,
    dtype=torch.bfloat16,
    log_level=logging.INFO,
)

In [9]:
tr_tracker = trainer.fit(model, tr_ds, ev_ds)
tr_tracker.records

Unnamed: 0,stage,epoch,step,acc,pre,rec,loss,epoch_acc,epoch_pre,epoch_rec,epoch_loss
0,validation,0,0,0.0,0.0,0.0,1.101419,,,,
1,validation,0,1,,,,,0.0,0.0,0.0,1.101419
2,train,0,0,0.0,0.0,0.0,1.084991,,,,
3,train,0,1,0.0,0.0,0.0,1.109656,,,,
4,train,0,2,0.0,0.0,0.0,1.109476,,,,
...,...,...,...,...,...,...,...,...,...,...,...
91,train,11,68,0.0,0.0,0.0,0.815292,,,,
92,train,11,69,0.0,0.0,0.0,0.792867,,,,
93,train,11,70,0.0,0.0,0.0,0.820310,,,,
94,train,11,71,0.0,0.0,0.0,0.872430,,,,


In [12]:
tr_tracker.plot()

NotImplementedError: This method is not implemented yet

In [14]:
te_metrics = trainer.test(model, te_ds)
te_metrics.records

Unnamed: 0,stage,epoch,step,acc,pre,rec,loss,epoch_acc,epoch_pre,epoch_rec,epoch_loss
0,test,0,0,0.0,0.0,0.0,0.811332,,,,
1,test,0,1,0.0,0.0,0.0,0.790323,,,,
2,test,0,2,0.0,0.0,0.0,0.792543,,,,
3,test,0,3,,,,,0.0,0.0,0.0,0.798066
