In [1]:
%reload_ext autoreload
%autoreload 2

! [ ! -L /datasets ] && ln -s /data/datasets/cv /datasets

from k12libs.utils.nb_easy import k12ai_set_notebook

k12ai_set_notebook(cellw=95)

## 需掌握知识点

1. 卷积层基本功能
2. 池化层基本功能

In [4]:
%reset -f

from pyr.app.k12ai import EasyaiClassifier, EasyaiTrainer
import torch
from torch import nn

class CustomClassifier(EasyaiClassifier):
    
    def __init__(self):
        super().__init__()
        self.input_size = 224 # 输入模型的图片大小
        self.batch_size = 64  # 输入模型的图片数量
        
        # 调试: 日志输出模型shape
        self.example_input_array = torch.zeros(
            self.batch_size, 3, self.input_size, self.input_size) 

    def prepare_dataset(self):
        """
        flowers: 22个分类
            0: 'salviasplendens' 鼠尾草
            1: 'daffodil' 水仙花
            2: 'snowdrop' 雪花莲
            3: 'lilyvalley' 铃兰花
            4: 'bluebell' 野风信子
            5: 'crocus' 番红花
            6: 'iris' 鸢尾花
            7: 'tigerlily' 卷丹
            8: 'tulip'  郁金香
            9: 'fritillary' 豹纹蝶
           10: 'sunflower' 向日葵
           11: 'daisy' 雏菊
           12: 'coltsfoot' 款冬
           13: 'dandelion' 蒲公英
           14: 'cowslip' 黄花九轮草
           15: 'buttercup' 毛茛
           16: 'windflower' 白头翁
           17: 'pansy' 蝴蝶花
           18: 'coxcomb'  鸡冠花
           19: 'flamingo' 红鹤
           20: 'lily'  百合花
           21: 'lotus' 荷花
        """
        return self.load_flowers()
    
    def build_model(self):
        class ConvNetwork(nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.features1 = nn.Sequential(
                    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=1, dilation=1, groups=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.AvgPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False),
                    nn.BatchNorm2d(num_features=64, momentum=0.1)
                )

                self.features2 = nn.Sequential(
                    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=2, padding=1, dilation=1, groups=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=False),
                    nn.Dropout2d(inplace=True, p=0.2)
                )

                self.features3 = nn.Sequential(
                    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.AdaptiveAvgPool2d(output_size=28),
                    nn.BatchNorm2d(num_features=64, momentum=0.1, affine=True, track_running_stats=True)
                )

                self.classifier = nn.Sequential(
                    nn.Flatten(start_dim=1, end_dim=-1),
                    nn.Linear(in_features=50176, out_features=128, bias=True),
                    nn.Dropout(inplace=True, p=0.5),
                    nn.Linear(in_features=128, out_features=22, bias=True)
                )

            def forward(self, x):
                x = self.features1(x)
                x = self.features2(x)
                x = self.features3(x)
                return self.classifier(x)
            
        return ConvNetwork(num_classes=22) # input: (32, 3, 224, 224)
    
    def train_dataloader(self):
        return self.get_dataloader('train', self.batch_size, self.input_size, normalize=True)
     
    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x)
        loss = self.cross_entropy(y_hat, y, reduction='mean') # 损失方法
        with torch.no_grad():
            accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'loss': loss, 'progress_bar': {'acc': accuracy}}

    def val_dataloader(self):
        return self.get_dataloader('val', self.batch_size, self.input_size, normalize=True)
    
    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x)
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'loss': loss, 'acc': accuracy}
        
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        return {'progress_bar': {'val_loss': avg_loss, 'val_acc': avg_acc}}
    
    def test_dataloader(self):
        return self.get_dataloader('test', self.batch_size, self.input_size, normalize=True)
    
    def configure_optimizer(self, model):
        return self.adam(model.parameters(), base_lr=0.001)
    
    
trainer = EasyaiTrainer(max_epochs=50, resume=False, model_summary='full', ckpt_path='4-2')

model = CustomClassifier()

# 训练
trainer.fit(model)

# 评估
trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
{'label_names': ['salviasplendens',
                 'daffodil',
                 'snowdrop',
                 'lilyvalley',
                 'bluebell',
                 'crocus',
                 'iris',
                 'tigerlily',
                 'tulip',
                 'fritillary',
                 'sunflower',
                 'daisy',
                 'coltsfoot',
                 'dandelion',
                 'cowslip',
                 'buttercup',
                 'windflower',
                 'pansy',
                 'coxcomb',
                 'flamingo',
                 'lily',
                 'lotus'],
 'mean': [0.4623, 0.4305, 0.295],
 'num_classes': 22,
 'num_records': 1760,
 'std': [0.252, 0.2242, 0.2091]}

--------------------------------------------------------------------------------

   | Name               | Type              | Params | In sizes           | Out sizes         

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..



(fit)	GPU-0 memory allocated: 50.17 MB	 max memory allocated: 1028.23 MB
--------------------------------------------------------------------------------
{'label_names': ['salviasplendens',
                 'daffodil',
                 'snowdrop',
                 'lilyvalley',
                 'bluebell',
                 'crocus',
                 'iris',
                 'tigerlily',
                 'tulip',
                 'fritillary',
                 'sunflower',
                 'daisy',
                 'coltsfoot',
                 'dandelion',
                 'cowslip',
                 'buttercup',
                 'windflower',
                 'pansy',
                 'coxcomb',
                 'flamingo',
                 'lily',
                 'lotus'],
 'mean': [0.4623, 0.4305, 0.295],
 'num_classes': 22,
 'num_records': 1760,
 'std': [0.252, 0.2242, 0.2091]}


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


(fit)	GPU-0 memory allocated: 75.25 MB	 max memory allocated: 1028.23 MB
(test)	GPU-0 memory allocated: 75.25 MB	 max memory allocated: 1028.23 MB
--------------------------------------------------------------------------------
{'test_acc': 0.3072916865348816}
--------------------------------------------------------------------------------
