In [None]:
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import clip
import torch.nn as nn
import numpy as np

In [None]:
# Dataset definition
class RSITMDDataset(Dataset):
    """Image/caption pairs read from an RSITMD-style annotation JSON file.

    Each item is ``(preprocessed_image, description)``. The description is all
    ``sentences[*]['raw']`` captions of the image joined with single spaces and
    cut to at most ``max_chars`` characters.

    Args:
        image_dir: Directory containing the files named by each entry's
            ``'filename'`` field.
        json_file: Path to the annotation JSON; its top-level ``'images'`` list
            is loaded eagerly in the constructor.
        preprocess: Callable applied to the RGB ``PIL.Image`` before it is
            returned.
        max_chars: Maximum number of characters kept from the joined captions;
            ``None`` disables truncation. Defaults to 77.
            NOTE(review): 77 is a *character* count; presumably it was chosen
            to mirror CLIP's 77-token context length -- confirm the intent.

    Raises:
        ValueError: If ``max_chars`` is negative.
    """

    def __init__(self, image_dir, json_file, preprocess, max_chars=77):
        if max_chars is not None and max_chars < 0:
            raise ValueError(f"max_chars must be non-negative or None, got {max_chars}")
        # Read as UTF-8 explicitly so loading does not depend on the platform's
        # locale encoding.
        with open(json_file, 'r', encoding='utf-8') as f:
            self.data = json.load(f)['images']
        print(f"Loaded {len(self.data)} image-sentence pairs from {json_file}")
        self.image_dir = image_dir
        self.preprocess = preprocess
        self.max_chars = max_chars

    def __len__(self):
        """Number of image entries in the annotation file."""
        return len(self.data)

    def _build_description(self, image_info):
        """Join every raw caption of one entry and cut it to ``max_chars``."""
        description = " ".join(sentence['raw'] for sentence in image_info['sentences'])
        if self.max_chars is not None:
            description = description[:self.max_chars]
        return description

    def __getitem__(self, idx):
        """Return ``(preprocess(image), description)`` for entry ``idx``."""
        image_info = self.data[idx]
        image_path = os.path.join(self.image_dir, image_info['filename'])
        # The context manager releases the file handle once the pixels have
        # been copied by convert().
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.preprocess(image)

        return image, self._build_description(image_info)

In [None]:
# Initialise the CLIP model together with its matching image preprocessing
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load('RN50', device=device, download_root='checkpoints')
# Cast the model to single precision (fp32)
model = model.float()

