# Dask processing tests

This notebook needs Graphviz to render the graph visualizations.

In [None]:
import dask
import dask.array as da
import numpy as np


## define a computational graph
Embarrassingly parallel here.

In [None]:
def f(x):
    """Return ``x`` incremented by one."""
    incremented = x + 1
    return incremented

def g(x):
    """Return ``x`` multiplied by two."""
    doubled = x * 2
    return doubled

def h(x):
    """Return ``x`` reduced by three."""
    shifted = x - 3
    return shifted

def m(x):
    """Return the mean of ``x`` as computed by ``dask.array.mean``.

    ``numpy.mean`` could be used here as well, but it would materialize the
    array directly. The ``da.mean`` result needs a ``compute`` call before
    the array is materialized.
    """
    lazy_mean = da.mean(x)
    return lazy_mean

def identity(x):
    """Hand the argument back untouched."""
    unchanged = x
    return unchanged

In [None]:
# Wrap each plain function with dask.delayed, in the order f, g, h, m.
f_delayed, g_delayed, h_delayed, m_delayed = (
    dask.delayed(func) for func in (f, g, h, m)
)

In [None]:
# Placeholder input node: a tiny dummy array passed through `identity`.
# Its key is remembered below so the real data can be swapped in later.
input_placeholder = dask.delayed(identity, pure = True)(da.arange(2))
# Chain the delayed steps: f -> g -> h -> m.
fx = f_delayed(input_placeholder)
gx = g_delayed(fx)
hx = h_delayed(gx)
mx = m_delayed(hx)
# Key under which the placeholder sits in the task graph.
input_key = input_placeholder.key
computation = mx 
# Copy the task graph into a plain dict so that entries can be reassigned.
graph = dict(computation.dask)

In [None]:
# Display the key of the placeholder input task.
input_key

In [None]:
# Display the final delayed object (the output of `m_delayed`).
computation

In [None]:
# Display the task graph as a plain dict.
graph

In [None]:
# Render the task graph and save it to 'dask_graph.svg'.
dask.visualize(computation, filename='dask_graph.svg', format='svg')

### Execute with dask

Replace the input with something more substantial.

In [None]:
# One billion random samples in a single in-memory NumPy array.
# NOTE(review): at 8 bytes per float64 this allocates roughly 8 GB of RAM;
# lower the size if the machine cannot hold it.
og_data = np.random.rand(int(1e9))

Wrap the NumPy array in a Dask array. This is only done for the experiments here.

In [None]:
# Split the NumPy array into chunks of one million elements each.
input_data = da.from_array(og_data, chunks=(int(1e6),))

Build the result by hand so that it can be checked.

Replace the input of the graph with the new input and compute the result.

In [None]:
# Set the scheduler globally (no `with` block), so it stays in effect for
# later cells until it is changed again.
dask.config.set(scheduler='synchronous') # make ALL operations synchronous
# Swap the placeholder input for the real data.
graph[input_key] = input_data 
# `dask.get` evaluates the graph up to `mx.key`; what comes back is still
# lazy (m uses da.mean), hence the extra `.compute()` call.
result = dask.get(graph, mx.key,).compute()
result

Try again with a parallel scheduler.

In [None]:
# Switch the global scheduler setting to the multiprocessing one.
dask.config.set(scheduler='processes') 

# The placeholder was already replaced in the previous cell; repeating the
# assignment here is harmless.
graph[input_key] = input_data 
# The trailing `.compute()` is the call that picks up the 'processes'
# scheduler configured above.
result = dask.get(graph, mx.key).compute()
result

For dataframes, have a look [here](https://examples.dask.org/dataframe.html); for xarray, check out [this](https://docs.xarray.dev/en/stable/user-guide/dask.html) and [that](https://examples.dask.org/xarray.html?highlight=schedulers).

## Partial parallelization of a computational graph with `dask.annotate`

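Mark `h` to run serially only. Note that the code below does this by wrapping `h` in a `dask.config.set(scheduler='synchronous')` context rather than by calling `dask.annotate`.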

In [None]:

def h_single(x):
    """Apply ``h`` to the dask array ``x`` using the synchronous scheduler.

    Dask array arithmetic is lazy: calling ``h(x)`` inside a
    ``dask.config.set(scheduler='synchronous')`` block only adds nodes to the
    task graph, and the setting has expired again by the time the result is
    actually computed. To really run ``h`` serially, its chunks must be
    executed while the context manager is active, which ``persist`` does.

    Parameters
    ----------
    x : dask.array.Array
        Lazy input array.

    Returns
    -------
    dask.array.Array
        ``h(x)`` with its chunks already computed and held in memory,
        rechunked with ``'auto'`` so downstream operations can run in
        parallel again.

    Notes
    -----
    Both ``persist`` calls keep their results in memory, so this needs room
    for the input and the output array at the same time.
    """
    # Execute everything upstream of ``h`` with the currently configured
    # scheduler, so that only ``h`` itself is forced onto the serial one.
    x = x.persist()
    with dask.config.set(scheduler='synchronous'):
        # ``persist`` triggers execution here, while the setting is active.
        hx = h(x).persist()
    return hx.rechunk('auto')  # re-split the data to make downstream operations parallel

# Delayed wrapper around `h_single`, used in place of `h_delayed` below.
h_delayed_serial = dask.delayed(h_single)

Build up the graph again.

In [None]:
# Same pipeline as before (f -> g -> h -> m), except for the h step.
input_placeholder = dask.delayed(identity, pure = True)(da.arange(2))
fx = f_delayed(input_placeholder)
gx = g_delayed(fx)
hx = h_delayed_serial(gx) # NOTE: use the serial version of h
mx = m_delayed(hx)
# Key of the placeholder task, needed to swap in the real input later.
input_key = input_placeholder.key
computation = mx 
# Copy the task graph into a plain dict so that entries can be reassigned.
graph = dict(computation.dask) 

Execute the whole thing again with a parallel scheduler and check how the CPU load behaves.

In [None]:
dask.config.set(scheduler='processes')  # make operations parallel
# The graph was rebuilt above, so the placeholder has to be replaced again.
graph[input_key] = input_data
# The trailing `.compute()` is the call that picks up the 'processes'
# scheduler configured above.
result = dask.get(graph, computation.key).compute()
result