In [1]:
# 导入必要的库
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import numpy as np
import json
from tqdm import tqdm
import open_clip
import logging
import random
import math
from torch.optim import lr_scheduler
from torchvision import transforms

import torch.nn as nn

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # 设置所有 GPU 的随机种子

    # 允许 CUDNN 使用非确定性算法以提升性能
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)  # 42 是任意选择的种子值

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 配置字典，存储可调节的参数
CONFIG = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'model_name': 'ViT-B-32',  # 使用的模型名称
    'pretrained': 'weights/RemoteCLIP-ViT-B-32.pt',  # 请替换为您的预训练模型路径
    'json_file': 'datasets/RSITMD/dataset_RSITMD.json',  # 请替换为您的JSON文件路径
    'images_root': 'datasets/RSITMD/images',  # 请替换为您的图像目录路径
    'batch_size': 256,  # 根据您的GPU内存进行调整
    'num_epochs': 20,
    'learning_rate': 1e-5,
    'weight_decay': 1e-4,
    'num_workers': 8,
    'recall_k_list': [1, 5, 10],
    'logs_dir': 'logs',  # 日志文件夹
    'warmup_epochs': 0,  # 设置为3，根据需要调整
    'min_lr_factor': 0.01,  # 最小学习率因子
}
# 设置设备
device = CONFIG['device']

In [None]:
# 提取数据集名称
dataset_name = os.path.basename(os.path.dirname(CONFIG['json_file']))

# 提取预训练权重类型
if CONFIG['pretrained']:
    pretrained_weights_type = os.path.splitext(os.path.basename(CONFIG['pretrained']))[0]
else:
    pretrained_weights_type = 'none'

# 创建日志目录
model_logs_dir = os.path.join(CONFIG['logs_dir'], dataset_name, CONFIG['model_name'], pretrained_weights_type)
os.makedirs(model_logs_dir, exist_ok=True)

# 配置日志记录
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(model_logs_dir, 'train.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger()

# 创建模型但不加载预训练权重
model, _, preprocess = open_clip.create_model_and_transforms(
    CONFIG['model_name'], pretrained=False)
tokenizer = open_clip.get_tokenizer(CONFIG['model_name'])
model = model.to(device)

# 从本地检查点加载自定义权重
if CONFIG['pretrained']:
    checkpoint = torch.load(CONFIG['pretrained'], map_location=device)
    state_dict = checkpoint.get('state_dict', checkpoint)
    model.load_state_dict(state_dict)

# 定义训练时的预处理，包括随机数据增强
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([
        transforms.RandomChoice([
            transforms.RandomRotation(90),
            transforms.RandomRotation(180),
            transforms.RandomRotation(270),
        ])
    ], p=0.5),
    preprocess,  # CLIP 模型的预处理操作
])

# 定义自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, json_file, img_dir, split, transform, tokenizer):
        self.json_file = json_file
        self.img_dir = img_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.split = split  # 'train' 或 'test'
        self.images = []
        self.captions = []
        self._load_data()
        
    def _load_data(self):
        # 加载JSON文件
        with open(self.json_file, 'r') as f:
            data = json.load(f)
        for item in data['images']:
            if item['split'] == self.split:
                image_file = item['filename']
                # 每张图像可能有多条描述
                for sentence in item['sentences']:
                    self.images.append(image_file)
                    self.captions.append(sentence['raw'])
                    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        caption = self.captions[idx]
        # Tokenizer expects a list of texts
        text = self.tokenizer([caption])[0]
        return image, text

# 定义评估数据集类
class EvaluationDataset(Dataset):
    def __init__(self, json_file, img_dir, split, transform):
        self.json_file = json_file
        self.img_dir = img_dir
        self.transform = transform
        self.split = split  # 'test' 或 'val'
        self.images = []
        self.captions = []
        self.image_ids = []
        self._load_data()
        
    def _load_data(self):
        # 加载JSON文件
        with open(self.json_file, 'r') as f:
            data = json.load(f)
        for idx, item in enumerate(data['images']):
            if item['split'] == self.split:
                image_file = item['filename']
                captions = [sentence['raw'] for sentence in item['sentences']]
                self.images.append(image_file)
                self.captions.append(captions)  # 此图像的描述列表
                self.image_ids.append(idx)  # 为每个图像分配唯一ID
                
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        captions = self.captions[idx]  # 此图像的描述列表
        return image, captions, idx  # 返回图像、描述列表和索引

# 创建数据集和数据加载器
batch_size = CONFIG['batch_size']

# 训练数据集和加载器
train_dataset = CustomDataset(CONFIG['json_file'], CONFIG['images_root'], 'train', train_transform, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=CONFIG['num_workers'], pin_memory=True)

# 验证数据集和加载器
val_dataset = EvaluationDataset(CONFIG['json_file'], CONFIG['images_root'], 'test', preprocess)

def evaluation_collate_fn(batch):
    images, captions_list, indices = zip(*batch)
    images = torch.stack(images, 0)
    return images, captions_list, indices

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=CONFIG['num_workers'], pin_memory=True, collate_fn=evaluation_collate_fn)

