In [21]:
import os
import random
import paddle
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import gzip
import json
import warnings 
warnings.filterwarnings('ignore')

import paddle.nn as nn
from paddle.nn import Conv2D, MaxPool2D, Linear
import paddle.nn.functional as F

In [22]:
class MnistDataset(paddle.io.Dataset):
    def __init__(self, mode):
        data = json.load(gzip.open('mnist.json.gz'))
        train_set, val_set, eval_set = data
        
        self.IMG_ROWS = 28
        self.IMG_COLS = 28

        if mode=='train':
            imgs, labels = train_set[0], train_set[1]
        elif mode=='valid':
            imgs, labels = val_set[0], val_set[1]
        elif mode=='eval':
            imgs, labels = eval_set[0], eval_set[1]
        else:
            raise Exception("mode can only be one of ['train', 'valid', 'eval']")
        
        # 校验数据
        imgs_length = len(imgs)
        assert len(imgs) == len(labels), \
            "length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))
        
        self.imgs = imgs
        self.labels = labels

    def __getitem__(self, idx):
        # MLP
        # img = np.array(self.imgs[idx]).astype('float32')
        # label = np.array(self.labels[idx]).astype('int64')
        # CNN
        img = np.reshape(self.imgs[idx], [1, self.IMG_ROWS, self.IMG_COLS]).astype('float32')
        label = np.reshape(self.labels[idx], [1]).astype('int64')
        return img, label

    def __len__(self):
        return len(self.imgs)

In [23]:
# 定义模型结构
class MNIST(paddle.nn.Layer):
     def __init__(self):
         super(MNIST, self).__init__()
         nn.initializer.set_global_initializer(nn.initializer.Uniform(), nn.initializer.Constant())
         self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
         self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
         self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
         self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
         self.fc = Linear(in_features=980, out_features=10)
         
     def forward(self, inputs):
         x = self.conv1(inputs)
         x = F.relu(x)
         x = self.max_pool1(x)
         x = self.conv2(x)
         x = F.relu(x)
         x = self.max_pool2(x)
         x = paddle.reshape(x, [x.shape[0], 980])
         x = self.fc(x)
        #  x = F.softmax(x)
         return x     

## 定义训练Trainer， 
包含训练过程和模型保存

In [24]:


class Trainer(object):
    def __init__(self,model_path , model , optimizer):
        self.model_path = model_path 
        self.model = model
        self.optimizer = optimizer 

    def save(self): 
        paddle.save(self.model.state_dict(), self.model_path)
    
    def train_step(self,data):
        images , labels = data
        predicts = self.model.forward(images)
        
        loss = F.cross_entropy(predicts , label=labels)
        avg_loss = paddle.mean(loss)
        avg_loss.backward()
        self.optimizer.step()
        self.optimizer.clear_grad()
        return avg_loss
    
    def train_epoch(self,datasets ,epoch):
        self.model.train()
        for batch_id , data in enumerate(datasets()):
            loss = self.train_step(data)
            if batch_id % 500 == 0:
                print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))
    
    def train(self, train_datasets, start_epoch, end_epoch, save_path):
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        
        # 保存每一轮epoch参数    
        for i in range(start_epoch, end_epoch):
            self.train_epoch(train_datasets, i)
            # 保存 优化器参数
            paddle.save(opt.state_dict(), './{}/mnist_epoch{}'.format(save_path,i)+'.pdopt')
            # 保存 模型参数
            paddle.save(model.state_dict(), './{}/mnist_epoch{}'.format(save_path,i)+'.pdparams')
        self.save() 
        
        
    

## 模型加载及恢复训练

###  保存每个epoch的参数

附：学习率多项式衰减api参数

![20221128134549](https://cdn.jsdelivr.net/gh/xihuishawpy/PicBad@main/blogs/pictures/20221128134549.png)

In [25]:
paddle.seed(2022) 

epochs = 3 
BATCH_SIZE = 32
model_path = 'mnist.pdparams'

# 准备训练数据
train_dataset = MnistDataset(mode='train')
train_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0) 

# 学习率衰减步长(所有epoch的总batch数据)
total_steps = (int(5000//BATCH_SIZE)+1) * epochs

model = MNIST()
lr = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps= total_steps ,end_lr= 0.001)
opt = paddle.optimizer.Momentum(learning_rate=lr , parameters=model.parameters())


trainer = Trainer(
    model_path,
    model = model,
    optimizer = opt
    
)


trainer.train(train_datasets = train_loader , start_epoch = 0, end_epoch = epochs , save_path = 'checkpoint')

epoch_id: 0, batch_id: 0, loss is: [57.481045]
epoch_id: 0, batch_id: 500, loss is: [2.3059423]
epoch_id: 0, batch_id: 1000, loss is: [2.3140857]
epoch_id: 0, batch_id: 1500, loss is: [2.2985983]
epoch_id: 1, batch_id: 0, loss is: [2.2895033]
epoch_id: 1, batch_id: 500, loss is: [2.3060064]
epoch_id: 1, batch_id: 1000, loss is: [2.3140857]
epoch_id: 1, batch_id: 1500, loss is: [2.2985983]
epoch_id: 2, batch_id: 0, loss is: [2.2895033]
epoch_id: 2, batch_id: 500, loss is: [2.3060062]
epoch_id: 2, batch_id: 1000, loss is: [2.3140857]
epoch_id: 2, batch_id: 1500, loss is: [2.2985983]


### 恢复训练

用上述保存的参数（载入第1轮参数），从第2轮epoch开始恢复训练，观察恢复训练和完整训练的loss变化是否差距不大

In [26]:
paddle.seed(2022)

epochs = 3
BATCH_SIZE = 32
model_path = './mnist_retrain.pdparams'

# 载入数据
train_dataset = MnistDataset(mode='train')
train_loader = paddle.io.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4) 

model = MNIST()
total_steps = (int(50000//BATCH_SIZE) + 1) * epochs
lr = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.01, decay_steps=total_steps, end_lr=0.001)
opt = paddle.optimizer.Momentum(learning_rate=lr, parameters=model.parameters())

# 载入参数（模型参数 + lr衰减参数）
params_dict = paddle.load('./checkpoint/mnist_epoch0.pdparams')
opt_dict = paddle.load('./checkpoint/mnist_epoch0.pdopt')

model.set_state_dict(params_dict)
opt.set_state_dict(opt_dict)

# 训练
trainer = Trainer(
    model_path=model_path,
    model=model,
    optimizer=opt
)

# 从第2个epoch开始训练
trainer.train(train_datasets=train_loader,start_epoch = 1, end_epoch = epochs, save_path='checkpoint_con')

AssertionError: Optimizer set error, conv2d_14.w_0_velocity_0 should in state dict