In [1]:
import xarray as xr
import numpy as np
import datetime
from tonik.utils import generate_test_data
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

In [12]:
class SliceBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, drop_last):
        super().__init__(sampler, batch_size, drop_last)
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Generate all indices from the sampler
        indices = list(self.sampler)

        # Yield slices instead of list of indices
        for i in range(indices[0], indices[-1], self.batch_size):
            if self.drop_last and i + self.batch_size > len(indices):
                break
            yield slice(i, min(i + self.batch_size, indices[-1]))

list(SliceBatchSampler(range(0, 12), 4, False))

[slice(0, 4, None), slice(4, 8, None), slice(8, 11, None)]

In [8]:
list(range(5, 10))

[5, 6, 7, 8, 9]

In [11]:
xds = generate_test_data(dim=2)
xds.to_netcdf("/tmp/test.nc")
xds = generate_test_data(dim=1)
xds.to_netcdf("/tmp/test_1d.nc")

In [21]:
with xr.open_dataset("/tmp/test.nc", chunks={'datetime': 144}) as xds1:
    print(len(xds1['ssam'].sizes))
    print(xds1['ssam'].median(axis=1).compute())

2
<xarray.DataArray 'ssam' (frequency: 10)>
array([488.6201254, 488.6201254, 488.6201254, 488.6201254, 488.6201254,
       488.6201254, 488.6201254, 488.6201254, 488.6201254, 488.6201254])
Coordinates:
  * frequency  (frequency) int64 0 1 2 3 4 5 6 7 8 9


In [22]:
try:
    with xr.open_dataset("/tmp/test_1d.nc", chunks={'datetime': 144}) as xds1:
        print(len(xds1['rsam'].sizes))
        print(xds1['rsam'].median().compute())
except NotImplementedError as e:
    print(e)
    with xr.open_dataset("/tmp/test_1d.nc") as xds1:
        print(xds1['rsam'].median().compute())
        print(xds1['dsar'].median().compute())

1
The da.nanmedian function only works along an axis or a subset of axes.  The full algorithm is difficult to do in parallel
<xarray.DataArray 'rsam' ()>
array(488.6201254)
<xarray.DataArray 'dsar' ()>
array(488.6201254)


In [20]:
xds1.isnull().sum().compute()

In [39]:
xds['ssam'][0].reset_coords(drop=True)

In [25]:
npts = 3 
npts2 = 4 
vals = [[1., np.nan, np.nan, 3.],
        [4., 5., np.nan, 7.],
        [np.nan, 9., 10., 11.]]
xds = xr.Dataset({'rsam': (['x', 'y'], vals)},
                 coords={'x': np.arange(npts),
                         'y': np.arange(npts2)}) 
xds.attrs['starttime'] = "2023-01-01"
xds

In [26]:
xda = xds['rsam']

In [27]:
xda.name

'rsam'

In [16]:
npts = int(1e3)
npts2 = 100
xds = xr.Dataset({'rsam': (['x', 'y'], np.random.randn(npts, npts2))},
                 coords={'x': np.arange(npts),
                         'y': np.arange(npts2)}) 
xds.attrs['starttime'] = "2023-01-01"

In [29]:
s = slice(1,None)
print(s.start, s.stop)

1 None


In [40]:
start = datetime.datetime(2023, 1, 1)
end = datetime.datetime(2023, 1, 12)
int((end - start)/datetime.timedelta(hours=1/6.))

1584

In [27]:
isinstance(s, slice)

True

In [17]:
xda = xds['rsam']
xda

In [19]:
xda.attrs = xds.attrs
xda.attrs

{'starttime': '2023-01-01'}

In [6]:
%time
with xr.open_dataset('/tmp/test.nc', chunks={}) as xda1:
    sz = xda1.sizes

print(sz)

CPU times: user 3 µs, sys: 4 µs, total: 7 µs
Wall time: 20.7 µs
Frozen({'x': 1000000, 'y': 100})


In [9]:
len(sz)

2

In [10]:
%time
with xr.open_dataset('/tmp/test.nc') as ds:
    rq = ds.load()
vals = rq.rsam.data[:]
print(vals.shape)

CPU times: user 3 µs, sys: 5 µs, total: 8 µs
Wall time: 15.7 µs
(1000000, 100)


In [27]:
vals

array([      0,       1,       2, ..., 9999997, 9999998, 9999999])