In [1]:
import os
import sys
# Put the parent of the notebook's working directory on the import path;
# presumably that is where the project-local modules imported in this cell
# (extract_data, SUPER_functions_pixels, utils) live.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir))
sys.path.append(project_root)

import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import time
import pandas as pd
import xarray as xr
import geopandas as gpd
import ipywidgets as widgets
from IPython.display import display
import warnings
import pickle

from extract_data import extract_imerg_daily, extract_era5_prec_daily, extract_chirps_daily, extract_chpclim
import SUPER_functions_pixels
import utils

# Ignore the specific RuntimeWarning
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Settings

In [23]:
# Input locations, all built from project_root.
# Note: the gridded file names hard-code the period and bounding box; they are
# not derived from the settings defined further down in this cell.
path_imerg_data = f'{project_root}/data/imerg/imerg_2018_2022_A_28.45_43.55_-5.05_5.45.nc'  # IMERG
path_era5_data = f'{project_root}/data/era5/era5_2018_2022_A_28.45_43.55_-5.05_5.45.nc'  # ERA5
path_chirps_data = f'{project_root}/data/chirps/chirps_2018_2022_A_28.45_43.55_-5.05_5.45.nc'  # CHIRPS
path_sm2r_data = f'{project_root}/data/sm2rain/ascat_sm2r_20180101_20221231_R01_rainclass_smoothed_All_stations_C.nc'  # SM2Rain
path_chpclim_data = f'{project_root}/data/chpclim/chpclim28.5_43.5_-5.0_5.4.nc'  # CHPClim
path_to_tahmo = f'{project_root}/data/tahmo'  # folder with the TAHMO station files
path_to_metadata = f'{project_root}/data/tahmo/metadata.csv'  # TAHMO metadata file
path_to_qualityflags = f'{project_root}/data/quality_flags.csv'  # TAHMO quality flags

# Area of interest (AOI) bounding box, ordered [MinLon, MaxLon, MinLat, MaxLat].
subset = [28.5, 43.5, -5.0, 5.4]
MinLon, MaxLon, MinLat, MaxLat = subset

# Period covered by the run.
start_time = datetime(2018, 1, 1)    # Start time
end_time = datetime(2022, 12, 31)    # End time
year = start_time.year

# Run parameters. In the cells shown here only ID is referenced (it becomes
# part of the output file name); the other values are not used in those cells.
nclasses = 7
min_value = 0
max_value = 1800
sigma = 5.0
tol = 0.2
ID = 'ER_SM'
num_stations_to_select = 'All'  # 'All' or integer

# True if data quality flags are used for removing dubious TAHMO data
data_flags_included = True


## Extract Data

In [15]:
# Open the five gridded precipitation products. Only the file path is passed,
# so every other xr.open_dataset option is left at its default.
chirps_precipitation = xr.open_dataset(path_chirps_data)
chpclim_precipitation = xr.open_dataset(path_chpclim_data)
era5_precipitation = xr.open_dataset(path_era5_data)
imerg_precipitation = xr.open_dataset(path_imerg_data)
sm2r_precipitation = xr.open_dataset(path_sm2r_data)

## Data preprocessing

#### Regrid datasets

In [16]:
def _regrid_to_common_grid(data, resolution=0.1):
    """Run utils.regrid_dataset on ``data`` for the AOI and order dims (time, lon, lat).

    ``resolution`` is forwarded as the third argument of utils.regrid_dataset
    (presumably the target grid spacing - confirm against that helper).
    """
    regridded = utils.regrid_dataset(data, subset, resolution)
    return regridded.transpose("time", "lon", "lat")


ds_imerg_regridded = _regrid_to_common_grid(imerg_precipitation)
ds_era5_regridded = _regrid_to_common_grid(era5_precipitation)
ds_chirps_regridded = _regrid_to_common_grid(chirps_precipitation)
# SM2Rain: only the 'precip' variable is regridded.
ds_sm2r_regridded = _regrid_to_common_grid(sm2r_precipitation['precip'])
# CHPclim: the whole dataset is regridded, then 'precip' is selected.
chpclim_regridded = _regrid_to_common_grid(chpclim_precipitation)['precip']

