In [1]:
# %pip install seaborn xhistogram "matplotlib<3.1"

In [2]:
# Workaround that keeps xhistogram computations lazy on dask arrays:
# switch off NumPy's experimental __array_function__ dispatch (NEP 18).
# NumPy reads this variable when it is imported, so this cell has to run
# before anything imports numpy.
# see https://github.com/xgcm/xhistogram/issues/12

import os
os.environ.update(NUMPY_EXPERIMENTAL_ARRAY_FUNCTION="0")

# Heat maps

In [3]:
# parameters
dataset_version = "v2019.09.11.2"

# Both trajectory stores live in the same versioned bucket directory and only
# differ in whether the run was done with or without Stokes drift.
bucket_stokes, bucket_nostokes = (
    f"pangeo-parcels/med_sea_connectivity_{dataset_version}/traj_data_{variant}.zarr"
    for variant in ("with_stokes", "without_stokes")
)

# Species-specific larval parameters: settling window (days since the particle
# started), survivable temperature range, and the last spawning month
# (6 = June, cf. the spawning-month section below).
species = {
    "Striped red mullet": dict(
        settling_start_days=25,
        settling_end_days=35,
        min_temp=10.0,
        max_temp=28.0,
        latest_spawning_month=7,
    ),
    "White seabream": dict(
        settling_start_days=26,
        settling_end_days=30,
        min_temp=10.0,
        max_temp=28.0,
        latest_spawning_month=6,
    ),
    "Comber": dict(
        settling_start_days=21,
        settling_end_days=28,
        min_temp=10.0,
        max_temp=28.0,
        latest_spawning_month=7,
    ),
}

# Human-readable names for the integer region codes found in the ``MPA``
# variable; 0 and 10 are the non-MPA categories (open ocean, coast).
MPA_names = dict(enumerate([
    'Ocean',
    'Cote Bleue',
    'Cerbere-Banyuls',
    'Cape de Creus',
    'Columbretes',
    'Cala Ratjada',
    'Menorca',
    'Tabarca',
    'Cabo de Palos',
    'Cabo de Gata',
    'Coast',
]))

# Resolution of the lat/lon heat-map grid.
n_lon_bins = n_lat_bins = 200

filter_warnings = "ignore"  # No warnings will bother you.  Change for debugging.

In [4]:
import warnings
# Install a global warnings filter with the action chosen in the parameters
# cell (e.g. "ignore" to silence everything, "default"/"always" for debugging).
warnings.filterwarnings(filter_warnings)

## Load all modules and spin up a Dask cluster

In [5]:
from dask import config as dconf

In [6]:
# Inspect the default worker container spec (output below) before it is overridden in the next cell.
dconf.get("kubernetes.worker-template.spec.containers")

[{'args': ['dask-worker',
   '--nthreads',
   '2',
   '--no-bokeh',
   '--memory-limit',
   '7GB',
   '--death-timeout',
   '60'],
  'image': '${JUPYTER_IMAGE_SPEC}',
  'name': 'dask-${JUPYTERHUB_USER}',
  'resources': {'limits': {'cpu': 2, 'memory': '7G'},
   'requests': {'cpu': 1, 'memory': '7G'}}}]

In [7]:
# Override the worker pod template: dask-worker keeps 2 threads, but its
# memory limit is raised to 15.5GB; the pod requests 2 CPUs / 15.5G and is
# limited to 4 CPUs / 15.5G.
dconf.set(
    {
        "kubernetes.worker-template.spec.containers": [
            {
                'args': ['dask-worker', '--nthreads', '2', '--no-bokeh',
                         '--memory-limit', '15.5GB', '--death-timeout', '60'],
                'image': '${JUPYTER_IMAGE_SPEC}',
                'name': 'dask-${JUPYTERHUB_USER}',
                'resources': {
                    'limits': {'cpu': 4, 'memory': '15.5G'},
                    'requests': {'cpu': 2, 'memory': '15.5G'},
                },
            },
        ],
    }
)

<dask.config.set at 0x7f55dc67b470>

In [8]:
%matplotlib inline
import seaborn as sns
from dask import array as da
import numpy as np
import xarray as xr
from gcsfs.mapping import GCSMap
from xhistogram.xarray import histogram as xhist
from matplotlib import pyplot as plt
import pandas as pd
from dask import delayed

