# Prepare protected areas data (WDPA)

In [None]:
# libraries
import os, time, shutil, rasterio, rioxarray
import numpy as np
import xarray as xr
import pandas as pd
import geopandas as gpd
from scipy import ndimage
import matplotlib.pyplot as plt

In [None]:
# Input data root and output directory of this preparation step.
# The trailing slashes matter: later cells build paths with string `+`.
dir_data, dir01 = '../data/', '../paper_deficit/output/01_prep/'

---

### Pre-processing

In [None]:
# IUCN management categories (values of the WDPA `IUCN_CAT` field)
# Ia - strict nature reserve.
# Ib - wilderness area.
# II - national park.
# III - natural monument or feature.
# IV - habitat or species management area.
# V - protected landscape or seascape.
# VI - protected area with sustainable use of natural resources.

In [None]:
def prep_wdpa():
    """
    Rasterize WDPA protected areas, one GeoTIFF per IUCN category.

    Reads and merges the three WDPA (Nov 2022) shapefile parts and
    reprojects them to EPSG:4326. It keeps polygons whose ``IUCN_CAT`` is
    one of Ia, Ib, II, III, IV, V, VI and whose ``STATUS`` is Adopted,
    Designated, Established or Inscribed. Each category is burned onto the
    grid of the WorldClim 30s elevation raster.

    Each category is exported to ``<dir01>/wdpa_<category>.tif`` as an int8
    raster (1 = protected, 0 = not), restricted to latitudes 80 to -60.

    Uses the module-level ``dir_data`` (input root) and ``dir01`` (output
    directory).
    """
    # ``import rasterio`` alone does not guarantee that the
    # ``rasterio.features`` submodule is loaded; import it explicitly.
    from rasterio import features

    # Path to reference raster and WDPA shapefiles
    rst_path = os.path.join(dir_data, 'worldclim/wc2.1_30s_elev/wc2.1_30s_elev.tif')
    wdpa_base = os.path.join(dir_data, 'wdpa/WDPA_Nov2022_Public_shp')

    # Take the grid (size, transform, CRS) from the reference raster.
    # The intermediate rasters only hold 0/1, so store them as single-band
    # uint8 without a nodata value.
    with rasterio.open(rst_path) as rst:
        meta = rst.meta.copy()
    meta.update(compress='lzw', dtype='uint8', nodata=None, count=1)

    # Load and merge shapefiles
    shapefiles = [os.path.join(wdpa_base, f'WDPA_Nov2022_Public_shp_{i}/WDPA_Nov2022_Public_shp-polygons.shp')
                  for i in range(3)]
    shp = pd.concat([gpd.read_file(shp_file) for shp_file in shapefiles]) \
        .to_crs("EPSG:4326")

    # Define relevant categories and statuses
    categories = ['Ia', 'Ib', 'II', 'III', 'IV', 'V', 'VI']
    statuses = ['Adopted', 'Designated', 'Established', 'Inscribed']

    def gpd_to_tif(gdf, out_fn):
        """Burn ``gdf['value']`` onto the reference grid and write ``out_fn``."""
        # rasterize raises on missing/invalid geometry objects, so drop null
        # and empty geometries first.
        gdf = gdf[gdf.geometry.notna() & ~gdf.geometry.is_empty]
        out_shape = (meta['height'], meta['width'])
        if len(gdf):
            burned = features.rasterize(
                shapes=zip(gdf.geometry, gdf['value']),
                out_shape=out_shape,
                fill=0,
                transform=meta['transform'],
                dtype=meta['dtype'],
            )
        else:
            # Nothing to burn (rasterize may raise on an empty shape list)
            burned = np.zeros(out_shape, dtype=meta['dtype'])
        with rasterio.open(out_fn, 'w', **meta) as out:
            out.write(burned, 1)

    # Process and export each category
    for cat in categories:
        name = f'wdpa_{cat.lower()}'
        shp_sel = shp[
            (shp['IUCN_CAT'] == cat) & (shp['STATUS'].isin(statuses))
        ][['geometry']].assign(value=1)

        output_file = os.path.join(dir01, f'{name}.tif')
        # Rasterize to an intermediate file first. The post-processing
        # below reads lazily (dask chunks), so writing the result back to
        # the path being read would truncate the source mid-read.
        raw_file = os.path.join(dir01, f'{name}_raw.tif')
        gpd_to_tif(shp_sel, raw_file)

        with rioxarray.open_rasterio(raw_file,
                                     chunks={'y': 5000, 'x': 5000}) as da_raw:
            # Cells burned by rasterize become 1 and all others 0. With the
            # rasterio default (all_touched=False), a cell is burned when
            # its centre lies inside a polygon.
            da_sel = xr.where(da_raw == 1, 1, 0).astype('int8')
            # Select lat between 80 and -60
            da_sel = da_sel.sel(y=slice(80, -60))
            # Adjust attributes
            da_sel.attrs['long_name'] = name
            # export
            da_sel.rio.to_raster(output_file)
        os.remove(raw_file)

In [None]:
# Rasterize all WDPA categories; the IPython %time magic reports the run time
%time prep_wdpa()

---

### Post-processing

