# Figure S6. Runoff ratio

In [None]:
import os 
import numpy as np
import pandas as pd
import xarray as xr

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import shapely.geometry

from tqdm import tqdm
import geopandas as gpd

# All relative paths used later ("results/...", "data/...", "reports/...") resolve
# against the project root, so switch to it first.
_home = "/home/rooda"
os.chdir(f"{_home}/OneDrive/Projects/DeepHydro/")

# Absolute inputs that live outside the project tree (string paths, joined with "+" below).
path_climate = f"{_home}/Pipeline/DeepHydro/CLIMATE/catchments"
path_geo = f"{_home}/Datasets/GIS/"

In [None]:
def pp_preprocessing(df, initial_date, final_date=None):
    """Collapse a time-indexed precipitation table into mean annual totals per column.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a DatetimeIndex; presumably one column per basin id
        (the result is labelled ``basin_id``) — confirm against the parquet inputs.
    initial_date : str or datetime-like
        Rows strictly before this date are dropped.
    final_date : str or datetime-like, optional
        If given, rows after this day are dropped as well (the whole final day is
        kept). Default ``None`` keeps everything from ``initial_date`` onward.

    Returns
    -------
    xarray.DataArray
        Mean over calendar years of the per-year sums, one value per column,
        along a ``basin_id`` dimension.

    Notes
    -----
    Years are binned with ``resample("YS")`` and summed, so a year only partly
    covered by the data contributes a partial (smaller) total to the mean.
    """
    df = df[df.index >= initial_date]
    if final_date is not None:
        # normalize() drops the time of day so every timestamp on final_date is kept
        df = df[df.index.normalize() <= final_date]

    # calendar-year totals ("YS" = year-start bins), then the mean across years
    annual_mean = df.resample("YS").sum().mean()
    annual_mean.index.name = "basin_id"
    return annual_mean.to_xarray()

## Data

In [None]:
# historical
# Runoff ratio = mean annual Q / mean annual PP, so both means must cover the same
# years. PP was restricted to >= 2000-01-01; Q now gets the same start date, and PP is
# trimmed to the last date present in the Q file.
hist_start = "2000-01-01"

q_historical = xr.open_dataset("results/zenodo/Q_historical.nc")
q_historical = q_historical.sel(date = slice(hist_start, None))
hist_end = pd.Timestamp(q_historical.date.max().values)

pp_base = pd.read_parquet(path_climate + "/PP_ref_all_basins_full.parquet").set_index("date")
pp_base.index = pd.to_datetime(pp_base.index)
pp_base = pp_base[pp_base.index <= hist_end]
pp_base = pp_preprocessing(pp_base, hist_start)

# mean of calendar-year sums of Q, divided by mean annual PP (aligned on basin_id).
# NOTE(review): assumes Q and PP share units (e.g. mm) — confirm.
q_historical = q_historical.resample(date = "YS").sum().mean(dim = "date").Q
q_historical = q_historical / pp_base

# to df: one row per basin_id, one "hist_<model>" column per remaining dim value
# (presumably the model dimension, e.g. LSTM_OGGM_on / TUWmodel / GR4J — confirm)
q_historical = q_historical.to_pandas().transpose()
q_historical = q_historical.reset_index().set_index("basin_id")
q_historical.columns = "hist_" + q_historical.columns


In [None]:
gcms  = ["GFDL-ESM4", "IPSL-CM6A-LR", "MIROC6", "MPI-ESM1-2-LR", "MRI-ESM2-0"]
ssps  = ["ssp126", "ssp585"]

# End-of-century averaging window, applied identically to PP and Q below so that Q/PP
# compares the same years (pp_preprocessing alone only applies a lower bound).
future_start, future_end = "2070-01-01", "2099-12-31"

# future precipitation: mean annual PP per basin, stacked along new "ssp" and "gcm" dims
df_ssp = []
for ssp in tqdm(ssps):

    df_gcm = []
    for gcm in tqdm(gcms, leave = False):

        df_i = pd.read_parquet(path_climate + "/PP_{}_{}_all_basins_full.parquet".format(gcm, ssp)).set_index("date")
        df_i.index = pd.to_datetime(df_i.index)
        # drop anything after the window end (keeps the whole final day)
        df_i = df_i[df_i.index.normalize() <= future_end]
        df_i = pp_preprocessing(df_i, future_start)
        df_gcm.append(df_i)

    df_gcm = xr.concat(df_gcm, dim='gcm')
    df_ssp.append(df_gcm)

df_ssp = xr.concat(df_ssp, dim='ssp')
pp_future = df_ssp.assign_coords(gcm=gcms, ssp=ssps)

q_future = xr.open_dataset("results/zenodo/Q_future.nc", chunks = "auto")
q_future = q_future.sel(date = slice(future_start, future_end))
q_future = q_future.resample(date = "YS").sum().mean(dim = "date").Q.load()
# runoff ratio per GCM first, then the multi-GCM mean of those ratios
q_future = q_future / pp_future
q_future = q_future.mean(dim = "gcm")

