# Prepare data

All trajectories are stored in a Google Cloud Storage bucket. We want to be able to load and filter all trajectories easily.  To this end, we load all the datasets (lazily), filter them by different parameters (starting MPA, depth, Stokes drift), and store a Pandas dataframe with virtual sub-datasets for each combination of the parameters.  This Pandas dataframe will be pickled for later re-use.

In [None]:
# parameters
dataset_version = "v2019.09.11.2"

# Both experiments (with and without Stokes drift) live next to each other
# under one versioned prefix in the bucket.
bucket_stokes, bucket_nostokes = (
    f"pangeo-parcels/med_sea_connectivity_{dataset_version}/traj_data_{variant}_stokes.zarr"
    for variant in ("with", "without")
)

# Action handed to warnings.filterwarnings further down; "ignore" silences
# all warnings.  Change it (e.g. to "default") when debugging.
filter_warnings = "ignore"

## Load all modules and spin up a Dask cluster

In [None]:
%matplotlib inline
from dask import array as da
import numpy as np
import xarray as xr
from gcsfs.mapping import GCSMap
from xhistogram.xarray import histogram as xhist
from matplotlib import pyplot as plt
import pandas as pd
from dask import delayed

In [None]:
# NOTE(review): `progress` is imported but not used anywhere in this notebook.
from dask.distributed import Client, progress

from dask_kubernetes import KubeCluster
# Start with 8 workers on Kubernetes ...
cluster = KubeCluster(n_workers=8)
# ... and let the cluster scale adaptively between 8 and 60 workers.
# (wait_count presumably controls how eagerly idle workers are retired —
# see the dask Adaptive docs to confirm.)
cluster.adapt(minimum=8, maximum=60, wait_count=15)

# Connect this notebook to the cluster; showing `client` displays the
# dashboard link.
client = Client(cluster)
client

**☝️ Don't forget to click the link above to view the scheduler dashboard!**

## Open datasets

In [None]:
def open_dataset(bucket, restrict_to_MPA=None, restrict_to_z=None):
    """Lazily open one trajectory zarr store from Google Cloud Storage.

    The store is opened with ``decode_cf=False``.  Two derived quantities are
    attached to the returned dataset:

    - ``initial_MPA`` (coordinate): the ``MPA`` value at the first
      observation of each trajectory, i.e. the region it started from.
    - ``before_land_contact`` (data variable): True for every observation
      that precedes the first observation whose ``land`` value is not 0,
      False from that observation onwards.

    NOTE(review): ``restrict_to_MPA`` and ``restrict_to_z`` are accepted but
    currently ignored by this function.
    """
    store = GCSMap(bucket)
    ds = xr.open_zarr(store, decode_cf=False)

    # Starting region of every trajectory as an easy-to-query coordinate.
    ds.coords["initial_MPA"] = ds.MPA.isel(obs=0).squeeze()

    # The running product along "obs" stays 1 only while every observation
    # so far had land == 0, and drops to 0 for good at the first contact.
    on_water = ds.land == 0
    ds["before_land_contact"] = on_water.cumprod("obs") == 1

    return ds

In [None]:
# Open both experiments (lazily); the Stokes run first, then the no-Stokes run.
ds_stokes, ds_nostokes = (open_dataset(b) for b in (bucket_stokes, bucket_nostokes))

## Simplify

We know a few things about our data that make it easier to deal with them:

- No vertical migration.  Hence, the initial depth of a particle is valid for all times.

- All time steps are the same. Hence, we can easily build a relative time axis that is valid for all particles.

