# Prepare data for Figure "DGVM" - Prepare other data for regridding

In [None]:
# Libraries
import os, shutil
import xarray as xr
import rioxarray

In [None]:
# Directories
# Root of the input data
dir_data = '../data/'
# Common root of the project output tree
_dir_out = '../paper_deficit/output/'
# Outputs of step 04 (carbon-density predictions read in this notebook)
dir04 = _dir_out + '04_out/'
# Destination for the data prepared here; keeps its trailing '/'
dir05x = _dir_out + '05_prep_other/fig_dgvm/'

---

In [None]:
# Create directories for other data if they do not exist yet.
# os.makedirs also creates dir05x itself (and any missing parents) on a
# fresh checkout, and exist_ok=True makes re-running the cell a no-op
# without a separate exists() check.
for subdir in ['luh2', 'pot', 'erb', 'walker', 'mo', 'sanderman030',
               'sanderman100', 'sanderman200']:
    os.makedirs(os.path.join(dir05x, subdir), exist_ok=True)

---

In [None]:
# Libraries
from dask_jobqueue import SLURMCluster
from dask.distributed import Client
import dask

# Settings applied to every SLURM job that hosts dask workers
slurm_job_kwargs = dict(
    queue='compute',                      # SLURM queue to use
    cores=12,                             # CPU cores per job
    memory='256 GB',                      # Memory per job
    account='bm0891',                     # Account allocation
    interface="ib0",                      # Network interface for communication
    walltime='02:00:00',                  # Maximum runtime per job
    local_directory='../dask/',           # Directory for local storage
    job_extra_directives=[                # Extra SLURM directives for logging
        '-o ../dask/LOG_worker_%j.o',     # Output log
        '-e ../dask/LOG_worker_%j.e'      # Error log
    ]
)

# Initialize dask cluster and request two worker jobs
cluster = SLURMCluster(**slurm_job_kwargs)
cluster.scale(jobs=2)

# Configure the dashboard link to go through the JupyterHub proxy prefix
dashboard_config = dask.config.config.get('distributed').get('dashboard')
dashboard_config.update(
    {'link': '{JUPYTERHUB_SERVICE_PREFIX}/proxy/{port}/status'}
)

# Connect a client to the cluster; the bare name displays it as cell output
client = Client(cluster)
client

In [None]:
# Prepare luh2 data
# Get luh2 states file; time is left undecoded and converted to years below
ds_luh2 = xr.open_dataset(os.path.join(dir_data, 'luh2_v2h/states.nc'),
                          chunks=dict(time=100),
                          decode_times=False)

# Adjust time: shift the raw time values by 850 so years such as 1700 can be
# selected directly (presumably the file counts years since 850 - confirm
# against the 'time' units attribute)
ds_luh2['time'] = ds_luh2.time.data.astype('int') + 850

# Primary forest and primary non-forest layers in 1700. Selecting the single
# year (and variable) first keeps both exports below, including the mask, to
# one time slice instead of building where() over the whole dataset.
da_primf_1700 = ds_luh2.primf.sel(time=1700).drop_vars('time')
da_primn_1700 = ds_luh2.primn.sel(time=1700).drop_vars('time')

# Sum primary forest and non-forest in 1700 and export
(da_primf_1700 + da_primn_1700) \
    .to_dataset(name='prim_1700') \
    .to_netcdf(os.path.join(dir05x, 'luh2/ds_luh2_prim_1700.nc'), mode='w')

# Prepare land-sea mask (1 where primf in 1700 is not NaN, else 0) and export
xr.where(da_primf_1700 == da_primf_1700, 1, 0) \
    .astype('float64') \
    .rename('land_sea_mask') \
    .to_netcdf(os.path.join(dir05x, 'luh2/ds_luh2_land.nc'), mode='w')

In [None]:
# Prepare predicted carbon density data
# Open the step-04 outputs lazily in 5000 x 5000 lat/lon chunks
ds_agbc = xr.open_dataset(os.path.join(dir04, 'agbc.nc')) \
    .chunk(dict(lat=5000, lon=5000))
