In [1]:
import numpy as np
import sys
import caffe
import time

In [2]:
N=1;C=3;W=80;H=160
crop_h=crop_w=0
transformer = caffe.io.Transformer({'data': (N,C,H+2*crop_h,W+2*crop_w)})
transformer.set_transpose('data', (2,0,1))
transformer.set_mean('data', np.array([ 104,  117,  123])) # mean pixel
transformer.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
transformer.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB


def readImages(images):
    imageLen=len(images)
    imageDataList=[]
    for imageIdx in range(imageLen):
        imageName=images[imageIdx]
        imageImage=transformer.preprocess('data', caffe.io.load_image(imageName))
        imageDataList.append(imageImage[:,crop_h:H+crop_h,crop_w:W+crop_w]) #center crop
        imageIdx+=1
    #imageData and imageData
    imageData=np.asarray(imageDataList)
    return imageData

def readDir(list_dir):
    import os
    file_list=os.listdir(list_dir)
    final_list=[]
    for filename in file_list:
        if filename[0]!='-' and filename[filename.rfind('.')+1:]=='jpg':
            final_list.append(list_dir+filename)
    return final_list   

In [3]:
def extract_features(file_list,net):
    file_len=len(file_list)
    features=[]
    batch_size=100
    for batch_idx in range(file_len/batch_size+1):
        cur_len=batch_size if batch_idx <file_len/batch_size else file_len%batch_size
        cur_list=file_list[batch_idx*batch_size+0:batch_idx*batch_size+cur_len]
        image_data=readImages(cur_list)
        net.blobs['data'].reshape(cur_len,C,H,W)
        net.blobs['data'].data[:] = image_data
        net.forward()
        normed_features=net.blobs['normed_feature'].data.copy()
        from sklearn.preprocessing import normalize
        for idx in range(cur_len):
            cur_feature=np.squeeze(normed_features[idx,:])
#             cur_feature = cur_feature/np.linalg.norm(cur_feature)
            features.append(cur_feature)
    return features
def get_gt_dict(gallery_list):
    gt_dict={}
    for idx in range(len(gallery_list)):
        gallery_name=gallery_list[idx]
        gallery_person_id=gallery_name[gallery_name.rfind('/')+1:gallery_name.rfind('/')+5]
        gallery_cam_id=gallery_name[gallery_name.rfind('/')+7:gallery_name.rfind('/')+8]
        if gt_dict.has_key(gallery_person_id):
            gt_dict[gallery_person_id].append(idx)
        else:
            gt_dict[gallery_person_id]=[idx]
    return gt_dict

In [4]:
def rank_for_queries(query_features,gallery_features):
    import numpy as np
    all_rank_list=[]
    for query_idx in range(len(query_features)):
        query_feature=query_features[query_idx]

        score_list=[]
        for gallery_idx in range(len(gallery_features)):
            gallery_feature=gallery_features[gallery_idx]
            dist = np.sqrt(np.sum((query_feature-gallery_feature)**2))
            similar_score=1.0/(1.0+dist)
            score_list.append(similar_score)
        #we get scoreList, then cal predictLists
        ranked_idx_list=np.argsort(score_list)[::-1]
        all_rank_list.append(ranked_idx_list)
    return all_rank_list

######################################################
##
## I use parallel to run the query in batch_num=10 batches
## In this way, one query on the 1w galleries takes 0.027s
## Multi-process does not work for ipython notebook on Windows
##
######################################################
def parallel_rank(query_features,gallery_features):
    import ipyparallel as ipp
    client = ipp.Client()
    view = client.load_balanced_view()
    batch_num=2*len(client.ids)
    batch_size_queries=len(query_features)/batch_num+1

    tic=time.time()
    task_results=[]
    for task_idx in range(batch_num):
        batch_query_features=query_features[task_idx*batch_size_queries:(task_idx+1)*batch_size_queries]
        task_results.append(view.apply(rank_for_queries,batch_query_features,gallery_features))    

    all_rank_list=[]
    for task_idx in range(batch_num):
        all_rank_list.extend(task_results[task_idx].result())
    toc=time.time()
    print len(all_rank_list),(toc-tic),(toc-tic)/len(query_features)
    return all_rank_list