#### Rename precipitation variable where needed

In [17]:
# Both statements rebind their own input, so this cell is meant to be run once
# after the regridding cell.

# IMERG: keep only 'precipitationCal', as a DataArray named 'precip'.
imerg_cal = ds_imerg_regridded['precipitationCal']
ds_imerg_regridded = xr.DataArray(imerg_cal, name='precip')

# ERA5: remains a Dataset; its 'tp' variable is renamed to 'precip'.
ds_era5_regridded = ds_era5_regridded.rename({'tp': 'precip'})

#### ERA5

In [18]:
# Keep strictly positive ERA5 values and set everything else to 0. NaN also
# fails the `> 0` test, so missing values become 0 as well. Then convert
# meters to mm.
# The result overwrites its input: re-running this cell multiplies by 1000 again.
era5_is_positive = ds_era5_regridded['precip'] > 0
ds_era5_regridded = ds_era5_regridded.where(era5_is_positive, 0) * 1000

#### CHPclim

In [19]:
# Relabel the CHPclim 'time' axis as 'month' and number its steps 1..12.
# Twelve labels are assigned, so the axis must hold exactly twelve entries.
chpclim = (
    chpclim_regridded
    .rename({'time': 'month'})
    .assign_coords(month=np.arange(1, 13))
)

### Water body mask

In [20]:
# Cells where the regridded SM2Rain array is NaN are treated as water bodies.
water_body_mask = np.isnan(ds_sm2r_regridded)

# Only the first time step of that mask is used; it is applied to every other
# product, turning the masked cells into NaN.
land_mask = ~water_body_mask[0]

ds_era5_regridded = ds_era5_regridded.where(land_mask)
ds_chirps_regridded = ds_chirps_regridded.where(land_mask)
ds_imerg_regridded = ds_imerg_regridded.where(land_mask)
chpclim = chpclim.where(land_mask)

## Read TAHMO

#### Create dictionary with rainfall data per station

In [31]:
# Build the list of TAHMO entries marked in the quality-flag file. Nothing is
# removed from the station data in this cell; it only collects the flagged
# (row label, column label) pairs.
flagged_data = []  # stays empty when quality flags are not used
if data_flags_included:
    flags = pd.read_csv(path_to_qualityflags, index_col=0)
    # A value > 0 in any of the last five columns counts as a flag.
    mask = flags.iloc[:, -5:] > 0

    # Stack into a boolean Series indexed by (row label, column label) and
    # keep the index entries whose value is True.
    stacked_df = mask.stack()
    flagged_data = stacked_df[stacked_df].index.tolist()

In [32]:
data_folder = path_to_tahmo

# Create list of all tahmo stations: the station name is the part before the
# first '.' of every directory entry whose name starts with 'TA'. The folder
# is listed once instead of once per loop iteration.
list_stations = [
    entry.split('.')[0]
    for entry in os.listdir(data_folder)
    if entry.startswith('TA')
]

# Select the stations that are located within the subset.
# Note the argument order (lat bounds first) differs from the order of `subset`.
list_stations_subset = utils.SelectStationsSubset(list_stations, path_to_metadata, MinLat, MaxLat, MinLon, MaxLon)

## SUPER algorithm

In [None]:
# Coordinates of the regridded SM2Rain array; they are reused when the merged
# result is wrapped in an xarray Dataset.
# NOTE: `time` rebinds the name of the `time` module imported in the first cell.
time = ds_sm2r_regridded["time"]
lon = ds_sm2r_regridded["lon"]
lat = ds_sm2r_regridded["lat"]

In [35]:
# Monthly correction: each product is passed to
# SUPER_functions_pixels.monthly_correction together with the CHPclim fields.
_monthly_correction = SUPER_functions_pixels.monthly_correction

