In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
import torch
import torch.nn as nn
import timm
from torchvision import transforms as Transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import glob
from tqdm import tqdm

from models.backbones.resnet_student import ResNetStudent
from models.dis_losses.fmdv2 import AttentionProjector

from data_provider.datasets.market1501 import Market1501
from data_provider.datasets.cuhk03 import CUHK03
from data_provider.datasets.msmt17 import MSMT17
from data_provider.datasets.dukemtmcreid import DukeMTMCreID
from data_provider.datasets.custom_data import CustomReid
from data_provider.datasets import ImageDataset
from data_provider.collate_batch import val_collate_fn
from utils.reid_metric import R1_mAP, R1_mAP_reranking

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# 递归删除指定目录下的.ipynb_checkpoints文件夹
def remove_ipynb_checkpoints(root_folder):
    """Recursively delete every ``.ipynb_checkpoints`` directory under ``root_folder``.

    Args:
        root_folder: Directory whose tree is scanned.

    Side effects:
        Removes the matching directories (and their contents) from disk and
        prints one ``Deleted: <path>`` line per removed directory.
    """
    # Imported locally because the notebook's import cell does not import shutil.
    import shutil

    checkpoint_name = ".ipynb_checkpoints"
    for root, dirs, _files in os.walk(root_folder):
        if checkpoint_name not in dirs:
            continue
        folder_path = os.path.join(root, checkpoint_name)
        shutil.rmtree(folder_path)
        print(f"Deleted: {folder_path}")
        # Prune the entry so os.walk does not try to descend into the
        # directory that was just removed.
        dirs.remove(checkpoint_name)

class SwinTransformerTeacher(nn.Module):
    """timm ``swin_base_patch4_window7_224`` backbone with an optional linear head.

    When ``num_features > 0`` a ``Linear(1024, num_features)`` layer is applied
    to the backbone output in ``forward``; otherwise the backbone output is
    returned unchanged.
    """

    def __init__(self, num_features=512):
        super().__init__()
        self.model = timm.create_model('swin_base_patch4_window7_224')
        self.num_features = num_features
        if num_features > 0:
            self.feat = nn.Linear(1024, num_features)
        else:
            self.feat = None

    def extract_feat(self, x):
        """Return a tuple with the output of each stage's blocks.

        Each entry is captured after the stage's blocks and before that
        stage's ``downsample`` module (when it has one) is applied.
        """
        backbone = self.model
        embed = backbone.patch_embed
        drop = backbone.pos_drop
        stages = backbone.layers

        tokens = drop(embed(x))
        collected = []
        for stage in stages:
            # Run the blocks by hand so the pre-downsample activation can be kept.
            for blk in stage.blocks:
                tokens = blk(tokens)
            collected.append(tokens)
            if stage.downsample is not None:
                tokens = stage.downsample(tokens)
        return tuple(collected)

    def forward_specific_stage(self, x, stage, down_sample=True):
        """Run ``x`` through the blocks of a single stage (2, 3 or 4).

        Args:
            x: Input tensor; it must be 3-dimensional (the shape is unpacked
                into three values).
            stage: 2, 3 or 4. Any other value returns ``x`` unchanged.
            down_sample: If true, first apply the ``downsample`` module that
                belongs to the preceding stage.

        Returns:
            The stage output; for stage 4 the backbone's final ``norm`` is
            applied as well.
        """
        batch, length, channels = x.shape  # unpacking fails unless x is 3-D

        if stage == 2:
            prev_idx, cur_idx = -4, -3
        elif stage == 3:
            prev_idx, cur_idx = -3, -2
        elif stage == 4:
            prev_idx, cur_idx = -2, -1
        else:
            return x

        stages = self.model.layers
        if down_sample:
            x = stages[prev_idx].downsample(x)
        for blk in stages[cur_idx].blocks:
            x = blk(x)

        if stage == 4:
            x = self.model.norm(x)
        return x

    def forward_features(self, x):
        """Delegate to the backbone's ``forward_features``."""
        return self.model.forward_features(x)

    def forward(self, x):
        """Backbone features, projected by ``self.feat`` when that head exists."""
        out = self.model.forward_features(x)
        return out if self.feat is None else self.feat(out)

