# CIFAR-10 Dataset Handling with Atria

## Setup and Auto-reloading Modules
We enable IPython's `autoreload` extension so that edits to imported modules (such as the local `atria_*` packages) are picked up without restarting the kernel.

In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Importing Dependencies
Here, we add the project's root directory (two levels above the notebook's working directory) to `sys.path` so that the local packages imported in the following cells can be found.

In [8]:
import sys 
sys.path.append('../../')

## Loading the CIFAR-10 Dataset
We load the CIFAR-10 dataset with the `HuggingfaceCifar10.load` method, selecting the training split and the `plain_text` configuration.

In [9]:
from atria_core.types import DatasetSplit
from atria_core.types import ImageInstance
from atria_examples.datasets.cifar10_huggingface import HuggingfaceCifar10

# Train split of the 'plain_text' configuration; no storage manager is passed here,
# so the split is read directly (see the log below).
cifar10 = HuggingfaceCifar10.load(
    config_name='plain_text',
    split=DatasetSplit.train,
)

[2025-04-28 03:22:18][atria.data.datasets.atria_dataset][INFO] Loading dataset HuggingfaceCifar10/plain_text from registry.


[2025-04-28 03:22:20][atria.data.datasets.atria_dataset][INFO] Setting up dataset type with split=train
[2025-04-28 03:22:20][atria.data.datasets.atria_dataset][INFO] No storage manager provided. Preparing downloads and split iterator.
Downloading data: 100%|██████████| 120M/120M [00:11<00:00, 10.5MB/s] 
Downloading data: 100%|██████████| 23.9M/23.9M [00:02<00:00, 10.3MB/s]


In [12]:
# Take the first instance from a fresh iterator over the dataset and display it.
first_instance = next(iter(cifar10))
first_instance

ImageInstance(
    index=0,
    id=UUID('73c861b9-3b9a-456a-bbbc-13e3dfa25130'),
    image=Image(
        file_path=None,
        content=<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x79EFA8FFB7A0>,
        source_size=None,
        shape=(3, 32, 32),
        dtype=None
    ),
    label=Label(value=0, name='airplane')
)

## Creating batched instances from a list of samples
We collect a few samples, convert each one to tensors with `to_tensor()`, and then pass the list to `batched`, a class method of the instance type that is invoked here through the first instance.

In [14]:
# Make a list of instances
instances = []
for idx, sample in enumerate(cifar10):
    instances.append(sample.to_tensor())
    if (idx + 1) >= 2:
        break

# Batch the instances
instances[0].batched(instances)

