In [None]:
import torch
from torch import nn
import torchvision.models as models
from torchvision import transforms

In [2]:
# 加载预训练的 ResNet-50 模型
resnet50 = models.resnet50(pretrained=True)

# 去掉最后的全连接层，只保留卷积层和全局平均池化层
resnet50 = torch.nn.Sequential(*list(resnet50.children())[:-1])

# 将模型设置为评估模式
resnet50.eval()



Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [3]:
# 定义图像预处理步骤
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [5]:
from torchvision import datasets
from torch.utils.data import Dataset, DataLoader

In [None]:
class ContrastiveLearningDataset(Dataset):
    def __init__(self, dataset, transform):
        """
        dataset: 原始数据集（如 ImageFolder 或其他自定义数据集）
        transform: 数据增强方法（如 transform_train）
        """
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # 获取原始图像和标签
        img, _ = self.dataset[index]

        # 应用 transform 生成两个不同的增强视图
        img1 = self.transform(img)
        img2 = self.transform(img)

        # 返回两个增强视图和标签
        return img1, img2

In [7]:
# 加载原始数据集
root_dir = '/Users/yonglxie/Desktop'  # 假设图片在 'data/train' 文件夹中
original_dataset = datasets.ImageFolder(root=root_dir)

# 创建对比学习的数据集
contrastive_dataset = ContrastiveLearningDataset(original_dataset, preprocess)

# 使用 DataLoader 载入对比学习数据集
train_loader = DataLoader(contrastive_dataset, batch_size=32, shuffle=True, num_workers=0)


In [8]:
import torchvision.models as models
class ResNet50Embedding(nn.Module):
    """
    Neural network model for ranking translations
    """
    def __init__(self):
        """
        Initialize the model with a pre-trained backbone and regression head
        """
        super(ResNet50Embedding, self).__init__()
        resnet50 = models.resnet50(pretrained=True)
        # 去掉最后的全连接层，只保留卷积层和全局平均池化层
        self.features = torch.nn.Sequential(*list(resnet50.children())[:-1])
        self.flatten = torch.nn.Flatten()

    def forward(self, imgs, **kwargs):
        """
        Forward pass of the model
        Returns: Ranking scores for input sequences
        """
        embedding = self.features(imgs)
        embedding = self.flatten(embedding)
        
        return embedding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        """
        参数：
            temperature: 温度参数，用于缩放相似度
        """
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """
        计算 NTXentLoss
        参数：
            z_i: tensor，形状为 [batch_size, embedding_dim]，来自一个增强视角的表示
            z_j: tensor，形状为 [batch_size, embedding_dim]，来自另一个增强视角的表示
        返回：
            loss: 标量，当前 batch 的对比损失
        """
        batch_size = z_i.size(0)

        # 对嵌入向量进行 L2 归一化
        z_i = F.normalize(z_i, p=2, dim=1)
        z_j = F.normalize(z_j, p=2, dim=1)

        # 将正样本对拼接在一起，形状变为 [2*batch_size, embedding_dim]
        z = torch.cat([z_i, z_j], dim=0)

        # 计算两两之间的余弦相似度矩阵，形状为 [2N, 2N]
        # 使用矩阵乘法，由于向量已归一化，相当于计算余弦相似度
        sim_matrix = torch.matmul(z, z.T)

        # 除以温度参数
        sim_matrix = sim_matrix / self.temperature

        # 对角线位置为自身与自身的相似度，需要屏蔽掉（设为一个很小的值）
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
        sim_matrix.masked_fill_(mask, -float('inf'))

        # 构造正样本标签：
        # 对于前 batch_size 个样本，其正样本为索引 i+batch_size；
        # 对于后 batch_size 个样本，其正样本为索引 i-batch_size。
        target = torch.arange(batch_size, device=z.device)
        target = torch.cat([target + batch_size, target], dim=0)

        # 使用交叉熵损失计算 NTXentLoss
        loss = F.cross_entropy(sim_matrix, target, reduction='sum')
        loss = loss / (2 * batch_size)
        return loss

In [None]:
import torch.optim as optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet50Embedding().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = NTXentLoss(temperature=0.5)

In [12]:
# 5. 开始训练
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for img_i, img_j in train_loader:
        img_i, img_j = img_i.to(device), img_j.to(device)
        embeddings_i = model(img_i)
        embeddings_j = model(img_j)
        
        loss = criterion(embeddings_i, embeddings_j)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader)}")


Epoch [1/10], Loss: 2.419577121734619
Epoch [2/10], Loss: 1.903899073600769
Epoch [3/10], Loss: 1.6054431200027466
Epoch [4/10], Loss: 1.4628502130508423
Epoch [5/10], Loss: 1.375741720199585
Epoch [6/10], Loss: 1.3202033042907715
Epoch [7/10], Loss: 1.2833561897277832
Epoch [8/10], Loss: 1.2583922147750854
Epoch [9/10], Loss: 1.2404769659042358
Epoch [10/10], Loss: 1.2266950607299805


In [None]:
# 6. 提取最后一层的embedding
def extract_embedding(model, img):
    model.eval()
    with torch.no_grad():
        embedding = model(img.to(device))
    return embedding.cpu().numpy()

In [None]:
extract_embedding(model, input_batch).shape