In [2]:
import os
import torch
import h5py
from tqdm import tqdm
from glob import glob
import re

# Folder holding the per-shard .pt image-embedding files (replace with your own folder path).
pt_folder = '/home/mila/l/le.zhang/scratch/light_align/data/tensor_data/image_embedding/dinov2-large/LLaVA558K'


def _natural_sort_key(path):
    """Sort key that orders digit runs in the file name numerically ('2.pt' before '10.pt')."""
    parts = re.split(r'(\d+)', os.path.basename(path))
    # With a capturing group, odd positions are always the digit runs.
    return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]


# glob() returns paths in arbitrary, filesystem-dependent order. The global
# tensor index assigned below depends on shard order, so the shards are sorted
# naturally to make the numbering reproducible across runs and machines.
pt_file_list = sorted(glob(os.path.join(pt_folder, '*.pt')), key=_natural_sort_key)


with h5py.File('LLaVA558K.h5', 'w') as hdf:
    tensor_counter = 0  # running index used to name each tensor
    for pt_file in tqdm(pt_file_list):
        # map_location='cpu' lets shards that were saved from a GPU load on any machine.
        tensors = torch.load(pt_file, weights_only=True, map_location='cpu')
        for tensor in tensors:
            # NOTE: names are not zero-padded, so readers must sort keys by their
            # numeric suffix (plain name order puts 'tensor_10' before 'tensor_2').
            # detach() keeps .numpy() working for tensors that track gradients.
            hdf.create_dataset(f'tensor_{tensor_counter}', data=tensor.detach().numpy())
            tensor_counter += 1

print("所有 .pt 文件已成功转换为 HDF5 文件！")


100%|██████████| 273/273 [01:31<00:00,  2.97it/s]


所有 .pt 文件已成功转换为 HDF5 文件！


In [1]:
import torch
import h5py
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class HDF5Dataset(Dataset):
    """Map-style dataset returning one tensor per top-level dataset of an HDF5 file.

    The file is expected to contain datasets named 'tensor_0', 'tensor_1', ...
    The file handle is opened lazily on first item access rather than in
    ``__init__``, so the handle is created by whichever process actually reads.

    Args:
        h5_file_path: Path to the HDF5 file to read.
    """

    def __init__(self, h5_file_path):
        # Store the path to the HDF5 file; the read handle is created on demand.
        self.h5_file_path = h5_file_path
        self._hf = None

        # Collect the dataset names once, using a short-lived handle.
        # h5py lists names in alphanumeric order by default ('tensor_10' before
        # 'tensor_2'), so sort by the numeric suffix to make index i map to
        # 'tensor_i'.
        with h5py.File(h5_file_path, 'r') as h5_file:
            self.tensor_keys = sorted(h5_file.keys(), key=self._natural_key)

    @staticmethod
    def _natural_key(name):
        """Return a sort key that compares the digit runs in ``name`` numerically."""
        import re

        parts = re.split(r'(\d+)', name)
        # With a capturing group, odd positions are always the digit runs.
        return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]

    def __len__(self):
        # The length of the dataset is the number of tensors in the HDF5 file.
        return len(self.tensor_keys)

    def _open_hdf5(self):
        """Open the read-only file handle used by ``__getitem__``."""
        self._hf = h5py.File(self.h5_file_path, 'r')

    def __getitem__(self, idx):
        """Read the dataset at position ``idx`` fully and return it as a tensor."""
        # (Re)open the file if there is no live handle, including after close().
        if getattr(self, '_hf', None) is None:
            self._open_hdf5()
        tensor_name = self.tensor_keys[idx]
        return torch.from_numpy(self._hf[tensor_name][...])

    def __getstate__(self):
        # An open h5py handle cannot be pickled; drop it so a copy of this
        # object (e.g. one sent to a spawned worker) reopens the file itself.
        state = self.__dict__.copy()
        state['_hf'] = None
        return state

    def close(self):
        """Close the file handle if one is open; the dataset stays usable afterwards."""
        handle = getattr(self, '_hf', None)
        if handle is not None:
            handle.close()
            # Reset so the next __getitem__ reopens instead of using a closed file.
            self._hf = None