2025-06-04 00:39:42,770 INFO: Loaded ViT-B-32 model config.


In [4]:
# 定义评估函数
def evaluate(model, val_loader, tokenizer, device, recall_k_list=[1, 5, 10]):
    model.eval()
    # 用于存储嵌入和索引的列表
    image_embs = []
    text_embs = []
    texts_image_indices = []
    all_captions = []
    
    # 第一步：计算图像嵌入并收集描述
    for batch in tqdm(val_loader, desc="Computing image embeddings"):
        images, captions_list, indices = batch
        images = images.to(device)
        indices = list(indices)
        
        with torch.no_grad():
            image_features = model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            image_embs.append(image_features)
        
        # 收集所有描述，并将每个描述映射到其图像索引
        for i, captions in zip(indices, captions_list):
            all_captions.extend(captions)
            texts_image_indices.extend([i] * len(captions))
    
    # 将图像嵌入连接起来并移动到同一设备
    image_embs = torch.cat(image_embs).to(device)
    
    # 第二步：计算文本嵌入
    # 由于可能有很多描述，分批处理
    batch_size = 256
    num_texts = len(all_captions)
    text_embs = []
    for i in tqdm(range(0, num_texts, batch_size), desc="Computing text embeddings"):
        captions_batch = all_captions[i:i+batch_size]
        texts = tokenizer(captions_batch).to(device)
        with torch.no_grad():
            text_features = model.encode_text(texts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            text_embs.append(text_features)
    text_embs = torch.cat(text_embs).to(device)
    
    # 计算相似度得分并应用温度缩放
    if hasattr(model, 'logit_scale'):
        logit_scale = model.logit_scale.exp()
    else:
        logit_scale = 1.0  # 如果不存在，则默认为 1.0
    scores = text_embs @ image_embs.t() * logit_scale  # 形状：(num_texts, num_images)
    
    metrics = {}
    
    # 图像检索文本
    image_ranks = []
    for i in range(image_embs.size(0)):
        # 获取图像 i 与所有文本的相似度
        sims = scores[:, i]
        inds = torch.argsort(sims, descending=True)
        # 找到与图像 i 相关的所有文本的位置
        pos_inds = (torch.tensor(texts_image_indices, device=device) == i).nonzero(as_tuple=False).squeeze()
        if pos_inds.ndim == 0:
            pos_inds = pos_inds.unsqueeze(0)
        # 在所有相关文本中找到最小排名
        rank = scores.size(0)  # 初始化排名为最大值
        for pos in pos_inds:
            pos_rank = (inds == pos).nonzero(as_tuple=False).item()
            if pos_rank < rank:
                rank = pos_rank
        image_ranks.append(rank)
    image_ranks = np.array(image_ranks)
    
    for k in recall_k_list:
        r_at_k = 100.0 * np.mean(image_ranks < k)
        metrics[f'Retrieval Image to Text R@{k}'] = r_at_k
    
    # 文本检索图像
    text_ranks = []
    for i in range(text_embs.size(0)):
        # 获取文本 i 与所有图像的相似度
        sims = scores[i, :]
        inds = torch.argsort(sims, descending=True)
        gt_image = texts_image_indices[i]
        rank = (inds == gt_image).nonzero(as_tuple=False).item()
        text_ranks.append(rank)
    text_ranks = np.array(text_ranks)
    
    for k in recall_k_list:
        r_at_k = 100.0 * np.mean(text_ranks < k)
        metrics[f'Retrieval Text to Image R@{k}'] = r_at_k
    
    # 计算平均召回率
    if metrics:
        recall_values = [metrics[key] for key in metrics.keys()]
        metrics['Retrieval Mean Recall'] = np.mean(recall_values)
    else:
        metrics['Retrieval Mean Recall'] = 0.0
        logger.warning("No metrics computed during evaluation.")
    
    # 打印metrics键以调试
    logger.debug(f"Metrics keys: {list(metrics.keys())}")
    
    return metrics 

In [None]:
# 定义对比损失函数，使用模型的 logit_scale 和软标签
class ContrastiveLoss(nn.Module):
    def __init__(self, model):
        super(ContrastiveLoss, self).__init__()
        self.model = model

    def forward(self, image_embeddings, text_embeddings):
        # 归一化嵌入
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

        # 计算 logits
        logits_per_image = image_embeddings @ text_embeddings.t()
        logits_per_text = text_embeddings @ image_embeddings.t()
        
        # 数值稳定性处理
        logits_per_image = logits_per_image - logits_per_image.max(dim=-1, keepdim=True)[0]
        logits_per_text = logits_per_text - logits_per_text.max(dim=-1, keepdim=True)[0]

        # 从模型中获取 logit_scale 并应用温度缩放
        if hasattr(self.model, 'logit_scale'):
            logit_scale = self.model.logit_scale.exp()
        else:
            logit_scale = 1.0  # 如果不存在，则默认为 1.0

        logits_per_image = logits_per_image * logit_scale
        logits_per_text = logits_per_text * logit_scale

        # 计算相似度矩阵（用于软标签）
        images_similarity = image_embeddings @ image_embeddings.t()
        texts_similarity = text_embeddings @ text_embeddings.t()

        # 计算软标签（目标分布）
        # 注意，这里也要乘以 logit_scale，以保持尺度一致
        targets_per_image = F.softmax((images_similarity + texts_similarity) / 2 * logit_scale, dim=-1)
        targets_per_text = targets_per_image.t()  # 转置以获得文本的目标

        # 计算图像到文本的损失
        loss_i = self.cross_entropy_with_soft_labels(logits_per_image, targets_per_image)
        # 计算文本到图像的损失
        loss_t = self.cross_entropy_with_soft_labels(logits_per_text, targets_per_text)
        # 平均两个损失
        loss = (loss_i + loss_t) / 2.0
        return loss

    def cross_entropy_with_soft_labels(self, logits, soft_targets):
        log_prob = F.log_softmax(logits, dim=-1)
        loss = (-soft_targets * log_prob).sum(dim=-1).mean()
        return loss
        
# 定义基于步骤的学习率调整函数
def lr_lambda_func(step, warmup_steps, total_steps, min_lr_factor=0.0):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))  # 线性 warmup
    else:
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return float(min_lr_factor + (1.0 - min_lr_factor) * cosine_decay)             

