# Prepare ecozone data (FAO global ecological zones)

In [None]:
# libraries
import os, time, shutil, rasterio, rioxarray
import numpy as np
import xarray as xr
import geopandas as gpd
from scipy import ndimage
from dask.distributed import Lock
import matplotlib.pyplot as plt

In [None]:
# Directories: root of the raw input data, and the output folder of the
# 01_prep step (both given with a trailing slash)
dir_data, dir01 = '../data/', '../paper_deficit/output/01_prep/'

---

### Pre-processing

In [None]:
# Ecozone shapefile to tif
# rasterio.features is a submodule that a bare `import rasterio` does not
# guarantee to load, so import it explicitly before using it below
import rasterio.features

# File paths
path_rst = f'{dir_data}/worldclim/wc2.1_30s_elev/wc2.1_30s_elev.tif'
path_shp = f'{dir_data}/fao2010gez/gez_2010_wgs84.shp'
path_out = f'{dir01}/fao2010_gez.tif'
# Intermediate file for the burned codes. The final export reads this file
# lazily and writes path_out, so it never overwrites the file it is reading.
path_tmp = f'{dir01}/fao2010_gez_tmp.tif'

# Copy the grid definition (size, transform, CRS, dtype, nodata) of the
# target raster; the file is closed as soon as the metadata is copied
with rasterio.open(path_rst) as rst:
    meta = rst.meta.copy()
meta.update(compress='lzw')

# Read the shapefile and reproject it onto the target raster's CRS so the
# polygons line up with the grid they are burned into
shp = gpd.read_file(path_shp).to_crs(meta['crs'])

# Background for pixels not covered by any polygon: the raster's nodata
# value when it defines one (turned into NaN further down), otherwise 0.
# rasterize() ignores `fill` when an `out` array is supplied, hence the grid
# is described through out_shape/dtype instead.
nodata = meta.get('nodata')
background = nodata if nodata is not None else 0

# (geometry, GEZ code) pairs; missing or empty geometries cannot be burned
shapes = ((geom, code) for geom, code in zip(shp.geometry, shp['gez_code'])
          if geom is not None and not geom.is_empty)

# Burn the GEZ codes on the target grid
burned = rasterio.features.rasterize(shapes=shapes,
                                     out_shape=(meta['height'], meta['width']),
                                     fill=background,
                                     transform=meta['transform'],
                                     dtype=meta['dtype'])

with rasterio.open(path_tmp, 'w', **meta) as out:
    out.write(burned, 1)

# Read the burned raster lazily, mask nodata as NaN and export to path_out
da_raw = rioxarray.open_rasterio(path_tmp, chunks=True)
try:
    fill_value = da_raw.attrs.get('_FillValue')
    da = da_raw if fill_value is None else da_raw.where(da_raw != fill_value)
    # Set attributes
    da.attrs['long_name'] = 'fao2010_gez'
    # Export as tif (the lazy reads come from path_tmp, not path_out)
    da.rio.to_raster(path_out)
finally:
    # Release the intermediate file before deleting it
    da_raw.close()
    os.remove(path_tmp)

In [None]:
# Reclassify
# Open the GEZ raster lazily in 5000 x 5000 blocks and keep the first band,
# dropping the now-redundant 'band' coordinate
da_gez = rioxarray.open_rasterio(dir01 + 'fao2010_gez.tif',
                                 chunks=dict(y=5000, x=5000))[0]
da_gez = da_gez.drop_vars('band')

# FAO GEZ codes that make up each aggregated ecozone
list_trop1 = [11]
list_trop2 = [12, 13, 14, 15, 16]
list_subt = [21, 22, 23, 24, 25]
list_temp = [31, 32, 33, 34, 35]
list_bore = [41, 42, 43]
list_pola = [50]

# Output variable name -> (class value, member GEZ codes)
gez_classes = {
    'fao2010_trop1': (1, list_trop1),
    'fao2010_trop2': (2, list_trop2),
    'fao2010_subt': (3, list_subt),
    'fao2010_temp': (4, list_temp),
    'fao2010_bore': (5, list_bore),
    'fao2010_pola': (6, list_pola),
}

