In [1]:
# Notebook setup (IPython line magics; keep comments on their own lines so
# they are not passed to the magics as arguments).
# (Re)load the autoreload extension.
%reload_ext autoreload
# Mode 2: re-import all modules before each cell execution, so edits to
# imported project code are picked up without restarting the kernel.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [7]:
from k12libs.utils.nb_easy import k12ai_start_html, W3URL
from urllib.parse import urlencode
import os

In [34]:
# Template source for the in-browser code editor.
#
# `code` is not executed by this notebook: it is a plain string that a later
# statement in this cell url-encodes and passes as the `default` query
# parameter of the codemirror.html page. Every character between the triple
# quotes (including the comments and blank/whitespace-only lines) is therefore
# runtime data -- editing it changes what is sent to the editor.
#
# The template, read top to bottom:
#   * subclasses EasyaiClassifier as CustomClassifier and overrides the
#     optimizer/scheduler hooks, the train/val/test dataloaders, and the
#     per-step / per-epoch metric hooks;
#   * builds an EasyaiTrainer (1 epoch, checkpoint monitored on 'val_loss',
#     early stopping monitored on 'val_acc');
#   * calls trainer.fit(model) followed by trainer.test(model).
#
# NOTE(review): inside the template, `nn` and `EasyaiDataset` are imported but
# never referenced, and test_step computes `loss` without using it -- confirm
# whether these are intentional placeholders for users to fill in.
# NOTE(review): the training dataloader is created with shuffle=False; verify
# this is the intended default for the template.
code='''
from k12ai import EasyaiClassifier, EasyaiTrainer, EasyaiDataset
import torch
from torch import nn

class CustomClassifier(EasyaiClassifier):
    
    def configure_optimizer(self, model):
        return self.adam(model.parameters(), base_lr=0.001)

    def configure_scheduler(self, optimizer):
        return self.period_step(optimizer, step_size=30, gamma=0.1)
    
    ## Train
    def train_dataloader(self):
        return self.get_dataloader(
            phase='train', # [M]
            data_augment=[
                self.random_resized_crop(size=(128, 128)),
                self.random_brightness(factor=0.3),
                self.random_rotation(degrees=30)
            ], # [O]
            random_order=False, # [O]
            input_size=128,
            normalize=True,
            batch_size=32,
            drop_last=False,
            shuffle=False)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x) # (32, 10)
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        with torch.no_grad():
            accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'loss': loss, 'progress_bar': {'acc': accuracy}}

    # def training_epoch_end(self, outputs):
    #     avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    #     avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
    #     return {'progress_bar': {'train_loss': avg_loss, 'train_acc': avg_acc}}
        
    ## Valid
    def val_dataloader(self):
        return self.get_dataloader('val',
                input_size=128,
                batch_size=32,
                drop_last=False,
                shuffle=False)

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x) # (32, 10)
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'loss': loss, 'acc': accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        return {'progress_bar': {'val_loss': avg_loss, 'val_acc': avg_acc}}
        
    ## Test
    def test_dataloader(self):
        return self.get_dataloader('test',
                input_size=128,
                batch_size=32,
                drop_last=False,
                shuffle=False)

    def test_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x) # (32, 10)
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'acc': accuracy}

    def test_epoch_end(self, outputs):
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        return {'test_acc': avg_acc}

    
trainer = EasyaiTrainer(
    max_epochs=1, # 训练过程遍历完整数据集的总次数(epoch)
    resume=False, # True: 模型继续上次训练(模型必须没有改变)
    log_rate=2,   # 日志打印的频率, 单位是迭代次数(iteration step) 
    model_summary='full',
    model_ckpt={'monitor': 'val_loss', 'period': 2, 'mode': 'min'},
    early_stop={'monitor': 'val_acc', 'patience': 3, 'mode': 'max'}
)

model = CustomClassifier()

# 训练
trainer.fit(model)

# 评估
trainer.test(model)

'''

# Query string for the editor page: the template text travels as the
# `default` parameter and is percent-encoded by urlencode.
params = dict(default=code)

# Keep this call as the last expression of the cell so that whatever it
# returns is shown as the cell output (presumably an embedded page -- confirm
# against k12libs.utils.nb_easy).
k12ai_start_html(
    '{}/codemirror.html?{}'.format(W3URL, urlencode(params)),
    width='100%',
    height='1400px',
)