# Example usage: build the dataset from 'LLaVA558K.h5'.
# The path is relative, so it is resolved against the kernel's current working directory.
h5_file_path = 'LLaVA558K.h5'
h5_dataset = HDF5Dataset(h5_file_path)

In [2]:
import time

# Number of batches to pull before stopping the timer.
NUM_TIMED_BATCHES = 6

# Single-process loading (num_workers=0): every item is read in the main process.
h5_dataloader = DataLoader(h5_dataset, batch_size=32568, num_workers=0)
idx = 0  # number of batches consumed so far; stays 0 if the loader is empty
# perf_counter() is a monotonic clock intended for measuring elapsed intervals.
start_time = time.perf_counter()
for idx, batch in enumerate(h5_dataloader, start=1):
    print(batch.shape)
    if idx >= NUM_TIMED_BATCHES:
        break
end_time = time.perf_counter()
# Report the number of batches actually timed rather than a hard-coded figure.
print(f"Time taken for first {idx} batches: {end_time - start_time} seconds")

torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
Time taken for first 11 batches: 33.375715494155884 seconds


In [6]:
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class HDF5Dataset(Dataset):
    """Map-style dataset that memory-maps each top-level dataset of an HDF5 file.

    Each item is located through its byte offset inside the file and mapped
    with ``np.memmap``. Datasets that report no offset are read through h5py
    instead.

    The h5py handle is opened lazily through the ``h5_file`` property, not in
    ``__init__``, so DataLoader workers each open their own handle instead of
    sharing one created in the parent process. If items were already read in
    the parent before workers are forked, call ``close()`` first.

    Args:
        h5_file_path: Path to the HDF5 file to read.
    """

    def __init__(self, h5_file_path):
        self.h5_file_path = h5_file_path
        self._h5_file = None
        # Use a short-lived handle to list the names. h5py lists them in
        # alphanumeric order by default ('tensor_10' before 'tensor_2'), so sort
        # by the numeric suffix to make index i map to 'tensor_i'.
        with h5py.File(self.h5_file_path, 'r') as h5_file:
            self.tensor_keys = sorted(h5_file.keys(), key=self._natural_key)

    @staticmethod
    def _natural_key(name):
        """Return a sort key that compares the digit runs in ``name`` numerically."""
        import re

        parts = re.split(r'(\d+)', name)
        # With a capturing group, odd positions are always the digit runs.
        return [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]

    @property
    def h5_file(self):
        """Read-only h5py handle, opened on first access and reused afterwards."""
        if getattr(self, '_h5_file', None) is None:
            self._h5_file = h5py.File(self.h5_file_path, 'r')
        return self._h5_file

    def __len__(self):
        return len(self.tensor_keys)

    def __getitem__(self, idx):
        """Return the dataset at position ``idx`` as a tensor."""
        tensor_name = self.tensor_keys[idx]
        dataset = self.h5_file[tensor_name]

        # get_offset() returns None when the data is not one contiguous block in
        # the file (e.g. chunked or compact storage); such data cannot be
        # memory-mapped directly, so read it through h5py instead.
        offset = dataset.id.get_offset()
        if offset is None:
            array = dataset[...]
        else:
            # Map the raw bytes of the dataset straight from the file.
            array = np.memmap(
                self.h5_file_path,
                mode='r',
                shape=dataset.shape,
                dtype=dataset.dtype,
                offset=offset,
            )

        # Convert the (possibly memory-mapped) array to a PyTorch tensor.
        return torch.from_numpy(array)

    def __getstate__(self):
        # An open h5py handle cannot be pickled; drop it so a copy of this
        # object (e.g. one sent to a spawned worker) reopens the file itself.
        state = self.__dict__.copy()
        state['_h5_file'] = None
        return state

    def close(self):
        """Close the h5py handle if one is open; it is reopened on next access."""
        handle = getattr(self, '_h5_file', None)
        if handle is not None:
            handle.close()
            self._h5_file = None

