TFRecord reader for PyTorch

TFRecord reader

Installation

pip3 install tfrecord

Usage

It's recommended to create an index file for each TFRecord file. An index file must be provided when using multiple workers; otherwise the loader may return duplicate records.

python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
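
To give a feel for what an index can hold, here is a minimal sketch that scans a file using the standard TFRecord framing (an 8-byte little-endian length, a 4-byte length checksum, the payload, and a 4-byte payload checksum) and records each record's offset and size. The helper names `write_framed_records` and `scan_record_offsets` are hypothetical, the checksums are written as zero placeholders and never verified, and the exact file format produced by `tfrecord2idx` is not described in this README, so treat this as an illustration only.

```python
import os
import struct
import tempfile


def write_framed_records(path, payloads):
    """Write payloads with TFRecord-style framing (hypothetical demo helper).

    Real TFRecord files store masked CRC32C checksums; zero bytes are used
    here as placeholders because this sketch never verifies them.
    """
    with open(path, "wb") as f:
        for payload in payloads:
            f.write(struct.pack("<Q", len(payload)))  # uint64 payload length
            f.write(b"\x00" * 4)                      # placeholder length checksum
            f.write(payload)                          # record bytes
            f.write(b"\x00" * 4)                      # placeholder payload checksum


def scan_record_offsets(path):
    """Return a list of (offset, size) pairs, one per framed record."""
    entries = []
    with open(path, "rb") as f:
        while True:
            start = f.tell()
            header = f.read(8)
            if len(header) < 8:
                break  # end of file
            (length,) = struct.unpack("<Q", header)
            # Skip the length checksum, the payload, and the payload checksum.
            f.seek(4 + length + 4, os.SEEK_CUR)
            entries.append((start, f.tell() - start))
    return entries


with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.tfrecord")
    write_framed_records(demo_path, [b"abc", b"hello", b""])
    index_entries = scan_record_offsets(demo_path)

print(index_entries)  # [(0, 19), (19, 21), (40, 16)]
```

With offsets like these, a reader can seek directly to any record instead of parsing the file from the start.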

Use TFRecordDataset to read TFRecord files in PyTorch.

import torch
from tfrecord.torch.dataset import TFRecordDataset

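tfrecord_path = "/path/to/data.tfrecord"
# No index file here; pass the path created by tfrecord2idx when using multiple workers.
index_path = None
# Maps each feature name to the type used to decode it.
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)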
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)
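
The multi-worker note above can be made concrete with a small sketch. PyTorch gives each `DataLoader` worker its own copy of an iterable dataset, so every worker would read the whole file unless the records are divided among them. Assuming an index supplies the record count, one way to split the work is to hand each worker a contiguous, non-overlapping range. The function `shard_bounds` is hypothetical and is not taken from the library; the library's own sharding strategy may differ.

```python
def shard_bounds(num_records, num_workers, worker_id):
    """Return the half-open [start, end) record range for one worker.

    Hypothetical helper: integer arithmetic keeps the ranges contiguous
    and non-overlapping, and together they cover every record once.
    """
    start = num_records * worker_id // num_workers
    end = num_records * (worker_id + 1) // num_workers
    return start, end


num_records = 10   # assumed to come from the index file
num_workers = 3
shards = [shard_bounds(num_records, num_workers, w) for w in range(num_workers)]
print(shards)  # [(0, 3), (3, 6), (6, 10)]

# Every record index appears exactly once across all workers.
covered = [i for start, end in shards for i in range(start, end)]
```

Without the record count and offsets from an index, a worker has no cheap way to find where its range begins, which is consistent with the duplicate-record warning above.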

Use MultiTFRecordDataset to read multiple TFRecord files. This class samples from the given TFRecord files with the given probabilities.

import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

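# "{}" in each pattern is filled in with a key from splits.
tfrecord_pattern = "/path/to/{}.tfrecord"
index_pattern = "/path/to/{}.index"
# Maps each file name to its sampling probability.
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}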
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)
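
The sampling behaviour described above can be sketched in plain Python. This is a rough illustration, not the library's implementation: `sample_from_streams` is a hypothetical function, the record streams are stand-in counters rather than real TFRecord readers, and stopping when a chosen stream runs out is an assumption made for the demo.

```python
import itertools
import random


def sample_from_streams(streams, weights, seed=0):
    """Yield (name, item) pairs, choosing the source stream by weight.

    Hypothetical helper: stops as soon as the chosen stream is exhausted.
    """
    rng = random.Random(seed)
    names = list(streams)
    iterators = {name: iter(streams[name]) for name in names}
    probs = [weights[name] for name in names]
    while True:
        # Pick which stream supplies the next record.
        name = rng.choices(names, weights=probs, k=1)[0]
        try:
            yield name, next(iterators[name])
        except StopIteration:
            return


# Stand-in record streams; real code would iterate over TFRecord files.
streams = {"dataset1": itertools.count(), "dataset2": itertools.count(1000)}
splits = {"dataset1": 0.8, "dataset2": 0.2}

samples = list(itertools.islice(sample_from_streams(streams, splits), 1000))
share = sum(name == "dataset1" for name, _ in samples) / len(samples)
print(f"dataset1 share: {share:.2f}")
```

Over many draws the share of records from each source approaches its weight, which is the effect the `splits` dictionary is meant to control.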