In [None]:
## Notebook steps

1. Read the CHIRPS monthly data and subset it to the region of interest
2. Regrid the CHIRPS monthly data from 5 km to 25 km
3. Plot the data before and after regridding
4. Apply the SPI calculation method to the regridded data
5. Collect the MAM (March–April–May) SPI product

In [None]:
## Discussion points

1. Finding SPI calculation methods: where to search (GitHub gists, GitHub author profiles)
2. Application of pandas DataFrames in date formatting

# notebook source

The function to calculate SPI is from https://gist.github.com/monocongo/978348233b4bde80e9bcc52fe8e4150c

# SPI calculation routines 

In [None]:
# standard library
import os
from tempfile import TemporaryDirectory
from typing import Dict

# third-party
from climate_indices.indices import spi, Distribution
from climate_indices.compute import Periodicity
import numpy as np
import pandas as pd
import requests
import xarray as xr

%env ESMFMKFILE=/srv/conda/envs/notebook/lib/esmf.mk
import xesmf as xe

## Regrid the CHIRPS data from 5 km to 25 km

### Before that, let's subset the data to a small region

In [None]:

# Open the full CHIRPS monthly file; the subset below is cut from it.
db = xr.open_dataset('/srv/repo/IBF_workshop_data/kmj_aa/chirps-v2.0.monthly.nc')

# Bounding box of the study region. The southern, eastern and northern edges
# are pushed outwards by half a degree; the western edge is used as-is.
# Coordinate arrays noted alongside the bounds for reference:
#   array([ 4.2,  3.2,  2.2,  1.2,  0.2, -0.8, -1.8]),
#   array([29., 30., 31., 32., 33., 34., 35.]))
min_lon, max_lon = 29.0, 35.0+0.5
min_lat, max_lat = -2.0-0.5, 4.2+0.5

# Label-based selection of everything inside the box.
kmj_db = db.sel(
    {
        'latitude': slice(min_lat, max_lat),
        'longitude': slice(min_lon, max_lon),
    }
)

# Persist the subset next to the source file.
local_path_nc = '/srv/repo/IBF_workshop_data/kmj_aa/kmj_chirps-v2.0.monthly.nc'
kmj_db.to_netcdf(local_path_nc)

In [None]:
# Regrid the subset's precipitation onto a regular 0.25-degree grid with xESMF.
# The coordinates are renamed to 'lon'/'lat' before being handed to the regridder.
ds1 = kmj_db.rename({'longitude': 'lon', 'latitude': 'lat'})
dr = ds1["precip"]

# Target grid: one coordinate value every 0.25 degrees.
ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.arange(-2, 5.25, 0.25), {"units": "degrees_north"}),
        "lon": (["lon"], np.arange(29, 38.25, 0.25), {"units": "degrees_east"}),
    }
)

# NOTE(review): the target grid runs to 38E / 5N while kmj_db was cut at
# 35.5E / 4.7N, so some target cells have no source data. unmapped_to_nan=True
# marks those cells as NaN instead of ESMF's default fill of 0, which would
# otherwise read as "no rainfall". Confirm whether 38.25 was meant to be 35.25.
regridder = xe.Regridder(ds1, ds_out, "bilinear", unmapped_to_nan=True)

# A bare expression in the middle of a cell displays nothing, so print explicitly.
print(regridder)  # basic regridder information

# Apply the regridding weights, wrap the result in a Dataset and save it.
dr_out = regridder(dr, keep_attrs=True)
ds2 = dr_out.to_dataset()
ds2.to_netcdf('/srv/repo/IBF_workshop_data/kmj_aa/kmj_km25_chirps-v2.0.monthly.nc')

## Checking the data before and after regridding

In [None]:
# Quick look at the subset's precipitation field for the 1981-01-01 time step.
kmj_db['precip'].sel(
    {'time': '1981-01-01T00:00:00.000000000'}
).plot()

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

# Compare the same time step before and after regridding, one panel above the other.
snapshot_time = '1981-01-01T00:00:00.000000000'

fig = plt.figure(figsize=(10, 10))

# Top panel: the subset on its original grid.
ax1 = fig.add_subplot(211, projection=ccrs.PlateCarree())
kmj_db['precip'].sel(time=snapshot_time).plot(ax=ax1)
# Set the title on the axes itself rather than via plt.title, which targets
# whichever axes happens to be current.
ax1.set_title('Original CHIRPS data with pixels 5X5 km')

# Bottom panel: the regridded data. It is placed at position 212; the previous
# position 231 (2x3 layout, first slot) landed on top of the upper panel.
ax2 = fig.add_subplot(212, projection=ccrs.PlateCarree())
ds2['precip'].sel(time=snapshot_time).plot(ax=ax2, add_colorbar=False)
ax2.set_title('Regridded CHIRPS data with pixels 25X25 km');