# Absolute, machine-specific location of the HDF5 file to read.
h5_file_path = '/home/mila/l/le.zhang/scratch/light_align/data/LLaVA558K.h5'
# Rebinds `h5_dataset`, replacing any object previously stored under that name.
h5_dataset = HDF5Dataset(h5_file_path)


In [17]:
import time

# Number of batches to pull before stopping the timer.
NUM_TIMED_BATCHES = 6

# Multi-process loading: 4 workers, each prefetching 2 batches ahead.
h5_dataloader = DataLoader(h5_dataset, batch_size=32568, num_workers=4, prefetch_factor=2)
idx = 0  # number of batches consumed so far; stays 0 if the loader is empty
# perf_counter() is a monotonic clock intended for measuring elapsed intervals.
start_time = time.perf_counter()
for idx, batch in enumerate(h5_dataloader, start=1):
    print(batch.shape)
    if idx >= NUM_TIMED_BATCHES:
        break
end_time = time.perf_counter()
# Report the number of batches actually timed rather than a hard-coded figure.
print(f"Time taken for first {idx} batches: {end_time - start_time} seconds")

torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
torch.Size([32568, 2048])
Time taken for first 11 batches: 79.70797538757324 seconds


# Current embedding dataset

In [47]:
import torch
import os
import glob
from natsort import natsorted
import numpy as np

class VLEmbeddingDataset(Dataset):
    """In-memory dataset of text embedding vectors loaded from ``.pt`` shard files.

    Every ``*.pt`` file in each given directory is loaded eagerly and its rows
    are flattened into one list, so item ``i`` is the i-th vector overall.

    Args:
        text_embedding_list: Directory path, or iterable of directory paths,
            containing the ``.pt`` shards. Directories are read in the order
            given; files inside each directory are read in natural-sort order.

    Raises:
        ValueError: If no embedding vectors are found in the given directories.
    """

    def __init__(self, text_embedding_list):
        # A bare path would otherwise be iterated character by character.
        if isinstance(text_embedding_list, (str, os.PathLike)):
            text_embedding_list = [text_embedding_list]
        else:
            text_embedding_list = list(text_embedding_list)
        # Attribute name kept for backward compatibility; it holds the directory list.
        self.text_embedding_dir = text_embedding_list

        # Note: file names must be sorted to ensure the correspondence of text
        # and image vectors.
        self.text_files = []
        for dir_path in text_embedding_list:
            shard_paths = glob.glob(os.path.join(dir_path, "*.pt"))
            self.text_files.extend(natsorted(shard_paths))

        # map_location='cpu' lets shards that were saved from a GPU load on any machine.
        self.text_vectors = [
            vector
            for file in self.text_files
            for vector in torch.load(file, weights_only=True, map_location='cpu')
        ]

        if not self.text_vectors:
            # Fail with a clear message instead of an IndexError on text_vectors[0].
            raise ValueError(
                f"No embedding vectors found in {len(self.text_files)} '.pt' file(s) "
                f"under {text_embedding_list!r}"
            )

        self.total_length = len(self.text_vectors)
        # Size of the first dimension of the first vector.
        self.text_dim = self.text_vectors[0].shape[0]

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        return self.text_vectors[idx]
    
# Load every text-embedding shard from this single directory into memory (absolute, machine-specific path).
dataset = VLEmbeddingDataset(['/home/mila/l/le.zhang/scratch/light_align/data/text_embedding/gte-large-en-v1.5/LLaVA558K'])

In [48]:
# Shuffled, single-process loader (num_workers=0) over the in-memory embedding dataset.
dataloader = DataLoader(dataset, batch_size=32568, shuffle=True, num_workers=0)

In [72]:
# Sanity check: print the shape of the first batch only, then stop iterating.
for batch in dataloader:
    print(batch.shape)
    break

torch.Size([32568, 1024])
