# Prepare HILDA+ LUC data (Winkler)

In [None]:
# Libraries
import os, time, shutil, rioxarray
import numpy as np
import xarray as xr
from scipy import ndimage
import matplotlib.pyplot as plt

In [None]:
# Directories: dir_data holds the raw inputs (e.g. the HILDA+ NetCDF file),
# dir01 receives every intermediate and prepared output of this notebook
dir_data, dir01 = '../data/', '../paper_deficit/output/01_prep/'

---

### Pre-processing

In [None]:
# Libraries
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask

# Initialize dask: each SLURM job provides 24 cores and 256 GB of memory;
# worker stdout/stderr logs go to ../dask/
cluster = SLURMCluster(
    queue='compute',                      # SLURM queue to use
    cores=24,                             # Number of CPU cores per job
    memory='256 GB',                      # Memory per job
    account='bm0891',                     # Account allocation
    interface="ib0",                      # Network interface for communication
    walltime='00:30:00',                  # Maximum runtime per job
    local_directory='../dask/',           # Directory for local storage
    job_extra_directives=[                # Additional SLURM directives for logging
        '-o ../dask/LOG_worker_%j.o',     # Output log
        '-e ../dask/LOG_worker_%j.e'      # Error log
    ]
)

# Scale dask cluster to 10 SLURM jobs
cluster.scale(jobs=10)

# Configure the dashboard URL so it is reachable through the JupyterHub proxy.
# dask.config.set creates missing nested keys instead of relying on the
# 'distributed' -> 'dashboard' sections already existing.
dask.config.set(
    {'distributed.dashboard.link': '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'}
)

# Create client (also becomes the default scheduler for dask computations)
client = Client(cluster)

client

In [None]:
# Helper that asks the C allocator to return freed heap memory to the OS;
# executed on the dask workers through client.run(trim_memory)
def trim_memory() -> int:
    """Call ``malloc_trim(0)`` from ``libc.so.6`` and return its result.

    The import lives inside the function so the function is self-contained
    when shipped to other processes.
    """
    from ctypes import CDLL
    libc_handle = CDLL("libc.so.6")
    return libc_handle.malloc_trim(0)