In [None]:
def apply_assumptions(ds):
    """Apply simplifications that hold for this specific set of experiments.

    Be careful when applying these to new experiments, because they might
    not hold there.  Two assumptions are built in:

    1. No vertical migration.  The depth ``z`` at the first observation is
       valid for the whole trajectory, so ``z`` is replaced by that
       per-trajectory value and made a coordinate.
    2. All time steps are equal, and ``time`` is measured in seconds since
       some reference.  A relative time axis is derived from the first
       trajectory and attached as the coordinate ``time_axis``, with units
       "seconds since start of particle".

    ``ds`` is modified in place and also returned.
    """
    # (1) Depth at the first observation, valid for all observations.
    z = ds.z.isel(obs=0).squeeze()
    ds["z"] = z
    ds.coords["z"] = ds.z

    # (2) Relative time axis taken from trajectory 0.  z and initial_MPA are
    # turned back into data variables first, presumably so they don't end up
    # as scalar coordinates on ``time_axis``.
    # ``.copy()`` is essential: reset_coords only makes a shallow copy and
    # isel(traj=0) is a view for in-memory data.  Without it, the in-place
    # subtraction below would write through into ``ds.time`` itself.
    # NOTE(review): assumes trajectory 0 covers the full time axis (no
    # padding / missing values) — confirm.
    time_axis = ds.reset_coords(["z", "initial_MPA"]).time.isel(traj=0).squeeze().copy()
    time_axis -= time_axis.isel(obs=0).squeeze()
    time_axis.attrs["units"] = "seconds since start of particle"
    ds.coords["time_axis"] = time_axis

    return ds

In [None]:
# Same simplifications for both experiments (Stokes run first).
ds_stokes, ds_nostokes = map(apply_assumptions, (ds_stokes, ds_nostokes))

## Load coordinates for quicker access

So far, we have read only the bare minimum of information (data types, variable names, number of time steps, ...) but have not loaded any of the data itself.  We want to keep it that way for the bulk of the data, but load the coordinates and similar small arrays now.

In [None]:
def persist_coords(ds, retries=40):
    """Persist the coordinate arrays on the cluster.

    ``z``, ``initial_MPA`` and ``time_axis`` are persisted one after the
    other, with ``retries`` forwarded to ``persist``.  All other variables
    stay untouched.  ``ds`` is updated in place and returned.
    """
    for coord_name in ("z", "initial_MPA", "time_axis"):
        ds[coord_name] = ds[coord_name].persist(retries=retries)
    return ds

In [None]:
def compute_coords(ds, retries=40):
    """Pull the coordinate arrays into local (front-end) memory.

    ``z``, ``initial_MPA`` and ``time_axis`` are computed one after the
    other, with ``retries`` forwarded to ``compute``.  All other variables
    stay untouched.  ``ds`` is updated in place and returned.
    """
    for coord_name in ("z", "initial_MPA", "time_axis"):
        ds[coord_name] = ds[coord_name].compute(retries=retries)
    return ds

In [None]:
# Persist coordinates of both experiments on the cluster (Stokes run first).
ds_stokes, ds_nostokes = map(persist_coords, (ds_stokes, ds_nostokes))

In [None]:
# Bring the (already persisted) coordinates to the front end (Stokes run first).
ds_stokes, ds_nostokes = map(compute_coords, (ds_stokes, ds_nostokes))

In [None]:
ds_stokes

In [None]:
ds_nostokes

In [None]:
def get_z_values(ds, retries=40):
    """Return the distinct depth levels of ``ds`` as a NumPy array.

    This triggers a computation across all of the z-level data.  Values are
    deduplicated with ``dask.array.unique``, and NaN entries are dropped.

    Parameters
    ----------
    ds :
        Dataset with a ``z`` variable.
    retries : int, optional
        Forwarded to ``compute``.  The default of 40 matches the previously
        hard-coded value and the defaults of ``persist_coords`` and
        ``compute_coords``.

    Returns
    -------
    numpy.ndarray
        Unique, non-NaN depth values.
    """
    z_values = da.unique(ds.z.data).compute(retries=retries)
    # NaN != NaN, so unique may keep one or more NaN entries; drop them all.
    return z_values[~np.isnan(z_values)]