In [None]:
# Adapter that lets the SPI routine be applied group-by-group across a whole Dataset
# (the first argument is the group object; the result is handed back as a DataArray)
def spi_wrapper(
    obj: xr.DataArray,
    precip_var: str,
    scale: int,
    distribution: Distribution,
    data_start_year: int,
    calibration_year_initial: int,
    calibration_year_final: int,
    periodicity: Periodicity,
    fitting_params: Dict = None,
) -> xr.DataArray:
    """
    Compute SPI for the precipitation series stored under ``obj[precip_var]``.

    :param obj: object that is indexed with ``precip_var``. It is annotated as a
        DataArray but looked up by variable name, so presumably a Dataset in
        practice (see the open TODO in the body).
    :param precip_var: name of the precipitation variable inside ``obj``
    :param scale: forwarded unchanged to ``spi``
    :param distribution: forwarded unchanged to ``spi``
    :param data_start_year: forwarded unchanged to ``spi``
    :param calibration_year_initial: forwarded unchanged to ``spi``
    :param calibration_year_final: forwarded unchanged to ``spi``
    :param periodicity: forwarded unchanged to ``spi``
    :param fitting_params: forwarded unchanged to ``spi`` (defaults to None)
    :return: DataArray carrying the dims and coords of ``obj[precip_var]``, the
        flattened SPI values as data, and descriptive attributes
    """
    precip = obj[precip_var]

    # TODO find out why the variable has to be fetched by name here instead of the
    # variable's DataArray being passed in directly (i.e. why is obj a Dataset?)
    spi_arguments = {
        'values': precip.to_numpy(),
        'scale': scale,
        'distribution': distribution,
        'data_start_year': data_start_year,
        'calibration_year_initial': calibration_year_initial,
        'calibration_year_final': calibration_year_final,
        'periodicity': periodicity,
        'fitting_params': fitting_params,
    }

    # TODO flattening is, for some reason, necessary for the nClimGrid low-resolution
    # example NetCDFs -- find out why
    # TODO the NCO-modified nClimGrid normal-resolution example NetCDFs instead
    # needed the (currently disabled) reshape below -- find out why
    #   spi_values = spi_values.reshape(spi_values.size, 1)
    spi_values = spi(**spi_arguments).flatten()

    metadata = {
        'description': 'SPI computed by the climate_indices Python package',
        'references': 'https://github.com/monocongo/climate_indices',
        # these bounds should mirror climate_indices.indices._FITTED_INDEX_VALID_MIN
        # and climate_indices.indices._FITTED_INDEX_VALID_MAX respectively
        'valid_min': -3.09,
        'valid_max':  3.09,
    }

    # reuse the input variable's dims/coords so the result lines up with the input
    return xr.DataArray(
        data=spi_values,
        coords=precip.coords,
        dims=precip.dims,
        attrs=metadata,
    )

# Calculate SPI on the xarray dataset

In [None]:
# Load the regridded precipitation file before previewing it, so that this cell
# also works on a fresh kernel (cdb is otherwise only assigned in a later cell).
cdb = xr.open_dataset('/srv/repo/IBF_workshop_data/kmj_aa/kmj_km25_chirps-v2.0.monthly.nc')
cdb

In [None]:
# Load the regridded precipitation and stack lat/lon into a single 'grid_cells'
# dimension, so the groupby below hands spi_wrapper one group per grid cell.
cdb = xr.open_dataset('/srv/repo/IBF_workshop_data/kmj_aa/kmj_km25_chirps-v2.0.monthly.nc')
cdb_sc = cdb.stack(grid_cells=('lat', 'lon',))

# Scale-3 gamma SPI on monthly data, with 1981 as the data start year and
# 1981-2018 as the calibration years. GroupBy.map is used instead of
# GroupBy.apply, which is only kept as a deprecated alias of map; the keyword
# arguments are passed through to spi_wrapper either way.
spi_cdb = (
    cdb_sc
    .groupby('grid_cells')
    .map(
        spi_wrapper,
        precip_var='precip',
        scale=3,
        distribution=Distribution.gamma,
        data_start_year=1981,
        calibration_year_initial=1981,
        calibration_year_final=2018,
        periodicity=Periodicity.monthly,
    )
    .unstack('grid_cells')  # restore the separate lat/lon dimensions
)

In [None]:
# Convert the SPI result into a Dataset whose single variable is named 'spi',
# then write it out as NetCDF.
spi_cdb1 = spi_cdb.to_dataset(name='spi')
spi_cdb1.to_netcdf(
    path='/srv/repo/IBF_workshop_data/kmj_aa/months3_spi_kmj_km25_chirps-v2.0.monthly.nc',
)

In [None]:
# Read the saved SPI file back in and display it as a check on what was written.
ds = xr.open_dataset(
    filename_or_obj='/srv/repo/IBF_workshop_data/kmj_aa/months3_spi_kmj_km25_chirps-v2.0.monthly.nc',
)

ds