ImageInstance(
    batch_size=2,
    index=[0, 1],
    id=[
        UUID('920b3bed-92e2-4e43-afb3-749fae4aefc9'),
        UUID('69349f28-b889-4a54-84b2-575c0911a143')
    ],
    image=Image(
        batch_size=2,
        file_path=None,
        content=tensor([[[[0.6980, 0.6980, 0.6980,  ..., 0.6667, 0.6588, 0.6471],
          [0.7059, 0.7020, 0.7059,  ..., 0.6784, 0.6706, 0.6588],
          [0.6941, 0.6941, 0.6980,  ..., 0.6706, 0.6627, 0.6549],
          ...,
          [0.4392, 0.4431, 0.4471,  ..., 0.3922, 0.3843, 0.3961],
          [0.4392, 0.4392, 0.4431,  ..., 0.4000, 0.4000, 0.4000],
          [0.4039, 0.3922, 0.4039,  ..., 0.3608, 0.3647, 0.3569]],

         [[0.6902, 0.6902, 0.6902,  ..., 0.6588, 0.6510, 0.6392],
          [0.6980, 0.6941, 0.6980,  ..., 0.6706, 0.6627, 0.6510],
          [0.6863, 0.6863, 0.6902,  ..., 0.6627, 0.6549, 0.6471],
          ...,
          [0.4196, 0.4275, 0.4314,  ..., 0.3804, 0.3686, 0.3725],
          [0.4000, 0.4039, 0.4039,  ..., 0.3725, 0.3647

## Dataset handling with File Storage
Load the dataset with a file storage manager that first caches the samples to disk — here as MessagePack files, capped at 100 samples.

In [15]:
from atria.data.storage.file_storage_manager import FileStorageManager
from atria.data.storage.utilities import FileStorageType

# Creat a file storage manager
file_storage_manager = FileStorageManager(
    storage_dir="/tmp", streaming_mode=False, storage_type=FileStorageType.MSGPACK, 
    max_samples=100, # save up to 100 samples
)

In [16]:
from atria_core.types import DatasetSplit
from atria_examples.datasets.cifar10_huggingface import HuggingfaceCifar10

# Reload the train split, this time routing it through the file storage manager
# so the samples are written to the on-disk cache first.
cifar10 = HuggingfaceCifar10.load(
    config_name='plain_text',
    split=DatasetSplit.train,
    storage_manager=file_storage_manager,
)

[2025-04-28 03:23:35][atria.data.datasets.atria_dataset][INFO] Loading dataset HuggingfaceCifar10/plain_text from registry.
[2025-04-28 03:23:37][atria.data.datasets.atria_dataset][INFO] Setting up dataset type with split=train
[2025-04-28 03:23:37][atria.data.storage.file_storage_manager][INFO] Preparing dataset split train to cache dir /tmp/msgpack/HuggingfaceCifar10/plain_text/a061c9f2b33df971/max_samples-100 with max samples 100.
2025-04-28 03:23:38,069	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
Writing dataset HuggingfaceCifar10 to FileStorageType.MSGPACK: 0it [00:00, ?it/s]2025-04-28 03:23:39,105	INFO worker.py:1816 -- Started a local Ray instance.
:job_id:01000000
:actor_name:FileStorageShardWriterActor
[2025-04-28 03:23:39][atria.data.storage.msgpack_shard_writer][INFO] # Writing /tmp/msgpack/HuggingfaceCifar10/plain_text/a061c9f2b33df971/max_samples-100/train-000000-000000.ms

:job_id:01000000
:actor_name:FileStorageShardWriterActor


Writing dataset HuggingfaceCifar10 to FileStorageType.MSGPACK: 99it [00:01, 62.56it/s]
train-000000-000000.msgpack   0% 00:00<?, ?it/s
[2025-04-28 03:23:39][atria.data.datasets.metadata][INFO] Writing dataset storage info to /tmp/msgpack/HuggingfaceCifar10/plain_text/a061c9f2b33df971/max_samples-100/train_storage_info.json:
AtriaDatasetStorageInfo(
    metadata=DatasetMetadata(
        dataset_name='HuggingfaceCifar10',
        citation='',
        homepage='',
        license='',
        config=AtriaHuggingfaceDatasetConfig(
            name='plain_text',
            description=None,
            version='0.0.0',
            data_urls=None,
            hf_repo='uoft-cs/cifar10'
        ),
        dataset_labels=DatasetLabels(
            instance_classification=[
                'airplane',
                'automobile',
                'bird',
                'cat',
                'deer',
                'dog',
                'frog',
                'horse',
                'ship',


100 samples written


In [17]:
# Fetch the first sample two ways — positional indexing, then iteration — and print each.
sample_by_index = cifar10[0]
print(sample_by_index)
sample_by_iteration = next(iter(cifar10))
print(sample_by_iteration)

ImageInstance(
    index=0,
    id=UUID('64ba361d-dca7-4b8c-9f75-f61eb661dfd3'),
    image=Image(
        file_path=None,
        content=<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x79EFA8ED8110>,
        source_size=None,
        shape=(3, 32, 32),
        dtype=None
    ),
    label=Label(value=0, name='airplane')
)
ImageInstance(
    index=0,
    id=UUID('64ba361d-dca7-4b8c-9f75-f61eb661dfd3'),
    image=Image(
        file_path=None,
        content=<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x79EFD9B3E900>,
        source_size=None,
        shape=(3, 32, 32),
        dtype=None
    ),
    label=Label(value=0, name='airplane')
)
