In [1]:
import os
import mindspore as ms
import mindspore.context as context
#transforms.c_transforms用于通用型数据增强，vision.c_transforms用于图像类数据增强
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
#nn模块用于定义网络，model模块用于编译模型，callback模块用于设定监督指标
from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
# Run in graph mode. The device target defaults to 'CPU' (the original experiment ran on a
# local Windows host). If your MindSpore build is the GPU or Ascend version, set the
# MINDSPORE_DEVICE_TARGET environment variable to 'GPU' or 'Ascend' instead of editing the code.
# A target that does not match the installed build will raise an error.
DEVICE_TARGET = os.environ.get('MINDSPORE_DEVICE_TARGET', 'CPU')
context.set_context(mode=context.GRAPH_MODE, device_target=DEVICE_TARGET)

In [2]:
#根据数据集存储地址，生成数据集
def create_dataset(data_dir, training=True, batch_size=32, resize=(32, 32),
                   rescale=1/(255*0.3081), shift=-0.1307/0.3081, buffer_size=64):
    """Build a batched MNIST input pipeline.

    Args:
        data_dir: dataset root containing ``train`` and ``test`` sub-directories.
        training: if true, read the ``train`` split and shuffle it. Otherwise read the
            ``test`` split in its stored order, so evaluation is reproducible.
        batch_size: samples per batch. The last incomplete batch is dropped.
        resize: target size passed to ``Resize``.
        rescale, shift: each pixel becomes ``pixel * rescale + shift``. The defaults give
            ``(pixel / 255 - 0.1307) / 0.3081``.
        buffer_size: shuffle buffer size. Only used when ``training`` is true.

    Returns:
        A MindSpore dataset with an ``image`` column (converted to CHW) and an
        ``label`` column cast to int32.
    """
    split_dir = os.path.join(data_dir, 'train' if training else 'test')
    # MnistDataset shuffles by default. Together with drop_remainder=True below, a
    # shuffled test split meant the samples in the dropped partial batch varied between
    # runs, so evaluation metrics were not reproducible. Keep the test split in order.
    ds = ms.dataset.MnistDataset(split_dir, shuffle=bool(training))
    # Resize, normalise, then convert HWC -> CHW for the Conv2d layers.
    ds = ds.map(input_columns=["image"],
                operations=[CV.Resize(resize), CV.Rescale(rescale, shift), CV.HWC2CHW()])
    # Cast labels to int32 (consumed as sparse class indices by the loss).
    ds = ds.map(input_columns=["label"], operations=C.TypeCast(ms.int32))
    if training:
        ds = ds.shuffle(buffer_size=buffer_size)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

In [3]:
#定义模型结构：MindSpore中的模型在__init__中初始化各层对象，并在construct中定义前向计算过程
class BasicBlock(nn.Cell):
    """Residual block: two 3x3 conv + BatchNorm layers plus an additive shortcut.

    The shortcut is an empty ``SequentialCell`` (identity) unless the block changes
    the spatial stride or the channel count. In that case a 1x1 conv + BatchNorm
    projection is used so both branches can be summed.
    """
    # Multiplier applied to out_channels for this block's output width
    # (ResNet._make_layer reads it to track the running channel count).
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        widened = self.expansion * out_channels

        # Main branch: conv(stride) -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                               pad_mode='pad', padding=1, has_bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               pad_mode='pad', padding=1, has_bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch: project only when the output shape differs from the input.
        shape_changes = stride != 1 or in_channels != widened
        if shape_changes:
            self.shortcut = nn.SequentialCell([
                nn.Conv2d(in_channels, widened, kernel_size=1, stride=stride, has_bias=False),
                nn.BatchNorm2d(widened),
            ])
        else:
            self.shortcut = nn.SequentialCell()

    def construct(self, x):
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(residual + self.shortcut(x))

# 定义ResNet模型

