In [1]:
import os
import torch

In [2]:
save_dir = "/mnt/data/CVPR2025/task1_data/images_features"
def load_features(save_dir, batch_num):
    """加载指定批次的特征"""
    file_path = os.path.join(save_dir, f'features_batch_{batch_num}.pt')
    data_dict = torch.load(file_path)
    return data_dict['features'], data_dict['image_paths']

In [3]:
features, image_paths = load_features(save_dir, 0)

  data_dict = torch.load(file_path)


In [5]:
image_paths[:2]
image_id = image_paths

['/mnt/data/CVPR2025/task1_data/images/images/39632936139492141.png',
 '/mnt/data/CVPR2025/task1_data/images/images/39632936139492779.png']

In [8]:
import os
import json
import torch
import numpy as np
import h5py
from astropy.table import Table
from torch.utils.data import Dataset
from typing import Dict, List, Optional, Tuple
from pathlib import Path
from tqdm import tqdm, trange

def bf16_to_uint16(tensor: torch.Tensor) -> np.ndarray:
    """将bf16张量转换为uint16数组进行存储"""
    return tensor.view(dtype=torch.uint16).numpy()

def uint16_to_bf16(array: np.ndarray) -> torch.Tensor:
    """将uint16数组转换回bf16张量"""
    return torch.from_numpy(array).view(dtype=torch.bfloat16)

def merge_feature_batches(save_dir: str, output_file: str, feature_shape: tuple = (4, 1601)):
    """
    将所有批次的特征文件合并到一个HDF5文件中，使用内存映射方式处理
    
    Args:
        save_dir: 特征文件所在目录
        output_file: 输出HDF5文件路径
        feature_shape: 每个特征的形状，默认为(4, 1601)
    
    Returns:
        None
    """
    # 首先统计总样本数和收集所有的image_paths
    total_samples = 0
    all_image_paths = []
    batch_files = []
    
    print("Counting total samples...")
    for batch_num in trange(1000):  # 设置一个足够大的上限
        file_path = os.path.join(save_dir, f'features_batch_{batch_num}.pt')
        if not os.path.exists(file_path):
            break
        
        # 只加载元信息
        data_dict = torch.load(file_path, map_location='cpu')
        total_samples += len(data_dict['features'])
        all_image_paths.extend(data_dict['image_paths'])
        batch_files.append(file_path)
        
        # 释放内存
        del data_dict
        torch.cuda.empty_cache()
    
    if not batch_files:
        raise ValueError(f"No feature files found in {save_dir}")
    
    print(f"Total samples: {total_samples}")
    
    # 从图像路径中提取targetID
    target_ids = [Path(p).stem for p in all_image_paths]
    
    # 创建targetID到index的映射
    id_to_idx = {tid: idx for idx, tid in enumerate(target_ids)}
    
    # 创建HDF5文件
    with h5py.File(output_file, 'w') as f:
        # 创建特征数据集，使用uint16存储bf16数据
        features_dataset = f.create_dataset(
            'features', 
            shape=(total_samples, *feature_shape),
            dtype='uint16',
            chunks=(1, *feature_shape),  # 每个样本作为一个chunk
            compression="gzip",
            compression_opts=4
        )
        
        # 写入特征
        current_idx = 0
        for file_path in tqdm(batch_files, desc="Merging features"):
            # 逐个批次加载并写入
            data_dict = torch.load(file_path, map_location='cpu')
            batch_features = data_dict['features']  # shape: (batch_size, 4, 1601)
            
            # 确保形状正确
            if batch_features.shape[1:] != feature_shape:
                raise ValueError(f"Unexpected feature shape: {batch_features.shape[1:]}, expected {feature_shape}")
            
            # 转换为uint16并保存
            batch_features_uint16 = bf16_to_uint16(batch_features)
            batch_size = len(batch_features)
            
            features_dataset[current_idx:current_idx + batch_size] = batch_features_uint16
            current_idx += batch_size
            
            # 释放内存
            del data_dict, batch_features, batch_features_uint16
            torch.cuda.empty_cache()
        
        # 将id_to_idx映射保存为属性
        f.attrs['id_to_idx'] = json.dumps(id_to_idx)
        # 保存特征形状信息
        f.attrs['feature_shape'] = json.dumps(feature_shape)

In [9]:
merge_feature_batches(save_dir, output_file="/mnt/data/CVPR2025/task1_data/images_features.hdf5")

Counting total samples...


  data_dict = torch.load(file_path, map_location='cpu')
  6%|▌         | 55/1000 [10:53<3:07:12, 11.89s/it]


KeyboardInterrupt: 