In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
import time
import os

In [2]:
data_transforms = transforms.Compose([
        transforms.Resize((288,144), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [3]:
data_dir = '/home/linshan/dataset/'
dataset_name = 'Market1501'
batch_size = 32

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir+'/'+ dataset_name + '/pytorch',x) ,data_transforms) for x in ['gallery','query']}
dataloader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                             shuffle=False, num_workers=4) for x in ['gallery','query']}

class_names = image_datasets['query'].classes

In [4]:
def load_network(network, which_epoch):
    save_path = os.path.join('./model','net_%s.pth'%which_epoch)
    network.load_state_dict(torch.load(save_path))
    return network

In [5]:
def fliplr(img):
    '''flip horizontal'''
    inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
    img_flip = img.index_select(3,inv_idx)
    return img_flip

def extract_feature(model,dataloaders):
    features = torch.FloatTensor()
    count = 0
    for data in dataloaders:
        img, label = data
        n, c, h, w = img.size()
        count += n
        print(count)
        ff = torch.FloatTensor(n,2048).zero_()
        for i in range(2):
            if(i==1):
                img = fliplr(img)
            input_img = Variable(img.cuda())
            outputs = model(input_img) 
            f = outputs.data.cpu()
            #print(f.size())
            ff = ff+f
        # norm feature
        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
        ff = ff.div(fnorm.expand_as(ff))
        features = torch.cat((features,ff), 0)
    return features

In [6]:
def get_id(img_path):
    camera_id = []
    labels = []
    for path, v in img_path:
        filename = path.split('/')[-1]
        label = filename[0:4]
        camera = filename.split('c')[1]
        if label[0:2]=='-1':
            labels.append(-1)
        else:
            labels.append(int(label))
        camera_id.append(int(camera[0]))
    return camera_id, labels

gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs

gallery_cam,gallery_label = get_id(gallery_path)
query_cam,query_label = get_id(query_path)

In [7]:
from model import ft_net
model_structure = ft_net(751)
model = load_network(model_structure,60)

-------test-----------


In [8]:
# Remove the final fc layer and classifier layer
model.model.fc = nn.Sequential()
model.classifier = nn.Sequential()

In [9]:
# Change to test mode
model = model.eval()
model = model.cuda()

In [16]:
# Extract feature
gallery_feature = extract_feature(model,dataloader['gallery'])
query_feature = extract_feature(model,dataloader['query'])
gallery_feature = gallery_feature.numpy()
query_feature = query_feature.numpy()

32
64
96
128
160
192
224
256
288
320
352
384
416
448
480
512
544
576
608
640
672
704
736
768
800
832
864
896
928
960
992
1024
1056
1088
1120
1152
1184
1216
1248
1280
1312
1344
1376
1408
1440
1472
1504
1536
1568
1600
1632
1664
1696
1728
1760
1792
1824
1856
1888
1920
1952
1984
2016
2048
2080
2112
2144
2176
2208
2240
2272
2304
2336
2368
2400
2432
2464
2496
2528
2560
2592
2624
2656
2688
2720
2752
2784
2816
2848
2880
2912
2944
2976
3008
3040
3072
3104
3136
3168
3200
3232
3264
3296
3328
3360
3392
3424
3456
3488
3520
3552
3584
3616
3648
3680
3712
3744
3776
3808
3840
3872
3904
3936
3968
4000
4032
4064
4096
4128
4160
4192
4224
4256
4288
4320
4352
4384
4416
4448
4480
4512
4544
4576
4608
4640
4672
4704
4736
4768
4800
4832
4864
4896
4928
4960
4992
5024
5056
5088
5120
5152
5184
5216
5248
5280
5312
5344
5376
5408
5440
5472
5504
5536
5568
5600
5632
5664
5696
5728
5760
5792
5824
5856
5888
5920
5952
5984
6016
6048
6080
6112
6144
6176
6208
6240
6272
6304
6336
6368
6400
6432
6464
6496
6528
6560
6592
6624

In [36]:
print(gallery_feature.shape)
print(query_feature.shape)

