# Prepare WorldClim data (Fick & Hijmans, 2017)

In [None]:
# Libraries
import os, time, shutil, rioxarray
import numpy as np
import xarray as xr
from scipy import ndimage
import matplotlib.pyplot as plt

In [None]:
# Directories (relative paths)
# Root folder of the raw input data
dir_data = '../data/'
# Folder for the prepared outputs of step 01
dir01 = '../paper_deficit/output/01_prep/'

---

### Pre-processing

In [None]:
# Libraries
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask

# Initialize dask: describe the SLURM job that hosts the dask workers
cluster = SLURMCluster(
    queue='compute',                      # SLURM queue to use
    cores=24,                             # Number of CPU cores per job
    memory='256 GB',                      # Memory per job
    account='bm0891',                     # Account allocation
    interface="ib0",                      # Network interface for communication
    walltime='00:30:00',                  # Maximum runtime per job
    local_directory='../dask/',           # Directory for local storage
    job_extra_directives=[                # Additional SLURM directives for logging
        '-o ../dask/LOG_worker_%j.o',     # Output log
        '-e ../dask/LOG_worker_%j.e'      # Error log
    ]
)

# Scale dask cluster to 10 SLURM jobs
cluster.scale(jobs=10)

# Configure dashboard url so the link points through the JupyterHub proxy.
# Use the public config API instead of mutating dask.config.config in place.
dask.config.set(
    {'distributed.dashboard.link':
         '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'}
)

# Create client (connects this notebook session to the cluster)
client = Client(cluster)

# Display the client summary in the notebook
client

In [None]:
# Function to trim memory of workers
def trim_memory() -> int:
    """Ask the C library to release unused heap memory of this process.

    Meant to be executed on every dask worker via ``client.run``.

    Returns:
        int: The return value of ``malloc_trim(0)``.
    """
    from ctypes import CDLL

    # Load the C library by its Linux shared-object name and trim the heap
    c_library = CDLL("libc.so.6")
    released = c_library.malloc_trim(0)
    return released

In [None]:
def _export_land_cells(file_path, var_name):
    """Mask fill values of one raster, crop it and write it to ``dir01``.

    Args:
        file_path (str): Path of the input GeoTIFF.
        var_name (str): Name given to the data array; the output is written
            to ``dir01 + var_name + '.tif'``.
    """
    chunks = dict(y=5000, x=5000)
    # Read lazily in chunks
    da = rioxarray.open_rasterio(file_path, chunks=chunks)
    # Set fill-value cells to NaN, limit to 80 to -60 along y, export
    da.where(da != da.attrs['_FillValue']) \
        .rename(var_name) \
        .sel(y=slice(80, -60)) \
        .drop_vars('spatial_ref') \
        .chunk(chunks) \
        .rio.to_raster(dir01 + var_name + '.tif')


def prep_worldclim():
    """Prepare worldclim data.

    Exports the elevation raster and every bioclimatic GeoTIFF found in
    ``worldclim/wc2.1_30s_bio/`` to ``dir01``, with fill-value cells set to
    NaN and the extent limited to 80 to -60 along the y axis. After each
    bioclimatic file the memory of the dask workers is trimmed.

    Relies on the notebook globals ``dir_data``, ``dir01``, ``client`` and
    ``trim_memory``.

    Returns:
        None
    """
    # Elevation
    _export_land_cells(
        dir_data + 'worldclim/wc2.1_30s_elev/wc2.1_30s_elev.tif',
        'worldclim_elev')

    # Bioclimatic variables: only GeoTIFFs, so that stray files in the
    # folder (sidecars, readme, ...) are not opened as rasters
    dir_bio = dir_data + 'worldclim/wc2.1_30s_bio/'
    bio_files = sorted(f for f in os.listdir(dir_bio) if f.endswith('.tif'))

    for file_name in bio_files:
        # Variable number is the part after the last underscore,
        # e.g. '..._bio_12.tif' -> 'worldclim_bio12'
        var_name = 'worldclim_bio' + file_name.split('_')[-1].split('.')[0]
        _export_land_cells(dir_bio + file_name, var_name)
        # Release memory on the workers before the next file
        client.run(trim_memory)

