# In [4]:
import os
import time
import logging
import numpy as np
from mindspore import  Tensor
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindspore.train.serialization import export
from lenet import LeNet5
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.transforms as C
import mindspore.dataset.vision as CV
from mindspore.nn.metrics import Accuracy
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import Callback

# Dataset creation
def create_dataset(dataset_path, batch_size=256, resize=(32, 32), rescale=1 / 255, shift=-0.5, buffer_size=1000):
    """Build a preprocessed, shuffled and batched MNIST dataset.

    Args:
        dataset_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Samples per batch; an incomplete final batch is dropped.
        resize: Target size passed to ``CV.Resize``.
        rescale: Multiplicative factor passed to ``CV.Rescale``.
        shift: Additive offset passed to ``CV.Rescale``.
        buffer_size: Buffer size used by the shuffle step.

    Returns:
        The dataset after the image map, label map, shuffle and batch steps.
    """
    mnist = ds.MnistDataset(dataset_path)

    # Image pipeline, applied in order: resize, rescale/shift, HWC -> CHW.
    image_ops = [
        CV.Resize(resize),
        CV.Rescale(rescale, shift),
        CV.HWC2CHW(),
    ]
    mnist = mnist.map(input_columns="image", operations=image_ops)

    # Labels are cast to int32.
    mnist = mnist.map(input_columns="label", operations=C.TypeCast(ms.int32))

    # Shuffle first, then batch and discard the incomplete remainder.
    shuffled = mnist.shuffle(buffer_size=buffer_size)
    return shuffled.batch(batch_size, drop_remainder=True)


# Hyper-parameter dictionary
class Config(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``cfg.name`` returns ``cfg["name"]`` and ``cfg.name = value`` stores
    ``cfg["name"] = value``.

    Raises:
        AttributeError: When reading an attribute that is not a key.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. A missing key has
        # to surface as AttributeError rather than KeyError; otherwise
        # hasattr(), getattr(obj, name, default), copy and pickle break.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, name)
            ) from None

    __setattr__ = dict.__setitem__

# Configure the log format
def set_logging(config):
    """Route root-logger output at INFO level and above to a log file.

    The file is ``config.log_file`` inside ``config.log_path``. The directory
    is created when missing and the file is opened with ``filemode='w'``.

    Args:
        config: Object exposing ``log_path`` and ``log_file`` attributes.

    Note:
        ``logging.basicConfig`` does nothing when the root logger already has
        handlers, so a repeated call in the same process does not switch the
        output to a different file.
    """
    # exist_ok=True removes the check-then-create race of testing
    # os.path.exists() before calling makedirs.
    os.makedirs(config.log_path, exist_ok=True)
    filename = os.path.join(config.log_path, config.log_file)
    logging.basicConfig(
        level=logging.INFO,
        filename=filename,
        filemode='w',
        format='[%(asctime)s %(levelname)-8s] %(message)s',
        datefmt='%Y%m%d %H:%M:%S'
    )

# Custom callback that records training results
class StepLossAccInfo(Callback):
    """Log the training loss once per epoch, on the last step of the epoch.

    Args:
        model: Model being trained; stored on the instance and not otherwise
            used by this callback.
        epoch: Zero-based epoch index supplied by the caller; it is reported
            in the log as ``epoch + 1``.
    """

    def __init__(self, model, epoch):
        super().__init__()
        self.model = model
        self.epoch = epoch

    def on_train_step_end(self, run_context):
        """Write the loss to the log when the finished step closes an epoch.

        Args:
            run_context: Context object whose ``original_args()`` returns the
                callback parameters of the running training job.
        """
        cb_params = run_context.original_args()
        # Steps per epoch come from the callback parameters (batch_num)
        # instead of a hard-coded constant or a module-level dataset, so the
        # check stays correct for any batch size and any number of epochs
        # per train() call. cur_step_num is treated as cumulative across the
        # epochs of one call, hence the modulo.
        steps_per_epoch = cb_params.batch_num
        step_in_epoch = (cb_params.cur_step_num - 1) % steps_per_epoch + 1
        if step_in_epoch != steps_per_epoch:
            return

        self.train_loss = str(cb_params.net_outputs)
        logging.info('======================================================================================')
        logging.info('===============================Train - Epoch :{} / Loss:{} =============================='.format(self.epoch + 1, self.train_loss))
        logging.info('======================================================================================')



if __name__ == '__main__':
    # Hyper-parameter dictionary
    config = Config({
        # Dataset name and paths
        "dataset_name": "MNIST",
        "train_path": "D:/data/MNIST/train",
        "test_path": "D:/data/MNIST/test",

        # Batch size
        "batch_size": 256,
        # Learning rate
        "learning_rate": 2e-3,
        # Number of epochs
        "n_epoch": 5,
        # Test frequency: 1 means evaluate after every epoch
        "test_freq": 1,
        # Model name, model export path and log path
        "model_name": "LeNet-5",
        "onnx_path": "./output/models/",
        "log_path": "./output/log/",
    })

    # Create the per-model export directory (no error if it already exists).
    config.onnx_path = os.path.join(config.onnx_path, config.model_name)
    os.makedirs(config.onnx_path, exist_ok=True)
    # Log file
    config.log_file = config.model_name + ".log"
    set_logging(config)
    logging.info(config)
    # Datasets
    train_data = create_dataset(config.train_path, batch_size=config.batch_size)
    test_data = create_dataset(config.test_path, batch_size=config.batch_size)
    # Network
    net = LeNet5()
    # Cross-entropy loss
    criterion = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    # Optimizer
    optimizer = nn.Adam(params=net.trainable_params(), learning_rate=config.learning_rate)
    # Tie the network, loss function, optimizer and metric together
    model = Model(net, loss_fn=criterion, optimizer=optimizer, metrics={'accuracy': Accuracy()})

    for epoch in range(config.n_epoch):
        # Training phase
        start_time = time.time()
        logging.info('\n')
        logging.info('Running training epoch {}'.format(epoch + 1))
        # One train() call per epoch, so evaluation and export can run
        # in between epochs.
        model.train(1, train_data, callbacks=[StepLossAccInfo(model, epoch)], dataset_sink_mode=False)

        # Evaluate according to the configured test frequency
        if (epoch + 1) % config.test_freq == 0:
            logging.info('Starting test...')
            logging.info('Running testing in epoch {}'.format(epoch + 1))
            # Run the evaluation
            acc = model.eval(test_data)

            logging.info('======================================================================================')
            logging.info('===========================Test - Epoch :{} / Accuracy: {}============================'.format(epoch+1,acc['accuracy']))
            logging.info('======================================================================================')

            logging.info('Test done...')
            # Export the model
            model_save_path = os.path.join(config.onnx_path, 'lenet-5-epoch' + str(epoch + 1) + '.onnx')
            logging.info('Saving weights and model of epoch{}, path:{}'.format(epoch + 1, model_save_path))
            export(net, Tensor(np.zeros((1, 1, 32, 32)).astype(np.float32)), file_name=model_save_path, file_format='ONNX')
        # '.2f' prints two decimals; a bare '.2' means two significant digits
        # and renders e.g. 123.4 as '1.2e+02'.
        logging.info('Epoch {} done. Time: {:.2f}s'.format(epoch + 1, (time.time() - start_time)))