In [1]:
%reload_ext autoreload
%autoreload 2

! [ -L /datasets ] && rm -f /datasets
! ln -s /data/datasets/cv /datasets

from k12libs.utils.nb_easy import k12ai_set_notebook

k12ai_set_notebook(cellw=95)

## 需掌握知识点

1. 学习率
2. 批量数(Batch Size)
3. 优化器

In [2]:
%reset -f

from pyr.app.k12ai import EasyaiClassifier, EasyaiTrainer
import torch
from torch import nn
from sklearn.metrics import confusion_matrix

class CustomClassifier(EasyaiClassifier):
    
    def __init__(self):
        super().__init__()
        self.input_size = 100 # 输入模型的图片大小
        self.batch_size = 32  # 输入模型的图片数量
        
        # 调试: 日志输出模型shape
        self.example_input_array = torch.zeros(
            self.batch_size, 3, self.input_size, self.input_size) 

    def prepare_dataset(self):
        """
        水果数据集:
            0: 'Banana'  香蕉
            1: 'Pear'    梨子
            2: 'Lemon'   柠檬
            3: 'Apple'   苹果
            4: 'Orange'  橘子
        """
        return self.load_fruits()
    
    def build_model(self):
        class LeNet(nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(3, 6, kernel_size=5),
                    nn.Tanh(),
                    nn.AvgPool2d(2, stride=2),
                    nn.Conv2d(6, 16, kernel_size=5),
                    nn.Tanh(),
                    nn.AvgPool2d(2, stride=2)
                )
                self.classifier = nn.Sequential(
                    nn.Flatten(start_dim=1, end_dim=-1),
                    nn.Linear(7744, 120),
                    nn.Tanh(),
                    nn.Linear(120, 84),
                    nn.Tanh(),
                    nn.Linear(84, num_classes)
                )
                
            def forward(self, x):     
                x = self.features(x)
                return self.classifier(x)
            
        return LeNet(num_classes=5) # input shape: (32, 3, 100, 100)
    
    def train_dataloader(self):
        return self.get_dataloader('train', self.batch_size, self.input_size)
     
    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x)
        loss = self.cross_entropy(y_hat, y, reduction='mean') # 损失方法
        with torch.no_grad():
            accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'loss': loss, 'progress_bar': {'acc': accuracy}}

    def val_dataloader(self):
        return self.get_dataloader('val', self.batch_size, self.input_size)
    
    def validation_step(self, batch, batch_idx):
        x, y, _ = batch
        y_hat = self(x)
        loss = self.cross_entropy(y_hat, y, reduction='mean')
        accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'loss': loss, 'acc': accuracy}
        
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        return {'progress_bar': {'val_loss': avg_loss, 'val_acc': avg_acc}}
    
    def test_dataloader(self):
        return self.get_dataloader('test', self.batch_size, self.input_size)
    
    def test_step(self, batch, batch_idx):
        x, y, p = batch
        y_hat = self(x)
        accuracy = (torch.argmax(y_hat, axis=1) == y).float().mean()
        return {'acc': accuracy}
        
    def test_epoch_end(self, outputs):
        avg_acc = torch.stack([x['acc'] for x in outputs]).mean()
        return {'test_acc': avg_acc}
    
    def configure_optimizer(self, model):
        """
        优化器:
            adam:
            sgd:
        """
        # 调整学习率
        return self.adam(model.parameters(), base_lr=0.1)
        # return self.sgd(model.parameters(), base_lr=0.1)
    
    def configure_scheduler(self, optimizer):
        """
        学习率修改策略:
            step_lr:
            multistep_lr:
            reduce_lr:
        """
        # return self.step_lr(optimizer, step_size=2, gamma=0.1)
        return self.reduce_lr(optimizer, factor=0.1, patience=2)
    
    
trainer = EasyaiTrainer(max_epochs=10, model_summary='full', ckpt_path='6-1')

model = CustomClassifier()

# 训练
trainer.fit(model)

# 评估
trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


--------------------------------------------------------------------------------
{'label_names': ['Banana', 'Pear', 'Lemon', 'Apple', 'Orange'],
 'mean': [0.7507, 0.6427, 0.5061],
 'num_classes': 5,
 'num_records': 3263,
 'std': [0.2251, 0.2804, 0.3714]}

--------------------------------------------------------------------------------

   | Name               | Type       | Params | In sizes          | Out sizes       
------------------------------------------------------------------------------------------
0  | model              | LeNet      | 942 K  | [32, 3, 100, 100] | [32, 5]         
1  | model.features     | Sequential | 2 K    | [32, 3, 100, 100] | [32, 16, 22, 22]
2  | model.features.0   | Conv2d     | 456    | [32, 3, 100, 100] | [32, 6, 96, 96] 
3  | model.features.1   | Tanh       | 0      | [32, 6, 96, 96]   | [32, 6, 96, 96] 
4  | model.features.2   | AvgPool2d  | 0      | [32, 6, 96, 96]   | [32, 6, 48, 48] 
5  | model.features.3   | Conv2d     | 2 K    | [32, 6, 48, 4

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

Saving latest checkpoint..



(fit)	GPU-0 memory allocated: 7.2 MB	 max memory allocated: 38.67 MB
--------------------------------------------------------------------------------
{'label_names': ['Banana', 'Pear', 'Lemon', 'Apple', 'Orange'],
 'mean': [0.7507, 0.6427, 0.5061],
 'num_classes': 5,
 'num_records': 3263,
 'std': [0.2251, 0.2804, 0.3714]}


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


(fit)	GPU-0 memory allocated: 10.8 MB	 max memory allocated: 38.67 MB
(test)	GPU-0 memory allocated: 10.8 MB	 max memory allocated: 38.67 MB
--------------------------------------------------------------------------------
{'test_acc': 0.20365919172763824}
--------------------------------------------------------------------------------
