In [None]:
from dataset import test_build_dataset
from torch.utils.data import DataLoader
import torch
import timm
from torch import nn

In [None]:
# Build the training split; the builder also returns the number of classes.
dataset_train, nb_classes = test_build_dataset()

# Un-shuffled loader: batches follow the dataset's own ordering.
dataloader_train = DataLoader(
    dataset_train,
    shuffle=False,
    batch_size=32,
    pin_memory=False,
    num_workers=4,
)

In [None]:
# ResNet-34 backbone with a classifier sized for this dataset; its weights are then
# overwritten by the fine-tuned checkpoint below.
model = timm.create_model('resnet34.a1_in1k', pretrained=True, num_classes=nb_classes).to('cuda')

CHECKPOINT_PATH = 'models_para/resnet34.a1_in1k_ecgid_kd.pth'
# map_location='cpu' deserialises on the CPU, so loading does not depend on which GPU the
# checkpoint was saved from (e.g. 'cuda:1' on a single-GPU machine); load_state_dict then
# copies the tensors into the CUDA parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')

# strict=False tolerates mismatched keys, so guard against the case where *nothing* matched
# (e.g. keys carrying a 'module.' prefix from DataParallel), which would otherwise leave the
# model with its ImageNet weights and no error.
load_result = model.load_state_dict(state_dict, strict=False)
assert len(load_result.missing_keys) < len(model.state_dict()), \
    "no checkpoint key matched the model (check for a 'module.' prefix)"
load_result

In [None]:
class EmbeddingHead(nn.Module):
    """
    Embedding compression head: dimensionality reduction followed by L2 normalisation.

    Layers: Linear(in_dim, hidden_dim) -> BatchNorm1d -> ReLU -> Dropout
            -> Linear(hidden_dim, target_dim) -> BatchNorm1d [-> ReLU]

    Args:
        in_dim: size of the input feature dimension (dim 1 of the input).
        target_dim: size of the output embedding.
        hidden_dim: width of the intermediate layer (default 256, the original value).
        dropout: dropout probability after the first block (default 0.1, the original value).
        final_relu: apply ReLU before normalisation (default True, the original behaviour).
            With it every embedding lies in the non-negative orthant, so the cosine
            similarity between two embeddings can never be negative; pass False to let
            embeddings use the whole unit sphere.
    """
    def __init__(self, in_dim: int, target_dim: int = 128, hidden_dim: int = 256,
                 dropout: float = 0.1, final_relu: bool = True):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, target_dim),
            nn.BatchNorm1d(target_dim),
        ]
        if final_relu:
            layers.append(nn.ReLU(inplace=True))
        # Layers with state keep indices 0, 1, 4, 5 either way, so state_dict keys are
        # identical whether or not the final ReLU is present.
        self.dim_reduction = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the reduced features L2-normalised along dim 1 (all-zero rows stay zero)."""
        return nn.functional.normalize(self.dim_reduction(x), p=2, dim=1)

In [None]:
# Take the first batch of the loader and move it to the GPU.
data, targets = next(iter(dataloader_train))
data, targets = data.to('cuda'), targets.to('cuda')

# Projection head sized from the backbone's feature width. It is freshly initialised and is
# NOT applied below: the embeddings analysed here are the backbone's pooled features.
embedding_head = EmbeddingHead(model.num_features, 128).to('cuda')

# Pure feature extraction: eval() makes BatchNorm use its running statistics instead of
# overwriting the loaded checkpoint's statistics from this single batch, and no_grad()
# avoids building an autograd graph that nothing backpropagates through.
model.eval()
with torch.no_grad():
    features = model.global_pool(model.forward_features(data))
    # NOTE(review): assumes global_pool also flattens, giving (batch, num_features) -- confirm.
    embs = nn.functional.normalize(features, p=2, dim=1)

In [None]:
def get_mask(batch_shape, device=None):
    """
    Build boolean masks that speed up the search for positive/negative samples.

    Assumes the batch is laid out as ``classes_num`` contiguous groups of
    ``embedding_num`` samples each, i.e. position ``k`` belongs to group
    ``k // embedding_num``.

    Args:
        batch_shape: ``(classes_num, embedding_num)``.
        device: optional device for the returned masks (default: the default device,
            as in the original implementation).

    Returns:
        ``(positive_mask, negative_mask)``: two bool tensors of shape
        ``(batch_size, batch_size)`` with ``batch_size = classes_num * embedding_num``.
        - ``positive_mask[i, j]`` is True when i and j are in *different* groups, i.e.
          j must be excluded when looking for i's positives.
        - ``negative_mask[i, j]`` is True when i and j are in the *same* group (including
          i == j), i.e. j must be excluded when looking for i's negatives.
    """
    classes_num, embedding_num = batch_shape
    # Group id of every position: [0]*embedding_num + [1]*embedding_num + ...
    # Replaces the element-by-element Python loops with one broadcast comparison.
    group_ids = torch.arange(classes_num, device=device).repeat_interleave(embedding_num)
    same_group = group_ids.unsqueeze(1) == group_ids.unsqueeze(0)
    return ~same_group, same_group

In [None]:
# Pairwise Euclidean distances between every pair of embeddings in the batch.
distance_matrix = torch.cdist(embs, embs, p=2)

# Build the masks from the actual labels. Position-based masks assume the batch is laid out
# as contiguous groups of 4 samples per class, which an un-shuffled DataLoader does not
# guarantee (and a batch size that is not a multiple of 4 would not even match in size).
# Convention: True marks pairs to EXCLUDE from the search.
assert targets.dim() == 1, "expected targets to be a 1-D tensor of class labels"
same_class = targets.unsqueeze(1) == targets.unsqueeze(0)
is_self = torch.eye(len(targets), dtype=torch.bool, device=targets.device)
positive_mask = ~same_class | is_self  # not a positive: other class, or the anchor itself
negative_mask = same_class             # not a negative: same class (anchor included)

# Hardest positive for each anchor: the farthest sample of the same class.
pos_masked_matrix = distance_matrix.masked_fill(positive_mask, float('-inf'))
_, hardest_positive_idxs = torch.max(pos_masked_matrix, dim=1)
positives = embs[hardest_positive_idxs]

# Hardest negative for each anchor: the closest sample of a different class.
neg_masked_matrix = distance_matrix.masked_fill(negative_mask, float('inf'))
_, hardest_negative_idxs = torch.min(neg_masked_matrix, dim=1)
negatives = embs[hardest_negative_idxs]

# An anchor with no valid positive (singleton class) or no valid negative (single-class
# batch) gets an arbitrary index from an all-±inf row; leave it out of the averages.
valid_anchors = (~positive_mask).any(dim=1) & (~negative_mask).any(dim=1)
if valid_anchors.any():
    print("dist_pos", (embs - positives)[valid_anchors].norm(dim=1).mean().item())
    print("dist_neg", (embs - negatives)[valid_anchors].norm(dim=1).mean().item())
else:
    print("no anchor in this batch has both a positive and a negative")

In [None]:
# Pairwise distances as nested Python lists, each value rounded to 4 decimal places for display.
distance_list = [
    [round(value, 4) for value in distance_row]
    for distance_row in distance_matrix.tolist()
]
distance_list