# Figure 7. Regionalization differences (selection)

In [None]:
import os 
import numpy as np
import pandas as pd
import xarray as xr

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Work from the project root so the relative data/, results/ and reports/ paths below resolve.
# Set the DEEPHYDRO_ROOT environment variable to run on a machine other than the original one.
os.chdir(os.environ.get("DEEPHYDRO_ROOT", '/home/rooda/OneDrive/Projects/DeepHydro/'))

In [None]:
# Basins shown in the figure: basin ID -> display name.
# Insertion order is significant: the IDs are reshaped row-major into the
# 3x3 panel grid, so this order fixes where each basin is drawn.
dict_basins = {
    "Y00004842": "Puelo",
    "Y00004886": "Yelcho",
    "Y00004938": "Palena",
    "Y00004983": "Cisnes",
    "Y00004956": "Aysen",
    "Y00004957": "Baker",
    "Y00004946": "Pascua",
    "Y00004814": "Santa Cruz",
    "Y00004894": "Grey",
}

## Data 

In [None]:
# Basin attribute table, indexed by its first CSV column (basin ID), restricted to
# the selected basins. Indexing with an explicit list keeps the rows in
# dict_basins order and avoids handing pandas a set-like dict_keys view.
metadata = pd.read_csv("data/Attributes_all_basins.csv", index_col=0)
metadata = metadata.loc[list(dict_basins)]

In [None]:
# Daily runoff climatology for the selected basins: mean over all years for each
# day of year, followed by a centred 7-day moving average.
# NOTE(review): with a standard calendar, day 366 only occurs in leap years, so that
# group averages fewer years than the others.
q_data = xr.open_dataset("results/zenodo/Q_historical.nc").sel(basin_id=list(dict_basins))
q_data = q_data.groupby(q_data.date.dt.dayofyear).mean()
# The annual cycle is periodic: wrap 3 days (half the 7-day window) from each end of
# the year onto the other so the windows around 1 Jan / 31 Dec are not truncated,
# then drop the padding again.
q_data = (
    q_data.pad(dayofyear=3, mode="wrap")
    .rolling(dayofyear=7, center=True, min_periods=1)
    .mean()
    .isel(dayofyear=slice(3, -3))
)

## Plot

In [None]:
# Basin IDs laid out row-major (dict_basins order) on the 3x3 panel grid.
basins_xy = np.reshape(list(dict_basins.keys()), (3, 3))
# Panel titles "<name> (<glacier_cover>%)"; metadata rows follow dict_basins order.
basins_titles = [f"{name} ({cover}%)"
                 for name, cover in zip(dict_basins.values(), metadata.glacier_cover.to_list())]

fig = make_subplots(rows=3, cols=3, horizontal_spacing=0.03, vertical_spacing=0.05,
                    shared_xaxes=True, shared_yaxes=False, y_title="Runoff (mm d <sup>-1</sup>)",
                    subplot_titles=basins_titles)

cl = px.colors.qualitative.D3

# One line style per model, shared by the panel traces AND their legend entries so the
# legend always matches what is drawn. Dict order is the drawing order: the solid
# "LSTM + OGGM" curve is added last so it sits on top of the dotted ones.
line_styles = {
    "TUWmodel":      dict(color=cl[4], width=1.2, dash="dot"),
    "GR4J":          dict(color=cl[1], width=1.2, dash="dot"),
    "LSTM_OGGM_off": dict(color=cl[2], width=1.2, dash="dot"),
    "LSTM_OGGM_on":  dict(color=cl[0], width=1.5),
}
# Legend label per model; dict order is the order of the legend entries.
legend_labels = {
    "LSTM_OGGM_on":  "LSTM + OGGM",
    "LSTM_OGGM_off": "Only LSTM",
    "TUWmodel":      "TUWmodel + OGGM",
    "GR4J":          "GR4J + OGGM",
}

for row in range(3):
    for col in range(3):
        q_basin = q_data.sel(basin_id=basins_xy[row, col])
        for model, line in line_styles.items():
            fig.add_trace(go.Scatter(x=q_data.dayofyear, y=q_basin.sel(model=model).Q, mode="lines",
                                     name=legend_labels[model], line=line, showlegend=False),
                          row=row + 1, col=col + 1)

# Legend-only dummy traces (no data points) on the first panel; the per-panel traces hide theirs.
for model, label in legend_labels.items():
    fig.add_trace(go.Scatter(x=[None], y=[None], mode="lines", name=label,
                             line=line_styles[model], showlegend=True),
                  row=1, col=1)

fig.update_yaxes(title="", row=1, col=1)
fig.update_yaxes(ticks="outside", griddash="dot", tickfont=dict(size=15))
# Ticks on the 15th of each month (non-leap-year day of year), labelled by month initial.
fig.update_xaxes(ticks="outside", griddash="dot", tickfont=dict(size=15),
                 tickvals=[15, 46, 74, 105, 135, 166, 196, 227, 258, 288, 319, 349],
                 ticktext=["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"])

fig.update_layout(legend=dict(y=1.08, x=0.46, orientation="h", bgcolor="rgba(0,0,0,0.0)", font_size=15))
fig.update_layout(height=800, width=1200, template="seaborn", margin=dict(l=10, r=10, b=10, t=10), hovermode=False)
fig.update_annotations(font_size=16)

# save figure
fig.write_image("reports/figures/Figure7_Regionalization_diff_selection.png", scale=4)
fig.show()