In [None]:
def fill_nans(da_var, var, dir_out):
    """
    Fill NaNs with the nearest valid value, apply the land mask, export to Zarr.

    Args:
        da_var (xarray.DataArray): 2-D field with dimensions ``('lat', 'lon')``
            in that order and matching ``lat``/``lon`` coordinates.
        var (str): Variable name. Used as the data variable name in the
            output Dataset and in the output store name ``ds_prep_<var>.zarr``.
        dir_out (str): Directory containing
            ``ds_prep_copernicus_land_mask.zarr``. The filled dataset is
            written to this directory as well.

    Returns:
        None. Writes ``<dir_out>/ds_prep_<var>.zarr`` (overwriting it), with
        cells outside the land mask set to NaN.
    """

    def fill_nans_array(data, invalid):
        """
        Replace invalid cells by the value of the nearest valid cell.

        With ``return_indices=True``, the Euclidean distance transform of
        the boolean ``invalid`` mask gives, for each cell, the index of the
        nearest ``False`` (valid) cell. Fancy-indexing ``data`` with those
        indices copies that cell's value.
        """
        ind = ndimage.distance_transform_edt(invalid,
                                             return_distances=False,
                                             return_indices=True)
        return data[tuple(ind)]

    # Paths for input and output
    land_mask_path = os.path.join(dir_out, 'ds_prep_copernicus_land_mask.zarr')
    output_path = os.path.join(dir_out, f'ds_prep_{var}.zarr')

    # Read land mask data into memory
    da_land = xr.open_zarr(land_mask_path) \
                .chunk(dict(lat=5000, lon=5000)) \
                .copernicus_land_mask \
                .compute()

    # Compute the NaN mask once. Fill only when there are NaNs to fill and
    # at least one valid cell to copy from. An all-NaN field has no nearest
    # valid neighbour, so it is left unchanged.
    data = da_var.values
    invalid = da_var.isnull().values
    if invalid.any() and not invalid.all():
        da_fill = fill_nans_array(data, invalid)
    else:
        da_fill = data

    # Create a new Dataset with filled data on the input's lat/lon grid
    ds_filled = xr.Dataset(dict(lat=da_var.lat, lon=da_var.lon))
    ds_filled[var] = (('lat', 'lon'), da_fill)

    # Apply land mask: cells where the mask is False/0 become NaN
    ds_filled = ds_filled.where(da_land)

    # Export the filled dataset to Zarr format
    ds_filled.chunk(dict(lat=5000, lon=5000)) \
             .to_zarr(output_path, mode='w')

In [None]:
# Convert each per-category GeoTIFF into a land-masked Zarr store
# (ds_prep_<name>.zarr in dir01) via fill_nans.
for i in ['wdpa_ia', 'wdpa_ib', 'wdpa_ii', 'wdpa_iii', 'wdpa_iv', 'wdpa_v',
          'wdpa_vi']:
    # Read the first band and rename y/x to the lat/lon names fill_nans uses
    da_var = rioxarray.open_rasterio(dir01 + i + '.tif')[0] \
        .rename(dict(y='lat', x='lon'))
    # Fill NaNs with the nearest neighbour (only if any exist), apply the
    # land mask and export; %time reports the run time per category
    %time fill_nans(da_var, i, dir01)

---

### Check

In [None]:
# Visual check: one map per prepared category store (ds_prep_<name>.zarr)
wdpa_names = ['wdpa_ia', 'wdpa_ib', 'wdpa_ii', 'wdpa_iii', 'wdpa_iv',
              'wdpa_v', 'wdpa_vi']
for name in wdpa_names:
    zarr_path = f'{dir01}ds_prep_{name}.zarr'
    fig, axis = plt.subplots(figsize=(10, 5), ncols=1, nrows=1)
    layer = xr.open_zarr(zarr_path)[name]
    layer.plot.imshow(ax=axis)
    axis.set_title(name)

In [None]:
def check_wdpa(lon_range=(-120, -70), lat_range=(20, 50)):
    """
    Overlay the summed WDPA category rasters with the source polygon outlines.

    Visual sanity check for the prepared ``ds_prep_wdpa_*.zarr`` stores. The
    sum of all category layers is drawn for a lon/lat window, and the
    outlines of the IUCN Ia-VI polygons are plotted on top in red.

    Args:
        lon_range (tuple): (west, east) longitude limits of the window.
        lat_range (tuple): (south, north) latitude limits of the window.
            The defaults reproduce the original fixed window.
    """
    lon_min, lon_max = lon_range
    lat_min, lat_max = lat_range

    # Read and merge the three WDPA shapefile parts. Reproject as prep_wdpa
    # does, so polygons and rasters share EPSG:4326 coordinates.
    wdpa_base = os.path.join(dir_data, 'wdpa/WDPA_Nov2022_Public_shp')
    shp = pd.concat(
        [gpd.read_file(os.path.join(
            wdpa_base,
            f'WDPA_Nov2022_Public_shp_{i}/WDPA_Nov2022_Public_shp-polygons.shp'))
         for i in range(3)],
        ignore_index=True,
    ).to_crs("EPSG:4326")

    # Keep polygons in any IUCN management category (Ia-VI) that intersect
    # the plotting window. Polygons outside it would be cropped anyway,
    # so plotting them only costs time.
    shp_cat = shp[shp['IUCN_CAT']
                  .isin(['Ia', 'Ib', 'II', 'III', 'IV', 'V', 'VI'])] \
        .cx[lon_min:lon_max, lat_min:lat_max]

    _, ax = plt.subplots(figsize=(20, 10), ncols=1, nrows=1, dpi=300)

    # Sum over all category layers. The colour scale saturates at 1, so any
    # protected cell is drawn at full intensity.
    xr.open_mfdataset(os.path.join(dir01, 'ds_prep_wdpa_*.zarr'), engine='zarr') \
        .to_array('x') \
        .sum('x') \
        .sel(lat=slice(lat_max, lat_min), lon=slice(lon_min, lon_max)) \
        .plot.imshow(ax=ax, vmin=0, vmax=1)

    shp_cat.plot(ax=ax, edgecolor='red', color='none')

    ax.set_ylim(lat_min, lat_max)
    ax.set_xlim(lon_min, lon_max)


check_wdpa()