(19732, 2048)
(3368, 2048)


In [40]:
temp = np.dot(query_feature,gallery_feature.transpose())
print(temp.shape)

(3368, 19732)


In [46]:
index = np.argsort(temp[0])
print(index)
print(index[::-1])

[6574 4163 5703 ..., 6640 6620 6617]
[6617 6620 6640 ..., 5703 4163 6574]


In [66]:
query_index = np.argwhere(gallery_label == query_label[0])
print(query_index.shape)
camera_index = np.argwhere(gallery_cam == query_cam[0])
print(camera_index.shape)
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
print(good_index.shape)
junk_index1 = np.argwhere(gallery_label==-1)
junk_index2 = np.intersect1d(query_index, camera_index)
junk_index = np.append(junk_index2, junk_index1)

(59, 1)
(3156, 1)
(51,)


In [17]:
def evaluate(qf,ql,qc,gf,gl,gc):
    query = qf
    score = np.dot(gf,query)
    # predict index
    index = np.argsort(score)  #from small to large
    index = index[::-1]
    #index = index[0:2000]
    # good index
    query_index = np.argwhere(gl==ql)
    camera_index = np.argwhere(gc==qc)

    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    junk_index1 = np.argwhere(gl==-1)
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1) #.flatten())
    
    CMC_tmp = compute_mAP(index, good_index, junk_index)
    return CMC_tmp


def compute_mAP(index, good_index, junk_index):
    ap = 0
    cmc = torch.IntTensor(len(index)).zero_()
    if good_index.size==0:   # if empty
        cmc[0] = -1
        return ap,cmc

    # remove junk_index
    mask = np.in1d(index, junk_index, invert=True)
    index = index[mask]

    # find good_index index
    ngood = len(good_index)
    mask = np.in1d(index, good_index)
    rows_good = np.argwhere(mask==True)
    rows_good = rows_good.flatten()
    
    cmc[rows_good[0]:] = 1
    for i in range(ngood):
        d_recall = 1.0/ngood
        precision = (i+1)*1.0/(rows_good[i]+1)
        if rows_good[i]!=0:
            old_precision = i*1.0/rows_good[i]
        else:
            old_precision=1.0
        ap = ap + d_recall*(old_precision + precision)/2

    return ap, cmc

In [18]:
CMC = torch.IntTensor(len(gallery_label)).zero_()
ap = 0.0
for i in range(len(query_label)):
    ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
    if CMC_tmp[0]==-1:
        continue
    CMC = CMC + CMC_tmp
    ap += ap_tmp
    print(i, CMC_tmp[0])