class ResNet(nn.Cell):
    """ResNet with a 3x3 stride-1 stem and four residual stages (64/128/256/512 channels).

    Args:
        block: residual block class with an ``expansion`` attribute (e.g. ``BasicBlock``).
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the final Dense layer.
        image_channels: channels of the input image. Defaults to 1, the value that was
            previously hard-coded, so existing calls behave the same.
    """

    def __init__(self, block, num_blocks, num_classes=10, image_channels=1):
        super(ResNet, self).__init__()
        # Running channel count; _make_layer reads and advances it stage by stage.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=3, stride=1, pad_mode='pad', padding=1, has_bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        # Three stride-2 stages downsample by 8. The fc layer expects a 1x1 map after this
        # 4x4 pool, which means the input must be 32x32 (create_dataset's default resize).
        self.avg_pool = nn.AvgPool2d(kernel_size=4)
        self.fc = nn.Dense(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks. Only the first block uses ``stride``; the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.SequentialCell(layers)

    def construct(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        # Flatten everything except the batch dimension before the classifier.
        out = out.view(out.shape[0], -1)
        out = self.fc(out)
        return out

def ResNet18(num_classes=10):
    """Build ResNet-18: BasicBlock stages with [2, 2, 2, 2] blocks.

    Args:
        num_classes: size of the output layer. Defaults to 10, matching the previous
            hard-wired behaviour.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

In [4]:
# 构建训练、验证函数进行模型训练和验证，提供数据路径，设定学习率，epoch数量
def train(data_dir, lr=0.01, momentum=0.9, num_epochs=3):
    """Train ResNet-18 on the MNIST train split, then evaluate it on the test split.

    Args:
        data_dir: dataset root passed to ``create_dataset`` (contains ``train``/``test``).
        lr: learning rate for the Momentum optimizer.
        momentum: momentum factor for the Momentum optimizer.
        num_epochs: number of passes over the training set.

    Returns:
        The metrics returned by ``model.eval``, which are also printed. Returning them is
        new; callers that ignore the return value are unaffected.
    """
    # Load the training set and the evaluation (test) set.
    ds_train = create_dataset(data_dir)
    ds_eval = create_dataset(data_dir, training=False)
    net = ResNet18()
    # sparse=True: labels are int32 class indices (see create_dataset), not one-hot vectors.
    loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    opt = nn.Momentum(net.trainable_params(), lr, momentum)
    # Report the training loss every 50 steps.
    loss_cb = LossMonitor(50)
    # 'acc' and 'loss' are the metrics computed by model.eval.
    model = Model(net, loss, opt, metrics={'acc', 'loss'})
    # dataset_sink_mode=False: sink mode is disabled, so batches are fed step by step
    # rather than sunk to the device.
    model.train(num_epochs, ds_train, callbacks=[loss_cb], dataset_sink_mode=False)
    # Evaluate the trained network on the held-out test split.
    metrics = model.eval(ds_eval, dataset_sink_mode=False)
    print('Metrics:', metrics)
    return metrics

In [5]:
#main函数负责调用之前定义的函数，完成整个训练验证过程
if __name__ == "__main__":
    # Entry point: parse the data location from the command line, then train and evaluate.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--data_url', required=False, default='./MNIST_Data/', help='Location of data.')
    cli.add_argument('--train_url', required=False, default=None, help='Location of training outputs.')
    # parse_known_args ignores unrecognised arguments instead of exiting
    # (presumably so this cell also runs inside a Jupyter kernel that passes its own flags).
    args, extra_argv = cli.parse_known_args()

    # A path starting with 's3' is treated as an OBS path and copied into the container first.
    if not args.data_url.startswith('s3'):
        data_path = os.path.abspath(args.data_url)
    else:
        import moxing

        # WAY1: copy dataset from your own OBS bucket to container/cache.
        # moxing.file.copy_parallel(src_url=args.data_url, dst_url='MNIST/')

        # WAY2: copy dataset from another user's OBS bucket that has been made public read or public read&write.
        moxing.file.copy_parallel(src_url="s3://share-course/dataset/MNIST/", dst_url='MNIST/')

        data_path = 'MNIST/'

    # Train, then evaluate.
    train(data_path)
    # Note: with ResNet this is very slow on CPU. The author measured close to an hour for three
    # epochs on a fully loaded i7-12700H. If running on a personal computer, watch the cooling and be patient.



epoch: 1 step: 50, loss is 0.33057069778442383
epoch: 1 step: 100, loss is 0.20067404210567474
epoch: 1 step: 150, loss is 0.015457981266081333
epoch: 1 step: 200, loss is 0.08370426297187805
epoch: 1 step: 250, loss is 0.23574164509773254
epoch: 1 step: 300, loss is 0.021979905664920807
epoch: 1 step: 350, loss is 0.09551975131034851
epoch: 1 step: 400, loss is 0.14652471244335175
epoch: 1 step: 450, loss is 0.29470163583755493
epoch: 1 step: 500, loss is 0.008271989412605762
epoch: 1 step: 550, loss is 0.05912448465824127
epoch: 1 step: 600, loss is 0.0041542332619428635
epoch: 1 step: 650, loss is 0.005579465068876743
epoch: 1 step: 700, loss is 0.5467463135719299
epoch: 1 step: 750, loss is 0.027821220457553864
epoch: 1 step: 800, loss is 0.014661386609077454
epoch: 1 step: 850, loss is 0.017835469916462898
epoch: 1 step: 900, loss is 0.03618411347270012
epoch: 1 step: 950, loss is 0.03444613888859749
epoch: 1 step: 1000, loss is 0.021323563531041145
epoch: 1 step: 1050, loss is 0.