ds_bgbc = xr.open_dataset(os.path.join(dir04, 'bgbc.nc')) \
    .chunk(dict(lat=5000, lon=5000))
ds_soc = xr.open_dataset(os.path.join(dir04, 'soc.nc')) \
    .chunk(dict(lat=5000, lon=5000))

# Create dataset with relevant variables:
# cveg = agbc + bgbc, csoil = soc; s2 uses the '*_prim' layers,
# s3 the '*_act' layers
ds_pot = xr.Dataset(
    dict(pot_s2_cveg = ds_agbc.agbc_max_prim + ds_bgbc.bgbc_max_prim,
         pot_s3_cveg = ds_agbc.agbc_max_act + ds_bgbc.bgbc_max_act,
         pot_s2_csoil = ds_soc.soc_mean_prim,
         pot_s3_csoil = ds_soc.soc_mean_act)
)

# Remove variable attributes (needed for regridding)
for var in list(ds_pot.data_vars):
    ds_pot[var].attrs = {}

# Export. Paths are joined component-wise (as in the other sections) so they
# do not depend on dir05x ending with '/'.
ds_pot.to_zarr(os.path.join(dir05x, 'pot', 'ds_pot.zarr'), mode='w')

# Create and export land-sea mask (1 where pot_s2_cveg is not NaN, else 0)
xr.where(ds_pot.pot_s2_cveg == ds_pot.pot_s2_cveg, 1, 0) \
    .rename('land_sea_mask') \
    .to_zarr(os.path.join(dir05x, 'pot', 'ds_pot_land.zarr'), mode='w');

In [None]:
# Prepare erb data
dir_erb = os.path.join(dir_data, 'erb2018')

# Use the 'ExtDat_Fig<N><panel>_gcm.tif' rasters found in the directory and
# open them by their real names, so other GeoTIFFs in the folder are ignored
# rather than mapped onto a reconstructed (possibly wrong) file name. Sorting
# makes the variable order reproducible.
erb_files = sorted(f for f in os.listdir(dir_erb)
                   if f.startswith('ExtDat_Fig') and f.endswith('_gcm.tif'))
if not erb_files:
    raise FileNotFoundError(
        f'No ExtDat_Fig*_gcm.tif files found in {dir_erb}')

# Create dataset with erb maps as variables
ds_erb = xr.Dataset()

for fname in erb_files:
    # Get data and rename coordinates
    da = rioxarray.open_rasterio(os.path.join(dir_erb, fname), chunks=True) \
        .rename(y='lat', x='lon') \
        .squeeze('band') \
        .drop_vars(['band', 'spatial_ref'])
    # Variable name 'fig<N>_<panel>' from characters 10 (figure number) and
    # 11 (panel letter) of the file name; nodata is masked to NaN and values
    # are multiplied by 0.01
    ds_erb['fig' + fname[10:11] + '_' + fname[11:12].lower()] = \
        da.where(da != da.attrs['_FillValue']) * 0.01

# Mean of erb actual vegetation carbon (all 'fig3*' maps)
da_erb_act = ds_erb[[v for v in ds_erb.data_vars if v.startswith('fig3')]] \
    .to_array() \
    .mean('variable')

# Mean of erb potential vegetation carbon (all 'fig4*' maps)
da_erb_pot = ds_erb[[v for v in ds_erb.data_vars if v.startswith('fig4')]] \
    .to_array() \
    .mean('variable')

# Create dataset with mean values
ds_erb2 = xr.Dataset()
ds_erb2['erb_s2_cveg'] = da_erb_pot
ds_erb2['erb_s3_cveg'] = da_erb_act

# Export
ds_erb2.to_zarr(os.path.join(dir05x, 'erb', 'ds_erb.zarr'), mode='w')