In [9]:
# Seaborn's "talk" context (larger fonts and lines) and a dark grid background for all figures.
sns.set_context("talk")
sns.set_style("darkgrid")

In [10]:
from dask.distributed import Client, progress

from dask_kubernetes import KubeCluster
# Start 20 workers and let the cluster scale adaptively between 20 and 100
# workers. silence_logs=50 is the numeric value of logging.CRITICAL, so only
# critical log messages are shown.
cluster = KubeCluster(n_workers=20, silence_logs=50)
cluster.adapt(minimum=20, maximum=100, wait_count=15)

client = Client(cluster, silence_logs=50)
client

0,1
Client  Scheduler: tcp://10.32.3.55:44697  Dashboard: /user/willirath/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


**☝️ Don't forget to click the link above to view the scheduler dashboard!**

## Load data

All trajectories are stored in a Google Cloud Storage bucket. We want to be able to load and filter all trajectories easily. To this end, we load all the datasets (lazily), filter them by different parameters (starting MPA, depth, Stokes drift), and store a Pandas dataframe with virtual sub-datasets for each combination of the parameters. This Pandas dataframe will be pickled for later re-use. (**TODO:** the dataframe/pickling step is not part of this notebook yet; below, the datasets are only opened lazily and tagged with their starting MPA.)

In [11]:
def open_dataset(bucket, restrict_to_MPA=None, restrict_to_z=None):
    """Lazily open a trajectory store from Google Cloud Storage.

    Parameters
    ----------
    bucket : str
        GCS path of the zarr store (e.g. ``bucket_stokes``).
    restrict_to_MPA : number, optional
        If given, mask (set to NaN via ``where``) every trajectory whose MPA
        at the first observation differs from this value.  These arguments
        used to be accepted but silently ignored.
    restrict_to_z : number, optional
        If given, mask every trajectory whose depth ``z`` at the first
        observation differs from this value.

    Returns
    -------
    xarray.Dataset
        The lazily opened dataset with

        * an ``initial_MPA`` coordinate: ``MPA`` at the first observation of
          each trajectory, and
        * a boolean ``before_land_contact`` variable that is True only while
          every observation so far had ``land == 0``.
    """
    # load data
    gcsmap = GCSMap(bucket)
    ds = xr.open_zarr(gcsmap, decode_cf=True)

    # get info on starting region and make it an easy-to-look-up coord
    initial_MPA = ds.MPA.isel(obs=0).squeeze()
    ds.coords["initial_MPA"] = initial_MPA

    # Optional restrictions, applied the same way the analysis cells filter
    # (masking with ``where``).  With the defaults the dataset is unchanged.
    if restrict_to_MPA is not None:
        ds = ds.where(ds.initial_MPA == restrict_to_MPA)
    if restrict_to_z is not None:
        ds = ds.where(ds.z.isel(obs=0).squeeze() == restrict_to_z)

    # add mask that is False after land contact: the cumulative product of
    # "no land" along obs drops to 0 at the first land contact and stays there
    ds["before_land_contact"] = ((ds.land == 0).cumprod("obs") == 1)

    return ds

In [12]:
# Open both experiments lazily (with and without Stokes drift).
ds_stokes, ds_nostokes = open_dataset(bucket_stokes), open_dataset(bucket_nostokes)

## Simplify

We know a few things about our data that make them easier to deal with:

- No vertical migration. Hence, the initial depth of a particle is valid for all times.

- All time steps are the same. Hence, we can easily build a relative time axis that is valid for all particles.

In [13]:
def apply_assumptions(ds):
    """Simplify ``ds`` using facts that hold for this particular set of experiments.

    Two assumptions are baked in; re-check them before applying this to new
    experiments, because they might not apply:

    * No vertical migration: the depth at the first observation is the depth
      of the whole trajectory.  ``z`` is replaced by that per-trajectory value
      and turned into a coordinate.
    * Equal time steps for all particles: the times of the first trajectory,
      shifted to start at zero, form a relative time axis (coordinate
      ``time_axis`` over ``obs``) valid for every trajectory.  Its ``units``
      attribute says seconds, although after CF decoding the values are
      timedelta64 -- NOTE(review).

    ``ds`` is modified in place and also returned.
    """
    # --- depth: constant along each trajectory -> coordinate over ``traj``
    depth_at_release = ds.z.isel(obs=0).squeeze()
    ds["z"] = depth_at_release
    ds.coords["z"] = ds.z

    # --- time: shared by all trajectories -> coordinate over ``obs``
    # ``z``/``initial_MPA`` are demoted to data variables first, presumably so
    # the selected time series does not carry them along as scalar coords.
    relative_time = ds.reset_coords(["z", "initial_MPA"]).time.isel(traj=0).squeeze()
    relative_time -= relative_time.isel(obs=0).squeeze()
    relative_time.attrs["units"] = "seconds since start of particle"
    ds.coords["time_axis"] = relative_time

    return ds