CMC = CMC.float()
CMC = CMC/len(query_label) #average CMC
print('top1:%f top5:%f top10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))

0 1
1 0
2 1
3 1
4 1
5 1
6 1
7 1
8 1
9 1
10 1
11 1
12 1
13 1
14 1
15 1
16 1
17 1
18 1
19 1
20 1
21 1
22 1
23 1
24 1
25 1
26 1
27 0
28 0
29 0
30 1
31 0
32 1
33 0
34 0
35 1
36 0
37 0
38 1
39 1
40 1
41 1
42 1
43 1
44 1
45 1
46 1
47 1
48 1
49 1
50 1
51 0
52 1
53 1
54 1
55 1
56 1
57 1
58 1
59 1
60 0
61 1
62 1
63 0
64 0
65 0
66 1
67 1
68 1
69 0
70 1
71 0
72 1
73 1
74 1
75 0
76 0
77 0
78 0
79 1
80 1
81 0
82 1
83 1
84 1
85 1
86 1
87 1
88 1
89 1
90 1
91 1
92 0
93 1
94 1
95 1
96 1
97 1
98 0
99 1
100 0
101 0
102 0
103 0
104 0
105 0
106 1
107 1
108 1
109 1
110 1
111 1
112 1
113 1
114 1
115 1
116 1
117 1
118 1
119 0
120 1
121 1
122 1
123 1
124 1
125 1
126 1
127 0
128 1
129 1
130 1
131 1
132 1
133 1
134 1
135 0
136 1
137 1
138 1
139 0
140 1
141 1
142 0
143 0
144 0
145 0
146 0
147 1
148 0
149 1
150 1
151 1
152 0
153 1
154 0
155 1
156 0
157 0
158 0
159 1
160 1
161 1
162 1
163 1
164 0
165 1
166 1
167 1
168 0
169 1
170 1
171 1
172 1
173 1
174 1
175 1
176 1
177 1
178 1
179 0
180 1
181 1
182 1
183 1
184 0


1335 1
1336 1
1337 1
1338 1
1339 0
1340 0
1341 1
1342 1
1343 1
1344 1
1345 1
1346 1
1347 1
1348 1
1349 1
1350 1
1351 1
1352 1
1353 1
1354 1
1355 0
1356 1
1357 1
1358 1
1359 1
1360 0
1361 0
1362 1
1363 1
1364 0
1365 1
1366 1
1367 1
1368 1
1369 1
1370 1
1371 1
1372 1
1373 1
1374 1
1375 1
1376 1
1377 1
1378 1
1379 1
1380 1
1381 1
1382 1
1383 1
1384 1
1385 1
1386 1
1387 1
1388 1
1389 1
1390 1
1391 1
1392 0
1393 1
1394 1
1395 1
1396 1
1397 1
1398 1
1399 1
1400 1
1401 1
1402 0
1403 0
1404 0
1405 1
1406 1
1407 1
1408 1
1409 1
1410 1
1411 1
1412 1
1413 1
1414 1
1415 1
1416 1
1417 1
1418 1
1419 1
1420 1
1421 1
1422 1
1423 1
1424 1
1425 1
1426 1
1427 1
1428 1
1429 1
1430 1
1431 1
1432 1
1433 1
1434 1
1435 1
1436 1
1437 1
1438 1
1439 1
1440 1
1441 1
1442 1
1443 1
1444 1
1445 1
1446 1
1447 1
1448 1
1449 0
1450 1
1451 1
1452 1
1453 1
1454 1
1455 1
1456 1
1457 1
1458 1
1459 1
1460 0
1461 1
1462 1
1463 1
1464 0
1465 1
1466 1
1467 1
1468 1
1469 1
1470 1
1471 1
1472 1
1473 1
1474 1
1475 0
1476 1
1477 1

2509 1
2510 1
2511 1
2512 1
2513 1
2514 1
2515 1
2516 1
2517 1
2518 1
2519 0
2520 1
2521 1
2522 1
2523 0
2524 1
2525 1
2526 1
2527 1
2528 1
2529 1
2530 1
2531 1
2532 1
2533 1
2534 1
2535 1
2536 1
2537 1
2538 1
2539 1
2540 1
2541 1
2542 1
2543 1
2544 1
2545 0
2546 1
2547 1
2548 1
2549 1
2550 1
2551 1
2552 1
2553 1
2554 1
2555 1
2556 1
2557 1
2558 1
2559 1
2560 1
2561 1
2562 1
2563 1
2564 1
2565 1
2566 1
2567 1
2568 1
2569 1
2570 1
2571 1
2572 1
2573 1
2574 1
2575 1
2576 1
2577 1
2578 1
2579 1
2580 1
2581 1
2582 0
2583 1
2584 1
2585 1
2586 1
2587 1
2588 1
2589 1
2590 1
2591 1
2592 1
2593 1
2594 1
2595 1
2596 1
2597 1
2598 1
2599 1
2600 1
2601 1
2602 0
2603 0
2604 1
2605 1
2606 1
2607 1
2608 1
2609 1
2610 1
2611 1
2612 1
2613 1
2614 1
2615 1
2616 1
2617 1
2618 1
2619 1
2620 1
2621 1
2622 0
2623 1
2624 1
2625 0
2626 1
2627 1
2628 1
2629 1
2630 1
2631 1
2632 1
2633 1
2634 1
2635 1
2636 1
2637 1
2638 1
2639 1
2640 1
2641 0
2642 1
2643 0
2644 1
2645 1
2646 1
2647 1
2648 1
2649 1
2650 1
2651 1