# Create and export land-sea mask (1 where erb_s2_cveg is not NaN, else 0)
xr.where(ds_erb2.erb_s2_cveg == ds_erb2.erb_s2_cveg, 1, 0) \
    .rename('land_sea_mask') \
    .to_zarr(os.path.join(dir05x, 'erb', 'ds_erb_land.zarr'), mode='w');

In [None]:
# Prepare sanderman data
# Soil-depth label used in the sanderman2017 file names, per case
sand_depth_labels = {
    'sanderman030': '0_30cm',
    'sanderman100': '0_100cm',
    'sanderman200': '0_200cm',
}

# Loop through data of the different soil depths
for case, fstr in sand_depth_labels.items():
    # s2 comes from the 'NoLU' map, s3 from the '2010AD' map
    sand_files = {
        case + '_s2_csoil': f'SOCS_{fstr}_year_NoLU_10km.tif',
        case + '_s3_csoil': f'SOCS_{fstr}_year_2010AD_10km.tif',
    }
    # Collect both maps in one dataset
    ds_sand = xr.Dataset()
    for var_name, fname in sand_files.items():
        ds_sand[var_name] = rioxarray.open_rasterio(
            os.path.join(dir_data, 'sanderman2017', fname),
            chunks=dict(y=5000, x=5000))
    # Rename coordinates, drop the band axis and CRS coordinate, and export
    ds_sand_out = ds_sand.rename(y='lat', x='lon') \
        .squeeze('band') \
        .drop_vars(['band', 'spatial_ref'])
    ds_sand_out.to_zarr(os.path.join(dir05x, case, f'ds_{case}.zarr'),
                        mode='w')

# Create and export land-sea mask (1 where the s2 soil carbon is not NaN)
for case in sand_depth_labels:
    ds_sand2 = xr.open_zarr(os.path.join(dir05x, case, f'ds_{case}.zarr'))
    da_s2 = ds_sand2[case + '_s2_csoil']
    xr.where(da_s2 == da_s2, 1, 0) \
        .rename('land_sea_mask') \
        .to_zarr(os.path.join(dir05x, case, f'ds_{case}_land.zarr'),
                 mode='w')

In [None]:
# Prepare walker data
# Get data
def read_walker(fstr):
    """Get walker data.

    Lazily open the GeoTIFF *fstr* from the ``walker2022`` folder under
    ``dir_data`` with rioxarray, in 5000 x 5000 (y, x) chunks.
    """
    path = os.path.join(dir_data, 'walker2022', fstr)
    return rioxarray.open_rasterio(path, chunks=dict(y=5000, x=5000))

# Potential ('Pot') and current ('Cur') AGB, BGB and SOC maps
da_walker_agb_pot = read_walker('Base_Pot_AGB_MgCha_500m.tif')
da_walker_agb_cur = read_walker('Base_Cur_AGB_MgCha_500m.tif')
da_walker_bgb_pot = read_walker('Base_Pot_BGB_MgCha_500m.tif')
da_walker_bgb_cur = read_walker('Base_Cur_BGB_MgCha_500m.tif')
da_walker_soc_pot = read_walker('Base_Pot_SOC_MgCha_500m.tif')
da_walker_soc_cur = read_walker('Base_Cur_SOC_MgCha_500m.tif')

# Nodata value of the walker rasters
walker_nodata = -32768


def _walker_masked_data(da):
    """Return the underlying array of *da* with nodata cells set to NaN."""
    return da.where(da != walker_nodata).data


# Create dataset with walker maps as variables. The band/y/x coordinates are
# taken from the AGB potential raster (presumably they also bring along its
# 'spatial_ref' coordinate, which the later reprojection needs - confirm).
ds_walker = xr.Dataset({'band': da_walker_agb_pot.band,
                        'y': da_walker_agb_pot.y,
                        'x': da_walker_agb_pot.x})

# Vegetation carbon = AGB + BGB with nodata treated as NaN in the sum; the
# result is stored back as int16 (fraction truncated) with NaN -> nodata
veg_inputs = (('walker_s2_cveg', da_walker_agb_pot, da_walker_bgb_pot),
              ('walker_s3_cveg', da_walker_agb_cur, da_walker_bgb_cur))
