# Step 1:  Create supply (runoff)
This script summarizes the PCR-GLOBWB NetCDF runoff outputs by Hydrobasin level 6 catchment. For each GCM, scenario, and region we convert the runoff flux to a volume, resample each grid cell into 5x5 identical sub-cells, clip the time series by each catchment polygon, and sum the clipped cells.

## Runoff (renewable water available inside sub-basin. Internal consumption is not removed)
runoff = land_surface_runoff * area

## After this script:
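The results will need to be combined with discharge. Because the data were resampled (each original grid cell is repeated as a 5x5 block of identical sub-cells before summing), the summed values will need to be adjusted for the resampling.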

# Setup

## Libraries

In [0]:
!pip install tqdm
!pip install rtree
!pip3 install numpy
!pip3 install pandas
!pip3 install scipy
!pip3 install geopandas
!pip3 install xarray
!pip3 install rasterio
!pip3 install rasterstats
!pip3 install rioxarray
!pip3 install netcdf4
!pip install psutil
!pip install dask
import psutil
import xarray
import rioxarray
import rasterio
import geopandas as gpd
import rasterstats as rstats
import netCDF4, os, subprocess, re, time, datetime, json
import numpy as np, pandas as pd
import netCDF4 as nc
from rasterio import Affine
from rasterio.enums import Resampling
import matplotlib.pyplot as plt
import math
from tqdm import tqdm
import dask
import gc
from joblib import Parallel, delayed


## Functions & Data Locations

In [0]:
def memory_usage():
    """Print the current process's memory usage, scaled from bytes to MiB.

    The value is the first field of psutil's ``memory_info()`` result
    divided by 2**20.
    """
    bytes_per_mib = float(2 ** 20)
    used_bytes = psutil.Process(os.getpid()).memory_info()[0]
    print('- - - Current memory usage is:', used_bytes / bytes_per_mib)
    

def find_runoff_paths(gcm, scen, m):
    '''
    PURPOSE: Build the path to the monthly-total runoff NetCDF for the defined parameters
    INPUTS:
        gcm: global climate model
        scen: future scenario
        m: region of the world
    OUTPUTS:
        rPATH: path to runoff_monthTot_output.nc for this gcm / scenario / region
    '''
    # The path contains a folder named after the run's beginning year
    if 'historical' in scen:
        begin_folder = 'begin_from_1960'
    else:
        begin_folder = 'begin_from_2015'
    # Runs for gswp3-w5e5 always sit in the 'historical-reference' folder
    if gcm == 'gswp3-w5e5':
        scen_folder = 'historical-reference'
    else:
        scen_folder = scen
    # Fill in: 0 = GCM; 1 = scenario folder; 2 = begin-year folder; 3 = M region
    path_template = '/dbfs/mnt/pgb-data-lake/pcrglobwb_output1/pcrglobwb_aqueduct_2021/version_2021-09-16/{0}/{1}/{2}/{3}/netcdf/runoff_monthTot_output.nc'
    return path_template.format(gcm, scen_folder, begin_folder, m)

def read_NETCDF(ncPATH):
    '''
    PURPOSE: Read in NetCDF, return an Xarray dataset with 'lon'/'lat' spatial dimensions and an EPSG:4326 CRS defined
    INPUTS:
        ncPATH: path to netCDF in Data Lake
    OUTPUTS:
        ds: Xarray dataset
    RAISES:
        IndexError: if no coordinate name contains "lat" or "lon"
    '''
    # Read in array
    ds = xarray.open_dataset(ncPATH)
    # Find coordinate names (first coordinate whose name contains "lat" / "lon")
    coord_names = list(ds.coords.keys())
    lat_variable = [x for x in coord_names if "lat" in x][0]
    lon_variable = [x for x in coord_names if "lon" in x][0]
    # Standardize lat and lon names
    ds = ds.rename({lon_variable: 'lon', lat_variable: 'lat'})
    # Set spatial dimensions and projection
    ds = ds.rio.set_spatial_dims('lon', 'lat')
    ds.rio.write_crs("epsg:4326", inplace=True)
    return ds

def fillnas(da):
    """Replace NA values with 0 in a data array and re-declare its spatial metadata.

    INPUTS:
        da: data array with 'lon' and 'lat' dimensions
    OUTPUTS:
        da_filled: new data array with NA replaced by 0, spatial dims set to
        'lon'/'lat' and CRS written as EPSG:4326
    """
    # Fill NA's with a where statement. fillna functions aren't working great
    da_filled = xarray.where(da.isnull(), 0, da)
    # Reset spatial dimensions and projection on the new array
    da_filled = da_filled.rio.set_spatial_dims('lon', 'lat')
    da_filled.rio.write_crs("epsg:4326", inplace=True)
    return da_filled

def resample_xarray(ds, downscale_factor):
    '''
    PURPOSE: Resample an Xarray onto a finer grid so zonal statistics can be more accurate
    INPUTS:
        ds: Xarray to downscale
        downscale_factor: 1-dimensional factor to multiply the pixel count by.
        Ex: 10 would turn each pixel into 100 smaller, identical pixels (10X10)
    OUTPUTS:
        xds_downscaled: downscaled Xarray with dimensions named 'lon' and 'lat'
    '''
    # Target grid size as (rows, columns)
    target_shape = (ds.rio.height * downscale_factor,
                    ds.rio.width * downscale_factor)
    # Reproject onto the same CRS at the new size, using nearest-neighbour resampling
    finer = ds.rio.reproject(ds.rio.crs,
                             shape=target_shape,
                             resampling=Resampling.nearest)
    # The result is renamed from x / y back to lon / lat
    return finer.rename({'x': 'lon', 'y': 'lat'})

def flux_to_volume(ds, ds_band):
    '''
    PURPOSE: Convert flux to volume by multiplying each cell by its area
    INPUTS:
        ds: Xarray in flux (m/month, per the original notes) to convert
        ds_band: band name of flux data
    OUTPUTS:
        merged: input dataset plus two bands: 'variable' (cell area divided by
        1,000,000) and 'volume' (ds_band * 'variable')
    NOTES:
        Reads the module-level ds_area / area_band and reports progress through
        print and memory_usage().
    '''
    print("- - - Conversion to volume (million m3) started")
    memory_usage()
    # 1. Clip the area grid to the extent of ds.
    #    Latitude is sliced from max to min -- assumes the area grid's lat runs north to south (TODO confirm)
    lon_extent = slice(ds['lon'].min(), ds['lon'].max())
    lat_extent = slice(ds['lat'].max(), ds['lat'].min())
    area_clip = ds_area.sel(lon=lon_extent, lat=lat_extent)
    print("- - - - total area (m2):", area_clip[area_band].sum().values)
    # 2. Scale the area to millions now, rather than scaling the PCR data later (which makes messy no data values)
    area_clip = area_clip.assign(variable=area_clip[area_band] / 1000000.0)
    print("- - - - total area (Million m2):", area_clip['variable'].sum().values)
    # 3. Snap the area coordinates onto the coordinates of ds (nearest match within 0.01)
    area_on_grid = area_clip.reindex_like(ds, method='nearest', tolerance=0.01)
    del area_clip
    # 4. Put flux and area in one dataset, then multiply them
    merged = xarray.merge([ds, area_on_grid['variable']])
    merged = merged.assign(volume=merged[ds_band] * merged['variable'])
    print("- - - - total WW in Million m3:", merged['volume'].sum().values)
    print("- - - Conversion complete.")
    memory_usage()
    return merged

def segment_id_list(lst, n):
    """Generate consecutive slices of lst, each holding at most n items."""
    chunk_starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in chunk_starts)
        
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
# ! - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - - UNIVERSAL DATA - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - -  - - !
# Module-level configuration and data shared by the functions in this notebook.
# 1. M folders (i.e., regional folders): 'M0000001' through 'M0000053'
mFolders = ['M' + str(x).zfill(7) for x in range(1, 54)]
# Global climate models to process.
# Alternative (swap in for the list below):
# gcmFolders = ['gswp3-w5e5']

gcmFolders = ['gfdl-esm4',
              'ipsl-cm6a-lr',
              'mpi-esm1-2-hr',
              'mri-esm2-0',
              'ukesm1-0-ll']
# Future scenarios to process.
# Full list (swap in for the single scenario below):
# scenFolders = ['historical',
#                'ssp126',
#                'ssp370',
#                'ssp585']
scenFolders = ['ssp585']
# 2. Hydrobasin 6 catchment polygons
shapePATH = '/dbfs/mnt/pgb-data-lake/aqueduct_dev/aux-boundaries/hydro_basin_lv6/aq3_pfaf_basins.shp'
hy6 = gpd.read_file(shapePATH , crs="epsg:4326")
# Lower-case the column names so the ID column can be addressed as 'pfaf_id'
hy6.columns= hy6.columns.str.lower()
# CRS used when clipping rasters by polygon
project_crs = hy6.crs # expected to be WGS84 aka epsg 4326
# Index by catchment ID so polygons can be selected with hy6.loc[...]
hy6.set_index('pfaf_id', inplace = True)
# Name of the catchment ID column written to the output tables
geog_id = 'pfaf_id'

# 3. Area: global grid of cell sizes, used by flux_to_volume
areaPATH = '/dbfs/mnt/pgb-data-lake/aqueduct_dev/aux-boundaries/global_area_5arcmin.nc'
ds_area = read_NETCDF(areaPATH)
# Band in ds_area that holds the cell size (m2, going by the band name)
area_band = 'global_cellsize_m2_05min.tif'  
    

# Previous single-template output path, kept for reference
# # 4. Output ( 0 = resample; 1 = GCM; 2 = SCEN; 3 = M region; )
# outPATH = '/dbfs/mnt/pgb-data-lake/aqueduct_dev/pcrglobwb_aqueduct_2021/version_2021-09-16/run_202205/zonal_statistics/pfaf6/runoff_resample_{0}/{1}_{2}_{3}.csv'.format

# 4. Output {root: 0 = resample size; 1 = GCM; 2 = SCEN, name: 0 = m region}
# Both names are bound str.format methods: call them with the values to fill in.
newROOT = '/dbfs/mnt/pgb-data-lake/aqueduct_dev/pcrglobwb_aqueduct_2021/version_2021-09-16/run_202205/zonal_statistics/pfaf6/runoff_resample_{0}/{1}/{2}/'.format
newNAME = '{0}.csv'.format

In [0]:
def run_runoff_parallelized(gcm, scen, m, resample_size):
    '''
    PURPOSE: Sum the runoff volume inside each catchment of one region and save the time series as CSV
    INPUTS:
        gcm: global climate model
        scen: scenario
        m: region
        resample_size: 1-dimensional factor passed to resample_xarray
                       (each pixel becomes resample_size x resample_size pixels)

    OUTPUTS:
        None returned. CSV containing zonal statistics will save.
    RAISES:
        ValueError: the regional lookup table lists no catchments
        RuntimeError: a full round over the remaining catchments finished none of them
    NOTES:
        The volume grid is resampled with nearest neighbour before summing, so each
        original cell is counted once per sub-pixel that falls inside a catchment.
        The sums written here are therefore not yet corrected for resample_size.
    '''
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # 
    # In function function
    def clip_by_pfaf_id(p):
        """Clip the resampled volume grid to catchment p and sum it over lon/lat."""
        # Step 6.1: Select 1 polygon per call (label slice keeps a GeoDataFrame)
        my_geom = hy6.loc[p:p, :]
        # Step 6.2: Clip NetCDF by polygon
        stime = time.time()
        clipped = rechunked2.rio.clip(my_geom.geometry, project_crs, drop=False)
        # Step 6.2 housekeeping
        print('- - - - Clipped NetCDF in {}'.format(time.time()-stime))
        memory_usage()
        # Step 6.3: Sum contents across lat and long
        geom_sum = clipped.sum(dim = ['lon', 'lat'])
        # Step 6.3 housekeeping
        del clipped
        # Convert array to DataFrame with a 'runoff' column
        df_t = geom_sum.to_dataframe(name = 'runoff')
        # Add geometry ID
        df_t[geog_id] = p
        # Step 6.4 housekeeping
        del geom_sum
        return df_t
    

    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
    # Step 1 - Read in 5 arcmin dataset
    runoff = find_runoff_paths(gcm, scen, m)
    ds_ww = read_NETCDF(runoff) 
    ww_band = 'land_surface_runoff'
    print(" - - - - Step 1: Read in runoff") 
    memory_usage()   
    
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #     
    # Step 2 - Convert flux to volume  
    stime = time.time()
    ds_vol = flux_to_volume(ds_ww, ww_band)
    del ds_ww
    # Persist and Fill NA with hand-made function 
    ds_vol = ds_vol.persist()
    ds_vol = fillnas(ds_vol['volume'])
    print(' - - - - Step 2: Convertered flux to volume in {}'.format(time.time()-stime))
    memory_usage()
    
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
    # Step 3 - Resample data
    stime = time.time()
    ds_rs = resample_xarray(ds_vol, resample_size)
    del ds_vol
    print('- - - - Step 3: Resampled data in {}'.format(time.time()-stime))
    memory_usage()
    
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
    # Step 4 - Chunk data
    stime = time.time()
    # Chunk data to improve speed
    ds_rs = ds_rs.persist()
    rechunked = ds_rs.chunk({"lon": 100, "lat": 100})
    # Fill NA's with where statement. fillna functions aren't working great
    rechunked2 = fillnas(rechunked)
    del ds_rs
    del rechunked
    print('- - - - Step 4: Chunked in {}'.format(time.time()-stime))
    
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
    # Step 5 - Find list of regional pfafs
    geogidlookupPATH = '/dbfs/mnt/pgb-data-lake/aqueduct_dev/aux-boundaries/m_region-pfaf6-lookups/{0}_pfaf6_lookup.csv'.format
    df_pfs = pd.read_csv(geogidlookupPATH(m))
    pfs = df_pfs['pfaf_id'].tolist()
    if not pfs:
        # Without this check the function would fail later with an unrelated error
        raise ValueError('No catchments listed in the pfaf6 lookup for region {}'.format(m))
    
    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #
    # Step 6 - Clip by polygons
    print(" - - Start Step 5: Start clip loop") 
    memory_usage()
    stime = time.time()
    # Set number of workers
    n_workers = 40
    # Keep track of while loop
    worker_count = 0
    # Results of every round. Must be created once, outside the while loop,
    # otherwise a later round would discard what earlier rounds produced.
    df_fs = []
    # While catchments remain, repeat this process. For every new round, use smaller batches to prevent memory overload
    while len(pfs) > 0:
        worker_count += 1
        print('- - - - - run number', worker_count, "\n- - - - - - Remaining catchments:", len(pfs))
        remaining_at_start = len(pfs)
        # Batch size shrinks each round but never reaches 0 (a 0 step would make range() fail)
        batch_size = max(1, int(n_workers / worker_count))
        # Segment catchments into batches for parallel process
        objectids_list = segment_id_list(lst=pfs, n=batch_size)
        # Clip and sum by polygon
        for oids in objectids_list:
            memory_usage()
            df_ts = Parallel(n_jobs=n_workers)(delayed(clip_by_pfaf_id)(p) for p in oids)
            df_t_all = pd.concat(df_ts)
            # Pull the finished PFs (as a set for fast lookup), remove from overall list
            finished_pfs = set(df_t_all[geog_id].tolist())
            pfs = [x for x in pfs if (x not in finished_pfs)]
            del df_ts
            df_fs.append(df_t_all)
            del df_t_all
            print("- - - - Remaining catchments:", len(pfs))
            gc.collect()
        if len(pfs) == remaining_at_start:
            # Every remaining catchment was attempted and none produced rows;
            # retrying would loop forever, and saving would write an incomplete CSV.
            raise RuntimeError(
                'No progress for region {0}: {1} catchments returned no data: {2}'.format(m, len(pfs), pfs))
    df_f = pd.concat(df_fs)
    outPATH = newROOT(resample_size, gcm, scen) + newNAME(m)
    # Create the gcm/scenario output folder on first use so the save cannot fail on a missing directory
    os.makedirs(os.path.dirname(outPATH), exist_ok=True)
    df_f.to_csv(outPATH)

    del df_f
    del df_fs
    gc.collect()

# Run

In [0]:
rs_size = 5
# Full run: every GCM / scenario / region combination that has no output CSV yet
for gcm in gcmFolders:
    print("- ", gcm)
    for scen in scenFolders:
        # gswp3-w5e5 is only processed for the historical scenario
        if gcm == 'gswp3-w5e5' and scen != 'historical':
            continue
        print("- - ", scen)
        for m in tqdm(mFolders):
            print("- - -", m)
            # Regions whose CSV already exists are skipped, so a stopped run can be restarted
            if not os.path.exists(newROOT(rs_size, gcm, scen) + newNAME(m)):
                run_runoff_parallelized(gcm = gcm, scen = scen, m = m, resample_size = rs_size)

In [0]:
# Scratch cell: open one runoff file (gfdl-esm4 / ssp126 / region M0000001) for manual inspection
runoff = find_runoff_paths('gfdl-esm4', 'ssp126', 'M0000001')
ds_ww = read_NETCDF(runoff) 

In [0]:
# Scratch cell: display the runoff band of the dataset opened above (notebook shows the last expression)
ds_ww['land_surface_runoff']