# One int8 layer per class: the class value inside the class, 0 elsewhere
ds = xr.Dataset({
    name: xr.where(da_gez.isin(codes), value, 0).astype('int8')
    for name, (value, codes) in gez_classes.items()
})

# The code lists are disjoint, so summing the layers gives each cell its
# class value, or 0 where its GEZ code is in none of the lists
da = ds.to_array().sum("variable")
# Unclassified cells (0) become NaN
da = da.where(da > 0)
# Keep y from 80 down to -60, then use lat/lon as dimension names
da = da.sel(y=slice(80, -60)).rename(dict(y='lat', x='lon'))

---

### Fill NaN gaps

In [None]:
def fill_nans(da_var, var, dir_out):
    """
    Fill NaN cells with the nearest valid value, apply the land mask and
    export the result to Zarr.

    Each NaN cell of ``da_var`` receives the value of its nearest non-NaN
    cell (Euclidean distance measured in grid cells). Afterwards every cell
    where the Copernicus land mask is falsy is set to NaN. The result is
    written to ``<dir_out>/ds_prep_<var>.zarr``, overwriting any existing store.

    Args:
        da_var (xarray.DataArray): 2-D field whose dimensions are ordered
            ('lat', 'lon'). It may be dask-backed; it is loaded into memory
            once.
        var (str): Variable name used in the output dataset and in the
            output file name (e.g. 'fao2010').
        dir_out (str): Directory that contains
            'ds_prep_copernicus_land_mask.zarr' and receives the output.

    Returns:
        None

    Raises:
        ValueError: If the land mask and ``da_var`` have different shapes.
    """

    def fill_nans_array(data, invalid):
        """
        Return ``data`` with every cell flagged in ``invalid`` replaced by
        the value of the nearest cell that is not flagged.
        """
        # For every position, the indices of the nearest False (valid) cell
        ind = ndimage.distance_transform_edt(invalid,
                                             return_distances=False,
                                             return_indices=True)
        return data[tuple(ind)]

    # Paths for input and output
    land_mask_path = os.path.join(dir_out, 'ds_prep_copernicus_land_mask.zarr')
    output_path = os.path.join(dir_out, f'ds_prep_{var}.zarr')

    # Load the land mask into memory as a plain array
    land_mask = xr.open_zarr(land_mask_path).copernicus_land_mask.values

    # Materialise the (possibly lazy) input once, so the upstream dask graph
    # is not re-evaluated for every access below
    da_var = da_var.compute()
    data = da_var.values
    invalid = da_var.isnull().values

    if land_mask.shape != data.shape:
        raise ValueError(
            f'Land mask shape {land_mask.shape} does not match '
            f'{var} shape {data.shape}')

    # Skip the fill when there is nothing to fill (no NaNs) or nothing to
    # fill from (all NaNs: no valid neighbour exists)
    if invalid.any() and not invalid.all():
        data = fill_nans_array(data, invalid)

    # Wrap the filled array with the input's lat/lon coordinates
    ds_filled = xr.Dataset(
        {var: (('lat', 'lon'), data)},
        coords=dict(lat=da_var.lat, lon=da_var.lon))

    # Keep only cells where the land mask is truthy; the rest become NaN
    ds_filled = ds_filled.where(land_mask)

    # Export the filled dataset to Zarr format
    ds_filled.chunk(dict(lat=5000, lon=5000)) \
             .to_zarr(output_path, mode='w')

In [None]:
# Gap-fill the reclassified ecozone map, mask it to land and export it
fill_nans(da_var=da, var='fao2010', dir_out=dir01)

---

### Check

In [None]:
# Inspect the exported dataset (displayed as the cell output)
xr.open_zarr(f'{dir01}ds_prep_fao2010.zarr')

In [None]:
# Plot each exported variable for a visual check
for var_name in ['fao2010']:
    ds_check = xr.open_zarr(f'{dir01}ds_prep_{var_name}.zarr')
    fig, ax = plt.subplots(figsize=(10, 5), ncols=1, nrows=1)
    ds_check[var_name].plot.imshow(ax=ax)
    ax.set_title(var_name)