!

In [None]:
%%capture
!pip install --upgrade xarray --quiet
!pip install --upgrade rioxarray --quiet

In [None]:
import xarray as xr
import rioxarray as rio
from rasterio.warp import reproject, Resampling
import pandas as pd
import numpy as np
import datetime
import os

In [None]:
# Mount Google Drive into the Colab runtime. force_remount=True is passed so
# the mount is redone on every run of this cell.
from google.colab import drive
drive.mount('/content/gdrive/', force_remount=True)

# Switch the working directory to the Drive root, so relative paths used
# later in the notebook resolve against it.
%cd /content/gdrive/MyDrive

In [None]:
# functions to calculate dekadal, monthly and seasonal sums from the ET_look output netcdf file
def resotor_encoding(ds, encoding, var_attrs, temp_res):
    """Re-attach encoding and refreshed attributes to an aggregated dataset.

    Parameters
    ----------
    ds : dataset-like
        Aggregated dataset. It must expose ``data_vars`` and item access that
        returns objects with ``encoding`` and ``attrs``.
    encoding : dict
        Variable name -> encoding saved before the aggregation.
    var_attrs : dict
        Variable name -> attributes saved before the aggregation.
    temp_res : str
        Label of the new temporal resolution (the callers in this file pass
        'Dekadal', 'Monthly' or 'Seasonal'). It replaces "Daily" in
        ``long_name`` and, without its last two characters, becomes the
        denominator of ``units`` (e.g. 'Monthly' -> 'mm/Month').

    Returns
    -------
    The same ``ds`` object, updated in place.
    """
    for var in ds.data_vars:
        if var in encoding:  # only variables that existed before aggregating
            ds[var].encoding = encoding[var]

        if var not in var_attrs:
            continue

        # Start from the attributes saved before the aggregation so they
        # survive reductions that dropped them; attributes still present on
        # the aggregated variable take precedence.
        merged = {**var_attrs[var], **ds[var].attrs}
        attrs = {
            key: value
            for key, value in merged.items()
            if 'NETCDF_' not in key and 'scale_factor' not in key
        }
        # A variable without long_name is left without one instead of
        # aborting the whole aggregation with a KeyError.
        if 'long_name' in attrs:
            attrs['long_name'] = attrs['long_name'].replace("Daily", temp_res)
        # NOTE(review): every variable is labelled in mm here - confirm this
        # is intended for non-depth variables.
        attrs.update({
            'source_data': 'Aggregated from ET_Look model output',
            'units': f"mm/{temp_res[:-2]}",
            'temporal_resolution': temp_res,
        })
        ds[var].attrs = attrs
    return ds

