# notebook source

https://gist.github.com/monocongo/978348233b4bde80e9bcc52fe8e4150c

# SPI calculation routines 

In [None]:
# standard library
import os
from tempfile import TemporaryDirectory
from typing import Dict

# third-party
from climate_indices.indices import spi, Distribution
from climate_indices.compute import Periodicity
import numpy as np
import pandas as pd
import requests
import xarray as xr

In [None]:
# create a wrapper function that can be applied to an entire Dataset
# (takes one group of the Dataset as first argument, returns a DataArray)
def spi_wrapper(
    obj: xr.Dataset,
    precip_var: str,
    scale: int,
    distribution: Distribution,
    data_start_year: int,
    calibration_year_initial: int,
    calibration_year_final: int,
    periodicity: Periodicity,
    fitting_params: Dict = None,
) -> xr.DataArray:
    """
    Compute SPI for the precipitation variable of a single group.

    Intended to be passed to a Dataset ``groupby(...)`` call, which hands each
    group to this function as ``obj``. Because the grouping is done on a
    Dataset, ``obj`` is itself a Dataset, which is why the precipitation
    variable has to be looked up by name.

    :param obj: one group of the Dataset being processed; must contain
        ``precip_var``
    :param precip_var: name of the precipitation variable within ``obj``
    :param scale: forwarded unchanged to ``climate_indices.indices.spi``
    :param distribution: forwarded unchanged to ``spi``
    :param data_start_year: forwarded unchanged to ``spi``
    :param calibration_year_initial: forwarded unchanged to ``spi``
    :param calibration_year_final: forwarded unchanged to ``spi``
    :param periodicity: forwarded unchanged to ``spi``
    :param fitting_params: optional, forwarded unchanged to ``spi``
        (``None`` by default)
    :return: DataArray holding the SPI values, with the same dims, coords
        and shape as ``obj[precip_var]``
    """

    # look the precipitation variable up once and reuse it below
    precip = obj[precip_var]

    # compute SPI for this timeseries
    spi_data = spi(
        values=precip.to_numpy(),
        scale=scale,
        distribution=distribution,
        data_start_year=data_start_year,
        calibration_year_initial=calibration_year_initial,
        calibration_year_final=calibration_year_final,
        periodicity=periodicity,
        fitting_params=fitting_params,
    )

    # Give the result the exact shape of the input variable so it lines up
    # with the dims/coords copied below. Depending on the dataset, the group
    # arrives either as (n,) or as (n, 1); a hard-coded flatten() or
    # reshape(size, 1) only handles one of those, reshaping to the input's
    # own shape handles both.
    spi_data = spi_data.reshape(precip.shape)

    # create the return DataArray (copy of input object's geospatial dims/coords plus SPI data)
    da_spi = xr.DataArray(
        data   = spi_data,
        dims   = precip.dims,
        coords = precip.coords,
        attrs  = {
            'description': 'SPI computed by the climate_indices Python package',
            'references': 'https://github.com/monocongo/climate_indices',
            'valid_min': -3.09, # this should mirror climate_indices.indices._FITTED_INDEX_VALID_MIN
            'valid_max':  3.09, # this should mirror climate_indices.indices._FITTED_INDEX_VALID_MAX
        },
    )

    return da_spi


# get data 

In [None]:
import pandas as pd
from datetime import datetime
import numpy as np
import rioxarray
import xarray as xr

# longitude/latitude bounds of the region cropped out of every monthly file
min_lon = 31
min_lat = 0.0
max_lon = 37
max_lat = 6


# One month-end timestamp per monthly file, February 1979 through November 2020.
# The frequency is given as an offset object rather than the 'M' alias string:
# the result is the same month-end index, but the string alias is deprecated
# in recent pandas releases.
datereg=pd.date_range(start='19790201',end='20201130', freq=pd.offsets.MonthEnd())
# 'YYYYMM' labels, which are the file name stems under EA_MSWEP/
datereg1=[x.strftime('%Y%m') for x in datereg]


datasets = []
for datr in datereg1:
    # open each file in a context manager so its handle is released again;
    # .load() pulls the cropped data into memory before the file is closed
    with xr.open_dataset(f'EA_MSWEP/{datr}.nc') as db:
        # the latitude slice runs from max to min, which presumably means the
        # files store latitude in descending order -- TODO confirm
        cropped_ds = db.sel(lat=slice(max_lat,min_lat), lon=slice(min_lon,max_lon)).load()
    datasets.append(cropped_ds)

# join the monthly crops into one Dataset along the time dimension
combined = xr.concat(datasets, dim='time')

# Calculate SPI on xarray

In [None]:
# collapse the two spatial dims into a single 'grid_cells' dim so that each
# group produced below is the timeseries of one (lat, lon) cell
sc = combined.stack(grid_cells=('lat', 'lon',))

# Run spi_wrapper once per grid cell and restore the lat/lon dims afterwards.
# GroupBy.map is used instead of GroupBy.apply: apply is a deprecated alias
# of map in xarray and forwards the keyword arguments in the same way.
# NOTE(review): the monthly files loaded earlier start in February 1979 while
# data_start_year is 1979 -- confirm that spi's monthly handling is fine with
# a series that does not begin in January.
spi_ds = sc.groupby('grid_cells').map(
    spi_wrapper,
    precip_var='precipitation',
    scale=3,
    distribution=Distribution.gamma,
    data_start_year=1979,
    calibration_year_initial=1981,
    calibration_year_final=2010,
    periodicity=Periodicity.monthly,
).unstack('grid_cells')