# Create a dataset with MNIST images

In [1]:
%load_ext autoreload
%autoreload 2

In [33]:
from pathlib import Path
import random

import keras
from tqdm import tqdm

In [5]:
from sedpack.io import Dataset, Metadata, DatasetStructure, Attribute
from sedpack.io.types import SplitT

## Describe the data we are saving

In [7]:
# General information about the dataset.
# The license text is stored verbatim, including its leading newline and the
# indentation of each line.
MNIST_LICENSE = """
    Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset, which is
    a derivative work from original NIST datasets. MNIST dataset is made
    available under the terms of the Creative Commons Attribution-Share Alike
    3.0 license.
    """

# Extra key/value information passed to sedpack as `custom_metadata`.
custom_info = {"list of authors": ["Alice", "Bob"]}

metadata = Metadata(
    description="MNIST dataset in the sedpack format",
    dataset_license=MNIST_LICENSE,
    custom_metadata=custom_info,
)

In [10]:
# Types of attributes stored with every example.
# NOTE(review): "input" is declared float16 while the scaling cell produces
# float32 arrays; presumably sedpack converts on write — confirm.
input_attribute = Attribute(
    name="input",
    shape=(28, 28),
    dtype="float16",  # We are going to directly save scaled data
)
digit_attribute = Attribute(
    name="digit",
    shape=(),  # a single scalar label per example
    dtype="int8",
)

dataset_structure = DatasetStructure(
    saved_data_description=[input_attribute, digit_attribute],
)

## Create a new dataset

Creating the dataset cannot accidentally overwrite an existing one:
`Dataset.create` first checks that no metadata file already exists at the target path.

In [29]:
# Directory under which all dataset files are stored.
DATASET_PATH = Path.home() / "Datasets/mnist_dataset"

dataset = Dataset.create(
    path=DATASET_PATH,
    metadata=metadata,
    dataset_structure=dataset_structure,
)

## Load the raw data and scale pixel values to [0, 1]

In [22]:
# Load MNIST (pixel values presumably 0..255) and rescale images to [0, 1].
(x_train_raw, y_train), (x_test_raw, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = (
    images.astype("float32") / 255 for images in (x_train_raw, x_test_raw)
)

## Convert the holdout split

In [30]:
# DatasetFiller makes sure that all shard files are written properly
# when exiting the context.
with dataset.filler() as holdout_filler:
    # Every example of the Keras test set goes into the "holdout" split.
    for idx, image in enumerate(tqdm(x_test, desc="holdout")):
        holdout_filler.write_example(
            values={"input": image, "digit": y_test[idx]},
            split="holdout",
        )

holdout: 100%|██████████| 10000/10000 [00:03<00:00, 3021.25it/s]


## Convert the training data and split it into training and test (validation) splits

The split assignment is saved together with the data, so others can load exactly the same splits deterministically.

In [34]:
# Seed for the train/validation shuffle. A dedicated seeded RNG makes
# re-running this cell reproduce the same split assignment, without touching
# the global `random` state.
SPLIT_SEED = 42

# Randomly assign 10% of the Keras training examples to validation ("test")
# and the rest to "train". Only indices are shuffled; the data is not copied.
order = list(range(len(x_train)))
random.Random(SPLIT_SEED).shuffle(order)
num_validation = len(order) // 10  # loop-invariant, computed once

# DatasetFiller makes sure that all shard files are written properly
# when exiting the context.
with dataset.filler() as dataset_filler:
    for position, index in enumerate(tqdm(order, desc="train and val")):
        # The first `num_validation` shuffled examples become validation data.
        split: SplitT = "test" if position < num_validation else "train"
        dataset_filler.write_example(
            values={
                "input": x_train[index],
                "digit": y_train[index],
            },
            split=split,
        )

train and val: 100%|██████████| 60000/60000 [00:20<00:00, 2979.46it/s]
