In [1]:
%matplotlib inline
from dask import array as da
import numpy as np
import xarray as xr
from gcsfs.mapping import GCSMap
from xhistogram.xarray import histogram as xhist
from matplotlib import pyplot as plt
import pandas as pd
from dask import delayed

In [2]:
# parameters
dataset_version = "v2019.09.11.2"
# Both trajectory stores live under one versioned prefix in the GCS bucket.
bucket_prefix = f"pangeo-parcels/med_sea_connectivity_{dataset_version}"
bucket_stokes = f"{bucket_prefix}/traj_data_with_stokes.zarr"
bucket_nostokes = f"{bucket_prefix}/traj_data_without_stokes.zarr"
# Fraction of trajectories to keep when randomly thinning (0.1 -> 10%).
# NOTE(review): the thinning cell further down passes factor=0.001 instead of
# this value — confirm which one is intended.
thinning_data_factor = 0.1

In [3]:
# NOTE(review): `progress` is imported but not used in the cells shown here.
from dask.distributed import Client, progress

from dask_kubernetes import KubeCluster
# Start a Kubernetes-backed dask cluster with 8 workers.
cluster = KubeCluster(n_workers=8)
# Let the cluster scale adaptively between 8 and 40 workers; `wait_count` is
# passed to dask's adaptive scaling (see the dask `Adaptive` docs for its
# exact meaning).
cluster.adapt(minimum=8, maximum=40, wait_count=15)
# Last expression: render the cluster widget, which links to the dashboard.
cluster