# analysis will consider only SSP 585; drop the leftover scalar "ssp" coordinate
q_future = q_future.sel(ssp = "ssp585").drop_vars("ssp")
q_future = q_future.to_pandas().transpose()
q_future = q_future.reset_index().set_index("basin_id")
q_future.columns = "SSP_" + q_future.columns


In [None]:
# Basin polygons, indexed by gauge id, joined column-wise with the hist_* / SSP_*
# runoff-ratio tables built above.
# NOTE(review): original note "Problem in Puelo" — some known issue with the Puelo basin
# in this file; the nature of the problem is not stated here, confirm before relying on it.
pmet_shape = gpd.read_file("data/GIS/Basins_Patagonia_all_data.gpkg") # Problem in Puelo 
pmet_shape = pmet_shape.set_index("gauge_id")
# scale by 1e6 to m2 — presumably the gpkg stores km2 (TODO confirm); not used further in this figure
pmet_shape["total_area"] = pmet_shape.total_area*1e6 # in m2 
# pd.concat(axis=1) is an outer join on the index: gauge_id values (and dtype) must match
# the basin_id index of the runoff tables, otherwise unmatched ids become NaN-filled rows
pmet_shape = pd.concat([pmet_shape, q_historical, q_future], axis = 1)

## Plot

In [None]:
# Background basemap: keep the "CI" and "AR" countries (presumably FIPS codes for
# Chile and Argentina), merge features per REGION, and simplify outlines (tolerance
# 0.01 in the layer's CRS units).
geo_map = gpd.read_file(path_geo + "south_america.shp")
geo_map = geo_map[geo_map.CC.isin(["CI", "AR"])].dissolve(by='REGION')
geo_map["geometry"] = geo_map.simplify(0.01)

# Rectangular study window (x -76..-68.01, y -54.99..-40.51) in the basemap's CRS,
# used to clip the basemap to the plotted area.
poly_gdf = gpd.GeoDataFrame(
    [1],
    geometry=[shapely.geometry.Polygon([(-76, -54.99), (-76, -40.51), (-68.01, -40.51), (-68.01, -54.99), (-76, -54.99)])],
    crs=geo_map.crs,
)

geo_map = geo_map.clip(poly_gdf)

In [None]:
from  plotly.colors import unlabel_rgb, hex_to_rgb
def binned_colorscale(seq, nr_swatches=5):
    """Build a stepped (discrete) plotly colorscale from a list of colors.

    The colors in *seq* are spread evenly over [0, 1] and linearly interpolated
    at *nr_swatches* evenly spaced points. Each interpolated color then fills an
    equal-width band of the returned colorscale, so a continuous choropleth is
    drawn with ``nr_swatches`` flat color classes.

    Parameters
    ----------
    seq : list of str
        Colors as hex strings ("#rrggbb") or as "rgb(r, g, b)" strings. The
        format of the first entry decides how every entry is parsed.
    nr_swatches : int, default 5
        Number of discrete color bands in the output.

    Returns
    -------
    list of [float, str]
        Plotly colorscale with two stops per band, ``[k/N, color]`` and
        ``[(k+1)/N, color]``, giving each band hard edges.

    Raises
    ------
    ValueError
        If *seq* is empty, *nr_swatches* < 1, or the colors are neither hex
        nor rgb strings.
    """
    if not seq or nr_swatches < 1:
        raise ValueError("need at least one color and at least one swatch")

    if seq[0].startswith('#'):
        arr_colors = np.array([hex_to_rgb(s) for s in seq]) / 255
    elif seq[0].startswith('rgb'):
        arr_colors = np.array([unlabel_rgb(s) for s in seq]) / 255
    else:
        raise ValueError("a plotly colorscale is given either with hex colors or as rgb colors")

    # scale positions of the input colors and of the swatches; linspace also covers the
    # single-color / single-swatch cases where k/(n-1) would divide by zero
    svals = np.linspace(0, 1, len(seq))
    grid = np.linspace(0, 1, nr_swatches)
    r, g, b = [np.interp(grid, svals, arr_colors[:, k]) for k in range(3)]  # linear interpolation per channel
    cmap_arr = np.clip(np.vstack((r, g, b)).T, 0, 1)
    # round rather than truncate: x/255*255 can land just below the original integer
    new_colors = np.rint(cmap_arr * 255).astype(int)

    n_bands = len(new_colors)
    discrete_colorscale = []
    for k, color in enumerate(new_colors):
        # plain Python ints: with NumPy >= 2, str() of a tuple of np.int64 gives
        # "(np.int64(29), ...)", which is not a valid plotly color string
        rgb = f"rgb{tuple(int(c) for c in color)}"
        discrete_colorscale.extend([[k / n_bands, rgb],
                                    [(k + 1) / n_bands, rgb]])
    return discrete_colorscale

