In [None]:
import itertools as it
import time

import dask
import dask.array as da
import numpy as np
import xarray as xr
from distributed import Client, LocalCluster

## Create the Dask cluster

In [None]:
# Start a local Dask cluster: 4 worker processes (processes=True), each with a single thread.
cluster = LocalCluster(processes=True, n_workers=4, threads_per_worker=1)

In [None]:
# Connect a client to the local cluster created above (the cluster object is passed as `address`)
# and display it as the cell's last expression.
client = Client(address=cluster)
client

In [None]:
# Shut down the client first, then the cluster it is connected to.
# NOTE(review): in a top-to-bottom "Run All" this cell executes before any tiles are
# computed, so the later .compute()/.plot() cells will not run on this LocalCluster.
# Consider running this cell last, or use the context-manager form described below.
client.close()
cluster.close()


Alternatively, use the cluster and client as context managers so they are closed automatically when the block exits, e.g.:
```
cluster_kwargs = {...}
with LocalCluster(**cluster_kwargs) as cluster, Client(cluster) as client:
    ... # code using the cluster
```

## Function to rasterise

In [None]:
@dask.delayed
def tile_function(i, j, chunk_x, chunk_y):
    """Build one tile of shape (chunk_x, chunk_y) whose every value is i * j.

    Wrapped in ``dask.delayed``, so calling it returns a lazy task; the body
    only runs when the result is computed.

    Parameters
    ----------
    i, j : tile indices; their product is the fill value.
    chunk_x, chunk_y : tile size along axis 0 and axis 1.
    """
    # Presumably a stand-in for an expensive per-tile rasterisation step.
    time.sleep(1)
    tile = np.ones(shape=(chunk_x, chunk_y))
    return tile * i * j

## Define chunks and number of tiles

In [None]:
# Tile size: chunk_x rows (axis 0, "x") by chunk_y columns (axis 1, "y").
chunk_x, chunk_y = 10, 20
# Number of tiles along each axis.
n_chunks_x = n_chunks_y = 10

## Run across tiles

In [None]:
# Nested list of lazy tiles: outer index i selects the block row (axis 0),
# inner index j the block column (axis 1). da.block stitches them into one 2-D array.
delayed_arr = [
    [
        da.from_delayed(
            tile_function(i, j, chunk_x, chunk_y),
            shape=(chunk_x, chunk_y),
            dtype=float,
        )
        for j in range(n_chunks_y)
    ]
    for i in range(n_chunks_x)
]
arr = da.block(delayed_arr)
arr

In [None]:
# Wrap the lazy dask array in a labelled DataArray with integer coordinates per axis.
# Dimension names are given explicitly rather than inferred from the coords dict,
# so the labelling does not depend on the dict's ordering or on inference rules.
xarr = xr.DataArray(
    arr,
    dims=("x", "y"),
    coords={"x": range(arr.shape[0]), "y": range(arr.shape[1])},
)
xarr

In [None]:
print(xarr)

In [None]:
# Execute every delayed tile task and rebind `xarr` to the in-memory result,
# so later cells (e.g. the plot) reuse the computed values instead of re-running the tasks.
xarr = xarr.compute()
xarr

In [None]:
xarr.plot()

In [None]:
list(it.product(range(10), range(20)))

In [None]:
# One lazy dask array per (i, j) tile index.
# tile_function needs the tile size as well as the indices; omitting chunk_x/chunk_y
# raises a TypeError when a tile is evaluated. The declared shape uses the same sizes,
# so it matches what tile_function actually returns.
# NOTE(review): the index ranges (10 x 20) differ from n_chunks_x x n_chunks_y used elsewhere -- confirm intent.
[
    da.from_delayed(tile_function(i, j, chunk_x, chunk_y), shape=(chunk_x, chunk_y), dtype=float)
    for i, j in it.product(range(10), range(20))
]

## Original code — flat array, not 2D tiled

In [None]:
# Flat list of lazy tiles in row-major (i, then j) order, concatenated along the last axis.
# The result is one long strip of shape (chunk_x, chunk_y * n_chunks_x * n_chunks_y), not a 2-D tiling.
delayed_arr = [
    da.from_delayed(
        tile_function(i, j, chunk_x, chunk_y),
        shape=(chunk_x, chunk_y),
        dtype=float,
    )
    for i in range(n_chunks_x)
    for j in range(n_chunks_y)
]
arr = da.concatenate(delayed_arr, axis=-1)
arr

In [None]:
# Same flat strip as the previous cell, built inline without keeping the tile list.
# Each tile is 2-D, so axis=1 is the last axis.
arr = da.concatenate(
    [
        da.from_delayed(tile_function(row, col, chunk_x, chunk_y), shape=(chunk_x, chunk_y), dtype=float)
        for row in range(n_chunks_x)
        for col in range(n_chunks_y)
    ],
    axis=1,
)
arr