'''
# 这个版本的残差网络是最基础的网络
class ResNetStudent(nn.Module):
    def __init__(self, num_features=512):
        super(ResNetStudent, self).__init__()
        self.model = timm.create_model('resnet50', pretrained=True)  # 使用ResNet-50作为学生模型
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten_dim = 2048
        self.feat = nn.Linear(self.flatten_dim, num_features) if num_features > 0 else None
    
    def extract_feat(self, x):
        # 创建一个空列表，用于保存各层的输出特征
        features = []
        
        # 提取每个阶段的层
        conv1 = self.model.conv1  # 初始卷积层
        bn1 = self.model.bn1
        act1 = self.model.act1
        maxpool = self.model.maxpool
        layer1 = self.model.layer1  # 第一阶段（残差块1）
        layer2 = self.model.layer2  # 第二阶段（残差块2）
        layer3 = self.model.layer3  # 第三阶段（残差块3）
        layer4 = self.model.layer4  # 第四阶段（残差块4）
        
        x = conv1(x)
        x = bn1(x)
        x = act1(x)
        x = maxpool(x)
        stage1_out = layer1(x)  # 第一阶段的输出
        features.append(stage1_out)
        stage2_out = layer2(stage1_out)  # 第二阶段的输出
        features.append(stage2_out)
        stage3_out = layer3(stage2_out)  # 第三阶段的输出
        features.append(stage3_out)
        stage4_out = layer4(stage3_out)  # 第四阶段的输出
        features.append(stage4_out)
        return tuple(features)
        
    def forward_features(self, x):
        x = self.model.forward_features(x)
        return x

    def forward(self, x):
        x = self.model.forward_features(x)
        # 池化操作，[batch_size, 2048, 7, 7] -> [batch_size, 2048, 1, 1]
        x = self.gap(x)
        # 展平特征图，将其变为 [batch_size, 2048 * 1 * 1]
        x = x.view(x.size(0), -1)  # 展平
        if not self.feat is None:
            x = self.feat(x)
        return x
'''
'''
# 这个版本的残差网络是修改过最大池化的网络
class ResNetStudent(nn.Module):
    def __init__(self, num_features=512):
        super(ResNetStudent, self).__init__()
        self.model = timm.create_model('resnet50', pretrained=True)  # 使用ResNet-50作为学生模型
        # 修改 layer1 和 layer2，向每个残差块中的 ReLU 前加上 InstanceNorm2d
        self._modify_layer(self.model.layer1)
        self._modify_layer(self.model.layer2)
        self._modify_layer_stride(self.model.layer4[0].conv2, self.model.layer4[0].downsample[0])
        # 输出设置
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten_dim = 2048
        self.feat = nn.Linear(self.flatten_dim, num_features) if num_features > 0 else None
    
    def extract_feat(self, x):
        # 创建一个空列表，用于保存各层的输出特征
        features = []
        
        # 提取每个阶段的层
        conv1 = self.model.conv1  # 初始卷积层
        bn1 = self.model.bn1
        act1 = self.model.act1
        maxpool = self.model.maxpool
        layer1 = self.model.layer1  # 第一阶段（残差块1）
        layer2 = self.model.layer2  # 第二阶段（残差块2）
        layer3 = self.model.layer3  # 第三阶段（残差块3）
        layer4 = self.model.layer4  # 第四阶段（残差块4）
        
        x = conv1(x)
        x = bn1(x)
        x = act1(x)
        x = maxpool(x)
        stage1_out = layer1(x)  # 第一阶段的输出
        features.append(stage1_out)
        stage2_out = layer2(stage1_out)  # 第二阶段的输出
        features.append(stage2_out)
        stage3_out = layer3(stage2_out)  # 第三阶段的输出
        features.append(stage3_out)
        stage4_out = layer4(stage3_out)  # 第四阶段的输出
        features.append(stage4_out)
        return tuple(features)

    def _modify_layer(self, layer):
        """
        在每个残差块中的 ReLU 前加上 InstanceNorm2d 操作。
        """
        for block in layer:
            # 修改 conv1 和 conv2 之后的 ReLU，将 InstanceNorm2d 放在 ReLU 前面
            # 对于每个残差块，将 InstanceNorm2d 加入到 ReLU 之前
            block.act3 = nn.Sequential(
                nn.InstanceNorm2d(block.conv3.out_channels, affine=True),
                nn.ReLU(inplace=True)
            )
            
    def _modify_layer_stride(self, last_layer, last_layer_downsample):
        # 在最后一层将stride改为1
        last_layer.stride = (1, 1)
        last_layer_downsample.stride = (1, 1)
        
    def forward_features(self, x):
        x = self.model.forward_features(x)
        return x

    def forward(self, x):
        x = self.model.forward_features(x)
        # 池化操作，[batch_size, 2048, 7, 7] -> [batch_size, 2048, 1, 1]
        x = self.gap(x)
        # 展平特征图，将其变为 [batch_size, 2048 * 1 * 1]
        x = x.view(x.size(0), -1)  # 展平
        if not self.feat is None:
            x = self.feat(x)
        return x
'''

'''
# 这个残差网络是提取每个残差块特征的网络
class ResNetStudent(nn.Module):
    def __init__(self, num_features=512):
        super(ResNetStudent, self).__init__()
        self.model = timm.create_model('resnet50', pretrained=True)  # 使用ResNet-50作为学生模型
        # 修改 layer1 和 layer2，向每个残差块中的 ReLU 前加上 InstanceNorm2d
        self._modify_layer(self.model.layer1)
        self._modify_layer(self.model.layer2)
        self._modify_layer_stride(self.model.layer4[0].conv2, self.model.layer4[0].downsample[0])
        # 进行特征映射
        self.projector_1 = AttentionProjector(student_dims=256, teacher_dims=128, hw_dims=(56, 56), pos_dims=128, window_shapes=(1, 1), self_query=True, 
                                 softmax_scale=5.0, num_heads=4)
        self.projector_2 = AttentionProjector(student_dims=512, teacher_dims=256, hw_dims=(28, 28), pos_dims=256, window_shapes=(1, 1), self_query=True, 
                                         softmax_scale=5.0, num_heads=8)
        self.projector_3 = AttentionProjector(student_dims=1024, teacher_dims=512, hw_dims=(14, 14), pos_dims=512, window_shapes=(1, 1), self_query=True, 
                                         softmax_scale=5.0, num_heads=16)
        self.projector_4 = AttentionProjector(student_dims=2048, teacher_dims=1024, hw_dims=(7, 7), pos_dims=1024, window_shapes=(1, 1), self_query=True, 
                                 softmax_scale=5.0, num_heads=32)
        # 输出设置
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten_dim = 2048
        self.feat = nn.Linear(self.flatten_dim, num_features) if num_features > 0 else None
    
    def extract_feat(self, x):
        # 创建一个空列表，用于保存各层的输出特征
        features = []
        
        # 提取每个阶段的层
        conv1 = self.model.conv1  # 初始卷积层
        bn1 = self.model.bn1
        act1 = self.model.act1
        maxpool = self.model.maxpool
        layer1 = self.model.layer1  # 第一阶段（残差块1）
        layer2 = self.model.layer2  # 第二阶段（残差块2）
        layer3 = self.model.layer3  # 第三阶段（残差块3）
        layer4 = self.model.layer4  # 第四阶段（残差块4）
        
        x = conv1(x)
        x = bn1(x)
        x = act1(x)
        x = maxpool(x)
        stage1_out = layer1(x)  # 第一阶段的输出
        # features.append(stage1_out)
        stage2_out = layer2(stage1_out)  # 第二阶段的输出
        # features.append(stage2_out)
        stage3_out = layer3(stage2_out)  # 第三阶段的输出
        student_feature_proj3 = self.projector_3(stage3_out)
        features.append(student_feature_proj3)
        stage4_out = layer4(stage3_out)  # 第四阶段的输出
        student_feature_proj4 = self.projector_4(stage4_out)
        features.append(student_feature_proj4)
        return tuple(features)
        
    def extract_feat_proj(self, x):
        features = []
        student_features = self.extract_feat(x)
        # student_feature_proj1 = self.projector_1(student_features[0])
        # features.append(student_feature_proj1)
        # student_feature_proj2 = self.projector_2(student_features[1])
        # features.append(student_feature_proj2)
        student_feature_proj3 = self.projector_3(student_features[2])
        features.append(student_feature_proj3)
        student_feature_proj4 = self.projector_4(student_features[3])
        features.append(student_feature_proj4)
        return tuple(features)
        
    def _modify_layer(self, layer):
        """
        在每个残差块中的 ReLU 前加上 InstanceNorm2d 操作。
        """
        for block in layer:
            # 修改 conv1 和 conv2 之后的 ReLU，将 InstanceNorm2d 放在 ReLU 前面
            # 对于每个残差块，将 InstanceNorm2d 加入到 ReLU 之前
            block.act3 = nn.Sequential(
                nn.InstanceNorm2d(block.conv3.out_channels, affine=True),
                nn.ReLU(inplace=True)
            )
            
    def _modify_layer_stride(self, last_layer, last_layer_downsample):
        # 在最后一层将stride改为1
        last_layer.stride = (1, 1)
        last_layer_downsample.stride = (1, 1)
        
    def forward_features(self, x):
        x = self.model.forward_features(x)
        return x

    def forward(self, x):
        x = self.model.forward_features(x)
        # 池化操作，[batch_size, 2048, 7, 7] -> [batch_size, 2048, 1, 1]
        x = self.gap(x)
        # 展平特征图，将其变为 [batch_size, 2048 * 1 * 1]
        x = x.view(x.size(0), -1)  # 展平
        if not self.feat is None:
            x = self.feat(x)
        return x
'''