In [None]:
# Run the pre-processing and report its run time
%time prep_worldclim()

In [None]:
# Close dask client and cluster (client first, so that it is disconnected
# cleanly before the scheduler it talks to is shut down)
client.close()
cluster.close()

---

### Post-processing

In [None]:
def fill_nans(da_var, var, dir_out):
    """
    Fills NaN values of a 2-D field using the nearest valid data, applies a
    land mask, and exports the result to a new Zarr dataset.

    Args:
        da_var (xarray.DataArray): Field to fill, with dimensions
            ('lat', 'lon').
        var (str): Name of the variable; used as the variable name in the
            exported dataset and in the output name 'ds_prep_{var}.zarr'.
        dir_out (str): Directory that contains
            'ds_prep_copernicus_land_mask.zarr' and to which the filled
            dataset is exported.

    Returns:
        None
    """

    def fill_nans_array(data, invalid):
        """
        Replace invalid (NaN) data cells by the value of the nearest valid data
        cell.
        """
        # For every cell, indices of the nearest cell where `invalid` is False
        ind = ndimage.distance_transform_edt(invalid,
                                             return_distances=False,
                                             return_indices=True)
        return data[tuple(ind)]

    # Paths for input and output
    land_mask_path = os.path.join(dir_out, 'ds_prep_copernicus_land_mask.zarr')
    output_path = os.path.join(dir_out, f'ds_prep_{var}.zarr')

    # Read land mask data
    # NOTE(review): presumably on the same lat/lon grid as da_var - verify
    da_land = xr.open_zarr(land_mask_path) \
                .chunk(dict(lat=5000, lon=5000)) \
                .copernicus_land_mask \
                .compute()

    # Evaluate the NaN mask only once; the arrays are large
    invalid = da_var.isnull().values

    # Fill nan using function fill_nans_array
    # Skip the filling if there are no NaNs, or if every cell is NaN (then
    # there is no valid cell to take a value from)
    if invalid.any() and not invalid.all():
        da_fill = fill_nans_array(da_var.values, invalid)
    else:
        da_fill = da_var.values

    # Create a new Dataset with filled data
    ds_filled = xr.Dataset(dict(lat=da_var.lat, lon=da_var.lon))
    ds_filled[var] = (('lat', 'lon'), da_fill)

    # Apply land mask to the filled data
    ds_filled = ds_filled.where(da_land)

    # Export the filled dataset to Zarr format
    ds_filled.chunk(dict(lat=5000, lon=5000)) \
             .to_zarr(output_path, mode='w')

In [None]:
for i in ['worldclim_bio1', 'worldclim_bio2', 'worldclim_bio3', 
          'worldclim_bio4', 'worldclim_bio5', 'worldclim_bio6',
          'worldclim_bio7', 'worldclim_bio8', 'worldclim_bio9',
          'worldclim_bio10', 'worldclim_bio11', 'worldclim_bio12',
          'worldclim_bio13', 'worldclim_bio14', 'worldclim_bio15',
          'worldclim_bio16', 'worldclim_bio17', 'worldclim_bio18',
          'worldclim_bio19', 'worldclim_elev']:
    # Read data
    da_var = rioxarray.open_rasterio(dir01 + i + '.tif')[0] \
        .rename(dict(y='lat', x='lon'))
    # Fill nans with nearest neighbour if there are nans
    %time fill_nans(da_var, i, dir01)

---

### Check

In [None]:
# Visual check: one map per exported variable (bio1 ... bio19, elevation)
check_vars = [f'worldclim_bio{k}' for k in range(1, 20)] + ['worldclim_elev']

for name in check_vars:
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
    # Open the exported Zarr store and draw the variable as an image
    ds_check = xr.open_zarr(dir01 + 'ds_prep_' + name + '.zarr')
    ds_check[name].plot.imshow(ax=ax)
    ax.set_title(name)