In [None]:
def prep_hilda_stable():
    """
    Prepare stable land cover masks from HILDA+ and export them as GeoTIFFs.

    Steps:
      1. Read ``LULC_states`` from the HILDA+ NetCDF file and keep the years
         selected by ``slice(1960, 2020)``.
      2. Split the states into one boolean layer per land cover class and
         write them to the intermediate store ``hilda_prep.zarr``.
      3. Flag cells that have the same class in every selected year.
      4. Combine the flags into three masks (crop, pasture,
         forest/shrub/other), write them to ``hilda_stable.zarr`` and convert
         each mask to ``<name>.tif`` (float32) in ``dir01``.
      5. Delete both intermediate Zarr stores.

    Uses the notebook globals ``dir_data``, ``dir01``, ``client`` and
    ``trim_memory``.

    Raises:
        ValueError: If the time selection contains no time steps (every cell
            would otherwise be flagged as stable).
    """

    # Load the HILDA+ states; the label-based slice keeps the years
    # 1960-2020 that exist in the file
    hilda = xr.open_dataset(
        f'{dir_data}hilda/hildap_vGLOB-1.0-f_netcdf/hildaplus_GLOB-1-0-f_states.nc',
        decode_coords='all') \
        .chunk({'latitude': 5000, 'longitude': 5000, 'time': 1}) \
        .sel(time=slice(1960, 2020)) \
        .LULC_states

    if hilda.sizes['time'] == 0:
        raise ValueError('No HILDA+ time steps found for the period 1960-2020')

    # Create a new dataset with each land cover class as a separate boolean
    # variable (True where the cell has that class in the given year)
    hilda_processed = xr.Dataset({
        'lat': hilda.latitude.values,
        'lon': hilda.longitude.values,
        'time': hilda.time.astype('int')
    })

    land_cover_classes = [22, 33, 40, 41, 42, 43, 44, 45, 55, 66]
    for lc_class in land_cover_classes:
        # NOTE(review): assumes LULC_states is ordered
        # (time, latitude, longitude) — TODO confirm against the NetCDF file
        hilda_processed[f'hilda_{lc_class}'] = (
            ('time', 'lat', 'lon'),
            xr.where(hilda == lc_class, True, False).data)

    # Export the per-class layers to Zarr (presumably so the reduction over
    # time below reads a materialised store instead of the NetCDF graph)
    hilda_processed.to_zarr(f'{dir01}hilda_prep.zarr', mode='w')

    # Re-open with the full time axis in a single chunk for the reduction
    hilda_prep = xr.open_zarr(f'{dir01}hilda_prep.zarr') \
        .chunk({'lat': 5000, 'lon': 5000, 'time': -1})

    # A cell is stable for a class when it has that class in every selected
    # year. ``all`` does not depend on the number of time steps, unlike
    # comparing the yearly sum against a fixed count.
    stable_cells = hilda_prep.all('time')

    # Store stable land cover information in a new dataset
    hilda_stable = xr.Dataset()
    hilda_stable['hilda_stable_crop'] = stable_cells.hilda_22
    hilda_stable['hilda_stable_pasture'] = stable_cells.hilda_33
    # Forest (40-45), shrub (55) and other (66). A cell has exactly one class
    # per year, so at most one of these flags is True and the sum of the
    # boolean layers acts as a logical OR.
    hilda_stable['hilda_stable_fso'] = (
        stable_cells.hilda_40 + stable_cells.hilda_41 +
        stable_cells.hilda_42 + stable_cells.hilda_43 +
        stable_cells.hilda_44 + stable_cells.hilda_45 +
        stable_cells.hilda_55 + stable_cells.hilda_66
    )

    # Export the stable land cover dataset to Zarr format
    hilda_stable \
        .chunk({'lat': 5000, 'lon': 5000}) \
        .to_zarr(f'{dir01}hilda_stable.zarr', mode='w')

    # Convert each Zarr variable to a float32 GeoTIFF (dims renamed to y/x
    # for rioxarray)
    for variable in hilda_stable.data_vars:
        xr.open_zarr(f'{dir01}hilda_stable.zarr')[variable] \
            .rename({'lat': 'y', 'lon': 'x'}) \
            .astype('float32') \
            .rio.to_raster(f'{dir01}{variable}.tif')

        # Trim worker memory after each export
        client.run(trim_memory)

    # Clean up intermediate files to free up disk space
    shutil.rmtree(f'{dir01}hilda_prep.zarr')
    shutil.rmtree(f'{dir01}hilda_stable.zarr')
    client.run(trim_memory)


# Run the stable land cover preparation and report its runtime (IPython magic)
%time prep_hilda_stable()

In [None]:
def prep_hilda2015():
    """
    Reclassify the HILDA+ land use/cover map of 2015 into 6 LUC groups and
    export one GeoTIFF per group.

    Groups (HILDA+ class codes): urban (11), crop (22), pasture (33),
    forest (40-45), shrub (55), other (66). Each output layer is 1 where the
    cell belongs to the group, 0 where it does not, and NaN where the HILDA+
    code is 0 (treated as non-land). Files are written to ``dir01`` as
    ``hilda2015_<group>.tif``.

    Uses the notebook globals ``dir_data`` and ``dir01``.
    """

    # File path
    file_path = f'{dir_data}hilda/hildap_vGLOB-1.0-f_netcdf/hildaplus_GLOB-1-0-f_states.nc'

    # HILDA+ class codes that make up each output group (dict order is the
    # export order)
    luc_groups = {
        'urban': [11],
        'crop': [22],
        'pasture': [33],
        'forest': [40, 41, 42, 43, 44, 45],
        'shrub': [55],
        'other': [66],
    }

    # Get data
    ds_hilda = xr.open_dataset(file_path, decode_coords='all') \
        .chunk(dict(latitude=5000, longitude=5000, time=1))

    # Select year
    da_hilda = ds_hilda.LULC_states.sel(time=2015).drop_vars('time')

    # Create dataset with one 0/1 variable for each LUC group
    ds_hilda_out = xr.Dataset()
    for group, codes in luc_groups.items():
        ds_hilda_out[group] = xr.where(da_hilda.isin(codes), 1, 0)

    # Set cells with HILDA+ code 0 (non-land) to NaN - ``where`` replaces
    # every cell failing the condition with NaN - and rename the spatial
    # dims to y/x for rioxarray
    ds_hilda_out = ds_hilda_out \
        .where(da_hilda != 0) \
        .rename(dict(latitude='y', longitude='x'))

    # Export one GeoTIFF per group
    for group in ds_hilda_out.data_vars:
        ds_hilda_out[group].rio.to_raster(f'{dir01}hilda2015_{group}.tif')