def evaluate(query_list,gallery_list,all_rank_list,gt_dict):
    histogram=np.zeros(len(gallery_list))
    meanAP=0.0
    len_queries=len(query_list)
    for query_idx in range(len_queries):#
        ranked_idx_list=all_rank_list[query_idx]
        #good or junk
        query_name=query_list[query_idx]
        query_person_id=query_name[query_name.rfind('/')+1:query_name.rfind('/')+5]
        query_cam_id=query_name[query_name.rfind('/')+7:query_name.rfind('/')+8]
        relevant_idx_list=gt_dict[query_person_id]
        good_relevant=[]
        junk_relevant=[]
        for relevant_idx in relevant_idx_list:
            gallery_name=gallery_list[relevant_idx]
            gallery_cam_id=gallery_name[gallery_name.rfind('/')+7:gallery_name.rfind('/')+8]
            if gallery_cam_id==query_cam_id:
                junk_relevant.append(relevant_idx)
            else:
                good_relevant.append(relevant_idx)
        #cmc and meanAP
        matched_num=0.0
        sum_precision=0.0
        rank_idx=0
        for perdicted_idx in ranked_idx_list:
            if perdicted_idx in junk_relevant:
                continue
            elif perdicted_idx in good_relevant:
                matched_num+=1.0
                sum_precision+=matched_num/(rank_idx+1)
                histogram[rank_idx]+= 1 if matched_num<=1 else 0 #multiple results
            rank_idx+=1
            if matched_num>=len(good_relevant): #recall=1
                break
        meanAP+=sum_precision/len(good_relevant)

    cmc=np.cumsum(histogram)/len_queries
    meanAP/=len_queries
    return cmc*100.0,meanAP*100.0

def save_log(model_name,cmc,meanAP):
    log_name='log.txt'
    log_file=open(log_name,'a')
    log_str='\n\n'+'*'*65+'\n'
    log_str=log_str+ model_name+","+"%dx%d, mAP=%.2f\n"%(H,W,meanAP)
    log_str=log_str+'{}\n\n'.format(cmc[:50])
    log_str=log_str+'rank1\trank5\trank10\trank20\n'
    log_str=log_str+'%.2f\t%.2f\t%.2f\t%.2f\n'%(cmc[0],cmc[4],cmc[9],cmc[19])
    print log_str
    log_file.write(log_str)
    log_file.close()

In [5]:
def main(EXP_DIR,PRETRAINED):
    #init net
    MODEL_FILE = EXP_DIR+'test.prototxt'
    model_name=PRETRAINED[PRETRAINED.rfind('/')+1:-11]

    caffe.set_device(1)
    caffe.set_mode_gpu()
    net = caffe.Classifier(MODEL_FILE, PRETRAINED,caffe.TEST)
    
    DATA_DIR= '../../dataset/market/Market-1501/'
    query_list=readDir(DATA_DIR+'query/')
    gallery_list=readDir(DATA_DIR+'bounding_box_test/')

    gt_dict=get_gt_dict(gallery_list)
    print len(query_list),len(gallery_list),len(gt_dict),len(gt_dict['0000'])
    
    ################extract features############
    tic=time.time()
    query_features=extract_features(query_list,net)
    gallery_features=extract_features(gallery_list,net)
    toc=time.time()
    print (toc-tic),(toc-tic)/(len(query_list)+len(gallery_list))
    
    ################rank############
    all_rank_list=parallel_rank(query_features,gallery_features)
    cmc,mAP=evaluate(query_list,gallery_list,all_rank_list,gt_dict)
    save_log('Market: '+EXP_DIR+model_name,cmc,mAP)
    return cmc,mAP

In [12]:
EXP_DIR='./partnet/'
PRETRAINED = EXP_DIR+'snapshot/model_iter_50000.caffemodel'
cmc,mAP=main(EXP_DIR,PRETRAINED)

3368 15913 751 2798
57.3356001377 0.0029736839447
3368 88.2150070667 0.026192104236


*****************************************************************
Market: ./partnet/model_iter_50000,160x80, mAP=63.42
[ 80.90855107  86.31235154  88.92517815  90.82541568  91.74584323
  92.42874109  93.14133017  93.76484561  94.35866983  94.68527316
  94.804038    95.19002375  95.42755344  95.63539192  95.8432304
  96.05106888  96.25890736  96.46674584  96.55581948  96.6152019
  96.67458432  96.70427553  96.82304038  96.97149644  97.03087886
  97.06057007  97.14964371  97.26840855  97.32779097  97.32779097
  97.35748219  97.35748219  97.50593824  97.53562945  97.56532067
  97.62470309  97.62470309  97.68408551  97.74346793  97.74346793
  97.74346793  97.80285036  97.80285036  97.80285036  97.83254157
  97.89192399  97.9216152   97.95130641  97.98099762  97.98099762]

rank1	rank5	rank10	rank20
80.91	91.75	94.69	96.62

