# Prepare potential training grid cells

In [None]:
# Directory holding the prepared input layers (the ds_prep_*.zarr stores read below).
dir01 = '../paper_deficit/output/01_prep/'
# Directory the potential-training-cell masks of this notebook are written to.
dir02 = '../paper_deficit/output/02_dbase/'

In [None]:
import xarray as xr
import matplotlib.pyplot as plt

---

In [None]:
# Read data: every input layer is a zarr store located in dir01.
def _open_prep(store_name):
    """Open the zarr store ``store_name`` found in ``dir01`` with xarray."""
    return xr.open_zarr(dir01 + store_name)


# Land-cover / screening layers
ds_hilda_stable_fso = _open_prep('ds_prep_hilda_stable_fso.zarr')
ds_esacci_stable_fso = _open_prep('ds_prep_esacci_stable_fso.zarr')
ds_cop_crops = _open_prep('ds_prep_cop_crops.zarr')
ds_cop_builtup = _open_prep('ds_prep_cop_builtup.zarr')
ds_riggio_vlhi0 = _open_prep('ds_prep_riggio_vlhi0.zarr')
ds_riggio_vlhi4 = _open_prep('ds_prep_riggio_vlhi4.zarr')

# WDPA protection categories (Ia, Ib, II, III, IV, V, VI)
ds_wdpa_ia = _open_prep('ds_prep_wdpa_ia.zarr')
ds_wdpa_ib = _open_prep('ds_prep_wdpa_ib.zarr')
ds_wdpa_ii = _open_prep('ds_prep_wdpa_ii.zarr')
ds_wdpa_iii = _open_prep('ds_prep_wdpa_iii.zarr')
ds_wdpa_iv = _open_prep('ds_prep_wdpa_iv.zarr')
ds_wdpa_v = _open_prep('ds_prep_wdpa_v.zarr')
ds_wdpa_vi = _open_prep('ds_prep_wdpa_vi.zarr')

# Copernicus land mask
ds_land = _open_prep('ds_prep_copernicus_land_mask.zarr')

In [None]:
# Select potential training grid cells for primary land simulations.
# A cell is flagged 1 only when every criterion below holds (otherwise 0);
# cells outside the Copernicus land mask end up as NaN.
prim_criteria = [
    ds_hilda_stable_fso.hilda_stable_fso > 0.9,
    ds_esacci_stable_fso.esacci_stable_fso > 0.9,
    ds_cop_crops.cop_crops == 0,
    ds_cop_builtup.cop_builtup == 0,
    ds_riggio_vlhi0.riggio_vlhi0 > 0.9,
]
prim_all_met = prim_criteria[0]
for prim_criterion in prim_criteria[1:]:
    prim_all_met = prim_all_met & prim_criterion

# Explicit comparison with True kept on purpose (same semantics as before).
prim_on_land = ds_land.copernicus_land_mask == True  # noqa: E712
da_pot_prim = xr.where(prim_all_met, 1, 0).where(prim_on_land).rename('pot_prim')

# Export as zarr (trailing ';' suppresses the notebook repr of the store)
da_pot_prim.to_zarr(dir02 + 'ds_prep_pot_prim.zarr', mode='w');

In [None]:
# Select grid cells whose WDPA protection value exceeds 0.9 (i.e. more than
# 90%) in at least one category: Ia, Ib, II, III, IV, V or VI.
# The comparison is strict (> 0.9). NaN values compare False and therefore
# count as unprotected.
wdpa_protected = ((ds_wdpa_ia.wdpa_ia > 0.9) |
                  (ds_wdpa_ib.wdpa_ib > 0.9) |
                  (ds_wdpa_ii.wdpa_ii > 0.9) |
                  (ds_wdpa_iii.wdpa_iii > 0.9) |
                  (ds_wdpa_iv.wdpa_iv > 0.9) |
                  (ds_wdpa_v.wdpa_v > 0.9) |
                  (ds_wdpa_vi.wdpa_vi > 0.9))
# Keep the 0/1 integer encoding of the protection mask.
da_wdpa = xr.where(wdpa_protected, 1, 0)

