## To Do

- what is a good chunk size?
- how can I verify a better performance?
- check accuracy of results
- check computation time on cluster

In [11]:
from toad import TOAD
import numpy as np
import xarray as xr
from toad.shifts_detection.methods import ASDETECT as ASDETECT
#import dask

fp = "tutorials/test_data/garbe_2020_antarctica.nc"
#fp = "tutorials/test_data/global_mean_summer_tas.nc"
var = "thk"

data = xr.open_dataset(fp)
spatial_dims = list(data[var].dims)
spatial_dims.remove("time")

c = 5
c_dict = {dim: c for dim in spatial_dims}
c_dict["time"] = 3
data = data.coarsen(**c_dict,
                    boundary="trim").reduce(np.mean)

print(f"Dimensions after coarsening:\n{data.sizes}")

cs = None
print(data[var].chunk({'x': cs, 'y': cs}).data.nbytes / 1e6, "MB")


Dimensions after coarsening:
Frozen({'time': 116, 'y': 38, 'x': 38})
0.670016 MB


In [2]:
from dask.distributed import Client

client = Client(n_workers=5)
client


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 5
Total threads: 10,Total memory: 15.29 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:43491,Workers: 0
Dashboard: http://127.0.0.1:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B

0,1
Comm: tcp://127.0.0.1:34241,Total threads: 2
Dashboard: http://127.0.0.1:42651/status,Memory: 3.06 GiB
Nanny: tcp://127.0.0.1:45665,
Local directory: /tmp/dask-scratch-space/worker-1jdsnjjx,Local directory: /tmp/dask-scratch-space/worker-1jdsnjjx

0,1
Comm: tcp://127.0.0.1:38363,Total threads: 2
Dashboard: http://127.0.0.1:44027/status,Memory: 3.06 GiB
Nanny: tcp://127.0.0.1:45927,
Local directory: /tmp/dask-scratch-space/worker-5fkng8da,Local directory: /tmp/dask-scratch-space/worker-5fkng8da

0,1
Comm: tcp://127.0.0.1:38025,Total threads: 2
Dashboard: http://127.0.0.1:42727/status,Memory: 3.06 GiB
Nanny: tcp://127.0.0.1:45021,
Local directory: /tmp/dask-scratch-space/worker-8fw2huuy,Local directory: /tmp/dask-scratch-space/worker-8fw2huuy

0,1
Comm: tcp://127.0.0.1:43559,Total threads: 2
Dashboard: http://127.0.0.1:35665/status,Memory: 3.06 GiB
Nanny: tcp://127.0.0.1:41247,
Local directory: /tmp/dask-scratch-space/worker-l34ofvwv,Local directory: /tmp/dask-scratch-space/worker-l34ofvwv

0,1
Comm: tcp://127.0.0.1:41587,Total threads: 2
Dashboard: http://127.0.0.1:44701/status,Memory: 3.06 GiB
Nanny: tcp://127.0.0.1:33167,
Local directory: /tmp/dask-scratch-space/worker-_kv60a03,Local directory: /tmp/dask-scratch-space/worker-_kv60a03


In [5]:
#dask.config.set(scheduler='threads')

td_new = TOAD(data)
td_new.compute_shifts(var,
                  method=ASDETECT(),
                  overwrite=True,
                  chunk_size=10,
                  dask_compute=True)

[                                        ] | 0% Completed | 169.48 us