In [14]:
# Same simplifications for both experiments.
ds_stokes, ds_nostokes = apply_assumptions(ds_stokes), apply_assumptions(ds_nostokes)

## Load coordinates for quicker access

So far, we have loaded only the bare minimum of information (data types, variable names, number of time steps, ...) but none of the data itself. We want to keep it that way for the bulk of the data, but load the coordinates and the like now.

In [15]:
def persist_coords(ds, retries=40):
    """Start loading the coordinate variables into cluster memory.

    ``z``, ``initial_MPA`` and ``time_axis`` are each replaced -- in that
    order -- by their ``.persist(retries=retries)`` result.  ``ds`` is
    modified in place and returned.
    """
    for coord_name in ("z", "initial_MPA", "time_axis"):
        ds[coord_name] = ds[coord_name].persist(retries=retries)
    return ds

In [16]:
def compute_coords(ds, retries=40):
    """Bring the coordinate variables into front-end (notebook) memory.

    Replaces ``z``, ``initial_MPA`` and ``time_axis`` -- in that order -- by
    the result of ``.compute(retries=retries)``.  ``ds`` is modified in place
    and returned.
    """
    for coord_name in ["z", "initial_MPA", "time_axis"]:
        computed = ds[coord_name].compute(retries=retries)
        ds[coord_name] = computed
    return ds

In [17]:
# Load the coordinates of both experiments into cluster memory.
ds_stokes, ds_nostokes = persist_coords(ds_stokes), persist_coords(ds_nostokes)

In [18]:
# Fetch the (now persisted) coordinates of both experiments to the notebook.
ds_stokes, ds_nostokes = compute_coords(ds_stokes), compute_coords(ds_nostokes)

In [19]:
# Inspect the Stokes-drift dataset; coordinates are in memory, data variables are still lazy.
ds_stokes

<xarray.Dataset>
Dimensions:              (obs: 962, traj: 2625480)
Coordinates:
    z                    (traj) float32 1.0182366 1.0182366 ... 1.0182366
    initial_MPA          (traj) float32 1.0 1.0 1.0 1.0 1.0 ... 9.0 9.0 9.0 9.0
    time_axis            (obs) timedelta64[ns] 00:00:00 ... 40 days 00:00:00
