安装mindspore、mindvision

https://www.mindspore.cn/install

https://mindspore.cn/vision/docs/zh-CN/r0.1/mindvision_install.html

In [18]:
import mindspore
mindspore.run_check()

MindSpore version:  1.8.1
The result of multiplication calculation is correct, MindSpore has been installed successfully!


In [19]:
import os
import os.path as osp
import time
import numpy as np
import re
import glob
from PIL import Image
import warnings #simplify


import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common import set_seed
from mindspore import Tensor, Model
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor,\
                                      TimeMonitor, SummaryCollector
import mindspore.dataset as ds
from mindspore.dataset.vision import Resize, Rescale, Normalize, HWC2CHW, RandomHorizontalFlip, RandomErasing
from mindspore.dataset.transforms import Compose
from mindspore.common.initializer import initializer, HeNormal
from collections import OrderedDict

# OSNet行人重识别

## 任务简介

![reid.png](./img/reid.png)

行人重识别是利用计算机视觉技术判断图像或者视频序列中是否存在特定行人的技术，通常被认为是一个图像检索的子问题。在监控视频中，由于相机分辨率和拍摄角度的缘故，通常无法得到高质量的人脸图片。当人脸识别失效的情况下，ReID就成为了人物身份识别的重要替代技术。
行人重识别的数据集通常是通过人工标注或者检测算法得到的行人图片。数据集分为训练集、验证集、Query、Gallery。在训练集上进行训练得到的模型对Query与Gallery中的图片分别提取特征并计算相似度。对于每个Query，在Gallery中会找出前N个与其相似的图片。训练、测试中人物身份不重复。

## OSNet简介

![schematic](./img/schematic.png)

首先需要理解omni-scale指什么？ReID的关键在于学习判别性特征，本文认为需要提取全方位特征具有判别性，文中解释为是多样化的同质和异质的分块的结合。

OSNet有多个分支，每个分支有不同感受野，捕获不同尺度的特征。通过Channel-wise Adaptive Aggregation融合特征，即一个特征聚合模块，可训练且权重与输入动态关联。

OSNet是一个轻量级的网络，它可带来以下好处：（1）轻量级网络具有更少的模型参数，不容易过拟合（2）在大型监视应用程序中（例如，使用数千个摄像头的城市范围的监视），ReID的唯一实用方法是在摄像头端执行特征提取。对于设备上的处理，小型ReID网络显然是首选。

## 模型解析

#### LightConv3x3

![](./img/LightConv.png)

（a）为标准的3\*3卷积，（b）为轻量化的。采用point-wise和depth-wise操作拆分传统的卷积减少参数量和运算量。

In [20]:
def kaiming_normal(shape, mode='fan_out', nonlinearity='relu'):
    '''initialize weight of conv2d layer.'''
    weight = initializer(HeNormal(mode=mode, nonlinearity=nonlinearity), shape=shape)
    return weight