# Run the 2015 reclassification and report its runtime (IPython magic)
%time prep_hilda2015()

In [None]:
# Close the dask client first so it stops being the default scheduler for
# later dask computations, then shut down the SLURM cluster
client.close()
cluster.close()

In [None]:
# Wait 60 s so the dask client has fully disconnected from the closed
# cluster before the next steps run
time.sleep(60)

---

### Regridding

In [None]:
# Import regridding function
from regrid_high_res_v1_01 import regrid_high_res, prep_tif

In [None]:
def regrid_da(f_source, dir_target, dir_source, dir_out,
              size_tiles, fill_value=None, olap=1,
              account='bm0891', partition='shared'):
    """Regrid a large GeoTIFF onto the target grid with regrid_high_res.

    The target grid is read from ``dir_target + 'target_grid.tif'`` and the
    source from ``dir_source + f_source + '.tif'``. Both are loaded with
    ``prep_tif`` and passed to ``regrid_high_res``, which is asked to export
    Zarr output to ``dir_out`` and to keep its intermediate files.

    Args:
        f_source (str): File name (without extension) of the source .tif
            file; also used as the name of the source data array.
        dir_target (str): Directory containing the target grid file
            ``target_grid.tif``.
        dir_source (str): Directory containing the source .tif file.
        dir_out (str): Directory to store the output and intermediate files.
        size_tiles (int): Size of the regridding tiles in degrees.
        fill_value (float, optional): Fill value to use in the regridding
            process. Defaults to None.
        olap (float, optional): Overlap size in degrees for regridding
            tiles. Defaults to 1.
        account (str, optional): Account passed to regrid_high_res
            (presumably the SLURM allocation). Defaults to 'bm0891'.
        partition (str, optional): Partition passed to regrid_high_res.
            Defaults to 'shared'.

    Returns:
        The value returned by regrid_high_res (see regrid_high_res_v1_01 for
        its type).
    """
    # Prepare the target and source data arrays from TIFF files
    da_target = prep_tif(dir_target + 'target_grid.tif', 'target_grid')
    da_source = prep_tif(dir_source + f_source + '.tif', f_source)

    # Regrid source array to target grid and forward the result
    return regrid_high_res(da_target, da_source, dir_out,
                           account=account, partition=partition,
                           size_tiles=size_tiles, olap=olap,
                           fill_value=fill_value,
                           type_export='zarr', del_interm=False)

In [None]:
# Regridding: regrid each stable land cover GeoTIFF in dir01 onto
# dir01 + 'target_grid.tif' with 30-degree tiles, 0.1-degree overlap and NaN
# as fill value; output and intermediate files go to dir01
for i in ['hilda_stable_crop', 'hilda_stable_pasture', 'hilda_stable_fso']:
    %time regrid_da(i, dir01, dir01, dir01, 30, np.nan, 0.1)
# Takes about 4min for one file

In [None]:
# Regrid the six 2015 LUC group GeoTIFFs with the same tile size, overlap
# and fill value as the stable layers
for i in ['hilda2015_urban', 'hilda2015_crop', 'hilda2015_pasture', 
         'hilda2015_forest', 'hilda2015_shrub', 'hilda2015_other']:
    %time regrid_da(i, dir01, dir01, dir01, 30, np.nan, 0.1)
# Takes about 4min for one file

---

### Fill NaNs