est1c = _monthly_correction(chpclim, ds_imerg_regridded)             # IMERG
est2c = _monthly_correction(chpclim, ds_chirps_regridded['precip'])  # CHIRPS
est3c = _monthly_correction(chpclim, ds_era5_regridded['precip'])    # ERA5
est4c = _monthly_correction(chpclim, ds_sm2r_regridded)              # SM2Rain

In [36]:
# Convert to numpy array (CTC function only works with numpy array).
# Mapping: p1 = ERA5 (est3c), p2 = SM2Rain (est4c),
#          p3 = IMERG (est1c), p4 = CHIRPS (est2c).
p1, p2 = est3c.values, est4c.values  # Dependent
p3, p4 = est1c.values, est2c.values  # Independent

In [37]:
# QC-merging: compute err_var from the four products, then merge the same four
# products using it.
products = [p1, p2, p3, p4]
err_var = SUPER_functions_pixels.quadruple_weights(*products)
p_qc = SUPER_functions_pixels.merge_data_QC(err_var, products)

In [38]:
# CTC-merging of p1, p3 and p4.
# NOTE(review): this call was previously annotated as merging
# IMERG/ERA5/SM2Rain, but in this notebook p1, p3 and p4 are built from ERA5,
# IMERG and CHIRPS; the SM2Rain array (p2) is not passed. Confirm which is intended.
rain_threshold = 0.5     # x: rain/no rain threshold
merging_parameter = 1.5  # n: merging parameter
m = SUPER_functions_pixels.CTC(p1, p3, p4, rain_threshold, merging_parameter)

In [40]:
# False Alarm correction: wherever the CTC output m is below 0.5, the
# QC-merged estimate is forced to zero.
# NOTE(review): 0.5 is also the rain/no-rain threshold passed to CTC in the
# previous cell - presumably the two values should stay in sync; confirm.
norain_days = m < 0.5  # no-rain days
p_final = p_qc.copy()  # the zeroing below is applied to this copy
p_final[norain_days] = 0 # set the merged data on these days to have no-rain

In [71]:
# Wrap the merged numpy array in an xarray Dataset, reusing the coordinates
# taken from the regridded SM2Rain array.
# NOTE: the name `super` shadows Python's built-in super(); it is kept because
# the following cells refer to it.
super = xr.Dataset(
    {
        'precip': xr.DataArray(
            data=p_final,
            dims=['time', 'lon', 'lat'],
            coords={'time': time, 'lon': lon, 'lat': lat},
            attrs={
                'description': 'SUPER precipitation estimates',
                # Per-day unit, matching the 'precip [mm/d]' colour-bar label
                # used when this field is plotted (was wrongly 'mm/h').
                'units': 'mm/d',
            },
        ),
    },
    attrs={'description': 'Outputs of the SUPER group'},
)

## Post processing

In [72]:
# 1) Clamp: every value that is not >= 0 becomes 0. NaN also fails that test,
#    so missing values are turned into 0 here as well.
# 2) Water mask: cells flagged in the first time step of water_body_mask are
#    set back to NaN.
super = super.where(super >= 0, 0).where(~water_body_mask[0])


## Export

In [79]:
# Write the result to NetCDF with dimensions ordered (time, lat, lon).
# The file name encodes the start/end year, the run ID and the AOI bounds.
super_nc_path = (
    f'{project_root}/data/super_{start_time.year}_{end_time.year}_{ID}'
    f'_{MinLon}_{MaxLon}_{MinLat}_{MaxLat}.nc'
)
super.transpose("time", "lat", "lon").to_netcdf(super_nc_path)

## Validation

In [81]:
# Read the 'precip' time series for the selected TAHMO stations from the
# exported SUPER file (original note: calibrated with 2021/2022).
super_nc_path = (
    f'{project_root}/data/super_{start_time.year}_{end_time.year}_{ID}'
    f'_{MinLon}_{MaxLon}_{MinLat}_{MaxLat}.nc'
)
ts_super = utils.timeSeriesAllTahmoFromNetCDF(
    super_nc_path, path_to_metadata, list_stations_subset, 'precip'
)

