In [None]:
# parameters
# Version tag of the dataset; it is interpolated into both bucket paths below.
dataset_version = "v2019.09.11.2"
# Paths of the two `.zarr` trajectory stores: one "with stokes", one "without stokes".
bucket_stokes = f"pangeo-parcels/med_sea_connectivity_{dataset_version}/traj_data_with_stokes.zarr"
bucket_nostokes = f"pangeo-parcels/med_sea_connectivity_{dataset_version}/traj_data_without_stokes.zarr"

# Action string handed to the `warnings` filter set up in a later cell.
filter_warnings = "ignore"  # "ignore" hides every warning; pick another action (e.g. "default") when debugging.

## Load all modules and spin up a Dask cluster

In [None]:
%matplotlib inline
from dask import array as da
import numpy as np
import xarray as xr
from gcsfs.mapping import GCSMap
from xhistogram.xarray import histogram as xhist
from matplotlib import pyplot as plt
import pandas as pd
from dask import delayed

In [None]:
import warnings

# Install one catch-all filter (any message, any module, any category) whose
# action is the notebook-level `filter_warnings` parameter.
warnings.simplefilter(filter_warnings)

In [None]:
# NOTE(review): `progress` is imported but not used in the cells visible here.
from dask.distributed import Client, progress

from dask_kubernetes import KubeCluster
# Create a Kubernetes-backed Dask cluster, requesting 8 workers up front.
cluster = KubeCluster(n_workers=8)
# Switch on adaptive scaling with a floor of 8 and a ceiling of 60 workers.
# NOTE(review): `wait_count=15` is forwarded to the adaptive scaler; presumably it is the
# number of consecutive scale-down recommendations required before workers are removed.
# Confirm against the installed dask.distributed version.
cluster.adapt(minimum=8, maximum=60, wait_count=15)

# Connect a client to the cluster; later cells call `.persist(...)` / `.compute(...)`,
# which presumably run through this client.
client = Client(cluster)
# Bare expression: the notebook renders the client's summary as the cell output.
client

In [None]:
import cloudpickle

In [None]:
# Load the pickled object that the plotting cells below index via `data.loc[...]`.
# The path is relative to the working directory; the file is presumably written by an
# earlier processing step — verify where it comes from.
# NOTE(review): unpickling executes code embedded in the file, so only load files
# that were produced by a trusted source.
with open("intermediate_data/all_traj_dataframe.pickle", mode="rb") as f:
    data = cloudpickle.load(f)

In [None]:
# Bare expression: display the loaded object as the cell output for a quick visual check.
data

In [None]:
# Log-scaled histograms of `distance` per observation index ("obs"):
# one row per MPA (1..9) and one column per experiment configuration.

# Histogram bin edges: 41 evenly spaced edges on [0, 1200], i.e. 40 bins.
# Computed once because they are identical for every panel.
distance_bin_edges = np.linspace(0, 1200, 41)

# One entry per panel column: (stokes flag, depth selector, title suffix).
# The meaning of the first two fields is taken from the title suffixes; they are
# used verbatim as the first and third keys of `data.loc[...]`.
panel_columns = [
    (True, 0, "with stokes, surface"),
    (False, 0, "without stokes, surface"),
    (False, -1, "without stokes, all depths"),
]


def _plot_log10_distance_histogram(axis, with_stokes, mpa, depth_selector, title):
    """Plot log10 of the per-"obs" distance histogram for one experiment.

    Parameters
    ----------
    axis : matplotlib axes
        Panel to draw into.
    with_stokes, mpa, depth_selector
        Keys passed, in this order, to ``data.loc[...]`` to pick the experiment.
        ``data`` is the notebook-level object loaded from the pickle file
        (presumably a pandas object with a three-level index; verify).
    title : str
        Title set on ``axis`` after plotting.

    Returns
    -------
    None
        The function only draws on ``axis``.
    """
    distance = data.loc[with_stokes, mpa, depth_selector]["thinned_data_005_percent"].distance

    # Histogram over the "traj" dimension; what remains is ("obs", "distance_bin"),
    # the two dimensions used as plot axes below.
    counts = xhist(
        distance.persist(retries=40),
        bins=[distance_bin_edges, ],
        dim=["traj", ],
    ).compute(retries=40)

    # NumPy ufuncs apply directly to xarray objects, so the deprecated
    # `xr.ufuncs` shim is not needed. Empty bins become log10(0) = -inf.
    np.log10(counts).plot(ax=axis, x="obs", y="distance_bin")
    axis.set_title(title)


fig, ax = plt.subplots(9, 3, sharex=True, sharey=True, figsize=(18, 54))

for MPA in range(1, 10):
    for column, (with_stokes, depth_selector, label) in enumerate(panel_columns):
        _plot_log10_distance_histogram(
            ax[MPA - 1, column],
            with_stokes,
            MPA,
            depth_selector,
            f"MPA {MPA}, {label}",
        )

In [None]:
# One panel per MPA (1..9) on a 3x3 grid. Each panel shows, per "obs" step, the
# percentage of trajectories whose `distance` is null, for the runs without and
# with stokes (third `data.loc` key fixed to 0).
fig, ax = plt.subplots(3, 3, sharex=False, sharey=False, figsize=(15, 15))
ax = ax.flatten()

for MPA in range(1, 10):
    panel = ax[MPA - 1]

    # "without stokes" is drawn first, which fixes the line order and colours.
    for with_stokes, label in ((False, "without stokes"), (True, "with stokes")):
        distance = data.loc[with_stokes, MPA, 0]["thinned_data_005_percent"].distance
        percent_null = (distance.isnull().mean("traj") * 100).compute(retries=40)
        percent_null.plot(ax=panel, label=label)

    panel.set_title(f"MPA {MPA}")
    panel.set_xlabel("time / hours")
    panel.set_ylabel("percent beached so far")
    panel.legend(loc=0, ncol=1)

fig.tight_layout()