This notebook documents ways to parallelize the marine heatwave (MHW) processing loop that was demonstrated in the `ejoliver_subset_MUR.ipynb` notebook. For simplicity, it uses a randomly generated 1,000 x 100 x 100 grid of data (time x longitude x latitude).
Two methods are shown here:
1. `da.apply_along_axis()`
2. `da.map_blocks()`

Key Stack Overflow threads that supported these approaches:
- [parallelizing the loop](https://stackoverflow.com/questions/71916577/dask-looping-over-library-function-call)
- [iterating map_blocks for debugging](https://stackoverflow.com/questions/72015205/iterating-through-dask-array-chunks)

# Code Snippets

## `da.apply_along_axis()`

The `result` output appears to be a single array into which the entire output of `mhw.detect()` has been merged.

Note: the `dtype=data.dtype, shape=(1000,)` arguments to `apply_along_axis()` are necessary. Without them, the function gets passed an array of shape `(1,)` and fails.
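To illustrate the note above: as I understand it, when `dtype` or `shape` is omitted, `da.apply_along_axis()` has to guess the output by calling the wrapped function on a tiny dummy array, which matches the `(1,)`-shaped input described here. The sketch below reproduces that failure with NumPy only. `fake_detect` is a hypothetical stand-in for `mhw.detect()` that only checks that the time and temperature vectors have equal length; it is not the real function.

```python
import numpy as np

time = np.arange(730_000, 731_000)  # 1,000 ordinal days, as in the notebook

def fake_detect(t, temp):
    """Hypothetical stand-in for mhw.detect(): needs one value per time step."""
    if len(t) != len(temp):
        raise ValueError(f"time has {len(t)} steps but temp has {len(temp)}")
    return temp - temp.mean()  # placeholder result with the same length as temp

def func1d(arr, time):
    return fake_detect(time, arr)

# A probe like the one described above: a dummy array of shape (1,)
probe = np.ones((1,))
try:
    func1d(probe, time=time)
    probe_failed = False
except ValueError as err:
    probe_failed = True
    print("probe call failed:", err)

# A full-length series works, which is what shape=(1000,) promises up front
full_output = func1d(np.random.default_rng(0).random(1000), time=time)
print(full_output.shape)
```

Passing `dtype` and `shape` explicitly means no probe call is needed, so the wrapped function only ever sees full-length time series.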

In [None]:
import numpy as np
import dask.array as da
import marineHeatWaves as mhw
from dask.distributed import Client

client = Client()

In [None]:
# Create fake input data
lat_size, long_size = 100, 100
data = da.random.random_integers(0, 30, size=(1_000, long_size, lat_size), chunks=(-1, 10, 10))  # size = (time, longitude, latitude)
time = np.arange(730_000, 731_000)  # time in ordinal days

# define a wrapper to rearrange arguments
def func1d(arr, time):
    return mhw.detect(time, arr)

result = da.apply_along_axis(func1d, 0, data, time=time, dtype=data.dtype, shape=(1000,))
result.compute()

## `da.map_blocks()`

This is technically a longer version of the `apply_along_axis()` method. It was helpful for seeing the process, though, when I was missing an argument I needed in the other method.

As with the first method, the `result` output appears to be a single array into which the entire output of `mhw.detect()` has been merged.
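A small NumPy-only sketch of why the two approaches match: applying a 1-D function block by block and stitching the blocks back together gives the same array as applying it across the whole grid. `demean` is a hypothetical stand-in for the `mhw.detect()` wrapper, and the 10 x 10 spatial tiles mirror the `chunks=(-1, 10, 10)` layout used here. The manual loop only imitates what `map_blocks()` would schedule.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.integers(0, 31, size=(50, 20, 20)).astype(float)  # (time, longitude, latitude)

def demean(series):
    """Hypothetical stand-in for the mhw.detect() wrapper: one output value per time step."""
    return series - series.mean()

def block_func(block):
    # Same kind of helper as in the notebook: apply the 1-D function along the time axis
    return np.apply_along_axis(demean, 0, block)

# Whole-array version
whole = np.apply_along_axis(demean, 0, data)

# Block-by-block version: full time axis, 10 x 10 spatial tiles
blockwise = np.empty_like(data)
for i in range(0, data.shape[1], 10):
    for j in range(0, data.shape[2], 10):
        blockwise[:, i:i + 10, j:j + 10] = block_func(data[:, i:i + 10, j:j + 10])

print(np.allclose(whole, blockwise))
```

Because each pixel's time series is processed independently, the tiling has no effect on the values, only on how the work can be split up.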

In [None]:
import numpy as np
import dask.array as da
import marineHeatWaves as mhw
from dask.distributed import Client

client = Client()

In [3]:
# Create fake input data
lat_size, long_size = 100, 100
data = da.random.random_integers(0, 30, size=(1_000, long_size, lat_size), chunks=(-1, 10, 10))  # size = (time, longitude, latitude)
time = np.arange(730_000, 731_000)  # time in ordinal days

# define a wrapper to rearrange arguments
def func1d(arr, time):
    return mhw.detect(time, arr)

def block_func(block, **kwargs):
    return np.apply_along_axis(func1d, 0, block, **kwargs)

result = data.map_blocks(block_func, meta=data, time=time)
result = result.compute()

## Final Output

This is the cleaned-up, final version of the techniques above. Because the `map_blocks()` approach is just a manual version of what happens in the `apply_along_axis()` approach, I went with `apply_along_axis()` for the final version. I wasn't able to figure out how to get `mhw.detect()` to run only once and assign its outputs to different places, although I'm pretty sure it can be done. Instead, the code below runs the function twice.
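One possible way to avoid the second call, sketched here with NumPy rather than Dask: have the wrapper return the climatology and threshold stacked into one array, then slice the result afterwards. `fake_detect` and `func1d_both` are hypothetical names; `fake_detect` mimics only the `(mhws, clim)` return structure and the `'seas'` and `'thresh'` keys used in this notebook. I have not tested this with `mhw.detect()` or with `da.apply_along_axis()`, where I would expect `shape=(2, time_size)` to be needed.

```python
import numpy as np

time_size = 30
time = np.arange(730_000, 730_000 + time_size)
rng = np.random.default_rng(0)
data = rng.integers(0, 31, size=(time_size, 4, 5)).astype(float)  # (time, longitude, latitude)

def fake_detect(t, temp):
    """Hypothetical stand-in for mhw.detect(): returns (events, climatology dict)."""
    seas = np.full(len(t), temp.mean())   # placeholder seasonal climatology
    thresh = seas + temp.std()            # placeholder threshold
    return {}, {"seas": seas, "thresh": thresh}

def func1d_both(arr, time):
    # Call the detector once and stack both outputs: shape (2, time)
    _, point_clim = fake_detect(time, arr)
    return np.stack([point_clim["seas"], point_clim["thresh"]])

# np.apply_along_axis puts the function's output dimensions where the time axis was
both = np.apply_along_axis(func1d_both, 0, data, time=time)  # (2, time, longitude, latitude)
climatology, threshold = both[0], both[1]
print(both.shape, climatology.shape, threshold.shape)
```

The trade-off is one array twice the size of the input instead of two separate arrays, in exchange for halving the number of detector calls.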

In [2]:
import numpy as np
import dask.array as da
import marineHeatWaves as mhw
from dask.distributed import Client

client = Client()

In [16]:
# Create fake input data
lat_size, long_size, time_size = 100, 100, 1000
data = da.random.random_integers(0, 30, size=(time_size, long_size, lat_size), chunks=(-1, 10, 10))  # size = (time, longitude, latitude)
time = np.arange(730_000, 731_000)  # time in ordinal days

# wrapper that rearranges arguments and returns only the climatology
def func1d_climatology(arr, time):
    _, point_clim = mhw.detect(time, arr)
    return point_clim['seas']

# wrapper that rearranges arguments and returns only the threshold
def func1d_threshold(arr, time):
    _, point_clim = mhw.detect(time, arr)
    return point_clim['thresh']

# output arrays
full_climatology = da.zeros_like(data)
full_threshold = da.zeros_like(data)

climatology = da.apply_along_axis(func1d_climatology, 0, data, time=time, dtype=data.dtype, shape=(time_size,))
threshold = da.apply_along_axis(func1d_threshold, 0, data, time=time, dtype=data.dtype, shape=(time_size,))


In [12]:
output_clim = climatology.compute()
# output_thresh = threshold.compute()

In [17]:
climatology

| | Array | Chunk |
|---|---|---|
| Bytes | 76.29 MiB | 781.25 kiB |
| Shape | (1000, 100, 100) | (1000, 10, 10) |
| Count | 200 Tasks | 100 Chunks |
| Type | int64 | numpy.ndarray |

In [14]:
data.nbytes / 1e6

80.0
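The 80.0 here and the 76.29 MiB in the array preview are the same quantity in different units (decimal megabytes versus binary mebibytes). A quick check of the arithmetic, assuming 8 bytes per `int64` value:

```python
# Array and chunk sizes for a (1000, 100, 100) int64 array with (1000, 10, 10) chunks
BYTES_PER_VALUE = 8  # int64

array_bytes = 1000 * 100 * 100 * BYTES_PER_VALUE
chunk_bytes = 1000 * 10 * 10 * BYTES_PER_VALUE

array_mb = array_bytes / 1e6        # decimal megabytes, as in data.nbytes / 1e6
array_mib = array_bytes / 2**20     # binary mebibytes, as in the array preview
chunk_kib = chunk_bytes / 2**10     # binary kibibytes, as in the array preview
n_chunks = array_bytes // chunk_bytes

print(array_mb, round(array_mib, 2), chunk_kib, n_chunks)
```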

# Timing blocks and chunks

### Block 1 (each chunk spans the full time axis)
**Source code**
```python
lat_size, long_size, time_size = 100, 100, 1000
data = da.random.random_integers(0, 30, size=(time_size, long_size, lat_size), chunks=(-1, 10, 10))
```
**Size Stats**

- Data size: 80 MB, 10 million values (1,000 time steps x 10,000 grid points)
- `climatology.compute()`: 1 min 28 s
- Tasks in Dask array preview: 200 tasks, 800 kB chunks (80 MB total array)

**Linear prediction**

(220 times smaller than gulfstream block --> 330 minutes = 5.5 hours)
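The estimate above is a straight multiplication of the measured run time by the size ratio. As a quick check, assuming run time scales linearly with data volume (the factor of 220 is taken from the note above; `linear_estimate` is a hypothetical helper):

```python
def linear_estimate(measured_seconds, size_ratio):
    """Scale a measured run time by a size ratio, assuming linear scaling."""
    return measured_seconds * size_ratio

measured = 1 * 60 + 28            # 1 min 28 s for the test block
estimate_s = linear_estimate(measured, 220)

estimate_min = estimate_s / 60
estimate_hr = estimate_s / 3600
print(f"{estimate_min:.0f} minutes, {estimate_hr:.1f} hours")
```

This comes out at roughly 323 minutes, or about 5.4 hours, in line with the rounded figure above.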

**Dask comments**

- 100 func1d_clim tasks; 100 random_integers tasks
- seems to hang for the first minute, then the last 30 seconds it buzzes through the climatology calculation

### Block 2 (each chunk spans the full spatial grid)
**Source code**

```python
lat_size, long_size, time_size = 100, 100, 1000
data = da.random.random_integers(0, 30, size=(time_size, long_size, lat_size), chunks=(2, 100, 100))
```
**Size Stats**

- Data size: 80 MB, 10 million values (1,000 time steps x 10,000 grid points)
- `climatology.compute()`: 11 min 50 s
- Tasks in Dask array preview: 502 tasks, 80 MB chunks (80 MB total array)

**Linear prediction**

(220 times smaller than gulfstream block --> 44 hours, 2 days)
I would guess that a larger spatial dimension won't affect timing much, but a larger time dimension will, dramatically. I'm not sure how the math on that will work out.

An alternative (albeit highly hopeful) perspective:
(7 times as many time chunks --> 12 minutes * 7 = 84 minutes)

**Dask comments**

- 500 random_integers tasks, 1 func1d_clim task, 1 rechunk-merge task
- The 500 random_integers tasks complete in the first minute; judging from CPU use (Workers tab), a single worker then handles the remaining func1d_clim task on its own
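The single func1d_clim task fits with `da.apply_along_axis()` needing the whole time axis in one chunk: once the time chunks are merged (the rechunk-merge task), the only remaining parallelism comes from how the spatial axes are chunked. A rough count for both layouts, assuming that merge happens (`n_chunks` and `tasks_after_time_merge` are hypothetical helpers, not Dask functions):

```python
import math

def n_chunks(shape, chunks):
    """Number of chunks along each axis; -1 means one chunk spanning the axis."""
    return tuple(
        1 if c == -1 else math.ceil(s / c)
        for s, c in zip(shape, chunks)
    )

def tasks_after_time_merge(shape, chunks):
    """Independent pieces left once the time axis (axis 0) is merged into one chunk."""
    _, n_lon, n_lat = n_chunks(shape, chunks)
    return n_lon * n_lat

shape = (1000, 100, 100)  # (time, longitude, latitude)

block1 = (-1, 10, 10)     # Block 1 layout
block2 = (2, 100, 100)    # Block 2 layout

print(n_chunks(shape, block1), tasks_after_time_merge(shape, block1))
print(n_chunks(shape, block2), tasks_after_time_merge(shape, block2))
```

Block 1 yields 100 independent pieces and Block 2 only 1, matching the task counts listed for each block. Under this reading, the Block 2 layout leaves nothing for the other workers to do once the random data has been generated.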

