In [None]:
from atdata.dataset import (
    Dataset,
    PackableSample,
)

from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray

In [6]:
@dataclass
class TestSample( PackableSample ):
    """Minimal two-field sample type used to exercise atdata serialization.

    Field order matters: it fixes the dataclass ``__init__`` signature and,
    judging by the serialized output of ``x.as_wds`` in this notebook, the
    order in which fields are packed ('value' before 'data').
    """
    value: int  # scalar field; shows up as a plain msgpack int in the packed map
    data: NDArray  # array field; appears to be embedded as a .npy-format blob (b'\x93NUMPY' header)

In [11]:
# Two small hand-built samples, used below to inspect the serialized form.
x = TestSample(
    value=42,
    data=np.array([1, 2, 3]),
)
y = TestSample(
    value=7,
    data=np.array([4, 5, 6]),
)

In [36]:
# Serialize `x` into its WebDataset form: a dict with '__key__' and 'msgpack'
# entries, as shown when `x_wds` is displayed.
# NOTE(review): the key printed for `x` differs between cell outputs, which
# suggests `as_wds` mints a fresh '__key__' on each access — confirm in atdata.
x_wds = x.as_wds

In [37]:
# Same serialization for `y`; in the output below its '__key__' differs from
# the one generated for `x`.
y_wds = y.as_wds

In [38]:
# Display the serialized dict: a '__key__' string plus a 'msgpack' bytes
# payload, in which `data` is embedded as a .npy blob (note the b'\x93NUMPY' header).
x_wds

{'__key__': '5d33e544-afd0-11f0-8000-000000000000',
 'msgpack': b"\x82\xa5value*\xa4data\xc4\x98\x93NUMPY\x01\x00v\x00{'descr': '<i8', 'fortran_order': False, 'shape': (3,), }                                                            \n\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00"}

In [28]:
# Each serialized sample must carry its own '__key__'; a collision would make
# WebDataset treat the two records as one. `sorted` already returns a list,
# so no extra `list(...)` wrapper is needed.
assert x_wds['__key__'] != y_wds['__key__'], 'samples must get distinct __key__ values'
sorted( [ x_wds['__key__'], y_wds['__key__'] ] )

['414e57e2-afd0-11f0-8000-000000000000',
 '4162ba48-afd0-11f0-8000-000000000000']

---

In [44]:
from pathlib import Path

# Destination for the generated shards, relative to the kernel's working
# directory. mkdir with exist_ok is idempotent, so re-running this cell is safe.
output_dir = Path('output')
output_dir.mkdir(exist_ok=True, parents=True)

In [45]:
import webdataset as wds

k_samples_test = 1_000

# Fixed seed so a fresh-kernel Restart & Run All regenerates the same array
# data. The per-sample '__key__' values appear to be generated by atdata and
# may still differ between runs.
rng_seed = 0
rng = np.random.default_rng( rng_seed )

# Write `k_samples_test` samples, rolling over to a new tar shard every
# `maxcount` samples; `%06d` in the pattern becomes the shard index.
with wds.ShardWriter( 
    pattern = str( output_dir / 'test_data-%06d.tar' ),
    maxcount = 100,
) as sink:
    for i_sample in range( k_samples_test ):
        sample = TestSample(
            value = i_sample,
            # float64 draws from N(0, 1): same shape, dtype and distribution
            # as the legacy global `np.random.randn( 64, 256 )`, but seeded.
            data = rng.standard_normal( (64, 256) ),
        )
        sink.write( sample.as_wds )

# writing output/test_data-000000.tar 0 0.0 GB 0
# writing output/test_data-000001.tar 100 0.0 GB 100
# writing output/test_data-000002.tar 100 0.0 GB 200
# writing output/test_data-000003.tar 100 0.0 GB 300
# writing output/test_data-000004.tar 100 0.0 GB 400
# writing output/test_data-000005.tar 100 0.0 GB 500
# writing output/test_data-000006.tar 100 0.0 GB 600
# writing output/test_data-000007.tar 100 0.0 GB 700
# writing output/test_data-000008.tar 100 0.0 GB 800
# writing output/test_data-000009.tar 100 0.0 GB 900


In [55]:
# The writer cell starts at shard 000000 and rolls over every `maxcount`
# samples, so it produces ceil(k_samples_test / maxcount) shards. Deriving the
# brace range from `k_samples_test` keeps it in sync when the sample count
# changes, instead of hardcoding the last shard index.
samples_per_shard = 100  # must match `maxcount` in the ShardWriter cell
n_shards = -(-k_samples_test // samples_per_shard)  # ceiling division
assert n_shards > 0, 'k_samples_test must be positive'

output_stem = (output_dir / 'test_data-{shards}.tar').as_posix()

# WebDataset brace-expansion range, e.g. '{000000..000009}' for 10 shards.
shard_range = f'{{000000..{n_shards - 1:06d}}}'
ds = Dataset[TestSample]( output_stem.format( shards = shard_range ) )

In [86]:
# Pull one shuffled batch to inspect. Unlike `for ...: break`, `next(iter(...))`
# raises StopIteration when the loader yields nothing. Otherwise `sample` would
# silently stay bound to the last TestSample from the shard-writing loop.
# NOTE: despite the name, `sample` holds a whole batch (see its shape below).
sample = next( iter( ds.shuffled( batch_size = 32 ) ) )

In [89]:
# Expected layout is (batch_size, *per-sample shape) = (32, 64, 256):
# `batch_size = 32` comes from the loader cell and the (64, 256) arrays from the
# writer cell. Asserting it makes Run All fail loudly if batching regresses.
assert sample.data.shape == (32, 64, 256), sample.data.shape
sample.data.shape

(32, 64, 256)