# 初始化自定义损失函数，使用模型的 logit_scale
loss_fn = ContrastiveLoss(model).to(device)

# 定义需要排除权重衰减的参数关键字
no_decay_keywords = ['bias', 'ln_', 'logit_scale']

decay_params = []
no_decay_params = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue  # 排除冻结的参数
    if any(keyword in name for keyword in no_decay_keywords):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

# 初始化优化器
optimizer = torch.optim.AdamW(
    [
        {'params': decay_params, 'weight_decay': CONFIG['weight_decay']},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ],
    lr=CONFIG['learning_rate']
)

# 计算总的训练步骤数和 warmup 步骤数
total_steps = CONFIG['num_epochs'] * len(train_loader)
warmup_steps = CONFIG['warmup_epochs'] * len(train_loader)

# 定义学习率调度器
scheduler = lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: lr_lambda_func(
        step,
        warmup_steps,
        total_steps,
        CONFIG['min_lr_factor']
    )
)

num_epochs = CONFIG['num_epochs']
best_mean_recall = 0.0
global_step = 0  # 全局训练步骤

# 训练循环
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}", leave=False)
    
    for batch in progress_bar:
        images, texts = batch  # images: [batch_size, 3, H, W], texts: [batch_size, seq_length]
        images = images.to(device)
        texts = texts.to(device)

        optimizer.zero_grad()

        # 获取嵌入
        image_embeddings = model.encode_image(images)
        text_embeddings = model.encode_text(texts)

        # 使用对比损失函数计算损失
        loss = loss_fn(image_embeddings, text_embeddings)
    
        loss.backward()
        optimizer.step()
        scheduler.step()  # 在每个 batch 后更新学习率

        running_loss += loss.item()
        global_step += 1  # 更新全局训练步骤

        # 获取当前学习率
        current_lr = optimizer.param_groups[0]['lr']
        
        # 更新进度条后缀，显示当前学习率
        progress_bar.set_postfix({'lr': f"{current_lr:.8f}", 'loss': f"{loss.item():.4f}"})

    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

    # 获取当前学习率
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current Learning Rate: {current_lr:.8f}")

    # 记录学习率到日志
    logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, LR: {current_lr:.8f}")

    # 评估
    try:
        metrics = evaluate(model, val_loader, tokenizer, device, CONFIG['recall_k_list'])

        # 打印所有指标
        for key, value in metrics.items():
            print(f"{key}: {value:.2f}")
            logger.info(f"{key}: {value:.2f}")

        # 如果平均召回率提升，保存模型
        if 'Retrieval Mean Recall' in metrics:
            mean_recall = metrics['Retrieval Mean Recall']
            if mean_recall > best_mean_recall:
                best_mean_recall = mean_recall
                best_model_path = os.path.join(model_logs_dir, f"{CONFIG['model_name']}_best_model.pt")
                torch.save(model.state_dict(), best_model_path)
                print(f"New best model saved with Mean Recall: {best_mean_recall:.2f}")
                logger.info(f"New best model saved with Mean Recall: {best_mean_recall:.2f}")
    except Exception as e:
        print(f"Error during evaluation: {e}")
        logger.error(f"Error during evaluation: {e}")

Training Epoch 1/20:   0%|          | 0/84 [00:00<?, ?it/s]

2025-06-04 00:40:21,438 INFO: Epoch 1/20, Loss: 1.2817, LR: 0.00000994                          


Epoch 1/20, Loss: 1.2817
Current Learning Rate: 0.00000994


Computing image embeddings: 0it [00:00, ?it/s]
2025-06-04 00:40:21,577 ERROR: Error during evaluation: torch.cat(): expected a non-empty list of Tensors


Error during evaluation: torch.cat(): expected a non-empty list of Tensors


Training Epoch 2/20:  74%|███████▍  | 62/84 [00:26<00:09,  2.37it/s, lr=0.00000982, loss=0.5543]