for var_name, da_agb, da_bgb in veg_inputs:
    ds_walker[var_name] = (
        ('band', 'y', 'x'),
        _walker_masked_data(da_agb) + _walker_masked_data(da_bgb))
    ds_walker[var_name] = \
        ds_walker[var_name].fillna(walker_nodata).astype('int16')

# Soil carbon is copied over unmasked
ds_walker['walker_s2_csoil'] = (('band', 'y', 'x'), da_walker_soc_pot.data)
ds_walker['walker_s3_csoil'] = (('band', 'y', 'x'), da_walker_soc_cur.data)

# Export to a temporary store, then reopen it lazily from disk
temp_store = os.path.join(dir05x, 'walker', 'walker_temp.zarr')
ds_walker.to_zarr(temp_store, mode='w')
ds_walker = xr.open_zarr(temp_store)

# Reproject variable of prepared walker dataset
def reproject_walker(fstr):
    """Reproject one variable of the prepared walker dataset to EPSG:4326.

    Takes variable *fstr* from the module-level ``ds_walker`` and reprojects
    it with ``rio.reproject('epsg:4326')``. It then renames ``y``/``x`` to
    ``lat``/``lon`` and squeezes out ``band``, dropping the ``band`` and
    ``spatial_ref`` coordinates. The result is written (mode 'w') as the
    single-variable store ``walker_temp_<fstr>.zarr`` in ``dir05x/walker``.
    """
    da_geo = ds_walker[fstr].rio.reproject('epsg:4326')
    da_geo = da_geo.rename(y='lat', x='lon').squeeze('band')
    da_geo = da_geo.drop_vars(['band', 'spatial_ref'])
    store = os.path.join(dir05x, 'walker', f'walker_temp_{fstr}.zarr')
    da_geo.to_dataset(name=fstr).to_zarr(store, mode='w')

# Reproject every variable into its own temporary store
for var_name in ds_walker.data_vars:
    reproject_walker(var_name)

# Combine reprojected variables in dataset and export
ds = xr.open_mfdataset(os.path.join(dir05x, 'walker', 'walker_temp_*.zarr'),
                       engine='zarr')

# Drop the chunk encoding inherited from the temporary stores so that the
# rechunking below determines the zarr chunks of the output (stale
# encoding['chunks'] can conflict with the new dask chunks). pop() also
# tolerates a variable that carries no such entry instead of raising KeyError.
for var in ds:
    ds[var].encoding.pop('chunks', None)

ds.chunk(dict(lat=5000, lon=5000)) \
    .to_zarr(os.path.join(dir05x, 'walker', 'ds_walker.zarr'), mode="w")

# Create and export land-sea mask (1 where walker_s2_cveg is not NaN, else 0)
ds_walker2 = xr.open_zarr(os.path.join(dir05x, 'walker', 'ds_walker.zarr'))

xr.where(ds_walker2.walker_s2_cveg == ds_walker2.walker_s2_cveg, 1, 0) \
    .rename('land_sea_mask') \
    .to_zarr(os.path.join(dir05x, 'walker', 'ds_walker_land.zarr'), mode='w')

# Delete temporary files: only the 'walker_temp*' stores written in this
# cell, so nothing else in the folder that happens to start with 'walker'
# is removed
for f in [f for f in os.listdir(os.path.join(dir05x, 'walker'))
          if f.startswith('walker_temp')]:
    shutil.rmtree(os.path.join(dir05x, 'walker', f));