In [None]:
def fill_nans(var, dir_out):
    """
    Fill NaN cells of a regridded variable with the nearest valid value,
    apply the Copernicus land mask and export the result to Zarr.

    Reads variable ``regridded_{var}`` from ``ds_regridded_{var}.zarr`` and
    ``copernicus_land_mask`` from ``ds_prep_copernicus_land_mask.zarr`` (both
    in ``dir_out``) and writes variable ``var`` to ``ds_prep_{var}.zarr`` in
    the same directory. If every cell is NaN, there is no valid value to copy
    and the data is exported unchanged.

    Args:
        var (str): Name of the variable to process (e.g. 'hilda_stable_crop').
        dir_out (str): Directory where prepared data is stored and the
            filled dataset will be exported.

    Returns:
        None
    """

    def fill_nans_array(data, invalid):
        """
        Replace invalid cells of ``data`` by the value of the nearest valid
        cell (Euclidean distance). ``invalid`` is a boolean array of the same
        shape; True marks cells to replace.
        """
        # For every cell: indices of the nearest cell where ``invalid`` is
        # False (valid cells map onto themselves)
        ind = ndimage.distance_transform_edt(invalid,
                                             return_distances=False,
                                             return_indices=True)
        return data[tuple(ind)]

    # Paths for input and output
    land_mask_path = os.path.join(dir_out, 'ds_prep_copernicus_land_mask.zarr')
    var_data_path = os.path.join(dir_out, f'ds_regridded_{var}.zarr')
    output_path = os.path.join(dir_out, f'ds_prep_{var}.zarr')

    # Read land mask data
    da_land = xr.open_zarr(land_mask_path) \
                .chunk(dict(lat=5000, lon=5000)) \
                .copernicus_land_mask \
                .compute()

    # Read variable data and load it once: the null check, the mask and the
    # values all reuse this single computation of the (lazy) Zarr array
    da_var = xr.open_zarr(var_data_path)['regridded_' + var].load()
    invalid = da_var.isnull().values

    # Fill NaNs only when there are some, and at least one valid cell to
    # copy from
    if invalid.any() and not invalid.all():
        da_fill = fill_nans_array(da_var.values, invalid)
    else:
        da_fill = da_var.values  # Nothing to fill (or nothing to fill from)

    # Create a new Dataset with filled data
    ds_filled = xr.Dataset(dict(lat=da_var.lat, lon=da_var.lon))
    ds_filled[var] = (('lat', 'lon'), da_fill)

    # Apply land mask to the filled data (cells outside the mask become NaN)
    ds_filled = ds_filled.where(da_land)

    # Export the filled dataset to Zarr format
    ds_filled.chunk(dict(lat=5000, lon=5000)) \
             .to_zarr(output_path, mode='w')

In [None]:
%%time
# Fill NaNs of the regridded stable land cover layers and apply the land
# mask (the %%time cell magic above must remain the first line of the cell)
for i in ['hilda_stable_crop', 'hilda_stable_pasture', 'hilda_stable_fso']:
    %time fill_nans(i, dir01)

In [None]:
# Fill NaNs of the regridded 2015 LUC group layers and apply the land mask
for i in ['hilda2015_urban', 'hilda2015_crop', 'hilda2015_pasture', 
         'hilda2015_forest', 'hilda2015_shrub', 'hilda2015_other']:
    %time fill_nans(i, dir01)

---

### Check

In [None]:
# Visual check: one map per filled stable land cover layer
for i in ('hilda_stable_crop', 'hilda_stable_pasture', 'hilda_stable_fso'):
    fig, ax = plt.subplots(figsize=(10, 5))
    (xr.open_zarr(f'{dir01}ds_prep_{i}.zarr')[i]
        .plot.imshow(ax=ax))
    # Set the title after plotting so it replaces the one xarray adds
    ax.set_title(i)

In [None]:
# Visual check: one map per filled 2015 LUC group layer
for i in ('hilda2015_urban', 'hilda2015_crop', 'hilda2015_pasture',
          'hilda2015_forest', 'hilda2015_shrub', 'hilda2015_other'):
    fig, ax = plt.subplots(figsize=(10, 5))
    (xr.open_zarr(f'{dir01}ds_prep_{i}.zarr')[i]
        .plot.imshow(ax=ax))
    # Set the title after plotting so it replaces the one xarray adds
    ax.set_title(i)