def _conv2d(in_channels, out_channels, kernel_size, stride, pad_mode, padding, group=1, has_bias=False):
    '''return conv2d layer with initialized weight'''
    if in_channels % group == 0 and out_channels % group == 0:
        weight_shape = [out_channels, in_channels//group, kernel_size, kernel_size]
    else:
        raise ValueError("In_ channels:{} and out_channels:{} must be divisible by the number of groups:{}."
                         .format(in_channels, out_channels, group))
    weight = kaiming_normal(weight_shape)
    conv = nn.Conv2d(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     pad_mode=pad_mode,
                     padding=padding,
                     group=group,
                     has_bias=has_bias,
                     weight_init=weight
                     )
    return conv


class ConvLayer(nn.Cell):
    '''Convolution layer (conv + bn + relu).'''

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            pad_mode='pad',
            padding=3,
            group=1,
            has_bias=False
    ):
        super(ConvLayer, self).__init__()
        self.conv = _conv2d(in_channels, out_channels, kernel_size,
                            stride, pad_mode, padding, group, has_bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Conv1x1(nn.Cell):
    """1x1 convolution + bn + relu."""

    def __init__(self, in_channels, out_channels, stride=1, group=1):
        super(Conv1x1, self).__init__()
        self.conv = _conv2d(in_channels, out_channels, 1, stride=stride, pad_mode='valid', padding=0, group=group)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Conv1x1Linear(nn.Cell):
    '''1x1 convolution + bn (w/o non-linearity).'''

    def __init__(self, in_channels, out_channels, stride=1):
        super(Conv1x1Linear, self).__init__()
        self.conv = _conv2d(in_channels, out_channels, 1, stride, pad_mode='valid', padding=0)
        self.bn = nn.BatchNorm2d(out_channels)

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Conv3x3(nn.Cell):
    """3x3 convolution + bn + relu."""
    def __init__(self, in_channels, out_channels, stride=1, group=1):
        super(Conv3x3, self).__init__()
        self.conv = _conv2d(in_channels, out_channels, 3, stride, pad_mode='pad', padding=1, group=group)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class LightConv3x3(nn.Cell):
    """Lightweight 3x3 convolution.
    1x1 (linear) + dw 3x3 (nonlinear).
    """
    def __init__(self, in_channels, out_channels):
        super(LightConv3x3, self).__init__()
        self.conv1 = _conv2d(in_channels, out_channels, 1, stride=1, pad_mode='valid', padding=0)
        self.conv2 = _conv2d(out_channels, out_channels, 3, stride=1, pad_mode='pad', padding=1, group=out_channels)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

#### ChannelGate

每个流都可以提供特定尺度的特征，即，它们是尺度均匀的。为了学习全尺度特征，以动态的方式组合不同流的输出，即，不同的权重根据输入图像分配到不同的尺度，而不是经过训练后固定。更具体地说，动态尺度融合是通过一种新的聚合门(AG)实现的，它是一种可学习的神经网络。

In [21]:
class ChannelGate(nn.Cell):
    """A mini-network that generates channel-wise gates conditioned on input tensor."""

    def __init__(
            self,
            in_channels,
            num_gates=None,
            return_gates=False,
            gate_activation='sigmoid',
            reduction=16,
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        self.global_avgpool = ops.ReduceMean(keep_dims=True)
        self.fc1 = _conv2d(in_channels, in_channels//reduction, kernel_size=1, stride=1,
                           pad_mode='valid', padding=0, has_bias=True)
        self.relu = nn.ReLU()
        self.fc2 = _conv2d(in_channels//reduction, num_gates, kernel_size=1, stride=1,
                           pad_mode='valid', padding=0, has_bias=True)
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU()
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def construct(self, x):
        '''constuct function'''
        inputs = x
        x = self.global_avgpool(x, (2, 3))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        return inputs * x


#### OSBlock（bottleneck）

![](./img/osblock.png)

使用全尺度残差块OSBlock（b）代替传统残差块（a）。

In [22]:
class OSBlock(nn.Cell):
    """Omni-scale feature learning block."""

    def __init__(
            self,
            in_channels,
            out_channels,
            bottleneck_reduction=4,
            **kwargs
    ):
        super(OSBlock, self).__init__()
        mid_channels = out_channels // bottleneck_reduction
        self.conv1 = Conv1x1(in_channels, mid_channels)
        self.conv2a = LightConv3x3(mid_channels, mid_channels)
        self.conv2b = nn.SequentialCell(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2c = nn.SequentialCell(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.conv2d = nn.SequentialCell(
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
            LightConv3x3(mid_channels, mid_channels),
        )
        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        self.downsample = None
        self.relu = nn.ReLU()
        if in_channels != out_channels:
            self.downsample = Conv1x1Linear(in_channels, out_channels)

    def construct(self, x):
        '''construct layer'''
        identity = x
        x1 = self.conv1(x)
        x2a = self.conv2a(x1)
        x2b = self.conv2b(x1)
        x2c = self.conv2c(x1)
        x2d = self.conv2d(x1)
        x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d)
        x3 = self.conv3(x2)
        if self.downsample is not None:
            identity = self.downsample(identity)
        add = x3 + identity
        out = self.relu(add)
        return out

#### OSNet

![](./img/Architecture.png)

OSNet是通过简单地逐层堆叠轻量级bottleneck构建，上图为网络架构。

In [23]:
class OSNet(nn.Cell):
    """Omni-Scale Network."""
    def __init__(
            self,
            num_classes,
            blocks,
            layers,
            channels,
            feature_dim=512,
            **kwargs
    ):
        super(OSNet, self).__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1
        self.feature_dim = feature_dim
        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3)
        self.pad = nn.Pad(paddings=((0, 0), (0, 0), (1, 1), (1, 1)), mode="CONSTANT")
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv2 = self._make_layer(
            blocks[0],
            layers[0],
            channels[0],
            channels[1],
            reduce_spatial_size=True,
        )
        self.conv3 = self._make_layer(
            blocks[1],
            layers[1],
            channels[1],
            channels[2],
            reduce_spatial_size=True
        )
        self.conv4 = self._make_layer(
            blocks[2],
            layers[2],
            channels[2],
            channels[3],
            reduce_spatial_size=False
        )
        self.conv5 = Conv1x1(channels[3], channels[3])
        self.global_avgpool = ops.ReduceMean(keep_dims=True)
        self.fc = self._construct_fc_layer(
            self.feature_dim, channels[3], dropout_p=None
        )
        self.classifier = nn.Dense(self.feature_dim, num_classes)
        self.stop_layer = ops.Identity()

    def _make_layer(
            self,
            block,
            layer,
            in_channels,
            out_channels,
            reduce_spatial_size,
    ):
        '''make block layers.'''
        layers = []
        layers.append(block(in_channels, out_channels))
        for _ in range(1, layer):
            layers.append(block(out_channels, out_channels))
        if reduce_spatial_size:
            layers.append(
                nn.SequentialCell(
                    Conv1x1(out_channels, out_channels),
                    nn.AvgPool2d(2, stride=2)
                )
            )

        return nn.SequentialCell(*layers)

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        '''constuct full-connection layer.'''
        if fc_dims is None or fc_dims < 0:
            self.feature_dim = input_dim
            return None
        if isinstance(fc_dims, int):
            fc_dims = [fc_dims]

        layers = []
        for dim in fc_dims:
            layers.append(nn.Dense(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU())
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.SequentialCell(*layers)

    def construct(self, x):
        '''construct'''
        x = self.conv1(x)
        x = self.pad(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        v = self.global_avgpool(x, (2, 3))
        v = v.view(v.shape[0], -1)
        if self.fc is not None:
            v = self.fc(v)
        if not self.training:
            return v
        y = self.stop_layer(v)
        y = self.classifier(v)
        return y


## 数据集准备与加载

### 下载数据

https://www.kaggle.com/datasets/pengcw1/market-1501/download?datasetVersionNumber=1

### 数据集介绍

Market-1501数据集[1]在清华开放环境通过6个摄像头采集得到，一共包括1501个行人，其中训练集中751个行人12936张图片，测试集（gallery）中750个行人19732张图片，query集中3368张图片。

数据文件说明：
1.  “bounding_box_test”——用于测试集的 750 人，包含 19,732 张图像，前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图（可能与query是同一个人），-1 表示检测出来其他人的图（不在这 750 人中）
2. “bounding_box_train”——用于训练集的 751 人，包含 12,936 张图像
3. “query”——为 750 人在每个摄像头中随机选择一张图像作为query，因此一个人的query最多有 6 个，共有 3,368 张图像
4. “gt_query”——matlab格式，用于判断一个query的哪些图片是好的匹配（同一个人不同摄像头的图像）和不好的匹配（同一个人同一个摄像头的图像或非同一个人的图像）
5. “gt_bbox”——手工标注的bounding box，用于判断DPM检测的bounding box是不是一个好的box

文件命名方式：
例如：0001_c1s1_001051_01.jpg

1. 0001 是行人 ID，Market 1501 有 1501 个行人，故行人 ID 范围为 0001-1501
2. c1 是摄像头编号(camera 4)，表明图片采集自第1个摄像头，一共有 6 个摄像头
3. s1 是视频的第一个片段(sequece1)，一个视频包含若干个片段
4. 001051 是视频的第 1051 帧图片，表明行人出现在该帧图片中
5. 01 代表第 826 帧图片上的第一个检测框，DPM 检测器可能在一帧图片上生成多个检测框，00为手工标注

### 加载数据

本案例实现选取Market-1501数据集中10个行人数据构建子数据集，在CPU上实现。

In [24]:
class Market10():
    _junk_pids = [0, -1]
    dataset_dir = 'Market-10'
    def __init__(self, root='',mode='train', verbose=True, **kwargs):
        self.root = osp.abspath(osp.expanduser(root))
        self.data_dir = osp.join(self.root, self.dataset_dir)
        self.train_dir = osp.join(self.data_dir, 'train')
        self.query_dir = osp.join(self.data_dir, 'query')
        self.gallery_dir = osp.join(self.data_dir, 'test')
        train = self.process_dir(self.train_dir, relabel=True)
        query = self.process_dir(self.query_dir, relabel=False)
        gallery = self.process_dir(self.gallery_dir, relabel=False)

        if len(train[0]) == 3:
            train = [(*items, 0) for items in train]
        if len(query[0]) == 3:
            query = [(*items, 0) for items in query]
        if len(gallery[0]) == 3:
            gallery = [(*items, 0) for items in gallery]

        self.train = train
        self.query = query
        self.gallery = gallery
        self.mode = mode
        self.verbose = verbose

        self.num_train_pids = self.get_num_pids(self.train)
        self.num_train_cams = self.get_num_cams(self.train)
        self.num_datasets = self.get_num_datasets(self.train)


        if self.mode == 'train':
            self.data = self.train
        elif self.mode == 'query':
            self.data = self.query
        elif self.mode == 'gallery':
            self.data = self.gallery

        if self.verbose:
            self.show_summary()

    def process_dir(self, dir_path, relabel=False):
        '''get images and labels from directory.'''
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue # junk images are just ignored
            assert 0 <= pid <= 1501 # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1 # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid))

        return data
    def __getitem__(self, index):
        img_path, pid, camid, _ = self.data[index]
        img = Image.open(img_path).convert('RGB')
        pid = np.array(pid).astype(np.int32)
        if self.mode == 'train':
            return img, pid

        return img, pid, camid


    def __len__(self):
        return len(self.data)

    def get_num_pids(self, data):
        pids = set()
        for items in data:
            pid = items[1]
            pids.add(pid)
        return len(pids)

    def get_num_cams(self, data):
        cams = set()
        for items in data:
            camid = items[2]
            cams.add(camid)
        return len(cams)

    def get_num_datasets(self, data):
        dsets = set()
        for items in data:
            dsetid = items[3]
            dsets.add(dsetid)
        return len(dsets)

    def show_summary(self):
        num_train_pids = self.get_num_pids(self.train)
        num_train_cams = self.get_num_cams(self.train)

        num_query_pids = self.get_num_pids(self.query)
        num_query_cams = self.get_num_cams(self.query)

        num_gallery_pids = self.get_num_pids(self.gallery)
        num_gallery_cams = self.get_num_cams(self.gallery)

        print('=> Loaded {}'.format(self.__class__.__name__))
        print('  ----------------------------------------')
        print('  subset   | # ids | # images | # cameras')
        print('  ----------------------------------------')
        print(
            '  train    | {:5d} | {:8d} | {:9d}'.format(
                num_train_pids, len(self.train), num_train_cams
            )
        )
        print(
            '  query    | {:5d} | {:8d} | {:9d}'.format(
                num_query_pids, len(self.query), num_query_cams
            )
        )
        print(
            '  gallery  | {:5d} | {:8d} | {:9d}'.format(
                num_gallery_pids, len(self.gallery), num_gallery_cams
            )
        )
        print('  ----------------------------------------')

    def __repr__(self):
        num_train_pids = self.get_num_pids(self.train)
        num_train_cams = self.get_num_cams(self.train)

        num_query_pids = self.get_num_pids(self.query)
        num_query_cams = self.get_num_cams(self.query)

        num_gallery_pids = self.get_num_pids(self.gallery)
        num_gallery_cams = self.get_num_cams(self.gallery)

        msg = '  ----------------------------------------\n' \
              '  subset   | # ids | # items | # cameras\n' \
              '  ----------------------------------------\n' \
              '  train    | {:5d} | {:7d} | {:9d}\n' \
              '  query    | {:5d} | {:7d} | {:9d}\n' \
              '  gallery  | {:5d} | {:7d} | {:9d}\n' \
              '  ----------------------------------------\n' \
              '  items: images/tracklets for image/video dataset\n'.format(
                  num_train_pids, len(self.train), num_train_cams,
                  num_query_pids, len(self.query), num_query_cams,
                  num_gallery_pids, len(self.gallery), num_gallery_cams
              )

        return msg


def dataset_creator(
        root='',
        height=256,
        width=128,
        norm_mean=[0.485, 0.456, 0.406],
        norm_std=[0.229, 0.224, 0.225],
        batch_size_train=32,
        batch_size_test=32,
        mode=None
):
    dataset_ = Market10(root=root, mode=mode)
    num_pids = dataset_.num_train_pids

    if mode == 'train':
        sampler = ds.RandomSampler()
        data_set = ds.GeneratorDataset(dataset_, ['img', 'pid'], sampler=sampler)
        transforms = Compose([
            Resize((height, width)),
            RandomHorizontalFlip(),
            Rescale(1.0 / 255.0, 0.0),
            Normalize(mean=norm_mean, std=norm_std),
            HWC2CHW(),
        ])
        data_set = data_set.map(operations=transforms, input_columns=['img'])
        data_set = data_set.batch(batch_size=batch_size_train, drop_remainder=True)
        return num_pids, data_set

    data_set = ds.GeneratorDataset(dataset_, ['img', 'pid', 'camid'])
    transforms = Compose([
        Resize((height, width)),
        Rescale(1.0/255.0, 0.0),
        Normalize(mean=norm_mean, std=norm_std),
        HWC2CHW(),
    ])
    data_set = data_set.map(operations=transforms, input_columns=['img'])
    data_set = data_set.batch(batch_size=batch_size_test, drop_remainder=False)
    return num_pids, data_set

## 模型训练与评估

In [25]:

def init_pretrained_weights(model, pretrained_param_dir):
    """
    Initializes model with pretrained weights.
    Layers that don't match with pretrained layers in name or size are kept unchanged.
    """

    filename = 'init_osnet.ckpt'
    file = os.path.join(pretrained_param_dir, filename)
    print(file)
    if not os.path.exists(file):
        raise ValueError(
            'The file:{} does not exist.'.format(file)
        )
    param_dict = load_checkpoint(file)
    model_dict = model.parameters_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in param_dict.items():
        if k in model_dict and model_dict[k].data.shape == v.shape:
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    load_param_into_net(model, model_dict)

    if not matched_layers:
        warnings.warn(
            'The pretrained weights from "{}" cannot be loaded, '
            'please check the key names manually '
            '(** ignored and continue **)'.format(file)
        )
    else:
        print(
            'Successfully loaded imagenet pretrained weights from "{}"'.
            format(file)
        )
        if discarded_layers:
            print(
                '** The following layers are discarded '
                'due to unmatched keys or layer size: {}'.
                format(discarded_layers)
            )


def create_osnet(num_classes=1500, pretrained=False, pretrained_dir='', **kwargs):
    '''create osnet.'''
    model = OSNet(
        num_classes,
        blocks=[OSBlock, OSBlock, OSBlock],
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        **kwargs
    )
    if pretrained:
        init_pretrained_weights(model, pretrained_dir)
    return model

### 模型训练

模型训练时前10个epoch固定网络参数训练分类器，之后才开始训练网络参数。

In [26]:
from mindvision.engine.loss import CrossEntropySmooth
from mindvision.engine.callback import LossMonitor

set_seed(1)
max_epoch = 100
fixbase_epoch = 10
batch_size_train = 32
data_path = './datasets'
height = 256
width = 128

num_classes, dataset1 = dataset_creator(root=data_path, height=height, width=width, batch_size_train=batch_size_train, mode='train')
num_classes, dataset2 = dataset_creator(root=data_path, height=height, width=width, batch_size_train=batch_size_train, mode='train')
num_batches = dataset1.get_dataset_size()

net = create_osnet(num_classes=num_classes, pretrained=True, pretrained_dir='./pretrained_model')

crit = CrossEntropySmooth(sparse=True,
                          reduction="mean",
                          smooth_factor=0.1,
                          classes_num=num_classes)

lr = nn.cosine_decay_lr(0., 0.001, num_batches * max_epoch, num_batches,
                                         max_epoch)
time_cb = TimeMonitor(data_size=num_batches)

net.stop_layer = ops.stop_gradient
lr1 = lr[:fixbase_epoch * num_batches]
opt1 = nn.Adam(net.classifier.trainable_params(), learning_rate=lr1, beta1=0.9,
               beta2=0.99, weight_decay=0.0005)
model1 = Model(network=net, optimizer=opt1, loss_fn=crit)
loss_cb1 = LossMonitor(lr1)
cb1 = [time_cb, loss_cb1]
model1.train(fixbase_epoch, dataset1, cb1, dataset_sink_mode=True)

net.stop_layer = ops.Identity()
lr2 = lr[fixbase_epoch * num_batches:]
loss_cb2 = LossMonitor(lr2)
opt2 = nn.Adam(net.trainable_params(), learning_rate=lr2, beta1=0.9, beta2=0.99,
               weight_decay=0.0005)
model2 = Model(network=net, optimizer=opt2, loss_fn=crit)

cb2 = [time_cb, loss_cb2]

ckpt_append_info = [{"epoch_num": fixbase_epoch, "step_num": fixbase_epoch}]
config_ck = CheckpointConfig(save_checkpoint_steps=10 * num_batches,
                             keep_checkpoint_max=10, append_info=ckpt_append_info)
ckpt_save_dir = 'output/checkpoint/market10'
ckpt_cb = ModelCheckpoint(prefix="osnet", directory=ckpt_save_dir, config=config_ck)
cb2 += [ckpt_cb]

model2.train(max_epoch-fixbase_epoch, dataset2, cb2, dataset_sink_mode=True)
print("train success")



=> Loaded Market10
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |     5 |       85 |         6
  query    |     5 |       24 |         6
  gallery  |     5 |       89 |         6
  ----------------------------------------
=> Loaded Market10
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |     5 |       85 |         6
  query    |     5 |       24 |         6
  gallery  |     5 |       89 |         6
  ----------------------------------------
./pretrained_model\init_osnet.ckpt
Successfully loaded imagenet pretrained weights from "./pretrained_model\init_osnet.ckpt"
** The following layers are discarded due to unmatched keys or layer size: ['classifier.weight', 'classifier.bias']
Epoch:[  0/ 10], step:[    1/    2], loss:[1.656/1.656], time:5912.812 ms, lr:0.00100
Epoch:[  0/ 10], step:[    2/    2], loss

Train epoch time: 5723.422 ms, per step time: 2861.711 ms
Epoch time: 5724.423 ms, per step time: 2862.211 ms, avg loss: 0.481
Epoch:[ 15/240], step:[    1/    2], loss:[0.483/0.483], time:2664.919 ms, lr:0.00098
Epoch:[ 15/240], step:[    2/    2], loss:[0.474/0.479], time:2549.893 ms, lr:0.00098
Train epoch time: 5235.832 ms, per step time: 2617.916 ms
Epoch time: 5236.833 ms, per step time: 2618.417 ms, avg loss: 0.479
Epoch:[ 16/240], step:[    1/    2], loss:[0.477/0.477], time:2504.364 ms, lr:0.00097
Epoch:[ 16/240], step:[    2/    2], loss:[0.477/0.477], time:2661.165 ms, lr:0.00097
Train epoch time: 5187.548 ms, per step time: 2593.774 ms
Epoch time: 5188.549 ms, per step time: 2594.275 ms, avg loss: 0.477
Epoch:[ 17/240], step:[    1/    2], loss:[0.475/0.475], time:2718.243 ms, lr:0.00097
Epoch:[ 17/240], step:[    2/    2], loss:[0.475/0.475], time:2767.173 ms, lr:0.00097
Train epoch time: 5504.433 ms, per step time: 2752.216 ms
Epoch time: 5505.434 ms, per step time: 2752.

Epoch:[ 42/240], step:[    1/    2], loss:[0.474/0.474], time:3039.341 ms, lr:0.00090
Epoch:[ 42/240], step:[    2/    2], loss:[0.470/0.472], time:2992.615 ms, lr:0.00090
Train epoch time: 6051.972 ms, per step time: 3025.986 ms
Epoch time: 6052.973 ms, per step time: 3026.487 ms, avg loss: 0.472
Epoch:[ 43/240], step:[    1/    2], loss:[0.481/0.481], time:2619.278 ms, lr:0.00089
Epoch:[ 43/240], step:[    2/    2], loss:[0.470/0.475], time:2749.896 ms, lr:0.00089
Train epoch time: 5390.192 ms, per step time: 2695.096 ms
Epoch time: 5391.193 ms, per step time: 2695.596 ms, avg loss: 0.475
Epoch:[ 44/240], step:[    1/    2], loss:[0.468/0.468], time:2880.180 ms, lr:0.00089
Epoch:[ 44/240], step:[    2/    2], loss:[0.467/0.467], time:2792.932 ms, lr:0.00089
Train epoch time: 5694.185 ms, per step time: 2847.093 ms
Epoch time: 5694.185 ms, per step time: 2847.093 ms, avg loss: 0.467
Epoch:[ 45/240], step:[    1/    2], loss:[0.468/0.468], time:2791.085 ms, lr:0.00089
Epoch:[ 45/240], 

Train epoch time: 5614.611 ms, per step time: 2807.306 ms
Epoch time: 5615.612 ms, per step time: 2807.806 ms, avg loss: 0.471
Epoch:[ 70/240], step:[    1/    2], loss:[0.469/0.469], time:2656.933 ms, lr:0.00077
Epoch:[ 70/240], step:[    2/    2], loss:[0.472/0.470], time:3045.095 ms, lr:0.00077
Train epoch time: 5723.046 ms, per step time: 2861.523 ms
Epoch time: 5723.046 ms, per step time: 2861.523 ms, avg loss: 0.470
Epoch:[ 71/240], step:[    1/    2], loss:[0.470/0.470], time:2893.305 ms, lr:0.00076
Epoch:[ 71/240], step:[    2/    2], loss:[0.470/0.470], time:2887.393 ms, lr:0.00076
Train epoch time: 5802.765 ms, per step time: 2901.383 ms
Epoch time: 5803.767 ms, per step time: 2901.883 ms, avg loss: 0.470
Epoch:[ 72/240], step:[    1/    2], loss:[0.471/0.471], time:2921.974 ms, lr:0.00076
Epoch:[ 72/240], step:[    2/    2], loss:[0.472/0.471], time:2754.349 ms, lr:0.00076
Train epoch time: 5696.339 ms, per step time: 2848.169 ms
Epoch time: 5697.341 ms, per step time: 2848.

Epoch:[ 97/240], step:[    1/    2], loss:[0.476/0.476], time:2475.767 ms, lr:0.00061
Epoch:[ 97/240], step:[    2/    2], loss:[0.479/0.478], time:3330.045 ms, lr:0.00061
Train epoch time: 5825.815 ms, per step time: 2912.908 ms
Epoch time: 5826.816 ms, per step time: 2913.408 ms, avg loss: 0.478
Epoch:[ 98/240], step:[    1/    2], loss:[0.469/0.469], time:3167.845 ms, lr:0.00061
Epoch:[ 98/240], step:[    2/    2], loss:[0.468/0.469], time:2806.574 ms, lr:0.00061
Train epoch time: 5998.438 ms, per step time: 2999.219 ms
Epoch time: 5998.438 ms, per step time: 2999.219 ms, avg loss: 0.469
Epoch:[ 99/240], step:[    1/    2], loss:[0.470/0.470], time:2534.035 ms, lr:0.00060
Epoch:[ 99/240], step:[    2/    2], loss:[0.471/0.470], time:2489.554 ms, lr:0.00060
Train epoch time: 5292.093 ms, per step time: 2646.046 ms
Epoch time: 5293.094 ms, per step time: 2646.547 ms, avg loss: 0.470
Epoch:[100/240], step:[    1/    2], loss:[0.466/0.466], time:2541.437 ms, lr:0.00059
Epoch:[100/240], 

Train epoch time: 5173.464 ms, per step time: 2586.732 ms
Epoch time: 5174.465 ms, per step time: 2587.232 ms, avg loss: 0.466
Epoch:[125/240], step:[    1/    2], loss:[0.467/0.467], time:2513.672 ms, lr:0.00044
Epoch:[125/240], step:[    2/    2], loss:[0.466/0.466], time:2493.711 ms, lr:0.00044
Train epoch time: 5027.401 ms, per step time: 2513.700 ms
Epoch time: 5028.402 ms, per step time: 2514.201 ms, avg loss: 0.466
Epoch:[126/240], step:[    1/    2], loss:[0.466/0.466], time:2535.617 ms, lr:0.00043
Epoch:[126/240], step:[    2/    2], loss:[0.466/0.466], time:2531.222 ms, lr:0.00043
Train epoch time: 5084.855 ms, per step time: 2542.427 ms
Epoch time: 5085.856 ms, per step time: 2542.928 ms, avg loss: 0.466
Epoch:[127/240], step:[    1/    2], loss:[0.465/0.465], time:2555.778 ms, lr:0.00042
Epoch:[127/240], step:[    2/    2], loss:[0.465/0.465], time:2584.415 ms, lr:0.00042
Train epoch time: 5162.262 ms, per step time: 2581.131 ms
Epoch time: 5162.262 ms, per step time: 2581.

Epoch:[152/240], step:[    1/    2], loss:[0.467/0.467], time:3092.586 ms, lr:0.00028
Epoch:[152/240], step:[    2/    2], loss:[0.469/0.468], time:3111.950 ms, lr:0.00028
Train epoch time: 6225.536 ms, per step time: 3112.768 ms
Epoch time: 6226.537 ms, per step time: 3113.268 ms, avg loss: 0.468
Epoch:[153/240], step:[    1/    2], loss:[0.465/0.465], time:3257.188 ms, lr:0.00027
Epoch:[153/240], step:[    2/    2], loss:[0.466/0.465], time:3141.823 ms, lr:0.00027
Train epoch time: 6419.568 ms, per step time: 3209.784 ms
Epoch time: 6420.568 ms, per step time: 3210.284 ms, avg loss: 0.465
Epoch:[154/240], step:[    1/    2], loss:[0.465/0.465], time:3103.691 ms, lr:0.00026
Epoch:[154/240], step:[    2/    2], loss:[0.466/0.466], time:3293.181 ms, lr:0.00026
Train epoch time: 6416.890 ms, per step time: 3208.445 ms
Epoch time: 6418.892 ms, per step time: 3209.446 ms, avg loss: 0.466
Epoch:[155/240], step:[    1/    2], loss:[0.465/0.465], time:3740.847 ms, lr:0.00026
Epoch:[155/240], 

Train epoch time: 6566.323 ms, per step time: 3283.162 ms
Epoch time: 6568.325 ms, per step time: 3284.162 ms, avg loss: 0.466
Epoch:[180/240], step:[    1/    2], loss:[0.466/0.466], time:3189.067 ms, lr:0.00014
Epoch:[180/240], step:[    2/    2], loss:[0.468/0.467], time:3105.118 ms, lr:0.00014
Train epoch time: 6318.256 ms, per step time: 3159.128 ms
Epoch time: 6319.257 ms, per step time: 3159.628 ms, avg loss: 0.467
Epoch:[181/240], step:[    1/    2], loss:[0.467/0.467], time:3238.214 ms, lr:0.00013
Epoch:[181/240], step:[    2/    2], loss:[0.468/0.468], time:3244.334 ms, lr:0.00013
Train epoch time: 6504.554 ms, per step time: 3252.277 ms
Epoch time: 6505.555 ms, per step time: 3252.777 ms, avg loss: 0.468
Epoch:[182/240], step:[    1/    2], loss:[0.466/0.466], time:3229.640 ms, lr:0.00013
Epoch:[182/240], step:[    2/    2], loss:[0.466/0.466], time:3353.150 ms, lr:0.00013
Train epoch time: 6603.809 ms, per step time: 3301.904 ms
Epoch time: 6604.810 ms, per step time: 3302.

Epoch:[207/240], step:[    1/    2], loss:[0.464/0.464], time:3251.417 ms, lr:0.00004
Epoch:[207/240], step:[    2/    2], loss:[0.465/0.465], time:3310.607 ms, lr:0.00004
Train epoch time: 6587.047 ms, per step time: 3293.524 ms
Epoch time: 6588.048 ms, per step time: 3294.024 ms, avg loss: 0.465
Epoch:[208/240], step:[    1/    2], loss:[0.465/0.465], time:3240.446 ms, lr:0.00004
Epoch:[208/240], step:[    2/    2], loss:[0.464/0.465], time:3311.095 ms, lr:0.00004
Train epoch time: 6572.631 ms, per step time: 3286.315 ms
Epoch time: 6574.633 ms, per step time: 3287.316 ms, avg loss: 0.465
Epoch:[209/240], step:[    1/    2], loss:[0.464/0.464], time:3178.106 ms, lr:0.00004
Epoch:[209/240], step:[    2/    2], loss:[0.466/0.465], time:3331.935 ms, lr:0.00004
Train epoch time: 6733.831 ms, per step time: 3366.915 ms
Epoch time: 6734.833 ms, per step time: 3367.416 ms, avg loss: 0.465
Epoch:[210/240], step:[    1/    2], loss:[0.465/0.465], time:3361.317 ms, lr:0.00004
Epoch:[210/240], 

Train epoch time: 6572.902 ms, per step time: 3286.451 ms
Epoch time: 6573.904 ms, per step time: 3286.952 ms, avg loss: 0.464
Epoch:[235/240], step:[    1/    2], loss:[0.465/0.465], time:3201.495 ms, lr:0.00000
Epoch:[235/240], step:[    2/    2], loss:[0.465/0.465], time:3224.311 ms, lr:0.00000
Train epoch time: 6444.823 ms, per step time: 3222.412 ms
Epoch time: 6445.824 ms, per step time: 3222.912 ms, avg loss: 0.465
Epoch:[236/240], step:[    1/    2], loss:[0.467/0.467], time:3192.538 ms, lr:0.00000
Epoch:[236/240], step:[    2/    2], loss:[0.464/0.466], time:3139.458 ms, lr:0.00000
Train epoch time: 6349.991 ms, per step time: 3174.995 ms
Epoch time: 6349.991 ms, per step time: 3174.995 ms, avg loss: 0.466
Epoch:[237/240], step:[    1/    2], loss:[0.465/0.465], time:2943.285 ms, lr:0.00000
Epoch:[237/240], step:[    2/    2], loss:[0.465/0.465], time:3011.216 ms, lr:0.00000
Train epoch time: 5972.518 ms, per step time: 2986.259 ms
Epoch time: 5973.519 ms, per step time: 2986.

### 模型验证

In [27]:
import mindspore
def euclidean_squared_distance(input1, input2):
    m, n = input1.shape[0], input2.shape[0]

    shape_tensor1 = Tensor(np.zeros((m, n), dtype=np.float32))
    shape_tensor2 = Tensor(np.zeros((n, m), dtype=np.float32))
    op_pow = ops.Pow()

    mat1 = op_pow(input1, 2).sum(axis=1, keepdims=True).expand_as(shape_tensor1)
    mat2 = op_pow(input2, 2).sum(axis=1, keepdims=True).expand_as(shape_tensor2).T
    distmat = mat1 + mat2
    matmul = ops.MatMul(False, True)
    cast = ops.Cast()
    input1 = cast(input1, mindspore.float16)
    input2 = cast(input2, mindspore.float16)
    output = cast(matmul(input1, input2), mindspore.float32)
    distmat = distmat - 2 * output

    return distmat

def eval_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    num_q, num_g = distmat.shape

    if num_g < max_rank:
        max_rank = num_g
        print(
            'Note: number of gallery samples is quite small, got {}'.
            format(num_g)
        )

    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    num_valid_q = 0. # number of valid query

    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        raw_cmc = matches[q_idx][
            keep] # binary vector, positions with value 1 are correct matches
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = raw_cmc.cumsum()
        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = raw_cmc.sum()
        tmp_cmc = raw_cmc.cumsum()
        tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
        tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP


class CustomWithEvalCell(nn.Cell):
    def __init__(self, network):
        super(CustomWithEvalCell, self).__init__(auto_prefix=False)
        self._network = network

    def construct(self, data):
        outputs = self._network(data)
        return outputs

batch_size_test = 32
checkpoint_file_path = "./output/checkpoint/market10/osnet-200_2.ckpt"

num_train_classes, query_dataset = dataset_creator(root=data_path, height=height, width=width, batch_size_test=batch_size_test, mode='query')
num_train_classes, gallery_dataset = dataset_creator(root=data_path, height=height, width=width, batch_size_test=batch_size_test, mode='gallery')

net = create_osnet(num_train_classes)
param_dict = load_checkpoint(checkpoint_file_path, filter_prefix='epoch_num')
load_param_into_net(net, param_dict)

net.set_train(False)
net_eval = CustomWithEvalCell(net)

def feature_extraction(eval_dataset):
    f_, pids_, camids_ = [], [], []
    for data in eval_dataset.create_dict_iterator():
        imgs, pids, camids = data['img'], data['pid'], data['camid']
        features = net_eval(imgs)
        f_.append(features)
        pids_.extend(pids.asnumpy())
        camids_.extend(camids.asnumpy())
    concat = ops.Concat(axis=0)
    f_ = concat(f_)
    pids_ = np.asarray(pids_)
    camids_ = np.asarray(camids_)
    return f_, pids_, camids_

print('Extracting features from query set ...')
qf, q_pids, q_camids = feature_extraction(query_dataset)
print('Done, obtained {}-by-{} matrix'.format(qf.shape[0], qf.shape[1]))

print('Extracting features from gallery set ...')
gf, g_pids, g_camids = feature_extraction(gallery_dataset)
print('Done, obtained {}-by-{} matrix'.format(gf.shape[0], gf.shape[1]))

# if normalize_feature:
#     l2_normalize = ops.L2Normalize(axis=1)
#     qf = l2_normalize(qf)
#     gf = l2_normalize(gf)

print('Computing distance matrix with metric={} ...'.format('euclidean'))
distmat = euclidean_squared_distance(qf, gf)
distmat = distmat.asnumpy()

print('Computing CMC and mAP ...')
cmc, mAP = eval_rank(
    distmat,
    q_pids,
    g_pids,
    q_camids,
    g_camids
)

print('** Results **')
print('ckpt={}'.format(checkpoint_file_path))
print('mAP: {:.1%}'.format(mAP))
print('CMC curve')
ranks = [1, 5, 10, 20]
i = 0
for r in ranks:
    print('Rank-{:<3}: {:.1%}'.format(r, cmc[i]))
    i += 1



=> Loaded Market10
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |     5 |       85 |         6
  query    |     5 |       24 |         6
  gallery  |     5 |       89 |         6
  ----------------------------------------
=> Loaded Market10
  ----------------------------------------
  subset   | # ids | # images | # cameras
  ----------------------------------------
  train    |     5 |       85 |         6
  query    |     5 |       24 |         6
  gallery  |     5 |       89 |         6
  ----------------------------------------
Extracting features from query set ...
Done, obtained 24-by-512 matrix
Extracting features from gallery set ...
Done, obtained 89-by-512 matrix
Computing distance matrix with metric=euclidean ...
Computing CMC and mAP ...
** Results **
ckpt=./output/checkpoint/market10/osnet-200_2.ckpt
mAP: 54.4%
CMC curve
Rank-1  : 62.5%
Rank-5  : 83.3%
Rank-10 : 83.3%
Rank-20 : 83

## 引用

[1] Liang Zheng*, Shengjin Wang, Liyue Shen*, Lu Tian*, Jiahao Bu, and Qi Tian. Person Re-identification Meets Image Search. Technical Report, 2015.

[2] Zhou K, Yang Y, Cavallaro A, et al. Omni-scale feature learning for person re-identification[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019: 3702-3712.