# Dekadal sum
def dekadal_sum(ds):
    """Sum a dataset with a daily ``time`` axis into dekads.

    Each month is split into three dekads: days 1-10, 11-20 and 21 to the end
    of the month. Every output step is labelled with the first day of its
    dekad.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate. It is not modified.

    Returns
    -------
    xarray.Dataset
        Dekadal sums (NaN is propagated, ``skipna=False``) with encoding and
        attributes re-attached by ``resotor_encoding``.
    """
    # Saved before aggregating so they can be re-attached afterwards.
    encoding = {var: ds[var].encoding for var in ds.data_vars}
    attrs = {var: ds[var].attrs for var in ds.data_vars}

    # Number of days between each time stamp and the first day of its dekad.
    # The clip keeps day 31 in the third dekad instead of opening a fourth.
    day = ds.time.dt.day
    offset = day - np.clip((day - 1) // 10, 0, 2) * 10 - 1
    dekad_start = ds.time.values - np.array(offset, dtype="timedelta64[D]")

    # assign_coords returns a new object, so the caller's dataset keeps its
    # original daily time axis (assigning ds['time'] would overwrite it and
    # distort later monthly/seasonal aggregations of the same dataset).
    relabelled = ds.assign_coords(time=dekad_start)
    ds_dk = relabelled.groupby(relabelled.time).sum(
        dim='time', skipna=False, keep_attrs=True
    )

    ds_dk = resotor_encoding(ds_dk, encoding, attrs, 'Dekadal')
    return ds_dk

# Monthly sum
def monthly_sum(ds):
    """Sum a dataset with a daily ``time`` axis into calendar months.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate.

    Returns
    -------
    xarray.Dataset
        Monthly sums (NaN is propagated, ``skipna=False``) with encoding and
        attributes re-attached by ``resotor_encoding``.
    """
    # Saved before aggregating so they can be re-attached afterwards.
    encoding = {var: ds[var].encoding for var in ds.data_vars}
    attrs = {var: ds[var].attrs for var in ds.data_vars}

    # "1ME" is the month-end frequency alias. keep_attrs is passed explicitly
    # (as dekadal_sum does) so the result does not depend on a global xarray
    # option having been set beforehand.
    ds_mn = ds.resample(time="1ME").sum(skipna=False, keep_attrs=True)

    ds_mn = resotor_encoding(ds_mn, encoding, attrs, 'Monthly')
    return ds_mn
# Check if the start and end time of the selected dataarray corresponds to sos and eos
def select_season_da(da_var, season_start_date, season_end_date):
    """Select the time steps of ``da_var`` lying between two ISO dates.

    A warning is printed when the requested season is not fully covered by
    the data; the selection is still attempted in that case.

    Parameters
    ----------
    da_var : xarray.DataArray or xarray.Dataset
        Object with a ``time`` coordinate. The first and last entries are
        taken as the start and end of the available period.
    season_start_date, season_end_date : str
        Start and end of the season in ISO format (e.g. '2022-10-01').

    Returns
    -------
    The selection ``da_var.sel(time=slice(start, end))``, or ``None`` when the
    selection raises ``ValueError``.

    Raises
    ------
    ValueError
        If either date string is not valid ISO format.
    """
    sos = datetime.datetime.fromisoformat(season_start_date)  # start of season
    eos = datetime.datetime.fromisoformat(season_end_date)    # end of season

    # First and last time stamps of the data, truncated to whole days.
    day_labels = pd.to_datetime(da_var.time.data).strftime('%Y-%m-%d')
    da_st = datetime.datetime.fromisoformat(day_labels[0])
    da_et = datetime.datetime.fromisoformat(day_labels[-1])

    # BOTH ends must fall inside the data period; with `or` the warning was
    # only shown when the season missed the data on both sides.
    if not (sos >= da_st and eos <= da_et):
        print("The sos and/or eos out of the time range of the dataset.")

    try:
        return da_var.sel(time=slice(sos, eos))
    except ValueError:
        print("Erro in selecting data for the season.")
        return None

# Seasonal Resample
def seasonal_sum(ds, sos, eos):
    """Sum all time steps of ``ds`` that fall inside one season.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate. Neither its data nor its encoding
        is modified.
    sos, eos : str
        Start and end of the season in ISO format (e.g. '2022-10-01').

    Returns
    -------
    xarray.Dataset
        Seasonal sums without a time dimension (NaN is propagated,
        ``skipna=False``). The season limits are stored under the 'sos' and
        'eos' keys of every variable's encoding, where ``write2gtiff`` reads
        them.

    Raises
    ------
    ValueError
        If no data could be selected for the season.
    """
    # Copy the encoding dicts: sos/eos are added below and must not leak into
    # the encoding of the input dataset.
    encoding = {var: dict(ds[var].encoding) for var in ds.data_vars}
    attrs = {var: ds[var].attrs for var in ds.data_vars}

    season = select_season_da(ds, sos, eos)
    if season is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(f"Could not select data for the season {sos} - {eos}.")

    # keep_attrs is explicit so the result does not depend on a global option.
    ds_sn = season.sum(dim='time', skipna=False, keep_attrs=True)

    ds_sn = resotor_encoding(ds_sn, encoding, attrs, 'Seasonal')
    for var in ds_sn.data_vars:
        if var in encoding:  # only variables that existed in the input
            ds_sn[var].encoding = {**encoding[var], 'sos': sos, 'eos': eos}
    return ds_sn
  
def reproject(ds, to_crs):
    """Reproject ``ds`` to an EPSG CRS or onto the grid of a template raster.

    NOTE(review): this name shadows ``reproject`` imported from
    ``rasterio.warp`` at the top of the notebook.

    Parameters
    ----------
    ds : xarray object with the ``rio`` accessor
        Data to reproject.
    to_crs : CRS-like or path
        Anything whose string form contains "EPSG" (case-insensitive) is
        passed to ``ds.rio.reproject``. Otherwise it is treated as the path
        of a template raster for ``ds.rio.reproject_match``.

    Returns
    -------
    The reprojected data, or ``None`` (after printing a message) when
    ``to_crs`` is neither usable as a CRS nor an existing template raster
    with CRS information.
    """
    invalid_msg = "Your input is not either a valid EPSG code or a teplate raster path."
    try:
        if 'epsg' in str(to_crs).lower():
            return ds.rio.reproject(to_crs)

        # Not an EPSG specification: assume it is a path to a template raster.
        # The isinstance guard keeps non-path inputs (e.g. an int) away from
        # os.path.exists.
        is_path = isinstance(to_crs, (str, os.PathLike))
        if not (is_path and os.path.exists(to_crs)):
            # Previously this case returned None without any message.
            print(invalid_msg)
            return None

        temp_rst_file = to_crs
        da_rst = rio.open_rasterio(temp_rst_file)
        if da_rst.rio.crs is None:
            print(f"the template raster {temp_rst_file} does not have CRS information.")
            return None
        return ds.rio.reproject_match(da_rst)
    except ValueError:
        print(invalid_msg)
        return None

# Lookup table from pywapor variable prefixes and temporal resolutions to the
# short codes used in output folder and file names.
switcher = dict(
    et='AETI',
    e='E',
    int='I',
    npp='NPP',
    t='T',
    se='RSM',
    dekadal='D',
    monthly='M',
    seasonal='S',
)


def get_code(code_name):
    """Return the short code registered for ``code_name``.

    Names missing from ``switcher`` yield the placeholder string "nothing".
    """
    try:
        return switcher[code_name]
    except KeyError:
        return "nothing"

# write to file
def write_file(da, to_crs, fname, encoding, date, attrs, temporal_res):
    """Write one 2-D layer to ``<fname>.tif`` as an LZW-compressed GeoTIFF.

    Parameters
    ----------
    da : xarray.DataArray with the ``rio`` accessor
        Layer to write.
    to_crs : CRS-like, path or None
        Target for ``reproject`` (f"EPSG:{code}" or a template raster path).
        ``None`` skips the reprojection.
    fname : str
        Output path without the ".tif" extension.
    encoding : dict
        Encoding assigned to the layer before writing.
    date : str
        Value stored in the 'date' attribute of the written layer.
    attrs : dict
        Attributes for the written layer. The dict itself is not modified.
    temporal_res : str
        Unused; kept for call compatibility.
    """
    if to_crs is not None:
        da = reproject(da, to_crs)

    da = da.round(2)
    # Attributes are set after rounding so they cannot be lost in that step,
    # and in a fresh dict so the caller's attrs do not pick up a 'date' key.
    da.attrs = {**attrs, 'date': date}
    da.encoding = encoding
    da.rio.to_raster(f"{fname}.tif", driver="GTiff", compress="LZW")
    da.close()

# netCDF to geotiff
def write2gtiff(ds, temporal_res, dir_out, to_crs = None):
    """Export every variable of an aggregated dataset to GeoTIFF files.

    Files are written to
    ``<dir_out>/<temporal_res>/pywapor_<VAR>_<RES>/pywapor_<VAR>_<RES>_<date>.tif``,
    where ``<VAR>`` and ``<RES>`` are the codes returned by ``get_code``.
    Dekadal/monthly datasets produce one file per time step; seasonal
    datasets produce one file labelled ``<sos>_<eos>``.

    Parameters
    ----------
    ds : xarray.Dataset
        Output of ``dekadal_sum``, ``monthly_sum`` or ``seasonal_sum``. Its
        encoding and attributes are not modified.
    temporal_res : str
        'dekadal', 'monthly' or 'seasonal' (case-insensitive).
    dir_out : str
        Root output folder.
    to_crs : CRS-like, path or None
        Forwarded to ``write_file``; ``None`` skips the reprojection.

    Raises
    ------
    ValueError
        If a non-seasonal dataset has no ``time`` dimension.
    KeyError
        If a seasonal dataset lacks 'sos'/'eos' in a variable's encoding.
    """
    res_key = temporal_res.lower()
    is_seasonal = res_key == 'seasonal'
    # Looked up with the lower-cased name so the code agrees with the
    # case-insensitive seasonal check.
    time_code = get_code(res_key)

    if not is_seasonal:
        if 'time' not in ds.dims:
            raise ValueError(
                f"A '{temporal_res}' dataset needs a 'time' dimension to be exported."
            )
        date_str = pd.to_datetime(ds.time.data).strftime('%Y-%m-%d')

    for var in ds.data_vars:
        # The prefix before the first underscore identifies the variable.
        var_name = f"{get_code(var.split('_')[0])}_{time_code}"
        fd = os.path.join(dir_out, temporal_res, f"pywapor_{var_name}")

        # Work on copies so the dataset's own encoding/attrs stay untouched.
        encoding = dict(ds[var].encoding)
        encoding['dtype'] = 'float32'
        encoding['scale_factor'] = 1.0
        if '_FillValue' in encoding:
            # np.float32() also accepts plain Python numbers, unlike .astype.
            encoding['_FillValue'] = np.float32(encoding['_FillValue'])
        attrs = dict(ds[var].attrs)

        # One folder per variable.
        os.makedirs(fd, exist_ok=True)

        if is_seasonal:
            date = f"{encoding['sos']}_{encoding['eos']}"
            fname = os.path.join(fd, f"pywapor_{var_name}_{date}")
            write_file(ds[var], to_crs, fname, encoding, date, attrs, temporal_res)
        else:
            for i, date in enumerate(date_str):
                fname = os.path.join(fd, f"pywapor_{var_name}_{date}")
                # Select by dimension name so the layer is correct whatever
                # the position of 'time' among the dimensions.
                da = ds[var].isel(time=i).drop_vars('time')
                write_file(da, to_crs, fname, encoding, date, attrs, temporal_res)

#### Step 1: Read pywapor output

In [None]:
# Location of the ET_Look output netCDF file on Google Drive.
path_et_look_out = r'/content/gdrive/MyDrive/pywapor/et_look_out.nc'

# Keep variable attributes through xarray operations from here on.
xr.set_options(keep_attrs=True)

# Open the file with all coordinate metadata decoded and rename its
# 'time_bins' dimension to 'time', the name the aggregation functions use.
ds = xr.open_dataset(path_et_look_out, decode_coords="all").rename({'time_bins': 'time'})

#### Step 2: Aggregate to the required timestep (dekadal, monthly or seasonal) and write the result to individual geotiff files per time step
The ET_look output is in EPSG:4326. If you would like to reproject the dataset to another projection, such as a UTM zone, provide the required EPSG code or a path to a raster template file. The default is a UTM CRS estimated from the dataset. To change it, provide the CRS in the following style: to_crs = f"EPSG:{epsg code (number)}"

In [None]:
# Root output folder. write2gtiff creates one sub-folder per temporal
# resolution (and per variable) beneath it.
dir_out = r'/content/gdrive/MyDrive/pywapor' # a folder in your gdrive to save the geotif files
# dir_out = r'pywapor_out' # a folder in colab working directory to save the geotif files

# Estimated UTM CRS of the dataset, used as the reprojection target.
# Set to_crs = None to skip the reprojection, or replace it with
# f"EPSG:{code}" or the path of a template raster.
to_crs = ds.rio.estimate_utm_crs()


In [None]:
# Dekadal export: sum the data per dekad, then write one GeoTIFF per
# variable and dekad.
temporal_res = 'dekadal'
ds_dk = dekadal_sum(ds)
write2gtiff(ds_dk, temporal_res=temporal_res, dir_out=dir_out, to_crs=to_crs)

In [None]:
# Monthly export: sum the data per calendar month, then write one GeoTIFF
# per variable and month.
temporal_res = 'monthly'
ds_mn = monthly_sum(ds)
write2gtiff(ds_mn, temporal_res=temporal_res, dir_out=dir_out, to_crs=to_crs)

In [None]:
# Seasonal export: sum the data between the two dates below, then write one
# GeoTIFF per variable.
season_start_date = '2022-10-01'  # start of the season in ISO format
season_end_date = '2023-04-30'  # end of the season in ISO format
temporal_res = 'seasonal'
ds_sn = seasonal_sum(ds, season_start_date, season_end_date)
write2gtiff(ds_sn, temporal_res=temporal_res, dir_out=dir_out, to_crs=to_crs)


### Zip and download the data folder to your local drive

In [None]:

!zip -r /content/pywapor.zip /content/gdrive/MyDrive/pywapor_out
from google.colab import files
files.download('/content/pywapor.zip')