In [None]:
# load dataset
from p_efficientnet_b7_1000 import define
from torchvision.datasets import ImageFolder
from torch.utils.data.dataloader import DataLoader

# Both splits share the preprocessing pipeline provided by `define`.
transforms = define.transform()

BATCH_SIZE = 20
data_root = define.get_datafolder()

train_data = ImageFolder(f"{data_root}/train", transform=transforms)
test_data = ImageFolder(f"{data_root}/test", transform=transforms)

# Shuffle only the training split. test_loader is passed to runner.run as the
# evaluation set (presumably used only to compute val_loss), so shuffling it
# adds nothing and makes the evaluation pass non-deterministic.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# define model
# Build the network and its human-readable description from the `define` module.
model = define.create_model()
model_desc = define.model_desc()
# Checkpoint folder = name of the directory that contains define's source file.
# Backslashes are normalised first so the split also works for Windows paths.
model_directory_name = define.__file__.replace("\\", "/").rsplit("/", 2)[-2]

In [None]:
# set parameters
from torch import nn
import torch


# Checkpoint paths; leaving them as None selects the defaults:
#   load_model_path = None -> load the last saved .pth found in `model_directory_name`
#   save_model_path = None -> save into `model_directory_name` under an auto-generated name
load_model_path = save_model_path = None
# Free-form note passed to runner.run as part of the run's `record`.
memo = ""

In [None]:
# create runner
import util

# Wrap the model in a util.Runner whose default checkpoint directory is the model folder.
runner = util.Runner(model, default_dir=model_directory_name)
device_banner = f"runner use {runner.device}"
print(device_banner)
# With load_model_path=None the config cell documents "load the last saved .pth".
runner.load(load_model_path)

In [None]:
# start train & test
class CustomSchedule(util.Schedule):
    """Adam + ReduceLROnPlateau training schedule with best-checkpoint saving.

    Training stops once ``runner.counter`` reaches ``max_epochs`` or once the
    plateau scheduler has decayed the learning rate below ``min_lr``. After
    ``save_after`` epochs, every new best validation loss triggers ``runner.save()``.
    """

    def __init__(
        self,
        learning_rate: float = 0.001,
        max_epochs: int = 100,
        min_lr: float = 1e-8,
        save_after: int = 10,
        factor: float = 0.5,
        patience: int = 5,
    ) -> None:
        # Initial LR; replaced in on_event_start by a non-zero runner.loaded_lr.
        self.learning_rate = learning_rate
        # Stop once runner.counter reaches this value.
        self.max_epochs = max_epochs
        # Stop once the current LR falls below this value.
        self.min_lr = min_lr
        # Best-loss checkpoints are only written when runner.counter exceeds this.
        self.save_after = save_after
        # ReduceLROnPlateau: multiply LR by `factor` after `patience` epochs without improvement.
        self.factor = factor
        self.patience = patience
        # Best validation loss seen so far; inf so the first eligible epoch always improves it.
        self.min_loss = float("inf")

    @staticmethod
    def _current_lr(runner: util.Runner) -> float:
        """Return the LR of the optimizer's last param group (what get_last_lr()[-1] reports)."""
        # Read straight from the optimizer: ReduceLROnPlateau in older PyTorch releases
        # (before it subclassed LRScheduler) has no get_last_lr().
        return runner.optimizer.param_groups[-1]["lr"]

    def on_event_start(self, runner: util.Runner):
        """Create the loss function, optimizer and LR scheduler before training."""
        # A non-zero loaded_lr (presumably restored by runner.load) overrides the initial LR.
        if runner.loaded_lr != 0:
            self.learning_rate = runner.loaded_lr
        runner.criterion = nn.CrossEntropyLoss()
        # NOTE(review): optimizes the notebook-global `model` (the one given to util.Runner).
        runner.optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            runner.optimizer,
            mode="min",
            factor=self.factor,
            patience=self.patience,
        )

    def is_end(self, runner: util.Runner) -> bool:
        """Return True when training should stop."""
        # `>=` instead of `==`: a resumed run may already be past max_epochs,
        # in which case an exact-equality check would never fire.
        if runner.counter >= self.max_epochs:
            return True
        return self._current_lr(runner) < self.min_lr

    def on_event_end_epoch(self, runner: util.Runner):
        """Save on a new best val_loss (after warm-up), then step the plateau scheduler."""
        val_loss = runner.learn[-1]["val_loss"]
        if val_loss < self.min_loss and runner.counter > self.save_after:
            self.min_loss = val_loss
            runner.save()
        self.scheduler.step(val_loss)
        print(f"lr: {self._current_lr(runner)}")

    def get_record(self, runner) -> dict:
        """Extra per-run fields to record: the current learning rate."""
        return {"lr": self._current_lr(runner)}
        

# Train on train_loader, evaluate on test_loader, driven by CustomSchedule.
schedule = CustomSchedule()
# The model description and memo are handed to runner.run as `record`.
run_record = {"model": model_desc, "memo": memo}
runner.run(train_loader, test_loader, schedule, record=run_record)

In [None]:
# save model
# NOTE(review): `save_model_path` from the parameters cell is never passed anywhere,
# so setting it currently has no effect — runner.save() is called without arguments.
# TODO confirm util.Runner.save's signature before wiring save_model_path in.
runner.save()