In [None]:
import os

# Expose only GPU 0 to this process; set before torch is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import sys
import os
import cv2
import csv
import time
import math
import random
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms as Transforms
import torch.nn.functional as F
import timm
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

In [None]:
# Evaluation-time preprocessing: resize every crop to 224x224, convert it to a
# float tensor, then normalise each channel with fixed mean/std values.
transformer = Transforms.Compose(
    [
        Transforms.Resize((224, 224)),
        Transforms.ToTensor(),
        Transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Use the first visible GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Market-1501 split locations, relative to the working directory.
train_dir = "test_data/Market-1501/bounding_box_train"  # training images
query_dir = "test_data/Market-1501/query"  # query images
test_dir = "test_data/Market-1501/bounding_box_test"  # gallery images

# Recursively delete .ipynb_checkpoints folders under a directory
def remove_ipynb_checkpoints(root_folder):
    """Delete every directory named ``.ipynb_checkpoints`` below *root_folder*.

    Args:
        root_folder: directory tree to scan. The root itself is only walked,
            never deleted.
    """
    import shutil  # not imported at the top of the notebook

    for root, dirs, _files in os.walk(root_folder):
        # Iterate over a copy so matched folders can be pruned from `dirs` in
        # place; os.walk then does not try to descend into deleted paths.
        for name in list(dirs):
            if name == ".ipynb_checkpoints":
                folder_path = os.path.join(root, name)
                shutil.rmtree(folder_path)
                print(f"Deleted: {folder_path}")
                dirs.remove(name)

def transform_img1(img_path):  # preprocess a single image
    """Load one image from disk and apply the module-level ``transformer``.

    The file is opened in a ``with`` block, so its handle is released as soon
    as the RGB copy exists rather than whenever the garbage collector runs.
    This matters when thousands of files are read in a loop.

    Args:
        img_path: path to an image file readable by PIL.

    Returns:
        ``transformer`` applied to the RGB image. With the pipeline defined in
        this notebook, that is a normalised float tensor of shape [3, 224, 224].
    """
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')
    return transformer(rgb)

# Mini-batch reader factory for the image lists
def data_loader(imgs_name,train_class_dict,path_dir=train_dir, batch_size=10, mode = 'train'):
    '''Build a reader that yields mini-batches of (images, labels).

    imgs_name: list of image file names, relative to ``path_dir``
    train_class_dict: maps the 4-character person-id prefix of a file name to a class index
    path_dir: directory containing the images
    batch_size: samples per mini-batch; the final batch may be smaller
    mode: when 'train', the sample order is reshuffled on every call to the reader

    Returns a zero-argument function. Each call returns a fresh generator that
    yields ``(imgs, labels)``:
    - imgs is ``torch.stack`` of the per-image tensors from ``transform_img1``
      ([B, 3, 224, 224] with this notebook's transformer).
    - labels has shape [B, 1].
    '''
    def _collate(imgs, labels):
        # Stack per-sample tensors into one batch; labels become a [B, 1] column.
        return torch.stack(imgs), torch.stack(labels).reshape(-1, 1)

    def reader():
        names = list(imgs_name)
        if mode == 'train':
            # Shuffle a private copy so the caller's list is never reordered.
            random.shuffle(names)
        batch_imgs, batch_labels = [], []
        for name in names:
            filepath = os.path.join(path_dir, name)
            batch_imgs.append(transform_img1(filepath))
            # The class index comes from the first 4 characters (person id).
            batch_labels.append(torch.tensor(train_class_dict[name[:4]]))
            if len(batch_imgs) == batch_size:
                yield _collate(batch_imgs, batch_labels)
                batch_imgs, batch_labels = [], []
        if batch_imgs:
            # Emit the leftover samples as one final, smaller batch.
            yield _collate(batch_imgs, batch_labels)

    return reader

In [None]:
class SwinTransformerTeacher(nn.Module):
    """Teacher network: a timm Swin-B backbone plus an optional linear head.

    Args:
        num_features: output width of the projection head. When <= 0, no head
            is created and ``forward`` returns the raw backbone features.
    """

    def __init__(self, num_features=512):
        super(SwinTransformerTeacher, self).__init__()
        # The architecture is created without requesting pretrained weights.
        self.model = timm.create_model('swin_base_patch4_window7_224')
        self.num_features = num_features
        # The head's input width is hard-coded to 1024.
        head = nn.Linear(1024, num_features) if num_features > 0 else None
        self.feat = head

    def extract_feat(self, x):
        """Run the backbone stage by stage and collect every stage's output.

        Each stage's tokens are recorded after its blocks and before its
        optional downsampling layer.

        Returns:
            tuple with one tensor per entry of ``self.model.layers``.
        """
        backbone = self.model
        embed, dropout, stages = backbone.patch_embed, backbone.pos_drop, backbone.layers
        x = dropout(embed(x))
        stage_outputs = []
        for stage in stages:
            for blk in stage.blocks:
                x = blk(x)
            stage_outputs.append(x)
            if stage.downsample is not None:
                x = stage.downsample(x)
        return tuple(stage_outputs)

    def forward_specific_stage(self, x, stage, down_sample=True):
        """Continue the forward pass through one Swin stage (2, 3 or 4).

        Args:
            x: token tensor; must be 3-D (it is unpacked into three dims).
            stage: stage number. Values other than 2, 3 and 4 return x unchanged.
            down_sample: first apply the downsample layer owned by the
                preceding layer.

        Returns:
            The tokens after that stage. Stage 4 also applies the final norm.
        """
        _batch, _length, _channels = x.shape  # rejects non-3-D input
        # (stage, layer owning the incoming downsample, layer whose blocks run)
        routes = ((2, -4, -3), (3, -3, -2), (4, -2, -1))
        for stage_id, down_idx, block_idx in routes:
            if stage == stage_id:
                layers = self.model.layers
                if down_sample:
                    x = layers[down_idx].downsample(x)
                for blk in layers[block_idx].blocks:
                    x = blk(x)
                if stage_id == 4:
                    x = self.model.norm(x)
        return x

    def forward_features(self, x):
        """Return the backbone's own ``forward_features`` output unchanged."""
        return self.model.forward_features(x)

    def forward(self, x):
        """Backbone features, projected by ``self.feat`` when a head exists.

        NOTE(review): ``self.feat`` expects a trailing dim of 1024. Whether the
        backbone's ``forward_features`` returns pooled [B, 1024] features or an
        unpooled token map depends on the installed timm version; confirm this
        before relying on the shape of the result.
        """
        feats = self.model.forward_features(x)
        return feats if self.feat is None else self.feat(feats)

class ResNetStudent(nn.Module):
    """Student network: timm ResNet-50, global average pooling, optional linear head.

    Args:
        num_features: output width of the projection head. When <= 0, no head
            is created and ``forward`` returns the pooled 2048-d features.
    """

    def __init__(self, num_features=512):
        super(ResNetStudent, self).__init__()
        # ResNet-50 (timm, pretrained=True) is the student backbone.
        self.model = timm.create_model('resnet50', pretrained=True)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten_dim = 2048
        head = nn.Linear(self.flatten_dim, num_features) if num_features > 0 else None
        self.feat = head

    def extract_feat(self, x):
        """Run the stem, then record the output of each of the four residual stages.

        Returns:
            tuple (layer1_out, layer2_out, layer3_out, layer4_out).
        """
        backbone = self.model
        stem = (backbone.conv1, backbone.bn1, backbone.act1, backbone.maxpool)
        stages = (backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4)
        for op in stem:
            x = op(x)
        stage_outputs = []
        for stage in stages:
            x = stage(x)
            stage_outputs.append(x)
        return tuple(stage_outputs)

    def forward_features(self, x):
        """Return the backbone's own ``forward_features`` output unchanged."""
        return self.model.forward_features(x)

    def forward(self, x):
        """Pool the backbone feature map to [B, C], then project it if a head exists."""
        pooled = self.gap(self.model.forward_features(x))  # [B, C, 1, 1]
        flat = pooled.view(pooled.size(0), -1)  # [B, C]
        if self.feat is not None:
            flat = self.feat(flat)
        return flat

# Build both networks on the configured `device`. It falls back to the CPU
# when CUDA is unavailable, whereas `.cuda()` would crash there. Both models
# are then put in inference mode.
swin_model = SwinTransformerTeacher(num_features=512).to(device)
swin_model.eval()

resnet_model = ResNetStudent(num_features=512).to(device)
resnet_model.eval()

In [None]:
# Load pretrained weights for the Swin teacher; this checkpoint keeps them under 'state_dict'.
# map_location lets checkpoints saved on a GPU load when only the CPU is available.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
# TODO: these absolute paths are machine-specific; consider making them configurable.
swin_weight_path = '/homec/xiaolei/projects/ISR/weights/swin_base_patch4_window7_224.pth'
swin_weight = torch.load(swin_weight_path, map_location=device)
swin_model.load_state_dict(swin_weight['state_dict'], strict=True)

# Load the student ResNet weights (this file is a bare state dict).
resnet_weight_path = '/homec/xiaolei/projects/ReID/weights/student_model_base_marke11501/best_student_model.pth'
resnet_weight = torch.load(resnet_weight_path, map_location=device)
resnet_model.load_state_dict(resnet_weight, strict=True)

In [None]:
# Bookkeeping for the training split
train_image_id = 1                  # running image counter for the training list
val_image_id = 1                    # running image counter for the validation list
super_class_id = 1                  # always 1
train_class_id = 0                  # next contiguous class index to assign
train_class_dict = {}               # original 4-char person id -> contiguous class index
train_total = []                    # "image_id class_id super_class_id path" lines (training)
val_total = []                      # same format for the validation list

train_iden_pathdic = defaultdict(list)  # class index -> file names of that identity

# List the training images in sorted order. Only .jpg/.png files are kept.
# The previous filter `'jpg' or 'png' in img` was always true, so any
# non-image entry in the directory would have been parsed as a person id.
train_imgs_name = sorted(
    name for name in os.listdir(train_dir) if name.endswith(('.jpg', '.png'))
)

# Map each original person id (first 4 characters) to a contiguous class index.
for img in train_imgs_name:
    train_person_id = img[:4]
    if train_person_id not in train_class_dict:
        train_class_dict[train_person_id] = train_class_id
        train_class_id += 1
    train_iden_pathdic[train_class_dict[train_person_id]].append(img)
print('最大train_class_id = ',train_class_id-1)  # largest class index

# Hold out a random 1% of the images as a validation list.
random.shuffle(train_imgs_name)
img_num_99 = int(len(train_imgs_name) * 0.99)
val_imgs_name = train_imgs_name[img_num_99:]
# NOTE(review): train_imgs_name is NOT truncated to its first 99%, so the
# validation images also stay in the training list. Use
# train_imgs_name[:img_num_99] if a disjoint split is required.

train_imgs_name.sort()
val_imgs_name.sort()

# Text lines for easy inspection, e.g. "1001 66 1 bounding_box_train/0066_c4s1_008826_00.jpg"
for img in train_imgs_name:
    train_person_id = img[:4]
    train_path = 'bounding_box_train/' + img
    line = '%s %s %s %s\n' % (train_image_id, train_class_dict[train_person_id], super_class_id, train_path)
    train_image_id += 1
    train_total.append(line)

for img in val_imgs_name:
    val_person_id = img[:4]
    val_path = 'bounding_box_train/' + img
    line = '%s %s %s %s\n' % (val_image_id, train_class_dict[val_person_id], super_class_id, val_path)
    val_image_id += 1
    val_total.append(line)
print('分类总数： ',len(train_class_dict.values()))  # number of training classes

In [None]:
# --- Query set ---
query_class_dict = {}   # 4-char person id -> contiguous class index (shared with the gallery)
query_class_id = 0      # next class index to assign
query_total = []        # "image_id class_id super_class_id path" lines
query_image_id = 1
# Keep only image files; the previous test `'jpg' or 'png' in img` was always true.
query_names = sorted(
    name for name in os.listdir(query_dir) if name.endswith(('.jpg', '.png'))
)
camera_class_id = 0
# Camera digit (character at index 6 of the file name) -> 0-based camera index.
camera_class_dict = {'1':0,'2':1,'3':2,'4':3,'5':4,'6':5}

query_path_list = {}    # 0-based query position -> relative path
gallery_path_list = {}
query_cam = []          # camera index per query image, in query_names order
gallery_cam = []        # camera index per gallery image, in test_names order

qg_iden_pathdic = defaultdict(list)  # class index -> query and gallery image paths

for img in query_names:
    query_cam.append(camera_class_dict[img[6]])
    query_person_id = img[:4]
    if query_person_id not in query_class_dict:
        query_class_dict[query_person_id] = query_class_id
        query_class_id += 1
    query_class = query_class_dict[query_person_id]
    qg_iden_pathdic[query_class].append(query_dir+'/'+img)
    query_path = 'query/'+img
    # e.g. "1001 66 1 query/0066_c4s1_008826_00.jpg"
    line = '%s %s %s %s\n' % (query_image_id, query_class, super_class_id, query_path)
    query_path_list[query_image_id-1] = query_path
    query_image_id += 1
    query_total.append(line)

print('query_class_id = ',query_class_id-1)  # largest query class index

# --- Gallery set: only identities that also occur in the query set ---
test_imgs_name = sorted(
    name for name in os.listdir(test_dir) if name.endswith(('.jpg', '.png'))
)
test_names = []
test_total = []
test_image_id = 1

for img in test_imgs_name:
    test_person_id = img[:4]
    if test_person_id not in query_class_dict:
        continue
    test_names.append(img)
    gallery_cam.append(camera_class_dict[img[6]])
    qg_iden_pathdic[query_class_dict[test_person_id]].append(test_dir+'/'+img)
    test_path = 'bounding_box_test/'+img
    line = '%s %s %s %s\n' % (test_image_id, query_class_dict[test_person_id], super_class_id, test_path)
    test_image_id += 1
    test_total.append(line)

# Write the query list to a text file
with open('test_data/Market-1501/query.txt', 'w', encoding='UTF-8') as f:
    f.write('image_id class_id super_class_id path\n')
    f.writelines(query_total)

# Write the filtered gallery list to a text file
with open('test_data/Market-1501/gallery.txt', 'w', encoding='UTF-8') as f:
    f.write('image_id class_id super_class_id path\n')
    f.writelines(test_total)

# Evaluation readers (mode='val' keeps file order). Every image is used, so the
# extracted features line up one-to-one with query_cam / gallery_cam; the
# previous [:-1] slices silently dropped the last query and gallery image.
query_loder = data_loader(query_names, query_class_dict, query_dir, batch_size=32, mode='val')
gallery_loder = data_loader(test_names, query_class_dict, test_dir, batch_size=32, mode='val')
QUERY_NAMES_LEN = len(query_names)
TEST_NAMES_LEN = len(test_names)
query_cam = np.array(query_cam)
gallery_cam = np.array(gallery_cam)

In [None]:
# 8. Feature extraction helpers
### 1. Compute features for every sample
def get_feature(data_loder,model):
    """Run ``model`` over every batch and index the outputs per sample.

    Args:
        data_loder: iterable of ``(images, labels)`` batches, e.g. the
            generator returned by calling a ``data_loader(...)`` reader.
        model: callable mapping an image batch to a batch of features.

    Returns:
        (idex_feture, idex_label): dicts keyed by each sample's 0-based
        position in iteration order. idex_feture holds the sample's feature as
        a NumPy array; idex_label holds the sample's entry of the label batch.
    """
    idex_feture = {}
    idex_label = {}
    offset = 0  # global position of the first sample in the current batch
    for img, label in data_loder:
        # Inference only: keep the forward pass out of autograd to save memory.
        with torch.no_grad():
            feature512 = model(torch.as_tensor(img, device=device))
        batch_len = len(label)
        for i in range(batch_len):
            idex_feture[offset + i] = feature512[i].cpu().detach().numpy()
            idex_label[offset + i] = label[i]
        # A running offset (not idex * batch_len) keeps keys correct when the
        # final batch is smaller than the others.
        offset += batch_len
    return idex_feture, idex_label

### 2. Score each query against the gallery and sort the scores
def rank_score(query_feature,gallery_feature):
    """Rank all gallery samples for each query by dot-product similarity.

    Args:
        query_feature: dict of query feature vectors; its iteration order
            defines the query numbering in the result.
        gallery_feature: dict of gallery feature vectors; its iteration order
            defines the 0-based gallery positions used in the result.

    Returns:
        dict mapping each query's 0-based position to a list of
        ``(gallery_position, score)`` tuples. The list is sorted by descending
        score; ties keep ascending gallery position.
    """
    gallery_matrix = np.array(list(gallery_feature.values()))
    print('>>>>Score geting&sorting......')
    ranking = {}
    for q_pos, q_vec in enumerate(query_feature.values()):
        scores = np.dot(gallery_matrix, q_vec)
        # A stable sort of the negated scores gives a descending order in which
        # ties keep ascending gallery position, like sorted(..., reverse=True).
        order = np.argsort(-scores, kind='stable')
        ranking[q_pos] = [(int(g_pos), scores[g_pos]) for g_pos in order]
    return ranking

### 3. Extract features and labels for both query and gallery
def rank_total(model):
    """Extract query and gallery features with ``model``.

    Calls the module-level readers ``query_loder`` and ``gallery_loder``, each
    of which returns a fresh batch generator.

    Returns:
        (query_idex_feature, gallery_idex_feature, query_idex_label, gallery_idex_label)
    """
    print('>>>>starting maping')
    q_feat, q_label = get_feature(query_loder(), model)
    g_feat, g_label = get_feature(gallery_loder(), model)
    return q_feat, g_feat, q_label, g_label

### 4. Build the 0/1 relevance matrix used for rank-k and mAP
def query_gallery_score_to1(query_gallery_score,query_idex_label,gallery_idex_label):
    """Turn per-query rankings into a 0/1 "correct match" matrix.

    Args:
        query_gallery_score: output of ``rank_score``. Entry ``i`` is the
            ranked list of ``(gallery_position, score)`` pairs for query ``i``.
        query_idex_label: query labels keyed 0..Q-1. A label may be a scalar
            or a 1-element container such as a tensor; its first element is used.
        gallery_idex_label: gallery labels keyed 0..G-1, same convention.

    Returns:
        float array of shape [Q, G]. Element [i, r] is 1.0 when the gallery
        item ranked r-th for query i has the query's label.
    """
    num_query = len(query_idex_label)
    num_gallery = len(gallery_idex_label)
    # Size the matrix from the actual inputs rather than the QUERY_NAMES_LEN /
    # TEST_NAMES_LEN globals. Otherwise unfilled all-zero rows would skew
    # rank-k and mAP, and extra inputs could overflow the matrix.
    query_gallery_score_list = np.zeros([num_query, num_gallery])
    print('>>>>Fit_score_1......')
    if num_query == 0 or num_gallery == 0:
        return query_gallery_score_list
    # Flatten each gallery label to a scalar once, not once per (query, rank) pair.
    gallery_labels = np.array(
        [np.asarray(gallery_idex_label[k]).reshape(-1)[0] for k in range(num_gallery)]
    )
    for i in range(num_query):
        ranked = np.array(
            [pos for pos, _ in query_gallery_score[i]][:num_gallery], dtype=np.int64
        )
        query_label = np.asarray(query_idex_label[i]).reshape(-1)[0]
        query_gallery_score_list[i, :len(ranked)] = gallery_labels[ranked] == query_label
    return query_gallery_score_list


### Rank-k accuracy
def rank_k(k,query_gallery_score_list):
    """Fraction of queries with at least one correct match in the top ``k`` ranks.

    Args:
        k: number of leading ranked columns to inspect.
        query_gallery_score_list: 2-D array (queries x ranked gallery) whose
            non-zero entries mark correct matches.
    """
    num_queries = query_gallery_score_list.shape[0]
    hit = np.sum(query_gallery_score_list[:, :k], axis=1) >= 1
    return np.sum(hit) / num_queries

# Plot the CMC curve
def draw_cmc(x,query_gallery_score_list):
    """Print and plot rank-k accuracy for k = 1..x (the CMC curve).

    Ranks start at 1. Rank 0 is always 0 because its slice ``[:, :0]`` is
    empty, so it carries no information.

    Args:
        x: highest rank to report.
        query_gallery_score_list: 0/1 relevance matrix (queries x ranked gallery).

    Returns:
        list of rank-k accuracies; element 0 is rank-1.
    """
    ranks = list(range(1, x + 1))
    x_value = []
    for k in ranks:
        acc = rank_k(k, query_gallery_score_list)  # computed once per rank
        print('>>>>rank_',k,'  ',acc)
        x_value.append(acc)
    plt.plot(ranks, x_value)
    plt.xlabel('rank')
    plt.ylabel('matching rate')
    plt.title('CMC curve')
    return x_value

# Mean average precision
def mAP_cul(query_gallery_score_list):
    """Mean average precision over a ranked relevance matrix.

    For each query row, AP is the mean of the precision values at every rank
    that holds a correct match (precision@r = hits so far / r). Precisions are
    collected per query; previously the list was never reset, so precisions
    accumulated across queries. Queries without any correct match are skipped.

    Args:
        query_gallery_score_list: 2-D array (queries x ranked gallery) whose
            non-zero entries mark correct matches.

    Returns:
        mAP as a float, or nan if no query has a correct match.
    """
    print('>>>>Starting calcu mAP......')
    scores = np.asarray(query_gallery_score_list)
    query_map = []
    for row in scores:
        hit_ranks = np.flatnonzero(row)  # 0-based ranks of the correct matches
        if hit_ranks.size == 0:
            continue
        precisions = np.arange(1, hit_ranks.size + 1) / (hit_ranks + 1.0)
        query_map.append(precisions.mean())
    if not query_map:
        return float('nan')
    return float(np.mean(query_map))

# Tile query and gallery images into one grid
def image_compose_qg(qlist,glist,rows_good_list,ROW_SIZE=64,COLUMN_SIZE=128,IMAGE_COLUMN=16):
    '''Tile query images followed by gallery images into a grid and display it.

    qlist: indices into ``query_names`` (placed first)
    glist: indices into ``test_names`` (placed after the queries)
    rows_good_list: 0-based gallery positions (within ``glist``) that are
        correct matches; these tiles get a green border, all others red
    ROW_SIZE, COLUMN_SIZE: width and height in pixels of each tile
    IMAGE_COLUMN: number of tiles per grid row

    Returns the composed PIL image.
    '''
    img_names = [query_dir+'/'+query_names[i] for i in qlist]
    img_names += [test_dir+'/'+test_names[i] for i in glist]
    img_len = len(img_names)
    num_query = len(qlist)
    # Exact number of grid rows (at least one, so the canvas is never empty).
    image_rows = max(1, math.ceil(img_len / IMAGE_COLUMN))
    # Tiles are ROW_SIZE wide and COLUMN_SIZE tall.
    to_image = Image.new('RGB', (IMAGE_COLUMN * ROW_SIZE, image_rows * COLUMN_SIZE), (255,255,255))
    print('glist: ',glist)
    print('rows_good_list: ',rows_good_list)
    for pos, img_name in enumerate(img_names):
        # Fill the grid row by row, at most IMAGE_COLUMN tiles per row.
        y, x = divmod(pos, IMAGE_COLUMN)
        # Gallery positions are offset by the number of query tiles in front.
        gallery_pos = pos - num_query
        if gallery_pos >= 0 and gallery_pos in rows_good_list:
            color = (0,255,0)
        else:
            color = (255,0,0)
        from_image = draw_rectange(img_name, color).convert('RGB')
        if from_image.size != (ROW_SIZE, COLUMN_SIZE):
            # paste() with a 4-tuple box needs an exactly matching size.
            from_image = from_image.resize((ROW_SIZE, COLUMN_SIZE))
        print(img_name,color)
        to_image.paste(from_image, (x * ROW_SIZE, y * COLUMN_SIZE, (x + 1) * ROW_SIZE, (y + 1) * COLUMN_SIZE))
    print('拼接图片成功！')
    plt.figure(figsize=[25,10])
    plt.imshow(to_image)
    return to_image

# Gallery positions whose label equals a given value
def get_same_label(query_num,gallery_idex_label):
    '''Return the gallery positions whose label equals ``query_num``.

    query_num: label value, compared directly against each gallery label
    gallery_idex_label: dict of gallery labels; values may be scalars or
        1-element arrays/tensors (one row per gallery sample)

    Returns a 1-D array of 0-based positions in the dict's value order.
    '''
    gallery_label = np.array(list(gallery_idex_label.values()))
    matches = gallery_label == query_num
    if matches.ndim > 1:
        # [N, 1]-shaped labels: reduce to one flag per row. Flattening the
        # argwhere output would interleave column indices with the row indices.
        matches = matches.reshape(len(matches), -1).any(axis=1)
    return np.flatnonzero(matches)

# Draw a coloured border on an image
def draw_rectange(img_name,color): #(255,0,0)red (0,255,0)green
    """Open an image as RGB and draw a rectangular outline inset 2 px from its edges.

    Args:
        img_name: path of the image file.
        color: RGB outline colour, e.g. (255,0,0) red or (0,255,0) green.

    Returns:
        The RGB PIL image with the outline. For a 64x128 image the box is
        [2, 2, 61, 125], the same as the previously hard-coded one.
    """
    from PIL import ImageDraw  # not imported at the top of the notebook

    # Convert to RGB so an RGB outline colour is valid for grayscale inputs too;
    # the with-block releases the file handle right away.
    with Image.open(img_name) as src:
        img = src.convert('RGB')
    width, height = img.size
    draw = ImageDraw.Draw(img)
    # [top-left x, top-left y, bottom-right x, bottom-right y]
    draw.rectangle([2, 2, width - 3, height - 3], outline=color)
    return img

# Evaluate the Swin teacher; pass resnet_model to rank_total to evaluate the student instead.
(query_idex_feature, gallery_idex_feature,
 query_idex_label, gallery_idex_label) = rank_total(swin_model)
query_gallery_score = rank_score(query_idex_feature, gallery_idex_feature)
query_gallery_score_list = query_gallery_score_to1(
    query_gallery_score, query_idex_label, gallery_idex_label
)

draw_cmc(15, query_gallery_score_list)  # plot the CMC curve
mAP = mAP_cul(query_gallery_score_list)  # mean average precision
print('>>>>mAP = ',mAP)

In [None]:
# 9. Per-query AP / CMC evaluation (Euclidean ranking, Market-1501-style junk removal)
def evaluate(qf,ql,qc,gf,gl,gc,score_list=None,rows_good_list=None,good_index_list=None,index_remove_junkindex=None):
    """Score one query against the gallery and return its (AP, CMC) pair.

    The gallery is ranked by ascending Euclidean distance to the query. Two
    kinds of gallery entries count as junk and are dropped from the ranking:
    images with the query's identity taken by the query's camera, and images
    labelled -1.

    Args:
        qf: query feature vector.
        ql: query label (scalar or 1-element array).
        qc: query camera index.
        gf: gallery feature matrix, one row per gallery image.
        gl: gallery labels, 1-D or [N, 1].
        gc: gallery camera indices, 1-D.
        score_list, rows_good_list, good_index_list, index_remove_junkindex:
            optional lists that collect per-query diagnostics; unused when None.

    Returns:
        (ap, cmc) as returned by ``compute_mAP``.
    """
    query = torch.as_tensor(qf)
    gallery = torch.as_tensor(gf)
    score = torch.sqrt(torch.sum(torch.square(query - gallery), dim=1)).numpy()  # Euclidean distance
    index = np.argsort(score)  # gallery positions, nearest first
    if score_list is not None:
        score_list.append(index)

    gl = np.asarray(gl)
    # Row indices (works for 1-D and [N, 1] labels) of same-identity / same-camera images.
    query_index = np.argwhere(gl == ql)[:, 0]
    camera_index = np.argwhere(np.asarray(gc) == qc)[:, 0]

    # Same identity seen by a different camera -> correct match (order kept).
    good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
    # Same identity from the same camera, or label -1 -> junk.
    junk_index1 = np.argwhere(gl == -1)[:, 0]
    junk_index2 = np.intersect1d(query_index, camera_index)
    junk_index = np.append(junk_index2, junk_index1).flatten()

    return compute_mAP(
        index, good_index, junk_index,
        rows_good_list=rows_good_list,
        good_index_list=good_index_list,
        index_remove_junkindex=index_remove_junkindex,
    )


def compute_mAP(index, good_index, junk_index, rows_good_list=None, good_index_list=None, index_remove_junkindex=None):
    """Average precision and CMC vector for a single ranked query.

    Args:
        index: gallery positions sorted from best to worst match.
        good_index: gallery positions that are correct matches.
        junk_index: gallery positions removed from the ranking before scoring.
        rows_good_list, good_index_list, index_remove_junkindex: optional
            diagnostic collectors; unused when None.

    Returns:
        (ap, cmc):
        - ap is the mean of the precision at each correct match.
        - cmc is a float tensor of length len(index), set to 1 from the first
          correct rank onwards.
        Without a usable correct match, returns (0, cmc) with cmc[0] == -1,
        so callers can skip the query.
    """
    ap = 0
    cmc = torch.zeros(len(index))
    if good_index.size == 0:
        cmc[0] = -1
        return ap, cmc

    # Drop junk entries while keeping the ranking order of the rest.
    index = index[np.isin(index, junk_index, invert=True)]
    if index_remove_junkindex is not None:
        index_remove_junkindex.append(index)
    if good_index_list is not None:
        good_index_list.append(index)

    # 0-based ranks (after junk removal) at which correct matches appear.
    rows_good = np.flatnonzero(np.isin(index, good_index))
    if rows_good_list is not None:
        rows_good_list.append(rows_good)
    if rows_good.size == 0:
        # Every correct match was also classified as junk.
        cmc[0] = -1
        return ap, cmc

    cmc[rows_good[0]:] = 1
    # Precision at the i-th hit (1-based) = i / (rank of that hit + 1).
    precisions = np.arange(1, rows_good.size + 1) / (rows_good + 1.0)
    ap = float(precisions.sum())
    return ap / len(rows_good), cmc


def rank_map(rows_good_list=None,good_index_list=None,score_list=None,index_remove_junkindex=None,ap=None,model=None):
    """Evaluate ``model`` on query/gallery and print Rank@1/5/10 and mAP.

    Args:
        rows_good_list, good_index_list, score_list, index_remove_junkindex:
            optional lists filled with per-query diagnostics. Fresh lists are
            created per call when omitted; shared mutable defaults are avoided.
        ap: optional list that receives each valid query's AP.
        model: network to evaluate. When omitted, a global named ``model`` is
            used if one exists.

    Returns:
        (CMC, mAP): the averaged CMC tensor and the mean AP.
    """
    if model is None:
        model = globals().get('model')
    if model is None:
        raise ValueError('rank_map() needs a model, e.g. rank_map(model=swin_model)')
    rows_good_list = [] if rows_good_list is None else rows_good_list
    good_index_list = [] if good_index_list is None else good_index_list
    score_list = [] if score_list is None else score_list
    index_remove_junkindex = [] if index_remove_junkindex is None else index_remove_junkindex
    ap = [] if ap is None else ap

    # Features and labels as arrays (query_cam / gallery_cam are module-level arrays).
    query_idex_feature,gallery_idex_feature,query_idex_label,gallery_idex_label=rank_total(model)
    query_feature = np.array(list(query_idex_feature.values()))
    query_label = np.array(list(query_idex_label.values()))
    gallery_feature = np.array(list(gallery_idex_feature.values()))
    gallery_label = np.array(list(gallery_idex_label.values()))

    # Evaluate
    CMC = torch.zeros(len(gallery_label))
    for i in range(len(query_label)):
        ap_tmp, CMC_tmp = evaluate(
            query_feature[i], query_label[i], query_cam[i],
            gallery_feature, gallery_label, gallery_cam,
            score_list=score_list,
            rows_good_list=rows_good_list,
            good_index_list=good_index_list,
            index_remove_junkindex=index_remove_junkindex,
        )
        if CMC_tmp[0] == -1:
            continue  # no valid correct match for this query
        CMC = CMC + CMC_tmp
        ap.append(ap_tmp)
    CMC = CMC/len(query_label)  # average CMC
    mean_ap = np.mean(ap)
    print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],mean_ap))
    return CMC, mean_ap

In [None]:
import torch
from torchvision import transforms
from PIL import Image

# Load a sample image. Normalize below expects 3 channels, so force RGB;
# the with-block releases the file handle immediately.
with Image.open('test_data/reid/20-15-40_0.jpg') as src:
    img = src.convert('RGB')

# Data-augmentation pipeline (demo on a single image).
# NOTE(review): RandomVerticalFlip yields upside-down pedestrians, which is
# unusual for person re-ID augmentation -- confirm it is intended.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomVerticalFlip(),    # random vertical flip
    transforms.RandomRotation(30),      # random rotation within +/-30 degrees
    transforms.RandomResizedCrop(224),  # random crop resized to 224x224
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # random brightness/contrast/saturation/hue
    transforms.ToTensor(),              # PIL image -> float tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # per-channel normalisation
])

# Apply the augmentation
img_tensor = transform(img)

print(img_tensor.shape)