In [3]:
import numpy as np
import torch
from torch.utils.data import Dataset

class BinaryFileDataset(Dataset):
    def __init__(self, file_path, dtype=np.float32, sample_shape=None):
        self.file_path = file_path
        self.dtype = dtype
        self.sample_shape = sample_shape

        # Compute the size of one sample (in bytes)
        sample_size = np.prod(sample_shape) * np.dtype(dtype).itemsize
        # Memory-map the binary file
        self.mem_map = np.memmap(file_path, dtype=dtype, mode='r')
        # Compute the total number of samples in the binary file
        self.num_samples = self.mem_map.size // np.prod(sample_shape)
        
    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        start_idx = np.prod(self.sample_shape) * idx
        end_idx = start_idx + np.prod(self.sample_shape)
        # Read a chunk from the memory-mapped file
        sample_np = self.mem_map[start_idx:end_idx]
        sample_np = sample_np.reshape(self.sample_shape)
        sample_torch = torch.from_numpy(sample_np)
        return {'input': sample_torch}  # No label here; you can add it if needed


In [7]:
# Define the shape of one sample and its data type (change as needed)
sample_shape = (128, 768)  # e.g., (1, 512)
dtype = np.float32  # or np.float64, depending on how you saved it

# Create a dataset from the binary file
dataset = BinaryFileDataset(
    "data/final_layer_output_128.bin", dtype=dtype, sample_shape=sample_shape
)

# Create a data loader
from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [8]:
# analyze the mean and variance of the dataset
mean = 0.
std = 0.

for batch in data_loader:
    data = batch['input']
    batch_samples = data.size(0) # batch size (the last batch can have smaller size!)
    data = data.view(batch_samples, data.size(1)*data.size(2))
    mean += data.mean(1).sum(0)
    std += data.std(1).sum(0)

mean /= len(data_loader.dataset)
std /= len(data_loader.dataset)

In [9]:
mean, std

(tensor(0.0094), tensor(0.2968))

In [6]:
dataiter = iter(data_loader)
sample = next(dataiter)
#print(sample['input'])
print(sample['input'].shape)
len(data_loader)*8

torch.Size([8, 128, 768])


9600