Dimensions without coordinates: obs, traj
Data variables:
    MPA                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    distance             (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    land                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lat                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lon                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    temp                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    time                 (traj, obs) datet

In [20]:
# Inspect the no-Stokes dataset; coordinates are in memory, data variables are still lazy.
ds_nostokes

<xarray.Dataset>
Dimensions:              (obs: 962, traj: 13188600)
Coordinates:
    z                    (traj) float32 1.0182366 1.0182366 ... 10.536604
    initial_MPA          (traj) float32 1.0 1.0 1.0 1.0 1.0 ... 9.0 9.0 9.0 9.0
    time_axis            (obs) timedelta64[ns] 00:00:00 ... 40 days 00:00:00
Dimensions without coordinates: obs, traj
Data variables:
    MPA                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    distance             (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    land                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lat                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lon                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    temp                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    time                 (traj, obs) date

In [21]:
def get_z_values(ds, retries=40):
    """Return the distinct, non-NaN depth levels present in ``ds.z``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``z`` variable (a per-trajectory coordinate after
        ``apply_assumptions``).
    retries : int, optional
        Number of retries for the Dask computation, consistent with
        ``persist_coords``/``compute_coords``.  Only used when ``ds.z`` is
        still a lazy (Dask) array.  Defaults to 40, the value that used to
        be hard-coded.

    Returns
    -------
    numpy.ndarray
        Sorted unique depth values with NaNs removed.
    """
    z_data = ds.z.data
    if isinstance(z_data, np.ndarray):
        # Already in front-end memory (e.g. after ``compute_coords``): no need
        # for a round trip through the cluster.
        z_values = np.unique(z_data)
    else:
        # Lazy data: this triggers a computation across all of the z-level data.
        z_values = da.unique(z_data).compute(retries=retries)
    z_values = z_values[~np.isnan(z_values)]
    return z_values

In [22]:
# Depth levels are taken from the no-Stokes run; judging from its repr above,
# the Stokes run seems to contain only the surface level.
z_values = get_z_values(ds_nostokes)

In [23]:
# The five depth levels (m), sorted from shallowest to deepest.
print(z_values)

[ 1.0182366  3.1657474  5.4649634  7.9203773 10.536604 ]


## Kill particles based on temperature

In [24]:
def apply_temp_range(ds, species_name):
    """Flag particles that leave the survivable temperature range of a species.

    Adds two variables to ``ds`` (modified in place and returned):

    ``does_not_die``
        True where the land-corrected temperature lies strictly between the
        species' ``min_temp`` and ``max_temp`` (from the ``species`` table).
    ``is_alive``
        True only while ``does_not_die`` has been True at every observation
        so far; once a particle dies it stays dead.
    """
    limits = species[species_name]

    # If we're close to land, temp=0 gets linearly interpolated to the
    # particle position.  Dividing by (1 - land) compensates for that,
    # presumably with ``land`` being the interpolated land fraction -- TODO
    # confirm.  Where land == 1 the division gives inf/NaN, which fails the
    # range test below, so such particles die.
    corrected_temp = ds.temp / (1 - ds.land)

    survives_this_step = (corrected_temp > limits["min_temp"]) & (corrected_temp < limits["max_temp"])
    ds["does_not_die"] = survives_this_step

    # The cumulative product along obs stays 1 only while every step so far
    # was survivable.
    ds["is_alive"] = survives_this_step.cumprod("obs") == 1

    return ds

In [25]:
# Pick first species. We don't have varying temperature ranges (yet?).
# All species in the parameters cell currently share min_temp=10.0 and
# max_temp=28.0, so which species is used here does not matter (yet).
ds_nostokes = apply_temp_range(ds_nostokes, "Striped red mullet")
ds_nostokes

<xarray.Dataset>
Dimensions:              (obs: 962, traj: 13188600)
Coordinates:
    z                    (traj) float32 1.0182366 1.0182366 ... 10.536604
    initial_MPA          (traj) float32 1.0 1.0 1.0 1.0 1.0 ... 9.0 9.0 9.0 9.0
    time_axis            (obs) timedelta64[ns] 00:00:00 ... 40 days 00:00:00
Dimensions without coordinates: obs, traj
Data variables:
    MPA                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    distance             (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    land                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lat                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lon                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    temp                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    time                 (traj, obs) date

In [26]:
# Same temperature range for the Stokes-drift run.
ds_stokes = apply_temp_range(ds_stokes, "Striped red mullet")
ds_stokes

<xarray.Dataset>
Dimensions:              (obs: 962, traj: 2625480)
Coordinates:
    z                    (traj) float32 1.0182366 1.0182366 ... 1.0182366
    initial_MPA          (traj) float32 1.0 1.0 1.0 1.0 1.0 ... 9.0 9.0 9.0 9.0
    time_axis            (obs) timedelta64[ns] 00:00:00 ... 40 days 00:00:00
Dimensions without coordinates: obs, traj
Data variables:
    MPA                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    distance             (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    land                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lat                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    lon                  (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    temp                 (traj, obs) float32 dask.array<chunksize=(100000, 962), meta=np.ndarray>
    time                 (traj, obs) datet

In [27]:
# Optional diagnostic, disabled because it scans all trajectories: fraction
# of particles still alive as a function of time for both experiments.
# ds_nostokes.is_alive.astype("float").mean("traj").compute(retries=40).plot(label="no stokes, all layers")
# ds_stokes.is_alive.astype("float").mean("traj").compute(retries=40).plot(label="stokes, surface");
# plt.gca().legend();

In [28]:
# actually apply the temp range
# actually apply the temp range: every observation from a particle's first
# out-of-range temperature onwards is masked (becomes NaN)
ds_stokes, ds_nostokes = (ds.where(ds.is_alive) for ds in (ds_stokes, ds_nostokes))

## Be able to filter for spawning month

Not all species spawn over the same time span. "White seabream" only spawns until June. As the physical trajectories are otherwise equivalent, we want to be able to remove those trajectories that start too late in the year. (**TODO:** this filter, based on `latest_spawning_month` in the parameters cell, is not implemented yet.)

## Be able to filter for settling period

In [29]:
def filter_settling_period(ds, species_name, obs_per_day=24):
    """Select the observations that fall into a species' settling window.

    Parameters
    ----------
    ds : xarray.Dataset
        Trajectory dataset whose ``obs`` index 0 is the start of every particle
        (all particles share the same time steps, see ``apply_assumptions``).
    species_name : str
        Key into the ``species`` table; its ``settling_start_days`` and
        ``settling_end_days`` define the window.
    obs_per_day : int or float, optional
        Number of observations per day.  Defaults to 24 (hourly output), the
        value that used to be hard-coded.

    Returns
    -------
    xarray.Dataset
        ``ds`` restricted to ``obs`` indices from
        ``int(settling_start_days * obs_per_day)`` up to and including
        ``int(settling_end_days * obs_per_day)``.
    """
    settling = species[species_name]
    first_obs = int(settling["settling_start_days"] * obs_per_day)
    last_obs = int(settling["settling_end_days"] * obs_per_day)
    # ``slice`` excludes its stop value, hence +1 to keep the last settling step.
    return ds.isel(obs=slice(first_obs, last_obs + 1))

## Be able to find trajectories that _never_ touch land

In [30]:
def get_never_touches_land(ds):
    """Mask out every trajectory that has land contact at any observation.

    A trajectory is kept only if the maximum of its ``land`` variable over
    ``obs`` is exactly 0; all other trajectories are masked via ``ds.where``.
    """
    max_land_along_track = ds.land.max("obs")
    return ds.where(max_land_along_track == 0)

## Heat Maps

In [31]:
def get_bnds(ds, varname, retries=40):
    """Return ``[min, max]`` of ``ds[varname]`` over all of its elements.

    Both reductions are persisted before either result is fetched, so both
    are submitted to the cluster before the notebook waits on them.

    Parameters
    ----------
    ds : xarray.Dataset
    varname : str
        Name of the variable, e.g. ``"lat"`` or ``"lon"``.
    retries : int, optional
        Number of retries for the Dask tasks, consistent with
        ``persist_coords``.  Defaults to 40, the value that used to be
        hard-coded.

    Returns
    -------
    list
        ``[minimum, maximum]`` as computed values.
    """
    var = ds[varname]
    var_min = var.min().persist(retries=retries)
    var_max = var.max().persist(retries=retries)
    return [var_min.compute(), var_max.compute()]

In [32]:
# Data-driven alternative to the fixed bounds below; disabled because it scans
# all positions of both experiments.
# NOTE(review): an earlier version combined lat/lon_bnds_stokes with itself
# and never used the *_nostokes bounds; the version below takes the overall
# min/max across both experiments.
# lat_bnds_stokes = get_bnds(ds_stokes, "lat")
# lat_bnds_nostokes = get_bnds(ds_nostokes, "lat")
# lat_bnds = [
#     min(lat_bnds_stokes[0], lat_bnds_nostokes[0]),
#     max(lat_bnds_stokes[1], lat_bnds_nostokes[1]),
# ]

# lon_bnds_stokes = get_bnds(ds_stokes, "lon")
# lon_bnds_nostokes = get_bnds(ds_nostokes, "lon")
# lon_bnds = [
#     min(lon_bnds_stokes[0], lon_bnds_nostokes[0]),
#     max(lon_bnds_stokes[1], lon_bnds_nostokes[1]),
# ]

In [33]:
# Only meaningful together with the disabled data-driven bounds cell above.
# print("lat_bnds", lat_bnds)
# print("lon_bnds", lon_bnds)

In [34]:
# Fixed heat-map domain [south, north] and [west, east] in degrees, used
# instead of the (expensive) data-driven bounds above.
lat_bnds, lon_bnds = [35.0, 44.0], [-6.0, 10.0]

In [35]:
# Bin edges: n bins need n + 1 evenly spaced edges between the two bounds.
lat_bins = np.linspace(*lat_bnds, num=n_lat_bins + 1)
lon_bins = np.linspace(*lon_bnds, num=n_lon_bins + 1)

In [36]:
def heat_map(ds, bins=None):
    """2-D histogram of particle positions (lat vs. lon) over all trajectories and times.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``lat`` and ``lon`` on the ``traj`` and ``obs``
        dimensions.  Masked (NaN) positions are presumably not counted by
        xhistogram -- TODO confirm.
    bins : list of two arrays, optional
        ``[lat_edges, lon_edges]``.  Defaults to the notebook-wide
        ``[lat_bins, lon_bins]``.

    Returns
    -------
    xarray.DataArray
        Counts per (``lat_bin``, ``lon_bin``) cell, with empty cells set to
        NaN so they stay blank in log-scale plots.
    """
    if bins is None:
        bins = [lat_bins, lon_bins]

    # (the local is no longer called ``heat_map``, which shadowed this function)
    counts = xhist(
        ds.lat,
        ds.lon,
        bins=bins,
        dim=["traj", "obs"]
    )

    # Empty cells -> NaN (log10(0) would be -inf).
    return counts.where(counts > 0)

## Heat map stokes vs. no stokes, surface

In [37]:
# Stokes run, no depth filter: its repr above only shows the surface level (z ≈ 1.02).
heat_map_stokes_z0 = heat_map(ds_stokes)
heat_map_stokes_z0

<xarray.DataArray 'histogram_lat_lon' (lat_bin: 200, lon_bin: 200)>
dask.array<where, shape=(200, 200), dtype=float64, chunksize=(200, 200), chunktype=numpy.ndarray>
Coordinates:
  * lat_bin  (lat_bin) float64 35.02 35.07 35.11 35.16 ... 43.89 43.93 43.98
  * lon_bin  (lon_bin) float64 -5.96 -5.88 -5.8 -5.72 ... 9.72 9.8 9.88 9.96

In [38]:
# No-Stokes run restricted to the shallowest depth level, for comparison with the Stokes run.
heat_map_nostokes_z0 = heat_map(ds_nostokes.where(ds_nostokes.z == z_values[0]))
heat_map_nostokes_z0

<xarray.DataArray 'histogram_lat_lon' (lat_bin: 200, lon_bin: 200)>
dask.array<where, shape=(200, 200), dtype=float64, chunksize=(200, 200), chunktype=numpy.ndarray>
Coordinates:
  * lat_bin  (lat_bin) float64 35.02 35.07 35.11 35.16 ... 43.89 43.93 43.98
  * lon_bin  (lon_bin) float64 -5.96 -5.88 -5.8 -5.72 ... 9.72 9.8 9.88 9.96

In [39]:
# Compute both surface heat maps on the cluster and keep the results there.
heat_map_stokes_z0, heat_map_nostokes_z0 = (
    hist.persist(retries=40)
    for hist in (heat_map_stokes_z0, heat_map_nostokes_z0)
)

In [40]:
import cartopy.crs as ccrs

In [41]:
def plot_two_histograms(hist1, hist2, title1, title2, vmin=1, vmax=6):
    """Plot two lat/lon histograms side by side on PlateCarree maps, log10 colour scale.

    Parameters
    ----------
    hist1, hist2 : xarray.DataArray
        2-D histograms (e.g. from ``heat_map``), already computed.
    title1, title2 : str
        Titles of the left and right panel, respectively.
    vmin, vmax : float, optional
        Colour-scale limits in log10 units, shared by both panels so they are
        directly comparable.  Defaults: 1 and 6, i.e. 10 to 10**6 counts.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(24, 6))

    axes = [plt.subplot(1, 2, position, projection=ccrs.PlateCarree())
            for position in (1, 2)]

    for ax in axes:
        ax.coastlines('10m')
        # ax.set_extent([-45, 55, 20, 80], ccrs.PlateCarree())

        gl = ax.gridlines(
            crs=ccrs.PlateCarree(), draw_labels=True,
            linewidth=1, color='gray', alpha=0.5, linestyle='--')
        # Only label the bottom and left edges.  Older cartopy versions use
        # ``xlabels_top``/``ylabels_right``, newer ones ``top_labels``/
        # ``right_labels``; setting both works with either.
        gl.xlabels_top = False
        gl.ylabels_right = False
        gl.top_labels = False
        gl.right_labels = False

    for ax, hist, title in zip(axes, (hist1, hist2), (title1, title2)):
        # np.log10 dispatches to xarray (``xr.ufuncs`` is deprecated);
        # NaN (empty) cells stay blank.
        np.log10(hist).plot(ax=ax, vmin=vmin, vmax=vmax,
                            transform=ccrs.PlateCarree(), rasterized=True)
        # Each panel gets its own title (the right panel used to get title1).
        ax.set_title(title)

    # fig.tight_layout();

    return fig

In [None]:
# Surface comparison: Stokes run (left) vs. shallowest level of the no-Stokes run (right).
fig = plot_two_histograms(
    heat_map_stokes_z0.compute(retries=40),
    heat_map_nostokes_z0.compute(retries=40),
    "stokes, surface, all positions, first 40 days, logscale",
    "no stokes, surface, all positions, first 40 days, logscale"
)

In [None]:
# savefig does not create missing directories, so make sure plots/ exists.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/heat_maps_stokes_vs_nostokes_z0.png")
fig.savefig("plots/heat_maps_stokes_vs_nostokes_z0.pdf")

## Heat map no stokes, z min vs. z max, MPA 5 (Cala Ratjada)

In [None]:
# No-Stokes particles starting in MPA 5 (Cala Ratjada) at the shallowest depth level.
heat_map_nostokes_z0_MPA5 = heat_map(
    ds_nostokes.where(
        (ds_nostokes.initial_MPA == 5)
        & (ds_nostokes.z == z_values[0])
    )
)
heat_map_nostokes_z0_MPA5

In [None]:
# Same MPA, deepest level: z_values is sorted and has five entries, so index 4 is the deepest.
heat_map_nostokes_z4_MPA5 = heat_map(
    ds_nostokes.where(
        (ds_nostokes.initial_MPA == 5)
        & (ds_nostokes.z == z_values[4])
    )
)
heat_map_nostokes_z4_MPA5

In [None]:
# Compute both MPA 5 heat maps on the cluster and keep the results there.
heat_map_nostokes_z0_MPA5, heat_map_nostokes_z4_MPA5 = (
    hist.persist(retries=40)
    for hist in (heat_map_nostokes_z0_MPA5, heat_map_nostokes_z4_MPA5)
)

In [None]:
# The titles name the starting MPA (only MPA 5 trajectories are shown) and
# round the float32 depths to one decimal.
fig = plot_two_histograms(
    heat_map_nostokes_z0_MPA5.compute(retries=40),
    heat_map_nostokes_z4_MPA5.compute(retries=40),
    f"no stokes, {MPA_names[5]}, {z_values[0]:.1f}m, all positions, first 40 days, logscale",
    f"no stokes, {MPA_names[5]}, {z_values[4]:.1f}m, all positions, first 40 days, logscale"
)

In [None]:
# savefig does not create missing directories, so make sure plots/ exists.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/heat_maps_nostokes_MPA5_z0_vs_z4.png")
fig.savefig("plots/heat_maps_nostokes_MPA5_z0_vs_z4.pdf")

## Heat map stokes noslip vs. no land contact, surface

In [None]:
# Stokes run, keeping only trajectories that never touch land.
heat_map_stokes_z0_no_land_contact = heat_map(get_never_touches_land(ds_stokes))
heat_map_stokes_z0_no_land_contact

In [None]:
# Compute on the cluster and keep the result there.
heat_map_stokes_z0_no_land_contact = heat_map_stokes_z0_no_land_contact.persist(retries=40)

In [None]:
# Left: all Stokes trajectories; right: only those without any land contact.
fig = plot_two_histograms(
    heat_map_stokes_z0.compute(retries=40),
    heat_map_stokes_z0_no_land_contact.compute(retries=40),
    "stokes, surface, all positions, first 40 days, logscale",
    "stokes, surface, all positions, no land contact, first 40 days, logscale"
)

In [None]:
# savefig does not create missing directories, so make sure plots/ exists.
os.makedirs("plots", exist_ok=True)
fig.savefig("plots/heat_maps_stokes_z0_noslip_vs_nolandcontact.png")
fig.savefig("plots/heat_maps_stokes_z0_noslip_vs_nolandcontact.pdf")

# Technical documentation

List the complete working environment (pip and conda packages) for reproducibility.

In [None]:
# Python packages installed in the kernel's environment.
%pip list

In [None]:
# Exact conda package list (explicit URLs), usable with `conda create --file` to recreate the environment.
%conda list --explicit