This example shows how to convert a PyTorch dataset into Parquet files for efficient storage and loading.

In [None]:
import os

import datasets
from torch.utils import data
from torchvision.datasets import MNIST

In [2]:
# convert torch dataset to datasets.Dataset
# Step 1: fetch both MNIST splits into a `data` folder under the current working directory.
root_path = os.getcwd()
print("Root path:", root_path)

mnist_root = f"{root_path}/data"  # shared download/cache location for both splits
train_dataset = MNIST(mnist_root, train=True, download=True)
val_dataset = MNIST(mnist_root, train=False, download=True)


def gen(torch_dataset):
    """Yield every sample of ``torch_dataset`` as an ``{"image", "label"}`` dict.

    The dataset only needs to support ``len()`` and integer indexing. For each
    item, element 0 is emitted under ``"image"`` and element 1 under
    ``"label"``; any further elements are ignored.

    Samples are yielded as dictionaries because this generator is meant to be
    passed to ``datasets.Dataset.from_generator``.
    """
    position, total = 0, len(torch_dataset)
    while position < total:
        sample = torch_dataset[position]
        yield dict(image=sample[0], label=sample[1])
        position += 1


def _mnist_features():
    """Return a fresh schema: an image column plus a 10-class label column."""
    return datasets.Features(
        {"image": datasets.Image(), "label": datasets.ClassLabel(num_classes=10)}
    )


# Materialize each torch split as a Hugging Face dataset with an explicit schema.
hf_train_dataset = datasets.Dataset.from_generator(
    gen, features=_mnist_features(), gen_kwargs={"torch_dataset": train_dataset}
)
hf_val_dataset = datasets.Dataset.from_generator(
    gen, features=_mnist_features(), gen_kwargs={"torch_dataset": val_dataset}
)

# Group both splits into a single DatasetDict keyed by split name.
hf_dataset_dict = datasets.DatasetDict(
    {"train": hf_train_dataset, "validation": hf_val_dataset}
)

Root path: /Users/tanganke/Documents/GitHub/fusion_bench/examples/convert_dataset


100%|██████████| 9.91M/9.91M [00:01<00:00, 6.08MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 766kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.63MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.96MB/s]


Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# save the dataset to parquet files
# Create the output folder up front: depending on the installed `datasets`
# version, `to_parquet` may not create missing parent directories, which would
# make this cell fail on a fresh checkout.
parquet_dir = "converted_data/mnist_dataset/data"
os.makedirs(parquet_dir, exist_ok=True)

# One file per split; the file stems (train / validation) are the split names
# reported when the folder is loaded back with `datasets.load_dataset`.
hf_train_dataset.to_parquet(f"{parquet_dir}/train.parquet")
hf_val_dataset.to_parquet(f"{parquet_dir}/validation.parquet")

# or push to the hub (recommended if you are going to share the dataset)
# hf_dataset_dict.push_to_hub(repo_id="<user_name>/mnist_dataset")

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

2865730

In [4]:
# load the dataset from parquet files
# Point `load_dataset` at the folder written above and display the resulting splits.
reloaded_dataset = datasets.load_dataset("converted_data/mnist_dataset")
reloaded_dataset

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})