In [None]:
import json

fig = make_subplots(rows=2, cols=2, vertical_spacing = 0.03, horizontal_spacing = 0.01, 
                    subplot_titles = ["a) LSTM + OGGM", "b) Only LSTM", "c) TUWmodel + OGGM", "d) GR4J + OGGM"],
                    specs=[[{"type": "scattergeo"}, {"type": "scattergeo"}],
                           [{"type": "scattergeo"}, {"type": "scattergeo"}]])

# colors
cl = px.colors.qualitative.D3
cs = px.colors.colorbrewer.GnBu
colorsc = binned_colorscale(["#fe7e0d", "#ffe9ba", "#1d78b4"], nr_swatches=15) 

# graticule spacing (degrees) and the positions of the coordinate labels
dtick = 2
x = list(range(-78, 0 + dtick, dtick))
y = list(range(-56, 0 + dtick, dtick))
xpos = -75.9
ypos = -56

# GeoJSON parsed once with json.loads (eval fails on JSON null and would execute the text)
basemap_geojson = json.loads(geo_map['geometry'].to_json())
basins_geojson = json.loads(pmet_shape['geometry'].to_json())

# coordinate labels: longitudes along the bottom edge, latitudes along the left edge
label_lon = [lon + 0.25 for lon in x[1:-1]] + [xpos + 0.25] * (len(y) - 2)
label_lat = [ypos + 0.1] * (len(x) - 2) + [lat + 0.1 for lat in y[1:-1]]
label_text = x[1:-1] + y[1:-1]

## Basemap -----------------------------------------------------------------------------------------------
# same trace order as before: column-major over the 2x2 grid
for x_plot in range(0, 2):
    for y_plot in range(0, 2):
        fig.add_trace(go.Choropleth(geojson = basemap_geojson, locations = geo_map.index, z = geo_map['iso_num'], 
                                    colorscale = ["#EAEAF2", "#EAEAF2"], showscale = False, marker_line_color = 'white', marker_line_width = 0.1),
                      row = y_plot + 1, col = x_plot + 1)

        fig.add_trace(go.Scattergeo(lon = label_lon, lat = label_lat, showlegend = False,
                                    text = label_text, textfont = dict(size = 11, color = "rgba(0,0,0,0.25)"),
                                    mode = "text"),
                      row = y_plot + 1, col = x_plot + 1)

## Change in runoff ratio (SSP_* minus hist_* columns) per model ---------------------------------------------
# (model column suffix, row, col); only the first panel carries the shared colorbar
panels = [("LSTM_OGGM_on", 1, 1), ("LSTM_OGGM_off", 1, 2), ("TUWmodel", 2, 1), ("GR4J", 2, 2)]
for model, row, col in panels:
    if (row, col) == (1, 1):
        scale_kw = dict(colorbar = dict(len = 0.9, x = 1, y = 0.50, title = 'Δ Runoff ratio<br>(Q/PP)', thickness = 25, tickwidth = 1))
    else:
        scale_kw = dict(showscale = False)

    fig.add_trace(go.Choropleth(geojson = basins_geojson, locations = pmet_shape.index,
                                z = pmet_shape['SSP_' + model] - pmet_shape['hist_' + model],
                                colorscale = colorsc, marker_line_color = 'white', marker_line_width = 0.2,
                                zmin = -0.5, zmax = 0.5, **scale_kw),
                  row = row, col = col)

## Layout -------------------------------------------------------------------------------------------------------------------
fig.update_xaxes(showline = True, linecolor = 'rgba(0,0,0,0.5)', linewidth = 1, ticks="outside", griddash = "dot", mirror=True)
fig.update_yaxes(showline = True, linecolor = 'rgba(0,0,0,0.5)', linewidth = 1, ticks="outside", griddash = "dot", mirror=True)

fig.update_geos(showframe = True, framewidth = 1, framecolor = "rgba(0,0,0,0.5)", 
                lonaxis_range=[-76, -68], lataxis_range=[-56, -40.5], 
                bgcolor = "rgb(255,255,255)", 
                showland = False, showcoastlines = False, showlakes = False, 
                lataxis_showgrid=True, lonaxis_showgrid=True, 
                lonaxis_dtick = dtick, lataxis_dtick = dtick, 
                lonaxis_gridcolor = "rgba(0,0,0,0.1)", lataxis_gridcolor = "rgba(0,0,0,0.1)", 
                lonaxis_griddash = "dot", lataxis_griddash = "dot")

fig.update_layout(autosize = False, width = 800, height = 1200, template = "seaborn", font_size = 17, margin = dict(l=5, r=5, b=5, t=30))
fig.write_image("reports/figures/FigureS6_Runoff_ratio.png", scale = 4)
#fig.show()