VBox(children=(HTML(value='<h2>KubeCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .…

**☝️ Don't forget to click the link above to view the scheduler dashboard!**

In [4]:
# Connect this session to the cluster's scheduler; the dask `.persist()` /
# `.compute()` calls in later cells are expected to run on this cluster.
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://10.32.60.28:44999  Dashboard: /user/0000-0003-1951-8494/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [None]:
def open_dataset(bucket, restrict_to_MPA=None, restrict_to_z=None):
    """Open one trajectory zarr store from GCS and prepare it for analysis.

    Parameters
    ----------
    bucket : str
        GCS path of the zarr store (e.g. ``bucket_stokes``).
    restrict_to_MPA, restrict_to_z : optional
        Accepted but currently ignored by this function; subsetting is done
        separately with ``restrict_to``.

    Returns
    -------
    xarray.Dataset
        The store opened with ``decode_cf=False``, plus:

        * coordinate ``initial_MPA``: ``MPA`` at the first observation,
        * coordinate ``z``: ``z`` at the first observation,
        * all data variables masked from the first observation whose
          ``land`` value is not 0 onwards,
        * ``time_since_start``: ``time`` minus its first observation,
          divided by 1e9 (computed after masking).
    """
    ds = xr.open_zarr(GCSMap(bucket), decode_cf=False)

    # Starting region of every trajectory, kept as an easy-to-look-up coord.
    ds.coords["initial_MPA"] = ds.MPA.isel(obs=0).squeeze()

    # The depth level does not change along a trajectory, so the first
    # observation's value replaces the variable and becomes a coordinate.
    ds["z"] = ds.z.isel(obs=0).squeeze()
    ds.coords["z"] = ds.z

    # The cumulative product along obs stays 1 only while every observation
    # so far has land == 0, i.e. strictly before the first land contact.
    still_at_sea = (ds.land == 0).cumprod("obs") == 1
    ds = ds.where(still_at_sea)

    # decode_cf=False leaves time as raw numbers; dividing by 1e9 presumably
    # turns nanoseconds into seconds — TODO confirm the stored units.
    elapsed = ds.time - ds.time.isel(obs=0)
    ds["time_since_start"] = elapsed / 1e9

    return ds

In [None]:
# Open the "with Stokes" and "without Stokes" trajectory stores.
ds_stokes, ds_nostokes = open_dataset(bucket_stokes), open_dataset(bucket_nostokes)

In [None]:
import cloudpickle

In [None]:
# Serialize `ds_stokes` with cloudpickle into a local file named "ds".
# NOTE(review): nothing in the cells shown reads this file back; this looks
# like a debugging leftover — confirm before keeping it.
with open("ds", "wb") as pickle_file:
    cloudpickle.dump(ds_stokes, pickle_file)

In [None]:
def persist_z_MPA(ds, retries=40):
    """Persist the ``z`` and ``initial_MPA`` entries of ``ds``.

    These are the per-trajectory values that ``restrict_to`` filters on;
    persisting them presumably avoids recomputing them for every selection.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset as returned by ``open_dataset``. It is modified in place.
    retries : int, default 40
        Forwarded to ``.persist(retries=...)``. This was previously
        hard-coded.

    Returns
    -------
    The same ``ds`` object, with both entries replaced by their persisted
    versions.
    """
    # Same order as before: z first, then initial_MPA.
    for name in ("z", "initial_MPA"):
        ds[name] = ds[name].persist(retries=retries)
    return ds

In [None]:
# Persist the lookup entries (z, initial_MPA) of both runs on the cluster.
ds_stokes, ds_nostokes = map(persist_z_MPA, (ds_stokes, ds_nostokes))

In [None]:
def get_z_values(ds, retries=40):
    """Return the distinct, non-NaN values of ``ds.z``.

    Parameters
    ----------
    ds : xarray.Dataset
        Its ``z.data`` is passed to ``dask.array.unique`` (presumably a dask
        array after ``persist_z_MPA``).
    retries : int, default 40
        Forwarded to ``.compute(retries=...)``. This was previously
        hard-coded.

    Returns
    -------
    numpy.ndarray
        The unique ``z`` values with NaN entries removed.
    """
    unique_levels = da.unique(ds.z.data).compute(retries=retries)
    # NaN never compares equal to anything, so a NaN "level" would select no
    # trajectories in restrict_to; drop it explicitly.
    return unique_levels[~np.isnan(unique_levels)]

In [None]:
# Distinct depth levels, taken from the run without Stokes; the lookup table
# below reuses them for both runs.
z_values = get_z_values(ds=ds_nostokes)

In [None]:
@delayed
def restrict_to(ds, MPA=None, z=None):
    """Select the trajectories of ``ds`` that match a starting MPA and/or depth.

    Decorated with ``dask.delayed``: calling it returns a ``Delayed``, and the
    selection runs only when that object is computed.

    NOTE(review): ``ds`` is an xarray Dataset, which dask treats as a
    collection, so ``delayed`` presumably materialises the whole dataset
    before this body runs. Confirm the memory use on the workers.

    Parameters
    ----------
    ds : xarray.Dataset
        Must have ``initial_MPA`` and ``z`` along the ``traj`` dimension.
    MPA : optional
        Keep only trajectories whose ``initial_MPA`` equals this value.
    z : optional
        Keep only trajectories whose ``z`` equals this value.

    Returns
    -------
    xarray.Dataset
        ``ds`` indexed along ``traj`` with the combined boolean mask. When
        both filters are None, every trajectory is kept.
    """
    keep = xr.full_like(ds.initial_MPA, True, dtype="bool")

    # Apply the MPA filter first, then depth. A variable is only looked up
    # when its filter value is given.
    for wanted, varname in ((MPA, "initial_MPA"), (z, "z")):
        if wanted is not None:
            keep = keep & (ds[varname] == wanted)

    return ds.isel(traj=keep)

In [None]:
from collections import OrderedDict

In [None]:
# Lookup table with one row per (run, starting MPA, depth-level index k).
# The "data" column holds the Delayed selections returned by restrict_to, so
# nothing is computed here.
#
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, so all
# records are collected first and the frame is built once. The row order
# (Stokes rows first) and the 0..N-1 index are the same as before.
#
# NOTE(review): the Stokes run only uses the first depth level (k = 0), while
# the run without Stokes uses every level in z_values. Confirm this is
# intentional.
mpa_ids = range(1, 10)
run_specs = (
    (True, ds_stokes, range(1)),
    (False, ds_nostokes, range(len(z_values))),
)
data = pd.DataFrame(
    [
        OrderedDict(
            {
                "stokes": stokes, "MPA": MPA, "k": k,
                "data": restrict_to(ds_run, MPA, z=z_values[k]),
            }
        )
        for stokes, ds_run, k_range in run_specs
        for MPA in mpa_ids
        for k in k_range
    ]
)

In [None]:
def get_thinned_data(ds, factor=0.5, seed=None):
    """Randomly keep roughly a fraction ``factor`` of the trajectories in ``ds``.

    Each trajectory is kept independently when a uniform draw on [0, 1) is
    below ``factor``. The number kept is therefore random, not exactly
    ``factor * n``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``ds.z`` (presumably one value per trajectory) and
        ``ds.isel(traj=...)``.
    factor : float, default 0.5
        Probability of keeping each trajectory.
    seed : int, optional
        When given, the draws come from a private ``RandomState(seed)``. This
        produces the same mask as ``np.random.seed(seed)`` followed by
        ``np.random.uniform``, but it no longer reseeds numpy's global RNG as
        a side effect. When None, numpy's global RNG is used, as before.

    Returns
    -------
    xarray.Dataset
        ``ds.isel(traj=mask)`` for the random boolean mask.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    traj_indices = rng.uniform(0, 1, size=ds.z.shape) < factor
    return ds.isel(traj=traj_indices)

In [None]:
def get_var(ds, varname):
    """Return ``ds[varname]``.

    A named helper rather than inline subscripting, presumably so it can be
    wrapped with ``dask.delayed`` or shipped to workers.
    """
    selected = ds[varname]
    return selected

In [None]:
# Display the lookup table: one row per (stokes, MPA, k). The "data" column
# holds the Delayed results of restrict_to, so displaying it computes nothing.
data

In [None]:
# Thin one selection (row 53 of `data`) to 0.1% of its trajectories and
# extract its "distance" variable on the cluster.
#
# `data["data"][53]` is a dask Delayed (restrict_to is @delayed), so the plain
# helpers cannot be applied to it directly. `get_thinned_data` would need the
# Delayed `ds.z.shape` as a numpy `size`, which fails, and a Delayed also has
# no `.result()`. Wrapping both helpers with `delayed` keeps the chain lazy,
# and `client.compute` returns a distributed Future, which is what
# `_dist.result()` expects below.
#
# NOTE(review): uses factor=0.001 rather than `thinning_data_factor` (0.1).
# TODO: row 53 is a positional pick; selecting by the stokes/MPA/k columns
# would survive changes to how `data` is built.
_thinned = delayed(get_thinned_data)(data["data"][53], factor=0.001, seed=1)
_dist = client.compute(delayed(get_var)(_thinned, "distance"), retries=40)

In [None]:
import cloudpickle

In [None]:
# `_dist` is expected to be a distributed Future: `.result()` fetches the
# DataArray, and `.compute` then loads any of its data that is still lazy.
_dist_value = _dist.result()
dist = _dist_value.compute(retries=40)