In [1]:
# Notebook setup: reload edited modules automatically before each cell runs
# (autoreload mode 2 = reload all modules), and render any matplotlib figures
# inline in the notebook output.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [49]:
from k12libs.utils.nb_easy import k12ai_start_html, W3URL
from urllib.parse import urlencode

In [52]:
# Template source for a custom classifier. It is executed in this notebook via
# `exec(code)` and also loaded into the browser CodeMirror editor (after the
# `pyr.app.` prefix and the local dataset root are rewritten), so it must stay
# valid, self-contained Python.
code='''
from pyr.app.k12ai import EasyaiClassifier, EasyaiTrainer, EasyaiDataset
from collections import OrderedDict
import torch

class CustomClassifier(EasyaiClassifier):
    ## Load
    def prepare_dataset(self):
        return self.load_presetting_dataset_('rmnist', '/data/datasets/cv')
    
    def build_model(self):
        return self.load_pretrained_model_('resnet18', num_classes=10)
    
    def configure_optimizer(self, model):
        return self.adam(model.parameters(),
            base_lr=0.001)

    def configure_scheduler(self, optimizer):
        return self.period_step(optimizer, step_size=30, gamma=0.1)
    
    ## Train
    def train_dataloader(self):
        return self.get_dataloader(
            phase='train',
            data_augment=[
                self.random_resized_crop(size=(128, 128)),
                self.random_brightness(factor=0.3),
                self.random_rotation(degrees=30)
            ],
            random_order=False,
            input_size=128,
            normalize=True,
            batch_size=32,
            drop_last=False,
            shuffle=False)

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y, _ = batch
        y_hat = self.forward(x) # (batch_size, 10); last batch may be smaller
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        with torch.no_grad():
            accuracy = (torch.argmax(y_hat, dim=1) == y).float().mean()
        log = {'train_loss': loss, 'train_acc': accuracy}
        output = OrderedDict({
            'loss': loss,        # M
            'acc': accuracy,     # O
            'progress_bar': log, # O
        })
        return output

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        log = {'train_loss': avg_loss, 'train_acc': avg_acc}
        output = OrderedDict({
            'progress_bar': log,
        })
        return output
        
    ## Valid
    def val_dataloader(self):
        return self.get_dataloader('val',
                input_size=128,
                batch_size=32,
                drop_last=False,
                shuffle=False)

    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self.forward(x) # (batch_size, 10); last batch may be smaller
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        accuracy = (torch.argmax(y_hat, dim=1) == y).float().mean()
        output = OrderedDict({
            'val_loss': loss,
            'val_acc': accuracy,
        })
        return output

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        log = {'val_loss': avg_loss, 'val_acc': avg_acc}
        output = OrderedDict({
            **log,
            'progress_bar': log,
        })
        return output
        
trainer = EasyaiTrainer(max_epochs=2, model_summary=None, early_stop={'monitor': 'val_acc', 'patience': 1, 'mode': 'max'})
trainer.fit(CustomClassifier())
print(trainer.test())
'''

In [46]:
# Run the template directly in the notebook's own namespace, so `CustomClassifier`
# and `trainer` stay defined for later cells. After editing `code`, re-run the
# defining cell first: otherwise the output shown here reflects an older template.
exec(code, globals(), locals())

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
TEST RESULTS
{'test_acc': tensor(0.4023, device='cuda:0'),
 'test_loss': tensor(1.5445, device='cuda:0')}
--------------------------------------------------------------------------------



In [53]:
# Load the template into the hosted CodeMirror page as its default contents.
# The `pyr.app.` package prefix is stripped and the local dataset root is remapped;
# presumably those are the names/paths used by the editor's runtime — confirm.
editor_source = code.replace('pyr.app.', '')
editor_source = editor_source.replace('/data/datasets/cv', '/datasets')
params = dict(default=editor_source)
editor_url = '{}/codemirror.html?{}'.format(W3URL, urlencode(params))
k12ai_start_html(editor_url, width='100%', height='900px')