class Data_Processor(object):
    """Callable that resizes, tensorises and normalises one image.

    The result carries an extra leading dimension of size 1, so a single
    image can be passed to a model as a one-element batch.
    """

    def __init__(self, height, width):
        self.height = height
        self.width = width
        steps = [
            Transforms.Resize((height, width)),
            Transforms.ToTensor(),
            Transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.transformer = Transforms.Compose(steps)

    def __call__(self, img):
        """Apply the transform pipeline to ``img`` and prepend a batch axis."""
        transformed = self.transformer(img)
        return transformed.unsqueeze(0)

# Shared preprocessing callable: resizes to 224x224, normalises, and adds a batch axis.
data_processor = Data_Processor(height=224, width=224)

# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# swin_model = SwinTransformerTeacher(num_features=512).cuda()
# swin_model.eval()

# resnet_model = ResNetStudent(num_features=512).cuda()
# resnet_model.eval()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# swin_transformer
# Teacher network (Swin Transformer), placed on the notebook-wide `device`.
swin_model = SwinTransformerTeacher(num_features=512).to(device)
swin_model.eval()

swin_weight_path = '/homec/xiaolei/projects/ISR/weights/swin_base_patch4_window7_224.pth'
# map_location="cpu" lets the checkpoint deserialize regardless of the device
# it was saved from; load_state_dict then copies the values into the model.
swin_weight = torch.load(swin_weight_path, map_location="cpu")
swin_model.load_state_dict(swin_weight['state_dict'], strict=True)

# Student network (ResNet)
# Build the network
resnet_model = ResNetStudent(num_features=512).to(device)
resnet_model.eval()
# Load the trained student weights (this checkpoint is the state dict itself)
resnet_weight_path = 'weights/student_model_base5_strong_reid_mmd_mse_loss_confusion/best_student_model.pth'
resnet_weight = torch.load(resnet_weight_path, map_location="cpu")
resnet_model.load_state_dict(resnet_weight, strict=True)

ResNetStudent(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [6]:
img_root = "/homec/xiaolei/projects/ultralytics/runs/results11"

# Accumulate the teacher/student cosine similarity over every .jpg found in
# the immediate sub-directories of img_root.
all_simlarity = 0
img_file_count = 0
for item in os.listdir(img_root):
    sub_img_root = os.path.join(img_root, item)
    print(f'sub_img_root: {sub_img_root}')

    # Only directories are scanned; plain files are skipped.
    if not os.path.isdir(sub_img_root):
        continue

    img_files = glob.glob(os.path.join(sub_img_root, '*.jpg'))
    img_file_count += len(img_files)
    for img_file in img_files:
        # The context manager closes the file handle once the image is decoded.
        with Image.open(img_file) as img:
            query = data_processor(img.convert('RGB')).to(device)
        with torch.no_grad():
            A_feat = F.normalize(swin_model(query), dim=1).cpu()
            B_feat = F.normalize(resnet_model(query), dim=1).cpu()
        # Both features are L2-normalised, so the dot product is their cosine similarity.
        simlarity = A_feat.matmul(B_feat.transpose(1, 0))
        all_simlarity += simlarity[0, 0]

# Guard against an empty image set instead of dividing by zero.
if img_file_count > 0:
    print(f'avg_simlarity: {all_simlarity / img_file_count}')
else:
    print(f'No .jpg images found under {img_root}')

sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/147_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/results.txt
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/.ipynb_checkpoints
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/148_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/202_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/203_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/236_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/302_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/303_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/438_1129
sub_img_root: /homec/xiaolei/projects/ultralytics/runs/results11/439_1129
avg_simlarity: 0.8889502882957458


In [None]:
# 测试自定义数据集
def make_val_data_loader():
    """Build the validation loader for the ``CustomReid`` dataset in ``court1888``.

    Returns:
        A ``(val_loader, num_query)`` pair: the DataLoader over the query
        samples followed by the gallery samples, and the number of query samples.
    """
    preprocess = Transforms.Compose([
        Transforms.Resize((224, 224)),
        Transforms.ToTensor(),
        Transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    reid_data = CustomReid(dataset_dir='court1888')
    # Query samples come first, then gallery samples.
    samples = reid_data.query + reid_data.gallery
    loader = DataLoader(
        ImageDataset(samples, preprocess),
        batch_size=128,
        shuffle=False,
        num_workers=8,
        collate_fn=val_collate_fn,
    )
    return loader, len(reid_data.query)

val_loader, num_query = make_val_data_loader()

# One R1_mAP_reranking evaluator per model; both are fed the same batches.
evaluator_resnet = R1_mAP_reranking(num_query)
evaluator_swin = R1_mAP_reranking(num_query)
evaluator_resnet.reset()
evaluator_swin.reset()

with tqdm(total=len(val_loader), desc="valid") as pbar, torch.no_grad():
    for data, pids, camids in val_loader:
        # Move the batch to `device` only when at least one CUDA device is visible.
        if torch.cuda.device_count() >= 1:
            data = data.to(device)
        feat_resnet = resnet_model(data)
        feat_swin = swin_model(data)
        evaluator_resnet.update((feat_resnet, pids, camids))
        evaluator_swin.update((feat_swin, pids, camids))
        pbar.update(1)

# Student (ResNet) metrics
cmc_resnet, mAP_resnet = evaluator_resnet.compute()
print('Validation Results')
print("mAP_resnet: {:.1%}".format(mAP_resnet))
for rank in (1, 5, 10):
    print("CMC_resnet curve, Rank-{:<3}:{:.1%}".format(rank, cmc_resnet[rank - 1]))

# Teacher (Swin) metrics
cmc_swin, mAP_swin = evaluator_swin.compute()
print('Validation Results')
print("mAP_swin: {:.1%}".format(mAP_swin))
for rank in (1, 5, 10):
    print("CMC_swin curve, Rank-{:<3}:{:.1%}".format(rank, cmc_swin[rank - 1]))

In [None]:
# 测试数据集Market1501
def make_val_data_loader():
    """Build the validation loader for the ``Market1501`` dataset.

    Returns:
        A ``(val_loader, num_query)`` pair: the DataLoader over the query
        samples followed by the gallery samples, and the number of query samples.
    """
    preprocess = Transforms.Compose([
        Transforms.Resize((224, 224)),
        Transforms.ToTensor(),
        Transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    reid_data = Market1501()
    # Query samples come first, then gallery samples.
    samples = reid_data.query + reid_data.gallery
    loader = DataLoader(
        ImageDataset(samples, preprocess),
        batch_size=128,
        shuffle=False,
        num_workers=8,
        collate_fn=val_collate_fn,
    )
    return loader, len(reid_data.query)

val_loader, num_query = make_val_data_loader()

# One R1_mAP_reranking evaluator per model; both are fed the same batches.
evaluator_resnet = R1_mAP_reranking(num_query)
evaluator_swin = R1_mAP_reranking(num_query)
evaluator_resnet.reset()
evaluator_swin.reset()

with tqdm(total=len(val_loader), desc="valid") as pbar, torch.no_grad():
    for data, pids, camids in val_loader:
        # Move the batch to `device` only when at least one CUDA device is visible.
        if torch.cuda.device_count() >= 1:
            data = data.to(device)
        feat_resnet = resnet_model(data)
        feat_swin = swin_model(data)
        evaluator_resnet.update((feat_resnet, pids, camids))
        evaluator_swin.update((feat_swin, pids, camids))
        pbar.update(1)

# Student (ResNet) metrics
cmc_resnet, mAP_resnet = evaluator_resnet.compute()
print('Validation Results')
print("mAP_resnet: {:.1%}".format(mAP_resnet))
for rank in (1, 5, 10):
    print("CMC_resnet curve, Rank-{:<3}:{:.1%}".format(rank, cmc_resnet[rank - 1]))

# Teacher (Swin) metrics
cmc_swin, mAP_swin = evaluator_swin.compute()
print('Validation Results')
print("mAP_swin: {:.1%}".format(mAP_swin))
for rank in (1, 5, 10):
    print("CMC_swin curve, Rank-{:<3}:{:.1%}".format(rank, cmc_swin[rank - 1]))

In [None]:
# 测试数据集MSMT17
def make_val_data_loader():
    """Build the validation loader for the ``MSMT17`` dataset.

    Returns:
        A ``(val_loader, num_query)`` pair: the DataLoader over the query
        samples followed by the gallery samples, and the number of query samples.
    """
    preprocess = Transforms.Compose([
        Transforms.Resize((224, 224)),
        Transforms.ToTensor(),
        Transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    reid_data = MSMT17()
    # Query samples come first, then gallery samples.
    samples = reid_data.query + reid_data.gallery
    loader = DataLoader(
        ImageDataset(samples, preprocess),
        batch_size=128,
        shuffle=False,
        num_workers=8,
        collate_fn=val_collate_fn,
    )
    return loader, len(reid_data.query)

val_loader, num_query = make_val_data_loader()

# One R1_mAP_reranking evaluator per model; both are fed the same batches.
evaluator_resnet = R1_mAP_reranking(num_query)
evaluator_swin = R1_mAP_reranking(num_query)
evaluator_resnet.reset()
evaluator_swin.reset()

with tqdm(total=len(val_loader), desc="valid") as pbar, torch.no_grad():
    for data, pids, camids in val_loader:
        # Move the batch to `device` only when at least one CUDA device is visible.
        if torch.cuda.device_count() >= 1:
            data = data.to(device)
        feat_resnet = resnet_model(data)
        feat_swin = swin_model(data)
        evaluator_resnet.update((feat_resnet, pids, camids))
        evaluator_swin.update((feat_swin, pids, camids))
        pbar.update(1)

# Student (ResNet) metrics
cmc_resnet, mAP_resnet = evaluator_resnet.compute()
print('Validation Results')
print("mAP_resnet: {:.1%}".format(mAP_resnet))
for rank in (1, 5, 10):
    print("CMC_resnet curve, Rank-{:<3}:{:.1%}".format(rank, cmc_resnet[rank - 1]))

# Teacher (Swin) metrics
cmc_swin, mAP_swin = evaluator_swin.compute()
print('Validation Results')
print("mAP_swin: {:.1%}".format(mAP_swin))
for rank in (1, 5, 10):
    print("CMC_swin curve, Rank-{:<3}:{:.1%}".format(rank, cmc_swin[rank - 1]))

In [None]:
# 测试数据集CUHK03
def make_val_data_loader():
    """Build the validation loader for the ``CUHK03`` dataset.

    Returns:
        A ``(val_loader, num_query)`` pair: the DataLoader over the query
        samples followed by the gallery samples, and the number of query samples.
    """
    preprocess = Transforms.Compose([
        Transforms.Resize((224, 224)),
        Transforms.ToTensor(),
        Transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    reid_data = CUHK03()
    # Query samples come first, then gallery samples.
    samples = reid_data.query + reid_data.gallery
    loader = DataLoader(
        ImageDataset(samples, preprocess),
        batch_size=128,
        shuffle=False,
        num_workers=8,
        collate_fn=val_collate_fn,
    )
    return loader, len(reid_data.query)

val_loader, num_query = make_val_data_loader()

# One R1_mAP_reranking evaluator per model; both are fed the same batches.
evaluator_resnet = R1_mAP_reranking(num_query)
evaluator_swin = R1_mAP_reranking(num_query)
evaluator_resnet.reset()
evaluator_swin.reset()

with tqdm(total=len(val_loader), desc="valid") as pbar, torch.no_grad():
    for data, pids, camids in val_loader:
        # Move the batch to `device` only when at least one CUDA device is visible.
        if torch.cuda.device_count() >= 1:
            data = data.to(device)
        feat_resnet = resnet_model(data)
        feat_swin = swin_model(data)
        evaluator_resnet.update((feat_resnet, pids, camids))
        evaluator_swin.update((feat_swin, pids, camids))
        pbar.update(1)

# Student (ResNet) metrics
cmc_resnet, mAP_resnet = evaluator_resnet.compute()
print('Validation Results')
print("mAP_resnet: {:.1%}".format(mAP_resnet))
for rank in (1, 5, 10):
    print("CMC_resnet curve, Rank-{:<3}:{:.1%}".format(rank, cmc_resnet[rank - 1]))

# Teacher (Swin) metrics
cmc_swin, mAP_swin = evaluator_swin.compute()
print('Validation Results')
print("mAP_swin: {:.1%}".format(mAP_swin))
for rank in (1, 5, 10):
    print("CMC_swin curve, Rank-{:<3}:{:.1%}".format(rank, cmc_swin[rank - 1]))

In [None]:
# 测试数据集DukeMTMCreID
def make_val_data_loader():
    """Build the validation loader for the ``DukeMTMCreID`` dataset.

    Returns:
        A ``(val_loader, num_query)`` pair: the DataLoader over the query
        samples followed by the gallery samples, and the number of query samples.
    """
    preprocess = Transforms.Compose([
        Transforms.Resize((224, 224)),
        Transforms.ToTensor(),
        Transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    reid_data = DukeMTMCreID()
    # Query samples come first, then gallery samples.
    samples = reid_data.query + reid_data.gallery
    loader = DataLoader(
        ImageDataset(samples, preprocess),
        batch_size=128,
        shuffle=False,
        num_workers=8,
        collate_fn=val_collate_fn,
    )
    return loader, len(reid_data.query)

val_loader, num_query = make_val_data_loader()

# One R1_mAP_reranking evaluator per model; both are fed the same batches.
evaluator_resnet = R1_mAP_reranking(num_query)
evaluator_swin = R1_mAP_reranking(num_query)
evaluator_resnet.reset()
evaluator_swin.reset()

with tqdm(total=len(val_loader), desc="valid") as pbar, torch.no_grad():
    for data, pids, camids in val_loader:
        # Move the batch to `device` only when at least one CUDA device is visible.
        if torch.cuda.device_count() >= 1:
            data = data.to(device)
        feat_resnet = resnet_model(data)
        feat_swin = swin_model(data)
        evaluator_resnet.update((feat_resnet, pids, camids))
        evaluator_swin.update((feat_swin, pids, camids))
        pbar.update(1)

# Student (ResNet) metrics
cmc_resnet, mAP_resnet = evaluator_resnet.compute()
print('Validation Results')
print("mAP_resnet: {:.1%}".format(mAP_resnet))
for rank in (1, 5, 10):
    print("CMC_resnet curve, Rank-{:<3}:{:.1%}".format(rank, cmc_resnet[rank - 1]))

# Teacher (Swin) metrics
cmc_swin, mAP_swin = evaluator_swin.compute()
print('Validation Results')
print("mAP_swin: {:.1%}".format(mAP_swin))
for rank in (1, 5, 10):
    print("CMC_swin curve, Rank-{:<3}:{:.1%}".format(rank, cmc_swin[rank - 1]))

In [None]:
import torch
import torch.nn as nn
import timm

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Throw-away instances used only to inspect tensor shapes; no checkpoint is
# loaded here. The teacher gets its own name so the weight-loaded `swin_model`
# used by the evaluation cells is not overwritten.
swin_probe = SwinTransformerTeacher().to(device)
swin_probe.eval()
res_model = ResNetStudent().to(device)
res_model.eval()

# Shapes are all that matter here, so skip autograd bookkeeping.
with torch.no_grad():
    swin_features = swin_probe.extract_feat(torch.randn(1, 3, 224, 224).to(device))
    res_features = res_model.extract_feat(torch.randn(1, 3, 224, 224).to(device))

    for swin_feature in swin_features:
        print(swin_feature.shape)
    print('='*100)
    for res_feature in res_features:
        print(res_feature.shape)

    print(res_model(torch.randn(1, 3, 224, 224).to(device)).shape)
print(swin_probe)

In [None]:
import torch
import torch.nn.functional as F

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def MMD(x, y, kernel):
    """Empirical maximum mean discrepancy. The lower the result,
       the more evidence that the distributions are the same.

    Args:
        x: first sample, distribution P, a 2-D tensor (rows are samples)
        y: second sample, distribution Q, a 2-D tensor with the same shape as x
        kernel: kernel type, either "multiscale" or "rbf"

    Returns:
        A 0-dim tensor on the same device and with the same dtype as the inputs.

    Raises:
        ValueError: if ``kernel`` is not one of the two supported names
            (previously this case silently returned 0).
    """
    if kernel not in ("multiscale", "rbf"):
        raise ValueError(f'unsupported kernel {kernel!r}; expected "multiscale" or "rbf"')

    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))

    # Pairwise squared Euclidean distances, built from the Gram matrices.
    dxx = rx.t() + rx - 2. * xx  # within x
    dyy = ry.t() + ry - 2. * yy  # within y
    dxy = rx.t() + ry - 2. * zz  # between x and y

    # zeros_like keeps the accumulators on the inputs' device and dtype;
    # torch.zeros(shape) would always be a CPU float32 tensor.
    XX, YY, XY = (torch.zeros_like(xx),
                  torch.zeros_like(xx),
                  torch.zeros_like(xx))

    if kernel == "multiscale":
        bandwidth_range = [0.2, 0.5, 0.9, 1.3]
        for a in bandwidth_range:
            XX += a**2 * (a**2 + dxx)**-1
            YY += a**2 * (a**2 + dyy)**-1
            XY += a**2 * (a**2 + dxy)**-1
    else:  # "rbf"
        bandwidth_range = [10, 15, 20, 50]
        for a in bandwidth_range:
            XX += torch.exp(-0.5*dxx/a)
            YY += torch.exp(-0.5*dyy/a)
            XY += torch.exp(-0.5*dxy/a)

    return torch.mean(XX + YY - 2. * XY)

class RBF(nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated at several bandwidths.

    Args:
        n_kernels: number of kernels in the sum.
        mul_factor: ratio between consecutive bandwidth multipliers; the
            multipliers are ``mul_factor ** (i - n_kernels // 2)`` for
            ``i`` in ``range(n_kernels)``.
        bandwidth: fixed base bandwidth, or ``None`` to derive it from the
            data on every call (see ``get_bandwidth``).
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        # Registered as a buffer so it follows .to()/.cuda()/.double() with the
        # module; non-persistent so the state_dict stays unchanged (empty).
        self.register_buffer("bandwidth_multipliers", multipliers, persistent=False)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the fixed bandwidth, or the mean off-diagonal squared distance.

        With a single sample the denominator ``n**2 - n`` is zero, so the
        data-derived bandwidth is not finite in that case.
        """
        if self.bandwidth is None:
            n_samples = L2_distances.shape[0]
            # detach(): the bandwidth is treated as a constant w.r.t. autograd.
            return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)

        return self.bandwidth

    def forward(self, X):
        """Return the (n, n) kernel matrix for the rows of ``X``."""
        L2_distances = torch.cdist(X, X) ** 2
        # One bandwidth per kernel, broadcast over the (n, n) distance matrix.
        scales = (self.get_bandwidth(L2_distances) * self.bandwidth_multipliers)[:, None, None]
        return torch.exp(-L2_distances[None, ...] / scales).sum(dim=0)


class MMDLoss(nn.Module):
    """Maximum mean discrepancy between two sample sets under a kernel module.

    Args:
        kernel: callable mapping an (n, d) tensor to its (n, n) kernel matrix.
            ``None`` (the default) creates a fresh ``RBF()`` per instance, so
            separate losses never share one kernel module.
    """

    def __init__(self, kernel=None):
        super().__init__()
        # Built here rather than in the signature: a default of `RBF()` would be
        # evaluated once at definition time and shared by every instance.
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, X, Y):
        """Return ``mean(K_XX) - 2 * mean(K_XY) + mean(K_YY)`` as a 0-dim tensor."""
        # Evaluate the kernel once on the stacked samples, then slice the blocks.
        K = self.kernel(torch.vstack([X, Y]))

        X_size = X.shape[0]
        XX = K[:X_size, :X_size].mean()
        XY = K[:X_size, X_size:].mean()
        YY = K[X_size:, X_size:].mean()
        return XX - 2 * XY + YY

def rbf_kernel(x, y, sigma=1.0):
    """
    Gaussian (RBF) kernel matrix between the rows of two tensors.
    :param x: input tensor x (batch_size, feature_dim)
    :param y: input tensor y (batch_size, feature_dim)
    :param sigma: kernel width; controls how quickly similarity decays with distance
    :return: kernel matrix with one row per sample of x and one column per sample of y
    """
    # Squared Euclidean distance via ||x||^2 + ||y||^2 - 2 * x.y
    x_sq = torch.sum(x ** 2, dim=1, keepdim=True)
    y_sq = torch.sum(y ** 2, dim=1, keepdim=True)
    cross = torch.matmul(x, y.t())
    sq_dist = x_sq + y_sq.t() - 2 * cross

    # Gaussian kernel
    denom = 2 * sigma ** 2
    return torch.exp(-sq_dist / denom)

def mmd_loss(X, Y, sigma=1.0):
    """
    Maximum mean discrepancy (MMD) loss with a single Gaussian kernel.
    :param X: sample set X (batch_size_1, feature_dim)
    :param Y: sample set Y (batch_size_2, feature_dim)
    :param sigma: kernel width passed to rbf_kernel
    :return: MMD loss
    """
    # Mean kernel value within X, within Y, and between X and Y
    mean_xx = rbf_kernel(X, X, sigma).mean()
    mean_yy = rbf_kernel(Y, Y, sigma).mean()
    mean_xy = rbf_kernel(X, Y, sigma).mean()

    return mean_xx + mean_yy - 2 * mean_xy

# Example: two random sample sets
X = torch.randn(100, 128)  # sample set X: 100 samples with 128-dimensional features
Y = torch.randn(100, 128)  # sample set Y: 100 samples with 128-dimensional features

# Compute the MMD with each of the three implementations
loss = mmd_loss(X, Y, sigma=1.0)
# The instance gets its own name; rebinding `MMDLoss` would replace the class
# with an instance for the rest of the session.
mmd_loss_module = MMDLoss()
res = mmd_loss_module(X, Y)
res2 = MMD(X, Y, kernel="rbf")
print(f"MMD Loss: {loss.item()}")
print(f"MMD Loss: {res.item()}")
print(f"MMD Loss: {res2.item()}")

In [None]:
# The path of your image pair
# NOTE(review): `model` below is not assigned in any cell visible in this
# notebook; confirm which network is intended before running on a fresh kernel.
image_pair_path = ['test_data/reid/19-24-16_0_2*2.jpg',
           'test_data/reid/19-24-19_0_1*2.jpg']

# Preprocess both images into single-image batches on the GPU
query_A = data_processor(Image.open(image_pair_path[0]).convert('RGB')).cuda()
query_B = data_processor(Image.open(image_pair_path[1]).convert('RGB')).cuda()
with torch.no_grad():
    A_feat = F.normalize(model(query_A), dim=1).cpu()
    B_feat = F.normalize(model(query_B), dim=1).cpu()
# Features are L2-normalised along dim=1, so this dot product is the cosine similarity
simlarity = A_feat.matmul(B_feat.transpose(1, 0))


# Show the two images side by side (resized to 64x128 for display only)
fig, axes = plt.subplots(1, 2, figsize = (2 * 2, 2))
image1 = np.array(Image.open(image_pair_path[0]).convert('RGB').resize((64, 128)))
image2 = np.array(Image.open(image_pair_path[1]).convert('RGB').resize((64, 128)))

axes[0].imshow(image1)
axes[0].set_title('image 1', fontsize=16)
axes[0].axis('off')

axes[1].imshow(image2)
axes[1].set_title('image 2', fontsize=16)
axes[1].axis('off')
# "\033[1;31m" is the ANSI escape sequence for bold red text
print("\033[1;31m The similarity is {}".format(simlarity[0,0]))

In [None]:
# The path of your image pair
# NOTE(review): `model` below is not assigned in any cell visible in this
# notebook; confirm which network is intended before running on a fresh kernel.
image_pair_path = ['test_data/reid/18-48-00_0_1*1.jpg',
           'test_data/reid/19-24-19_0_1*2.jpg']

# Preprocess both images into single-image batches on the GPU
query_A = data_processor(Image.open(image_pair_path[0]).convert('RGB')).cuda()
query_B = data_processor(Image.open(image_pair_path[1]).convert('RGB')).cuda()
with torch.no_grad():
    A_feat = F.normalize(model(query_A), dim=1).cpu()
    B_feat = F.normalize(model(query_B), dim=1).cpu()
# Features are L2-normalised along dim=1, so this dot product is the cosine similarity
simlarity = A_feat.matmul(B_feat.transpose(1, 0))


# Show the two images side by side (resized to 64x128 for display only)
fig, axes = plt.subplots(1, 2, figsize = (2 * 2, 2))
image1 = np.array(Image.open(image_pair_path[0]).convert('RGB').resize((64, 128)))
image2 = np.array(Image.open(image_pair_path[1]).convert('RGB').resize((64, 128)))

axes[0].imshow(image1)
axes[0].set_title('image 1', fontsize=16)
axes[0].axis('off')

axes[1].imshow(image2)
axes[1].set_title('image 2', fontsize=16)
axes[1].axis('off')
# "\033[1;31m" is the ANSI escape sequence for bold red text
print("\033[1;31m The similarity is {}".format(simlarity[0,0]))

In [None]:
# Model A: Swin teacher. Model B: ResNet student. Both are switched to eval
# mode (the previous code called .eval() on an unrelated name, `model`).
model_a = SwinTransformerTeacher(num_features=512).cuda()
model_a.eval()

model_b = ResNetStudent(num_features=512).cuda()
model_b.eval()

# swin_transformer
weight_path = '/homec/xiaolei/projects/ISR/weights/swin_base_patch4_window7_224.pth'
# map_location="cpu" lets the checkpoint deserialize regardless of the device
# it was saved from; load_state_dict then copies the values into the model.
weight = torch.load(weight_path, map_location="cpu")
model_a.load_state_dict(weight['state_dict'], strict=True)

# ResNet student (this checkpoint is the state dict itself)
weight_path = 'weights/student_model_base5_strong_reid_mmd_mse_loss/best_student_model.pth'
weight = torch.load(weight_path, map_location="cpu")
model_b.load_state_dict(weight, strict=True)

In [None]:
def get_model_size(model):
    """Return the combined size of a model's parameters and buffers in MB.

    Args:
        model: Object exposing ``parameters()`` and ``buffers()`` iterables whose
            items provide ``nelement()`` and ``element_size()``.

    Returns:
        Total number of bytes (element count * bytes per element, summed over
        all parameters and buffers) divided by 1024**2.
    """
    def _total_bytes(tensors):
        # Element count times bytes per element, accumulated over all tensors.
        return sum(t.nelement() * t.element_size() for t in tensors)

    param_bytes = _total_bytes(model.parameters())
    buffer_bytes = _total_bytes(model.buffers())
    return (param_bytes + buffer_bytes) / 1024**2  # bytes -> MB

# Compare the sizes of the two models (parameters + buffers, in MB)
print(f"Model A size (MB): {get_model_size(model_a)}")
print(f"Model B size (MB): {get_model_size(model_b)}")

In [None]:
from ptflops import get_model_complexity_info

# Compare the FLOPs and parameter counts of the two models.
# Both models are profiled with the same input resolution.
PROFILE_INPUT_RES = (3, 224, 224)

with torch.cuda.device(0):  # run the profiling with GPU 0 selected
    flops_a, params_a = get_model_complexity_info(
        model_a,
        PROFILE_INPUT_RES,
        as_strings=True,
        print_per_layer_stat=False,
    )
    flops_b, params_b = get_model_complexity_info(
        model_b,
        PROFILE_INPUT_RES,
        as_strings=True,
        print_per_layer_stat=False,
    )

print(f"Model A - FLOPs: {flops_a}, Params: {params_a}")
print(f"Model B - FLOPs: {flops_b}, Params: {params_b}")

In [None]:
import torch
import torch.cuda

def get_memory_usage(model, input_size=(1, 3, 224, 224)):
    """Report the peak CUDA memory (in MB) recorded around one forward pass.

    Args:
        model: Network to run; it is moved to the GPU with ``.cuda()``.
        input_size: Shape of the random input tensor fed to the model.

    Returns:
        ``torch.cuda.max_memory_allocated()`` after the forward pass,
        divided by 1024**2.
    """
    bytes_per_mb = 1024**2

    # Release cached blocks first so the measurement starts from a clean state.
    torch.cuda.empty_cache()
    dummy_batch = torch.randn(input_size).cuda()
    net = model.cuda()

    # The peak counter is reset after the model and input are already on the
    # GPU, so they are presumably part of the reported peak.
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        _ = net(dummy_batch)
        peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb

    return peak_mb

# Report the peak GPU memory measured for one forward pass of each model.
print(f"Model A Memory Usage (MB): {get_memory_usage(model_a)}")
print(f"Model B Memory Usage (MB): {get_memory_usage(model_b)}")

In [None]:
import time

def get_inference_time(model, input_size=(1, 3, 224, 224), iterations=100):
    """Measure the average wall-clock time of one forward pass.

    CUDA kernels are launched asynchronously, so the device is synchronised
    before the timer starts and again before it stops; otherwise only the
    kernel launch overhead would be measured instead of the execution time.

    Args:
        model: Network to benchmark. It is moved to the GPU when CUDA is
            available, otherwise it is timed on the CPU.
        input_size: Shape of the random input tensor fed to the model.
        iterations: Number of timed forward passes; must be at least 1.

    Returns:
        Average time per forward pass in seconds.

    Raises:
        ValueError: If ``iterations`` is smaller than 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    input_data = torch.randn(input_size).to(device)
    model = model.to(device)

    def _wait_for_device():
        # Block until all queued GPU work has finished (no-op on the CPU).
        if use_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        # Warm-up run so one-off initialisation cost is not included in the timing.
        _ = model(input_data)
        _wait_for_device()

        # perf_counter is a monotonic, high-resolution clock meant for intervals.
        start_time = time.perf_counter()
        for _ in range(iterations):
            _ = model(input_data)
        _wait_for_device()
        elapsed = time.perf_counter() - start_time

    return elapsed / iterations

# Report the average per-forward-pass time of each model with the default settings.
print(f"Model A Inference Time (seconds): {get_inference_time(model_a)}")
print(f"Model B Inference Time (seconds): {get_inference_time(model_b)}")

In [None]:
import json

def read_json(file_path):
    """Load a UTF-8 encoded JSON file and return the parsed Python object.

    Errors are reported on stdout instead of being raised.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed data, or ``None`` if the file is missing, is not valid
        JSON, or any other exception occurs while reading it.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except FileNotFoundError:
        print(f"文件 {file_path} 未找到")
    except json.JSONDecodeError as err:
        print(f"JSON 文件解析错误: {err}")
    except Exception as err:
        print(f"发生错误: {err}")
    return None

def modify_path(original_path):
    """Normalise a path to forward slashes and move its root to ./test_data.

    Only a leading ``./data`` path component is rewritten. A plain substring
    replacement would also corrupt names that merely start with the same
    characters (e.g. ``./database``) or contain ``./data`` further inside.

    Args:
        original_path: Path string, possibly using Windows backslashes.

    Returns:
        The path with backslashes replaced by forward slashes and, if it is
        ``./data`` or lies below it, with that root replaced by ``./test_data``.
    """
    old_root = "./data"
    new_root = "./test_data"

    # Convert Windows separators first so the root check sees a uniform form.
    normalized = original_path.replace("\\", "/")

    # Match the root only as a whole leading path component.
    if normalized == old_root or normalized.startswith(old_root + "/"):
        return new_root + normalized[len(old_root):]
    return normalized

def save_to_json(data, file_path):
    """Write ``data`` to ``file_path`` as indented, UTF-8 encoded JSON.

    Failures are reported on stdout instead of being raised.

    Args:
        data: JSON-serialisable object to store.
        file_path: Destination file; it is opened in write mode.
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as out_file:
            # indent=4 pretty-prints; ensure_ascii=False keeps non-ASCII text readable.
            json.dump(data, out_file, indent=4, ensure_ascii=False)
            print(f"数据已成功保存到 {file_path}")
    except Exception as err:
        print(f"保存数据时发生错误: {err}")

# Example usage
file_path = "test_data/cuhk03/splits_new_labeled.json"  # path of the JSON file to read
output_file_path = "test_data/cuhk03/splits_new_labeled2.json"  # path of the JSON file to write
json_data = read_json(file_path)

if json_data:
    print("读取到的 JSON 数据：")
    # Rewrite the first field (the path) of every record in each subset of each split.
    for split in json_data:
        for key in ('train', 'query', 'gallery'):
            for record in split[key]:
                record[0] = modify_path(record[0])
    save_to_json(json_data, output_file_path)

In [None]:
import os
import shutil

def copy_files(src_folder, dest_folder):
    """Copy every regular file directly inside ``src_folder`` to ``dest_folder``.

    Sub-directories and other non-file entries are skipped with a message;
    the copy is not recursive.

    Args:
        src_folder: Directory to copy from. If it is not an existing directory
            a message is printed and nothing is copied.
        dest_folder: Directory to copy into; created if it does not exist.
    """
    # isdir (rather than exists) also rejects a path that points at a regular
    # file, which would otherwise make os.listdir fail below.
    if not os.path.isdir(src_folder):
        print(f"源文件夹 {src_folder} 不存在")
        return

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(dest_folder, exist_ok=True)

    for file_name in os.listdir(src_folder):
        src_file = os.path.join(src_folder, file_name)
        dest_file = os.path.join(dest_folder, file_name)

        if os.path.isfile(src_file):
            shutil.copy2(src_file, dest_file)  # copy2 also preserves file metadata
        else:
            print(f"跳过非文件: {src_file}")

# Example usage
src_folder = "datasets/valid"
dest_folder = "datasets/basketball_player_fusion/valid"
copy_files(src_folder, dest_folder)
# NOTE(review): this message is printed unconditionally, including when
# copy_files returned early because the source folder was not found.
print(f'复制完成')

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split

def split_dataset(source_dir, target_dir, train_ratio=0.8):
    """Randomly split the files of ``source_dir`` into train/valid folders.

    The files are copied (not moved) into ``<target_dir>/train`` and
    ``<target_dir>/valid``.

    Args:
        source_dir: Directory whose regular files are split; sub-directories
            are ignored.
        target_dir: Directory under which ``train`` and ``valid`` are created.
        train_ratio: Fraction of the files assigned to the training set; the
            complement is passed to ``train_test_split`` as ``test_size``.
    """
    # os.listdir returns entries in arbitrary order, so the listing is sorted:
    # with the fixed random_state below, the same set of files then always
    # produces the same split, independent of filesystem or machine.
    files = sorted(
        f for f in os.listdir(source_dir)
        if os.path.isfile(os.path.join(source_dir, f))
    )

    # Split into training and validation file names.
    train_files, val_files = train_test_split(files, test_size=1-train_ratio, random_state=42)

    # Destination folders for the two subsets.
    train_dir = os.path.join(target_dir, "train")
    val_dir = os.path.join(target_dir, "valid")

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Copy each file into the folder of the subset it was assigned to.
    for file in train_files:
        shutil.copy(os.path.join(source_dir, file), os.path.join(train_dir, file))

    for file in val_files:
        shutil.copy(os.path.join(source_dir, file), os.path.join(val_dir, file))

    print(f"训练集文件数量: {len(train_files)}")
    print(f"验证集文件数量: {len(val_files)}")
    print(f"数据集已成功分配到 {target_dir}")

# Example usage
source_dir = "test_data/MSMT17/bounding_box_train"  # source folder path
target_dir = "datasets/basketball_player_fusion"  # target folder path
split_dataset(source_dir, target_dir)