In [None]:
# Depth levels are taken from the no-Stokes run.  The module-level `z_values`
# is also what wrap_in_dataframe indexes by level k for both runs.
z_values = get_z_values(ds_nostokes)

In [None]:
print(z_values)

## Filter data

We want to quickly select:
- Stokes drift on or off
- MPA a trajectory started from
- z-level

In [None]:
def restrict_to(ds, MPA=None, z=None):
    """Select the trajectories that started in ``MPA`` and/or sit at depth ``z``.

    A criterion given as ``None`` is not applied.  With both left as
    ``None``, every trajectory is kept.  Matching uses exact equality, so
    ``z`` should be one of the values actually present in ``ds.z`` (e.g. an
    entry of ``get_z_values(ds)``).
    """
    conditions = []
    if MPA is not None:
        conditions.append(ds.initial_MPA == MPA)
    if z is not None:
        conditions.append(ds.z == z)

    # Start from "keep everything" and narrow down with each active criterion.
    mask = xr.full_like(ds.initial_MPA, True, dtype="bool")
    for condition in conditions:
        mask = mask & condition

    return ds.isel(traj=mask)

In [None]:
from collections import OrderedDict

In [None]:
def wrap_in_dataframe(ds, stokes=True, num_levels=1, num_mpas=9, levels=None):
    """Build a DataFrame with one restricted sub-dataset per row.

    One group of rows is created for every starting MPA in
    ``1 .. num_mpas`` (in that order).

    Parameters
    ----------
    ds :
        Dataset to split; it must provide ``initial_MPA`` and ``z`` (see
        ``restrict_to``).
    stokes : bool
        Value written to the "stokes" column.  It is only a label and does
        not affect the selection.
    num_levels : int or None
        If an int, each MPA gets one row per level ``k in range(num_levels)``,
        restricted to depth ``levels[k]``.  If None, each MPA gets a single
        row spanning all depths, with ``k = -1``.
    num_mpas : int
        Number of MPAs, numbered from 1.
    levels : sequence, optional
        Depth values indexed by ``k``.  Defaults to the notebook-level
        ``z_values``, which reproduces the original behaviour.  Pass it
        explicitly to avoid depending on that global.

    Returns
    -------
    pandas.DataFrame
        Columns "stokes", "MPA", "k" and "data" (the sub-dataset).
    """
    if num_levels is None:
        # k = -1 marks "no depth restriction".
        level_indices = [-1]
    else:
        if levels is None:
            # Backward-compatible default: the notebook-global depth levels.
            levels = z_values
        level_indices = range(num_levels)

    def _row(mpa, k):
        # One DataFrame row.  A list of pairs makes the column order explicit;
        # OrderedDict({...}) would only inherit the dict literal's order.
        z = None if k == -1 else levels[k]
        return OrderedDict(
            [
                ("stokes", stokes),
                ("MPA", mpa),
                ("k", k),
                ("data", restrict_to(ds, MPA=mpa, z=z)),
            ]
        )

    return pd.DataFrame(
        _row(mpa, k)
        for mpa in range(1, 1 + num_mpas)
        for k in level_indices
    )

The following will trigger computations.

In [None]:
# Quick-access dataframe with one row per (stokes, MPA, k), made of three parts:
#   - Stokes-drift data at the surface (k = 0 only),
#   - non-Stokes data, one row per depth level (k = 0 .. len(z_values) - 1),
#   - non-Stokes data without distinguishing levels (k = -1).
# DataFrame.append is deprecated since pandas 1.4 and removed in 2.0, so the
# pieces are combined with a single pd.concat instead (same rows, same order).
data = pd.concat(
    [
        wrap_in_dataframe(ds_stokes, stokes=True, num_levels=1, num_mpas=9),
        wrap_in_dataframe(ds_nostokes, stokes=False, num_levels=len(z_values), num_mpas=9),
        wrap_in_dataframe(ds_nostokes, stokes=False, num_levels=None, num_mpas=9),
    ],
    ignore_index=True,
)