# Evaluation statistics against the TAHMO observations; getMean=True is
# requested, so each returned value is printed as a single number below.
evaluation_stats = utils.GetEvaluationStatisticsPerStation(
    path_to_tahmo,
    path_to_qualityflags,
    ts_super,
    list_stations_subset,
    0.25,
    correlation='spearman',
    only_raindays=True,
    getMean=True,
)
rmse, bias, spearman, kge, pod, far, hss = evaluation_stats
print(f'rmse = {rmse}, bias = {bias}, spearman correlation = {spearman}, KGE = {kge}, pod = {pod}, far = {far}, hss = {hss}')

  station_data = pd.read_csv(data_dir, index_col=0, parse_dates=True)


rmse = 13.32264231613737, bias = -0.43313258952876116, spearman correlation = 0.4834609026320591, KGE = 0.25516642969477993, pod = 0.41346891321287976, far = 0.11213633878398659, hss = 0.2619645564549368


### Visualisation

In [90]:
def _country_outline(folder, iso3):
    """Read data/shapes/<folder>/gadm41_<iso3>_0.shp under project_root as a GeoDataFrame."""
    return gpd.read_file(f'{project_root}/data/shapes/{folder}/gadm41_{iso3}_0.shp')


# Country outlines drawn on top of the precipitation map.
kenya = _country_outline('kenya', 'KEN')
rwanda = _country_outline('rwanda', 'RWA')
uganda = _country_outline('uganda', 'UGA')
tanzania = _country_outline('tanzania', 'TZA')
som = _country_outline('somalia', 'SOM')
ssd = _country_outline('southsudan', 'SSD')
eth = _country_outline('ethiopia', 'ETH')
brn = _country_outline('burundi', 'BDI')  # note: `brn` holds Burundi
drc = _country_outline('drcongo', 'COD')

def plot_super_slider(timestep):
    """Draw the SUPER precipitation field of one time step with country borders.

    Reads the notebook-level ``super`` dataset and the country GeoDataFrames
    (kenya, rwanda, uganda, tanzania, som, ssd, eth, brn, drc).

    Parameters
    ----------
    timestep : int
        Position along the ``time`` dimension of ``super`` to display.
    """
    fig, ax = plt.subplots(figsize=(4, 2))

    snapshot = super['precip'].isel(time=timestep)
    snapshot.plot(ax=ax, cbar_kwargs={"label": "precip [mm/d]"})

    # Overlay every country outline with the same unfilled black style.
    for outline in (kenya, rwanda, uganda, tanzania, som, ssd, eth, brn, drc):
        outline.plot(ax=ax, edgecolor="black", facecolor="none")

    # Truncate the timestamp to day precision and convert it to a Python date
    # so it can be formatted with strftime.
    day = super["time"].isel(time=timestep).values.astype("M8[D]").tolist()
    plt.title(f'Precipitation (SUPER) {day.strftime("%Y-%m-%d")}')
    plt.xlabel('lon', size=10)
    plt.ylabel('lat', size=10)
    plt.show()

In [91]:
# Re-open the exported file; from here on `super` is the dataset read from disk.
super_nc_path = (
    f'{project_root}/data/super_{start_time.year}_{end_time.year}_{ID}'
    f'_{MinLon}_{MaxLon}_{MinLat}_{MaxLat}.nc'
)
super = xr.open_dataset(super_nc_path)

# One slider position per time index; the final expression renders the widget.
last_step = len(super['time']) - 1
time_slider = widgets.IntSlider(value=0, min=0, max=last_step, step=1, description='Time Step')
widgets.interactive(plot_super_slider, timestep=time_slider)

interactive(children=(IntSlider(value=0, description='Time Step', max=1825), Output()), _dom_classes=('widget-…