In [None]:
# Mo
def prep_mo(act_pot, fstr):
    """Get Mo data.

    Lazily open one Mo carbon-density GeoTIFF from ``dir_data/mo2023/v1_1``,
    keep its first band, mask its ``_FillValue`` to NaN and rename the
    spatial dimensions to ``lat``/``lon``.

    Parameters
    ----------
    act_pot : {'act', 'pot'}
        'act' reads ``{fstr}_Present_TGB_Density_Map_Merged.tif``;
        'pot' reads ``{fstr}_Full_TGB_carbon_density_Map_Merged.tif``.
    fstr : str
        File-name prefix selecting the map (e.g. 'SD', 'GS_Max1').

    Returns
    -------
    xarray.DataArray
        Masked map with dimensions ``lat``/``lon``; the ``band`` and
        ``spatial_ref`` coordinates are dropped.

    Raises
    ------
    ValueError
        If *act_pot* is neither 'act' nor 'pot'.
    """
    if act_pot == 'pot':
        path_fstr = os.path.join(
            dir_data, f'mo2023/v1_1/{fstr}_Full_TGB_carbon_density_Map_Merged.tif')
    elif act_pot == 'act':
        path_fstr = os.path.join(
            dir_data, f'mo2023/v1_1/{fstr}_Present_TGB_Density_Map_Merged.tif')
    else:
        # Without this, an unexpected value surfaced as an UnboundLocalError
        raise ValueError(f"act_pot must be 'act' or 'pot', got {act_pot!r}")

    # [0] keeps the first band of the raster
    da = rioxarray.open_rasterio(path_fstr, chunks=dict(y=5000, x=5000))[0]
    da = da.where(da != da.attrs['_FillValue'])
    return da.rename(y='lat', x='lon').drop_vars(['band', 'spatial_ref'])

# Create Dataset with Mo data
# One actual-density map is the lat/lon template for all Mo layers; it is
# opened once and reused for both coordinates.
da_mo_template = prep_mo('act', 'SD')
ds_mo = xr.Dataset(dict(lat=da_mo_template.lat, lon=da_mo_template.lon))

# Variable name -> (act/pot, file prefix for prep_mo). Insertion order fixes
# the variable order in ds_mo.
# NOTE(review): the l/u suffixes map to GS_Max/GS_Mean for 'act' but to
# GS_Mean/GS_Max for 'pot'. This does not affect the means below, but the
# labels should be confirmed.
mo_layers = {
    'act_gsl': ('act', 'GS_Max'),
    'act_gsu': ('act', 'GS_Mean'),
    'act_sdh': ('act', 'HM'),
    'act_sde': ('act', 'SD'),
    'act_sdw': ('act', 'WK'),
    'pot_gs1l': ('pot', 'GS_Mean1'),
    'pot_gs2l': ('pot', 'GS_Mean2'),
    'pot_gs1u': ('pot', 'GS_Max1'),
    'pot_gs2u': ('pot', 'GS_Max2'),
    'pot_sd1h': ('pot', 'HM1'),
    'pot_sd2h': ('pot', 'HM2'),
    'pot_sd1e': ('pot', 'SD1'),
    'pot_sd2e': ('pot', 'SD2'),
    'pot_sd1w': ('pot', 'WK1'),
    'pot_sd2w': ('pot', 'WK2'),
}
for var_name, (act_pot, fstr) in mo_layers.items():
    ds_mo[var_name] = (('lat', 'lon'), prep_mo(act_pot, fstr).data)

# Dataset with mean actual and potential mo values: cell-wise mean over all
# 'pot_*' layers (s2) and all 'act_*' layers (s3). A prefix match is used so
# that names merely containing 'act'/'pot' elsewhere cannot be picked up.
ds_mo2 = xr.Dataset()
ds_mo2['mo_s2_cveg'] = ds_mo[[v for v in ds_mo.data_vars
                              if v.startswith('pot_')]] \
    .to_array(dim='s').mean('s')
ds_mo2['mo_s3_cveg'] = ds_mo[[v for v in ds_mo.data_vars
                              if v.startswith('act_')]] \
    .to_array(dim='s').mean('s')

# Export
ds_mo2.to_zarr(os.path.join(dir05x, 'mo', 'ds_mo.zarr'), mode='w');