# Select potential training grid cells for secondary land simulations.
# A cell is flagged 1 when it passes the HILDA/ESA-CCI stability thresholds,
# has no Copernicus crops or built-up area, exceeds the Riggio vlhi4
# threshold, and is WDPA-protected (da_wdpa above); otherwise it is 0.
# Cells outside the Copernicus land mask end up as NaN.
da_pot_secd = xr.where((ds_hilda_stable_fso.hilda_stable_fso > 0.9) &
                       (ds_esacci_stable_fso.esacci_stable_fso > 0.9) &
                       (ds_cop_crops.cop_crops == 0) &
                       (ds_cop_builtup.cop_builtup == 0) &
                       (ds_riggio_vlhi4.riggio_vlhi4 > 0.9) &
                       (da_wdpa == 1), 1, 0) \
    .where(ds_land.copernicus_land_mask == True) \
    .rename('pot_secd')

# Export as zarr
da_pot_secd.to_zarr(dir02 + 'ds_prep_pot_secd.zarr', mode='w');

---

### Check

In [None]:
# Number of potential primary-land training cells (NaN cells outside the land mask are skipped by the sum).
n_pot_prim = da_pot_prim.sum().values.item()
n_pot_prim

In [None]:
# Number of potential secondary-land training cells (NaN cells outside the land mask are skipped by the sum).
n_pot_secd = da_pot_secd.sum().values.item()
n_pot_secd

In [None]:
# Quick-look map of the primary-land candidate mask.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 10))
da_pot_prim.plot.imshow(ax=ax)

In [None]:
# Quick-look map of the secondary-land candidate mask.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(20, 10))
da_pot_secd.plot.imshow(ax=ax)

---

In [None]:
# Directory holding the Natural Earth shapefiles (relative to the working directory).
dir_nearth = '../data/naturalearth/'

In [None]:
import geopandas as gpd
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Read coastline data (Natural Earth 1:110m), used as a backdrop in the hexbin maps.
coastline_shp = dir_nearth + 'ne_110m_coastline/ne_110m_coastline.shp'
coastline110 = gpd.read_file(coastline_shp)

In [None]:
def plot_hexbin_pot_train(df, cbar_ticks=(1, 5, 10, 25, 50), gridsize=150):
    """Plot a log-scaled hexbin map of potential training grid cells.

    Coastlines are drawn underneath, and a colour bar is attached on the right.

    Parameters
    ----------
    df : pandas.DataFrame or dask.dataframe.DataFrame
        Must have ``lat`` and ``lon`` columns. The column in the third
        position holds the values summed within each hexagon. The callers in
        this notebook pass the 0/1 training flag already filtered to 1, so
        each hexagon shows a cell count.
    cbar_ticks : sequence of numbers, optional
        Tick positions (and labels) of the colour bar.
    gridsize : int, optional
        Number of hexagons in the x-direction, passed to ``Axes.hexbin``.
    """
    # matplotlib converts each column (and calls len) separately. With a dask
    # DataFrame that recomputes the whole graph several times, so materialise
    # it once up front.
    if hasattr(df, 'compute'):
        df = df.compute()

    fig, ax = plt.subplots(figsize=(20, 10), ncols=1, nrows=1)

    # Geographic backdrop.
    coastline110.plot(ax=ax, color='#000000', linewidth=0.5)

    im = ax.hexbin(x=df.lon, y=df.lat, C=df.iloc[:, 2],
                   gridsize=gridsize, reduce_C_function=sum, linewidths=0.2,
                   cmap='inferno', bins='log')

    # Narrow colour-bar axis appended to the right of the map.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="2%", pad=0.2)
    cbar = fig.colorbar(im, cax=cax, label='Grid cells for training')
    ticks = list(cbar_ticks)
    cbar.set_ticks(ticks)
    cbar.set_ticklabels(ticks)

In [None]:
# Flatten the primary-land mask to (lat, lon, pot_prim) rows, keep only the
# candidate cells and plot their spatial density.
df_prim = da_pot_prim.to_dask_dataframe(dim_order=['lat', 'lon'])
df_prim = df_prim[['lat', 'lon', 'pot_prim']]
df_prim = df_prim[df_prim['pot_prim'] == 1]

plot_hexbin_pot_train(df_prim)

In [None]:
# Flatten the secondary-land mask to (lat, lon, pot_secd) rows, keep only the
# candidate cells and plot their spatial density.
df_secd = da_pot_secd.to_dask_dataframe(dim_order=['lat', 'lon'])
df_secd = df_secd[['lat', 'lon', 'pot_secd']]
df_secd = df_secd[df_secd['pot_secd'] == 1]

plot_hexbin_pot_train(df_secd)