In [None]:
data

## Create thinned out data

We're not sure whether we need the full statistics.  Create sub-sampled datasets that contain only about 1%, 5%, and 10% of the trajectories.  Sub-sampling is done randomly: each trajectory is kept independently, so the fractions are approximate.

In [None]:
def get_thinned_data(ds, percent=50, seed=None):
    """Return ``ds`` thinned to roughly ``percent`` % of its trajectories.

    Each trajectory is kept independently with probability ``percent / 100``,
    so the number of trajectories kept varies from draw to draw.

    Parameters
    ----------
    ds :
        Dataset whose ``z`` has one entry per trajectory.  The mask drawn
        with ``ds.z.shape`` is used as the ``traj`` indexer, so ``z`` is
        assumed to be one-dimensional along ``traj``, as after
        ``apply_assumptions``.
    percent : float
        Target percentage of trajectories to keep.
    seed : int or None
        If given, the draws come from a private ``RandomState(seed)``.  This
        produces exactly the same numbers as ``np.random.seed(seed)`` would,
        but leaves NumPy's global random state untouched.  If None, the
        global generator is used.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    keep = rng.uniform(0, 1, size=ds.z.shape) < (percent / 100.0)
    return ds.isel(traj=keep)

In [None]:
# One extra column per percentage, holding an independently thinned copy of
# each row's sub-dataset.  The default argument binds the current
# percentage into the lambda explicitly.
for perc in (1, 5, 10):
    data[f"thinned_data_{perc:03d}_percent"] = data["data"].apply(
        lambda ds, p=perc: get_thinned_data(ds, percent=p)
    )

In [None]:
data

## Make it easy to index

In [None]:
# Index rows by (stokes, MPA, k) so sub-datasets can be looked up directly,
# e.g. data.loc[(False, 3, 0), "data"].
data = data.set_index(["stokes", "MPA", "k"])

## Check data volumes

In [None]:
def get_total_size(series):
    """Return the sum of the ``nbytes`` attributes of all objects in ``series``.

    For the (possibly still lazy) sub-datasets in this notebook this is
    presumably their nominal in-memory size — confirm for your xarray
    version.
    """
    sizes = series.map(lambda obj: obj.nbytes)
    return sizes.sum()

In [None]:
# Report the total volume per column in GB.  The non-Stokes data appear both
# level by level (k >= 0) and as all-level rows (k = -1), hence counted twice.
for label, column in [
    ("All data (doubly counting non-stokes data):", "data"),
    ("All thinned data (10%):", "thinned_data_010_percent"),
    ("All thinned data (5%):", "thinned_data_005_percent"),
    ("All thinned data (1%):", "thinned_data_001_percent"),
]:
    print(label, get_total_size(data[column]) / 1e9, "GB")

In [None]:
import warnings
# Apply the warning action chosen in the parameters cell (`filter_warnings`).
# NOTE(review): this only affects warnings raised from here on; anything
# emitted by the cells above was already shown.
warnings.filterwarnings(filter_warnings)

## Store dataframe for later re-use.  Then re-load to check.

In [None]:
import cloudpickle

In [None]:
!mkdir -p intermediate_data

In [None]:
# Serialise the whole dataframe, including the virtual sub-datasets stored in
# its cells, with cloudpickle.  The target directory is created by the
# `mkdir -p intermediate_data` cell.
with open("intermediate_data/all_traj_dataframe.pickle", mode="wb") as f:
    cloudpickle.dump(data, f)

In [None]:
# Re-load the pickle to check that the round trip works.
# NOTE(review): unpickling can execute arbitrary code; only load pickle files
# from a trusted source (e.g. the one written by this notebook).
with open("intermediate_data/all_traj_dataframe.pickle", mode="rb") as f:
    data = cloudpickle.load(f)

# Technical documentation

Lists the whole working environment.

In [None]:
%pip list

In [None]:
%conda list --explicit