In [None]:
import mxnet as mx
from os import path as osp
import glob
import re
import scipy.io as sio
from mxnet import autograd, gluon, image, init, nd
import numpy as np
from matplotlib import pyplot as plt
import cv2
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data import Dataset
from mxnet.image import imread
from mxnet.gluon import nn,model_zoo

In [None]:

class Market1501(object):
    """Index the Market-1501 person re-identification dataset on disk.

    Expects ``<root>/Market-1501-v15.09.15`` to contain the
    ``bounding_box_train``, ``query`` and ``bounding_box_test`` folders of
    ``.jpg`` images whose file names contain a ``<pid>_c<camid>`` tag.

    After construction these attributes are available:

    * ``train`` / ``query`` / ``gallery``: lists of ``(img_path, pid, camid)``
      tuples in sorted path order. ``camid`` is 0-based, and training pids are
      relabelled to ``0..N-1`` in ascending pid order.
    * ``num_train_pids`` / ``num_query_pids`` / ``num_gallery_pids``: the
      number of distinct identities in each split.

    Args:
        root: directory that contains the dataset folder.
        **kwargs: accepted so the class can be built through a generic
            factory; ignored.

    Raises:
        RuntimeError: if the dataset folder or any split folder is missing.
        ValueError: if an image file name has no ``<pid>_c<camid>`` tag.
    """

    dataset_dir = 'Market-1501-v15.09.15'

    def __init__(self, root='', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()

        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        print("=> Market1501 loaded")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # ids | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print("  ------------------------------")
        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
        print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        for required_dir in (self.dataset_dir, self.train_dir,
                             self.query_dir, self.gallery_dir):
            if not osp.exists(required_dir):
                raise RuntimeError("'{}' is not available".format(required_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Parse every ``*.jpg`` in ``dir_path`` into ``(path, pid, camid)``.

        Images with pid ``-1`` (junk) are skipped. Pid ``0`` (background) is
        kept. Camera ids are shifted to start from 0.

        Args:
            dir_path: folder holding the images.
            relabel: if True, map pids to contiguous labels ``0..N-1`` in
                ascending pid order.

        Returns:
            A tuple ``(dataset, num_pids, num_imgs)``.

        Raises:
            ValueError: if a file name lacks the ``<pid>_c<camid>`` tag.
        """
        # glob order depends on the file system; sorting makes the sample
        # order (and therefore any extracted-feature order) reproducible.
        img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        parsed = []
        for img_path in img_paths:
            # Search the file name only. Otherwise digits followed by "_c" in a
            # parent directory name could be mistaken for the pid/camera tag.
            match = pattern.search(osp.basename(img_path))
            if match is None:
                raise ValueError(
                    "cannot parse pid/camid from '{}'".format(img_path))
            pid, camid = map(int, match.groups())
            if pid == -1:
                continue  # junk images are just ignored
            parsed.append((img_path, pid, camid))

        pid_container = {pid for _, pid, _ in parsed}
        # Ascending pid order makes the relabelling deterministic across runs.
        pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

        dataset = []
        for img_path, pid, camid in parsed:
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            dataset.append((img_path, pid, camid))

        return dataset, len(pid_container), len(dataset)



# Registry: dataset name -> dataset class.
__factory = {
    'market': Market1501,
}


def get_names():
    """Return the registered dataset names (a dict keys view)."""
    return __factory.keys()


def init_dataset(name, *args, **kwargs):
    """Build the dataset registered under ``name``, forwarding all arguments.

    Raises:
        KeyError: if ``name`` is not registered.
    """
    if name in __factory.keys():
        dataset_cls = __factory[name]
        return dataset_cls(*args, **kwargs)
    raise KeyError("Unknown datasets: {}".format(name))


# Index Market-1501 from the current working directory.
dataset = init_dataset('market')

In [None]:
from collections import defaultdict

import numpy as np
from mxnet.gluon.data.sampler import Sampler


class RandomIdentitySampler(Sampler):
    """Yield ``num_instances`` indices per identity, with identities shuffled.

    Each pass visits all identities in a random order. For each identity it
    draws ``num_instances`` dataset indices from that identity's samples:
    without replacement when there are enough of them, with replacement
    otherwise. Consecutive groups of ``num_instances`` indices therefore
    share the same pid.

    Args:
        data_source: sequence of ``(img_path, pid, camid)`` tuples, e.g. the
            training split.
        num_instances: number of indices drawn per identity per pass.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source  # e.g. the training records
        self.num_instances = num_instances
        # pid -> positions in data_source that belong to that identity
        self.index_dic = defaultdict(list)
        for position, (_path, pid, _cam) in enumerate(data_source):
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())  # distinct identity ids
        self.num_identities = len(self.pids)  # number of distinct identities

    def __iter__(self):
        order = np.random.permutation(self.num_identities)
        drawn = []
        for identity_idx in order:
            candidates = self.index_dic[self.pids[identity_idx]]
            # Replacement is only needed when the identity has too few samples.
            with_replacement = len(candidates) < self.num_instances
            drawn.extend(np.random.choice(candidates, size=self.num_instances,
                                          replace=with_replacement))
        return iter(drawn)

    def __len__(self):
        # One group of num_instances indices per identity (e.g. 751 * 4).
        return self.num_identities * self.num_instances
    
from mxnet.gluon.data import Dataset

class ImageData(Dataset):
    """Gluon ``Dataset`` over ``(img_path, pid, camid)`` records.

    Args:
        dataset: sequence of ``(img_path, pid, camid)`` tuples, e.g. a split
            produced by ``Market1501``.
        transform: optional callable applied to each decoded image.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, item):
        """Return ``(image, pid, camid)`` for record number ``item``."""
        img_path, pid, camid = self.dataset[item]
        img = self.read_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, pid, camid

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def read_image(img_path, max_retries=10):
        """Decode ``img_path`` with ``imread``, retrying when it raises IOError.

        Retries are bounded, so a permanently unreadable file raises instead
        of hanging data loading forever.

        Args:
            img_path: path of the image file.
            max_retries: extra attempts allowed after the first failure.

        Returns:
            The decoded image as returned by ``imread``.

        Raises:
            IOError: the last error, once every attempt has failed.
        """
        max_retries = max(0, max_retries)
        for attempt in range(max_retries + 1):
            try:
                return imread(img_path)
            except IOError:
                if attempt == max_retries:
                    raise
                print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))

from mxnet.gluon.data.vision import transforms as T

class TrainTransform(object):
    """Training-time augmentation and normalization pipeline.

    Steps, in order:

    * ``Random2DTranslation(h, w)``
    * random horizontal flip
    * colour jitter (brightness/contrast/saturation 0.4)
    * ``ToTensor``
    * per-channel normalization with the ImageNet mean/std

    Args:
        h: target height passed to ``Random2DTranslation``.
        w: target width passed to ``Random2DTranslation``.
    """

    def __init__(self, h, w):
        self.h = h
        self.w = w
        # Build the transform objects once instead of re-creating every
        # Gluon block for each image inside __call__.
        # NOTE(review): Random2DTranslation is neither defined nor imported in
        # this notebook; it must be in scope before TrainTransform is built.
        self._steps = [
            Random2DTranslation(self.h, self.w),
            T.RandomFlipLeftRight(),
            T.RandomColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            # T.Cast(dtype='float16') could be appended for fp16 training.
        ]

    def __call__(self, x):
        """Apply the augmentation pipeline to image ``x`` and return the result."""
        for step in self._steps:
            x = step(x)
        return x





class TestTransform(object):
    """Deterministic evaluation preprocessing: resize, ToTensor, normalize.

    Args:
        h: target height in pixels.
        w: target width in pixels.
    """

    def __init__(self, h, w):
        self.h = h
        self.w = w
        # Build the transform blocks once rather than on every call.
        # Gluon's Resize takes its size as (width, height).
        self._steps = [
            T.Resize((self.w, self.h)),
            T.ToTensor(),
            T.Normalize(mean=(0.485, 0.456, 0.406),
                        std=(0.229, 0.224, 0.225)),
            # T.Cast(dtype='float16') could be appended for fp16 inference.
        ]

    def __call__(self, x=None):
        """Apply resize -> ToTensor -> Normalize to the decoded image ``x``."""
        for step in self._steps:
            x = step(x)
        return x

In [None]:
# Network input size (an alternative setting is height=256, width=128).
height = 224
width = 224
#     height = 256
#     width = 128
num_instances = 4
train_batch = 32
workers = 8
test_batch = 32


def _make_eval_loader(records, **loader_kwargs):
    """Wrap ``records`` in an unshuffled, single-process evaluation DataLoader."""
    return DataLoader(
        ImageData(records, TestTransform(height, width)),
        batch_size=test_batch,
        num_workers=0,
        **loader_kwargs
    )


galleryloader = _make_eval_loader(dataset.gallery, last_batch="keep")
queryloader = _make_eval_loader(dataset.query)

In [None]:
# Load the exported ResNet graph and its checkpoint-0370 weights onto the GPU.
net = gluon.SymbolBlock.imports(
    symbol_file='resnet/jxx_res-symbol.json',
    input_names=['data'],
    param_file='resnet/jxx_res-0370.params',
    ctx=mx.gpu(),
)

In [None]:
ctx = mx.gpu()

# Run every gallery batch through the network and collect per-batch pieces.
# Concatenating once at the end avoids re-copying the ever-growing arrays on
# every iteration (nd.concat inside the loop was quadratic in batch count).
gallery_feature_parts, gallery_label_parts, gallery_cam_parts = [], [], []
for batch_imgs, batch_labels, batch_cams in galleryloader:
    # The network returns two outputs; the second one is kept as the feature.
    _, batch_feature = net(batch_imgs.as_in_context(ctx))
    gallery_feature_parts.append(batch_feature)
    gallery_label_parts.append(batch_labels)
    gallery_cam_parts.append(batch_cams)

if not gallery_feature_parts:
    raise RuntimeError("galleryloader produced no batches; nothing to extract")

gallery_feature = nd.concat(*gallery_feature_parts, dim=0)
gallery_label = nd.concat(*gallery_label_parts, dim=0)
gallery_cam = nd.concat(*gallery_cam_parts, dim=0)
        


In [None]:

# for i,inputs in enumerate(galleryloader):
#     img=inputs[0].as_in_context(ctx)
#     if i==0:
#         _,gallery_feature=net(img)
#     break
# #         gallery_label = inputs[1]
# #         gallery_cam = inputs[2]
        
# # #         gallery_feature=gallery_feature_t.as_in_context(mx.cpu())
# # #         gallery_label=gallery_label.as_in_context(mx.cpu())
# # #         gallery_feature=gallery_feature.as_in_context(mx.cpu())
# #     else:     
# #         gallery_label=nd.concat(gallery_label,inputs[1],dim=0)
# #         _,t_feature=net(img)
# # #         t_feature=t_feature.as_in_context(mx.cpu())
        
# #         gallery_feature=nd.concat(gallery_feature,t_feature,dim=0)
# #         gallery_cam=nd.concat(gallery_cam,inputs[2],dim=0)
        
    



In [None]:

# Save gallery features, person ids and camera ids to a MATLAB .mat file.
result = dict(
    gallery_f=gallery_feature.asnumpy(),
    gallery_label=gallery_label.asnumpy(),
    gallery_cam=gallery_cam.asnumpy(),
)
sio.savemat('result_res_gallery.mat', result)

In [None]:
# Spot-check: elements 1..9 of the second gallery feature vector.
print(gallery_feature[1][1:10].asnumpy())

In [None]:
# Extract query features batch by batch and concatenate once at the end.
# This avoids the quadratic re-copying of an nd.concat inside the loop.
# Relies on `ctx` and `net` defined in earlier cells.
query_feature_parts, query_label_parts, query_cam_parts = [], [], []
for batch_imgs, batch_labels, batch_cams in queryloader:
    # The network returns two outputs; the second one is kept as the feature.
    _, batch_feature = net(batch_imgs.as_in_context(ctx))
    query_feature_parts.append(batch_feature)
    query_label_parts.append(batch_labels)
    query_cam_parts.append(batch_cams)

if not query_feature_parts:
    raise RuntimeError("queryloader produced no batches; nothing to extract")

query_feature = nd.concat(*query_feature_parts, dim=0)
query_label = nd.concat(*query_label_parts, dim=0)
query_cam = nd.concat(*query_cam_parts, dim=0)

In [None]:
# NOTE(review): this recomputes the same gallery features as the earlier
# gallery-extraction cell; one of the two cells could be dropped.
# Batches are collected and concatenated once, avoiding the quadratic
# re-copying of an nd.concat inside the loop. Relies on `ctx` and `net`.
gallery_feature_parts, gallery_label_parts, gallery_cam_parts = [], [], []
for batch_imgs, batch_labels, batch_cams in galleryloader:
    # The network returns two outputs; the second one is kept as the feature.
    _, batch_feature = net(batch_imgs.as_in_context(ctx))
    gallery_feature_parts.append(batch_feature)
    gallery_label_parts.append(batch_labels)
    gallery_cam_parts.append(batch_cams)

if not gallery_feature_parts:
    raise RuntimeError("galleryloader produced no batches; nothing to extract")

gallery_feature = nd.concat(*gallery_feature_parts, dim=0)
gallery_label = nd.concat(*gallery_label_parts, dim=0)
gallery_cam = nd.concat(*gallery_cam_parts, dim=0)

In [None]:

# Save gallery and query features, person ids and camera ids into a single
# MATLAB .mat file.
result = dict(
    gallery_f=gallery_feature.asnumpy(),
    gallery_label=gallery_label.asnumpy(),
    gallery_cam=gallery_cam.asnumpy(),
    query_f=query_feature.asnumpy(),
    query_label=query_label.asnumpy(),
    query_cam=query_cam.asnumpy(),
)
sio.savemat('result_res.mat', result)