In [7]:
%matplotlib inline
from dask import array as da
import numpy as np
import xarray as xr
from gcsfs.mapping import GCSMap
from xhistogram.xarray import histogram as xhist

In [2]:
# parameters
# Version tag of the dataset; it is interpolated into both bucket paths below.
dataset_version = "v2019.09.11.2"
# Zarr store paths for the two runs. Judging by the store names, one run
# includes Stokes drift and the other does not — confirm against the data docs.
bucket_stokes = f"pangeo-parcels/med_sea_connectivity_{dataset_version}/traj_data_with_stokes.zarr"
bucket_nostokes = f"pangeo-parcels/med_sea_connectivity_{dataset_version}/traj_data_without_stokes.zarr"

In [3]:
# NOTE(review): `progress` is imported but not used in the cells visible here.
from dask.distributed import Client, progress

from dask_kubernetes import KubeCluster
# Request 4 workers up front, then switch on adaptive scaling with a floor of
# 4 and a ceiling of 40 workers.
cluster = KubeCluster(n_workers=4)
cluster.adapt(minimum=4, maximum=40)
# Bare last expression so the notebook renders the cluster object.
cluster

VBox(children=(HTML(value='<h2>KubeCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .…

**☝️ Don't forget to click the link above to view the scheduler dashboard!**

In [4]:
# Attach a distributed client to `cluster`; the bare `client` expression
# makes the notebook display the connection summary.
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://10.32.60.28:35895  Dashboard: /user/0000-0003-1951-8494/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [5]:
def open_dataset(bucket, restrict_to_MPA=None, restrict_to_z=None):
    """Open a trajectory Zarr store and prepare it for analysis.

    Steps, in order:

    1. Open ``bucket`` through ``GCSMap`` with ``decode_cf=False``.
    2. Store ``MPA`` at ``obs=0`` as the coordinate ``initial_MPA``.
    3. Replace ``z`` by its value at ``obs=0`` and make it a coordinate.
    4. Mask (not drop) every observation from the first one where
       ``land != 0`` onwards, using a cumulative product along ``obs``.
    5. Add ``time_since_start``: ``time`` minus ``time`` at ``obs=0``,
       divided by 1e9 (presumably nanoseconds to seconds — TODO confirm
       the unit of the raw ``time`` variable).
    6. Optionally mask everything whose ``initial_MPA`` / ``z`` differs from
       the requested value.

    Parameters
    ----------
    bucket : str
        Bucket path handed to ``GCSMap``.
    restrict_to_MPA : optional
        If not ``None``, keep only data where ``initial_MPA`` equals this.
    restrict_to_z : optional
        If not ``None``, keep only data where ``z`` equals this.

    Returns
    -------
    The prepared dataset.
    """
    store = GCSMap(bucket)
    ds = xr.open_zarr(store, decode_cf=False)

    # starting region, looked up once and attached as a coordinate
    ds.coords["initial_MPA"] = ds.MPA.isel(obs=0).squeeze()

    # depth level taken from the first observation, then promoted to a coord
    ds["z"] = ds.z.isel(obs=0).squeeze()
    ds.coords["z"] = ds.z

    # the running product stays 1 only until the first obs with land != 0
    off_land = ds.land == 0
    before_land_contact = off_land.cumprod("obs") == 1
    ds = ds.where(before_land_contact)

    # elapsed time relative to the first observation
    start_time = ds.time.isel(obs=0)
    ds["time_since_start"] = (ds.time - start_time) / 1e9

    # optional restriction to one starting region
    if restrict_to_MPA is not None:
        ds = ds.where(ds.initial_MPA == restrict_to_MPA)

    # optional restriction to one depth level
    if restrict_to_z is not None:
        ds = ds.where(ds.z == restrict_to_z)

    return ds

In [6]:
# Open both stores with the default arguments, i.e. without restricting to a
# particular MPA or z value.
ds_stokes = open_dataset(bucket_stokes)
ds_nostokes = open_dataset(bucket_nostokes)

In [8]:
def get_z_values(ds):
    """Return the distinct, non-NaN values found in ``ds.z``.

    Parameters
    ----------
    ds
        Object exposing ``ds.z.data``; that array is passed to
        ``dask.array.unique``.

    Returns
    -------
    numpy.ndarray
        The computed unique values of ``ds.z`` with NaN entries removed.
    """
    # retries=40 lets the scheduler re-run failed tasks instead of aborting
    # the whole computation on the first failure.
    z_values = da.unique(ds.z.data).compute(retries=40)
    # Boolean-mask *indexing* (square brackets) drops the NaN entries; calling
    # the array with parentheses raises TypeError.
    z_values = z_values[~np.isnan(z_values)]
    return z_values

In [None]:
# Display the z values that get_z_values reports for `ds_nostokes`.
get_z_values(ds_nostokes)