In [None]:
# Load the dataset
# os.path.join keeps these paths portable: a literal backslash separator inside
# a normal string forms invalid escape sequences ('\i', '\d') and would only
# resolve on Windows.
image_dir = os.path.join('RSITMD', 'images')  # must exist and contain all images
json_file = os.path.join('RSITMD', 'dataset_RSITMD.json')  # must exist and contain all image descriptions
dataset = RSITMDDataset(image_dir, json_file, preprocess)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [None]:
# Contrastive loss
class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over a batch of paired features.

    Row ``i`` of ``image_features`` is treated as the positive match of row
    ``i`` of ``text_features``; every other row in the batch is a negative.
    Both feature sets are L2-normalised, their pairwise dot products are scaled
    by ``exp(logit_scale)``, and cross-entropy is averaged over the
    image-to-text and text-to-image directions.

    Args:
        device: Device on which the learnable ``logit_scale`` is created.

    Attributes:
        logit_scale: Scalar ``nn.Parameter`` holding the log of the similarity
            scale, initialised to ``log(1 / 0.07)``.
    """

    def __init__(self, device):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        # Create the tensor on the target device *before* wrapping it in
        # nn.Parameter. Calling .to(device) on a Parameter returns a plain
        # non-leaf tensor whenever a copy is needed (e.g. CPU -> CUDA), which
        # is neither registered as a module parameter nor optimisable.
        self.logit_scale = nn.Parameter(
            torch.ones([], device=device) * float(np.log(1 / 0.07))
        )

    def forward(self, image_features, text_features):
        """Compute the loss for one batch.

        Args:
            image_features: 2-D tensor with one row per sample.
            text_features: 2-D tensor with one row per sample, aligned
                row-for-row with ``image_features``.

        Returns:
            Scalar tensor: mean of the image-to-text and text-to-image
            cross-entropy losses.
        """
        batch_size = len(image_features)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        logit_scale = self.logit_scale.exp()

        # The positive for row i is column i.
        labels = torch.arange(batch_size, dtype=torch.long, device=image_features.device)
        logits_per_image = logit_scale * image_features @ text_features.t()
        # The text-to-image logits are the transpose of the image-to-text
        # logits, so the second matrix product is not needed.
        logits_per_text = logits_per_image.t()

        loss_i = self.criterion(logits_per_image, labels)
        loss_t = self.criterion(logits_per_text, labels)
        return (loss_i + loss_t) / 2

In [None]:
# Initialise the loss function and the optimizer
loss_fn = ContrastiveLoss(device)
# Optimise the loss module's own parameters (its learnable logit scale, when it
# is registered) together with the CLIP weights; otherwise that parameter would
# stay at its initial value for the whole run.
trainable_params = list(model.parameters()) + list(loss_fn.parameters())
optimizer = torch.optim.Adam(trainable_params, lr=1e-5)

# Train the model
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for images, descriptions in dataloader:
        images = images.to(device)
        texts = clip.tokenize(descriptions).to(device)

        image_features = model.encode_image(images)
        text_features = model.encode_text(texts)

        loss = loss_fn(image_features, text_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the dataloader yields no batches.
    avg_loss = total_loss / max(len(dataloader), 1)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

# Save the fine-tuned weights
torch.save(model.state_dict(), "clip_finetuned.pth")

In [None]:
# Write the model's state dict to disk
torch.save(obj=model.state_dict(), f="clip_finetuned.pth")


In [None]:
# Set up the model and dataset for evaluation
device = "cuda" if torch.cuda.is_available() else "cpu"
# `model` and `preprocess` are the objects created by the clip.load cell earlier
# in this notebook; only the weights are replaced here.

# Load the fine-tuned weights. map_location lets a checkpoint that was saved
# from GPU tensors load on a CPU-only machine as well.
model.load_state_dict(torch.load("clip_finetuned.pth", map_location=device))
model.eval()  # switch the model to evaluation mode

# Load the test dataset
# NOTE(review): this is the same image directory and JSON file that the training
# cell uses, so the scores below are measured on training data -- confirm
# whether a held-out split should be used instead.
test_image_dir = os.path.join('RSITMD', 'images')  # test image directory
test_json_file = os.path.join('RSITMD', 'dataset_RSITMD.json')  # test-set JSON file
test_dataset = RSITMDDataset(test_image_dir, test_json_file, preprocess)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Compute image and text features
all_image_features = []
all_text_features = []

with torch.no_grad():
    for images, descriptions in test_dataloader:
        images = images.to(device)
        texts = clip.tokenize(descriptions).to(device)

        image_features = model.encode_image(images)
        text_features = model.encode_text(texts)

        # L2-normalise so the dot product below is a cosine similarity, the
        # same quantity the contrastive training loss is computed on.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        all_image_features.append(image_features)
        all_text_features.append(text_features)

    all_image_features = torch.cat(all_image_features)
    all_text_features = torch.cat(all_text_features)

    # Similarity matrix: rows are images, columns are texts. It is built inside
    # no_grad so that no autograd graph is attached through model.logit_scale.
    logit_scale = model.logit_scale.exp()
    similarity_matrix = logit_scale * all_image_features @ all_text_features.t()
# Recall@K
def recall_at_k(similarity_matrix, k):
    """Fraction of rows whose same-index column is among the ``k`` highest scores.

    Row ``i`` is a query and column ``i`` is its correct match. Rows whose
    index has no corresponding column always count as misses.

    Args:
        similarity_matrix: 2-D tensor of query-by-candidate scores.
        k: Number of top-scoring candidates inspected per query. Values larger
            than the number of columns are clamped to it.

    Returns:
        Recall as a float in ``[0, 1]``; ``0.0`` when there are no rows or
        ``k <= 0``.
    """
    num_queries, num_candidates = similarity_matrix.shape[0], similarity_matrix.shape[1]
    if num_queries == 0 or k <= 0:
        return 0.0

    # One vectorised top-k over all rows replaces a full argsort per row.
    top_indices = similarity_matrix.topk(min(k, num_candidates), dim=1).indices
    targets = torch.arange(num_queries, device=similarity_matrix.device).unsqueeze(1)
    num_correct = (top_indices == targets).any(dim=1).sum().item()

    return num_correct / num_queries

# Report Recall@K for each cutoff
for k in (1, 5, 10):
    recall = recall_at_k(similarity_matrix, k)
    print(f"Recall@{k}: {recall:.4f}")