In [None]:
# Mo land sea mask
# Land sea differences between actual, potential, open-water file
def prep_mo_land(f_in, f_out):
    """Prepare a Mo land-sea mask and export it as a zarr store.

    Parameters
    ----------
    f_in : str
        Name (without ``.tif``) of a GeoTIFF in ``dir_data/mo2023/v1_1``;
        only its first band is used.
    f_out : str
        Output store name (without ``.zarr``), written to ``dir05x/mo``.
        For 'ds_mo_land_act' and 'ds_mo_land_pot', the mask is 1.0 where
        the raster differs from its ``_FillValue`` and 0.0 elsewhere.
        For 'ds_mo_land', the raster values are used as the mask unchanged.

    Raises
    ------
    ValueError
        If *f_out* is not one of the three names above (checked before any
        file is opened).
    """
    from_fill_value = ('ds_mo_land_act', 'ds_mo_land_pot')
    if f_out not in from_fill_value and f_out != 'ds_mo_land':
        # Previously an unknown name left da_land unbound (UnboundLocalError)
        raise ValueError(
            "f_out must be 'ds_mo_land_act', 'ds_mo_land_pot' or "
            f"'ds_mo_land', got {f_out!r}")

    da_land_x = rioxarray.open_rasterio(
        os.path.join(dir_data, f'mo2023/v1_1/{f_in}.tif'),
        chunks = dict(y=5000, x=5000))[0]

    if f_out in from_fill_value:
        v_na = da_land_x.attrs['_FillValue']
        da_land = xr.where(da_land_x != v_na, 1., 0.)
    else:
        da_land = da_land_x
    # Attach the raster values to the lat/lon grid of the Mo dataset (ds_mo)
    ds_land = xr.Dataset(dict(lat=ds_mo.lat, lon=ds_mo.lon))
    ds_land['land_sea_mask'] = (('lat', 'lon'), da_land.data.astype('float32'))
    ds_land.to_zarr(
        os.path.join(dir05x, 'mo', f'{f_out}.zarr'), mode='w')


# Build the three Mo masks: from the actual map, from the potential map,
# and from the open-water mask file
for f_in, f_out in (
        ('SD_Present_TGB_Density_Map_Merged', 'ds_mo_land_act'),
        ('SD1_Full_TGB_carbon_density_Map_Merged', 'ds_mo_land_pot'),
        ('Open_Water_mask_Map', 'ds_mo_land')):
    prep_mo_land(f_in, f_out)

---

### Check Mo land masks

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Full-extent view of the three Mo land-sea masks side by side
fig, axs = plt.subplots(figsize=(20, 5), ncols=3, nrows=1)
mo_mask_stores = ['mo/ds_mo_land.zarr',
                  'mo/ds_mo_land_act.zarr',
                  'mo/ds_mo_land_pot.zarr']
for ax, store in zip(axs, mo_mask_stores):
    da_mask = xr.open_zarr(os.path.join(dir05x, store)).land_sea_mask
    da_mask.plot.imshow(ax=ax)

In [None]:
fig, axs = plt.subplots(figsize=(20, 5), ncols=3, nrows=1)

# Zoomed view (lat 60..40, lon 0..20) of each Mo mask with its source label
zoom_panels = [('mo/ds_mo_land.zarr', 'open_water_mask'),
               ('mo/ds_mo_land_act.zarr', 'from actual carbon file'),
               ('mo/ds_mo_land_pot.zarr', 'from potential carbon file')]
for ax, (store, title) in zip(axs, zoom_panels):
    xr.open_zarr(os.path.join(dir05x, store)) \
        .sel(lat=slice(60, 40), lon=slice(0, 20)) \
        .land_sea_mask.plot.imshow(ax=ax, add_colorbar=False)
    ax.set_title(title)

plt.tight_layout()

In [None]:
# Disconnect the dask client first, then shut down the SLURM cluster.
# Closing the client first keeps it from trying to reach a scheduler that
# has already stopped.
client.close()
cluster.close()