In [3]:
# Notebook setup cell: enable module auto-reloading, point /datasets at the
# shared dataset directory, and apply the project's notebook settings.

# IPython autoreload extension; mode 2 reloads all imported modules before
# each cell execution.
%reload_ext autoreload
%autoreload 2

# If /datasets already exists as a symbolic link, remove it so the link can be
# recreated below.
! [ -L /datasets ] && rm -f /datasets
# Create /datasets as a symbolic link to /data/datasets/cv.
! ln -s /data/datasets/cv /datasets

from k12libs.utils.nb_easy import k12ai_set_notebook

# NOTE(review): `cellw` presumably sets the notebook cell width (in percent);
# confirm against k12libs.
k12ai_set_notebook(cellw=95)

## 需掌握知识点

1. 卷积层基本功能
2. 池化层基本功能

In [6]:
from pyr.app.k12ai import EasyaiClassifier, EasyaiTrainer
import torch
from torch import nn

class CustomClassifier(EasyaiClassifier):
    """Flower classifier built from one convolution layer and one pooling layer.

    The network is: BatchNorm2d -> Conv2d -> MaxPool2d -> Flatten -> Linear
    -> Linear. Image size, batch size and class count are defined once as
    class-level constants and reused by the model and every dataloader, so
    they cannot drift apart.
    """

    # Class-level (not instance-level) so they are already available in any
    # hook the base class might invoke from inside its own ``__init__``.
    INPUT_SIZE = 224   # images are INPUT_SIZE x INPUT_SIZE
    BATCH_SIZE = 32
    NUM_CLASSES = 22   # number of flower categories (see prepare_dataset)

    def __init__(self):
        """Initialise the base classifier and register an example input."""
        super().__init__()
        self.input_size = self.INPUT_SIZE
        self.batch_size = self.BATCH_SIZE

        # Example input used to report per-layer input/output shapes in the
        # model summary.
        self.example_input_array = torch.zeros(
            self.batch_size, 3, self.input_size, self.input_size)

    def prepare_dataset(self):
        """Load the flowers dataset (22 classes).

        Class indices and labels:
            0: 'salviasplendens'
            1: 'daffodil'
            2: 'snowdrop'
            3: 'lilyvalley'
            4: 'bluebell'
            5: 'crocus'
            6: 'iris'
            7: 'tigerlily'
            8: 'tulip'
            9: 'fritillary'
           10: 'sunflower'
           11: 'daisy'
           12: 'coltsfoot'
           13: 'dandelion'
           14: 'cowslip'
           15: 'buttercup'
           16: 'windflower'
           17: 'pansy'
           18: 'coxcomb'
           19: 'flamingo'
           20: 'lily'
           21: 'lotus'

        Returns:
            Whatever ``self.load_flowers()`` returns.
        """
        return self.load_flowers()

    def build_model(self):
        """Build and return the convolutional network.

        Returns:
            An ``nn.Module`` mapping a ``(N, 3, INPUT_SIZE, INPUT_SIZE)``
            batch to ``(N, NUM_CLASSES)`` logits.
        """
        class ConvNetwork(nn.Module):
            """BatchNorm -> Conv -> MaxPool -> Flatten -> two Linear layers."""

            def __init__(self, num_classes, input_size=224):
                super().__init__()
                out_channels = 64

                # Spatial size after each layer, using the standard formula
                # floor((size + 2*padding - kernel) / stride) + 1.
                # For input_size=224: conv -> 110, pool -> 110,
                # so the flattened size is 64 * 110 * 110 = 774400.
                conv_size = (input_size + 2 * 1 - 7) // 2 + 1
                pool_size = (conv_size + 2 * 1 - 3) // 1 + 1
                flat_features = out_channels * pool_size * pool_size

                self.conv_net = nn.Sequential()
                self.conv_net.add_module(
                    'batchnorm',
                    nn.BatchNorm2d(num_features=3, momentum=0.1))
                self.conv_net.add_module(
                    'conv2d',
                    nn.Conv2d(in_channels=3, out_channels=out_channels,
                              kernel_size=7, stride=2, padding=1, bias=True))
                self.conv_net.add_module(
                    'maxpool2d',
                    nn.MaxPool2d(kernel_size=3, stride=1, padding=1,
                                 ceil_mode=False))
                self.conv_net.add_module(
                    'flatten', nn.Flatten(start_dim=1, end_dim=-1))
                self.conv_net.add_module(
                    'linear1',
                    nn.Linear(in_features=flat_features, out_features=100,
                              bias=True))
                self.conv_net.add_module(
                    'linear2',
                    nn.Linear(in_features=100, out_features=num_classes,
                              bias=True))

            def forward(self, x):
                return self.conv_net(x)

        # Expected input: (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
        return ConvNetwork(num_classes=self.NUM_CLASSES,
                           input_size=self.INPUT_SIZE)

    def train_dataloader(self):
        """Return the dataloader for the 'train' split."""
        return self.get_dataloader('train', self.INPUT_SIZE, self.BATCH_SIZE)

    def training_step(self, batch, batch_idx):
        """Compute loss and accuracy for one training batch.

        Args:
            batch: A 3-tuple whose first two items are inputs and targets;
                the third item is ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the ``'loss'`` to optimise and an ``'acc'`` value for
            the progress bar.
        """
        x, y, _ = batch
        y_hat = self(x)
        loss = self.cross_entropy(y_hat, y, reduction='mean')  # loss function
        # Accuracy is only a metric; keep it out of the autograd graph.
        with torch.no_grad():
            accuracy = (torch.argmax(y_hat, dim=1) == y).float().mean()
        return {'loss': loss, 'progress_bar': {'acc': accuracy}}

    def val_dataloader(self):
        """Return the dataloader for the 'val' split."""
        return self.get_dataloader(phase='val',
                                   input_size=self.INPUT_SIZE,
                                   batch_size=self.BATCH_SIZE)

    def validation_step(self, batch, batch_idx):
        """Compute loss and accuracy for one validation batch.

        Returns:
            Dict with per-batch ``'loss'`` and ``'acc'`` tensors, consumed by
            ``validation_epoch_end``.
        """
        x, y, _ = batch
        y_hat = self(x)
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        accuracy = (torch.argmax(y_hat, dim=1) == y).float().mean()
        return {'loss': loss, 'acc': accuracy}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics over the epoch.

        Args:
            outputs: List of dicts returned by ``validation_step``.

        Returns:
            Dict exposing ``'val_loss'`` and ``'val_acc'`` to the progress bar.
        """
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        return {'progress_bar': {'val_loss': avg_loss, 'val_acc': avg_acc}}

    def test_dataloader(self):
        """Return the dataloader for the 'test' split."""
        return self.get_dataloader('test',
                                   input_size=self.INPUT_SIZE,
                                   batch_size=self.BATCH_SIZE)
    
    
# Trainer limited to 5 epochs; model_summary='full' requests the full
# per-layer model summary.
trainer = EasyaiTrainer(max_epochs=5, model_summary='full')

model = CustomClassifier()

# Train
trainer.fit(model)

# Evaluate on the test split (the cell output reports 'test_acc').
trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
{'classes': ['salviasplendens',
             'daffodil',
             'snowdrop',
             'lilyvalley',
             'bluebell',
             'crocus',
             'iris',
             'tigerlily',
             'tulip',
             'fritillary',
             'sunflower',
             'daisy',
             'coltsfoot',
             'dandelion',
             'cowslip',
             'buttercup',
             'windflower',
             'pansy',
             'coxcomb',
             'flamingo',
             'lily',
             'lotus'],
 'mean': [0.4623, 0.4305, 0.295],
 'records': 1760,
 'std': [0.252, 0.2242, 0.2091]}
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

  | Name                     | Type        | Params | In sizes           | Out sizes         
-------------------------------

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..



--------------------------------------------------------------------------------
{'classes': ['salviasplendens',
             'daffodil',
             'snowdrop',
             'lilyvalley',
             'bluebell',
             'crocus',
             'iris',
             'tigerlily',
             'tulip',
             'fritillary',
             'sunflower',
             'daisy',
             'coltsfoot',
             'dandelion',
             'cowslip',
             'buttercup',
             'windflower',
             'pansy',
             'coxcomb',
             'flamingo',
             'lily',
             'lotus'],
 'mean': [0.4623, 0.4305, 0.295],
 'records': 1760,
 'std': [0.252, 0.2242, 0.2091]}
--------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
{'test_acc': 0.2955729365348816}
--------------------------------------------------------------------------------