[########################################] | 100% Completed | 1.67 sms


## Chunk Size

In [6]:
import time

# Define the chunk sizes to test
chunk_sizes = [None, 5, 20, 30, 40, 50]
sample_size = 5

# Run tests
results = [0] * len(chunk_sizes)

print("Benchmarking chunk sizes...\n")
for i in range(len(chunk_sizes)):
    size = chunk_sizes[i]
    for j in range(sample_size):
        # get test data
        td = TOAD(data)

        # Time the execution
        lazy_shifts = td_new.compute_shifts(var,
                                            method=ASDETECT(),
                                            overwrite=True,
                                            return_results_directly=True,
                                            chunk_size=size,
                                            dask_compute=False)
        
        start_time = time.time()
        _ = lazy_shifts.compute()
        elapsed = time.time() - start_time

        results[i] += elapsed
    results[i] /= sample_size
    print(f"Chunk size {size}x{size}: {results[i]:.2f} seconds")

Benchmarking chunk sizes...

Chunk size NonexNone: 1.29 seconds
Chunk size 5x5: 1.42 seconds
Chunk size 20x20: 1.41 seconds
Chunk size 30x30: 1.33 seconds
Chunk size 40x40: 1.43 seconds
Chunk size 50x50: 1.33 seconds


## Artificial Dataset

In [12]:
import dask.array as da
import xarray as xr
import numpy as np

# Parameters
size = 2500
shape = (50, 50000, 50000)
chunks = (1, size, size)

# Create Dask array lazily
data_dask = da.random.random(shape, chunks=chunks)

# Wrap in xarray
time = np.arange(shape[0])
lat = np.linspace(-90, 90, shape[1])
lon = np.linspace(-180, 180, shape[2])

data = xr.DataArray(
    data_dask,
    dims=["time", "x", "y"],
    coords={"time": time, "x": lat, "y": lon}
)

sizes = [np.prod(c) * data.dtype.itemsize for c in zip(*data.chunks)]
print("Chunk sizes (MB):", [float(s) / 1e6 for s in sizes[:]])

print(f"Expected number of chunks: {shape[1] // size}")
print(f"Actual number of chunks: {len(sizes)}")


Chunk sizes (MB): [50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0]
Expected number of chunks: 20
Actual number of chunks: 20


In [20]:
import xarray as xr
import numpy as np

# Define large shape (e.g., 1000 time steps, 1000x1000 spatial grid)
shape = (500, 100, 100)  # (time, lat, lon)

# Create coordinate values
time = np.arange(shape[0])
lat = np.linspace(-90, 90, shape[1])
lon = np.linspace(-180, 180, shape[2])

# Create synthetic data using Dask
data = xr.DataArray(
    {"var": (['time','lat','lon'],np.random.rand(*shape))},
    #dims=["time", "lat", "lon"],
    coords={"time": time, "lat": lat, "lon": lon},
)
td_art = TOAD(data) 

print(td_art.data)


ValueError: coordinate time has dimensions ('time',), but these are not a subset of the DataArray dimensions ()

In [21]:
import xarray as xr
import numpy as np

# Define large shape (e.g., 500 time steps, 100x100 spatial grid)
shape = (500, 100, 100)  # (time, lat, lon)

# Create coordinate values
time = np.arange(shape[0])
lat = np.linspace(-90, 90, shape[1])
lon = np.linspace(-180, 180, shape[2])

# Create synthetic data
data = np.random.rand(*shape)

# Create a Dataset with a named variable
dataset = xr.Dataset(
    {"temperature": (["time", "lat", "lon"], data)},
    coords={"time": time, "lat": lat, "lon": lon}
)

# If TOAD expects a DataArray, extract the named variable
td_art = TOAD(dataset["temperature"])  # 'temperature' is the variable name

print(td_art.data)


<xarray.DataArray 'temperature' (time: 500, lat: 100, lon: 100)> Size: 40MB
array([[[0.4964379 , 0.56487303, 0.0873338 , ..., 0.79871699,
         0.06645395, 0.25481763],
        [0.5672902 , 0.54881669, 0.2441572 , ..., 0.87434235,
         0.37644149, 0.69728853],
        [0.4961602 , 0.29740827, 0.5065987 , ..., 0.44672649,
         0.82140411, 0.13097236],
        ...,
        [0.37807596, 0.63159655, 0.94648246, ..., 0.80824281,
         0.0150163 , 0.14470417],
        [0.95143854, 0.19375991, 0.94351719, ..., 0.05937079,
         0.01415021, 0.60344563],
        [0.38335639, 0.06730697, 0.93073946, ..., 0.63246888,
         0.21162823, 0.42703749]],

       [[0.57597649, 0.00543985, 0.06140057, ..., 0.40709494,
         0.95218625, 0.07490342],
        [0.61286478, 0.18191182, 0.47682589, ..., 0.13871663,
         0.26933242, 0.03425804],
        [0.74979889, 0.69611793, 0.42371704, ..., 0.29718326,
         0.41019395, 0.86923975],
...
        [0.35342874, 0.25212406, 0.935644

In [23]:
dataset

In [20]:
td_art.compute_shifts(td_art.data,
                  method=ASDETECT(),
                  overwrite=True,)

ValueError: The truth value of a Array is ambiguous. Use a.any() or a.all().

In [None]:
import time

# Define the chunk sizes to test
chunk_sizes = [50, 100, 500]

# Run tests
results = []

print("Benchmarking chunk sizes...\n")
for size in chunk_sizes:
    # get test data
    td = TOAD(data)

    # Time the execution
    start_time = time.time()
    td_art.compute_shifts(var,
                    method=ASDETECT(),
                    overwrite=True,
                    chunk_size=size,)
    elapsed = time.time() - start_time

    results.append((size, elapsed))
    print(f"Chunk size {size}x{size}: {elapsed:.2f} seconds")

# Summary
print("\nSummary:")
for size, elapsed in results:
    print(f"Chunk size {size}x{size}: {elapsed:.2f} s")


Benchmarking chunk sizes...

