# Figure 9. Future seasonal changes

In [None]:
import os 
import numpy as np
import pandas as pd
import xarray as xr

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# All relative paths used below (data/, results/, reports/) resolve against the
# project root. The default is a user-specific absolute path; set the
# DEEPHYDRO_ROOT environment variable to run the notebook from another machine.
os.chdir(os.environ.get("DEEPHYDRO_ROOT", '/home/rooda/OneDrive/Projects/DeepHydro/'))

In [None]:
# Basin ID -> display name. The insertion order matters: it sets the order in
# which basins fill the 3x3 subplot grid (row-major) in the plotting cell.
dict_basins = dict([
    ("Y00004842", "Puelo"),
    ("Y00004886", "Yelcho"),
    ("Y00004938", "Palena"),
    ("Y00004983", "Cisnes"),
    ("Y00004956", "Aysen"),
    ("Y00004957", "Baker"),
    ("Y00004946", "Pascua"),
    ("Y00004814", "Santa Cruz"),
    ("Y00004894", "Grey"),
])

## Data 

In [None]:
# Basin attributes, indexed by basin ID (first CSV column). Rows are restricted
# to the basins in dict_basins and kept in that order, so they line up with the
# subplot titles built in the plotting cell.
metadata = pd.read_csv("data/Attributes_all_basins.csv", index_col = 0)
# Index with a plain list, as the other cells do, rather than a dict view:
# pandas refuses set/dict indexers, and a list makes the intended order explicit.
metadata = metadata.loc[list(dict_basins.keys())]

In [None]:
# Historical runoff for the selected basins, opened lazily with dask chunks.
q_historical = xr.open_dataset("results/zenodo/Q_historical.nc", chunks = "auto").sel(basin_id = list(dict_basins.keys()))
# Mean annual cycle: average every calendar day-of-year over all years in the file.
# NOTE(review): assumes daily data; day 366 only occurs in leap years, so it is
# averaged over fewer samples.
q_historical = q_historical.groupby(q_historical.date.dt.dayofyear).mean()
# 7-day centred moving average to smooth the climatology. The window does not
# wrap around Dec/Jan; at the ends of the year it uses whatever days exist
# (at least 2).
q_historical = q_historical.rolling(dayofyear = 7, center = True, min_periods = 2).mean()

# Future runoff, treated the same way but restricted to the 2070-2099 period.
q_future = xr.open_dataset("results/zenodo/Q_future.nc", chunks = "auto").sel(basin_id = list(dict_basins.keys()))
q_future = q_future.sel(date = slice("2070-01-01", "2099-12-31"))
q_future = q_future.groupby(q_future.date.dt.dayofyear).mean()
q_future = q_future.rolling(dayofyear = 7, center = True, min_periods = 2).mean()
# Average over the "gcm" dimension (ensemble mean across climate models).
q_future = q_future.mean(dim = "gcm")
# Percent change of the future climatology relative to the historical one.
# q_future becomes a DataArray here. Any dimensions not shared by both
# (e.g. "ssp", which the plotting cell selects on) are broadcast by xarray.
# Days with zero historical runoff would yield inf/NaN.
q_future = ((q_future.Q / q_historical.Q) - 1)*100
# Trigger the dask computation so the plotting loop works on in-memory data.
q_future = q_future.load()

## Plot

In [None]:
cl = px.colors.qualitative.D3
scen_dict = {"ssp126":"SSP 1-2.6", "ssp585": "SSP 5-8.5"}

# 3x3 grid of basin IDs, filled row-major in dict_basins order.
basins_xy  = np.reshape(list(dict_basins.keys()), (3, 3))
# Subplot titles "<name> (<glacier_cover>%)". metadata rows were selected in
# dict_basins order, so the zip pairs each name with its own glacier cover.
basins_titles = [a + " (" + str(b) + "%)" for a, b in zip(list(dict_basins.values()), metadata.glacier_cover.to_list())]

# One entry per hydrological model: (model coordinate in q_future, legend label,
# line style). Curves are drawn in this order, so "LSTM + OGGM" ends up on top.
# The same style dict is used for the curves and their legend entry.
model_styles = [
    ("GR4J",          "GR4J + OGGM",     dict(color=cl[1], width=1.2, dash="dot")),
    ("TUWmodel",      "TUWmodel + OGGM", dict(color=cl[4], width=1.2, dash="dot")),
    ("LSTM_OGGM_off", "Only LSTM",       dict(color=cl[2], width=1.2, dash="dot")),
    ("LSTM_OGGM_on",  "LSTM + OGGM",     dict(color=cl[0], width=1.5)),
]


for scenario in ["ssp126", "ssp585"]:

    fig = make_subplots(rows=3, cols=3, horizontal_spacing=0.03, vertical_spacing=0.05,
                        shared_xaxes=True, shared_yaxes=True,
                        y_title="ΔRunoff (%) [{}]".format(scen_dict[scenario]),
                        subplot_titles=basins_titles)

    # Select the scenario once instead of inside every add_trace call.
    q_scen = q_future.sel(ssp=scenario)

    for row, basin_row in enumerate(basins_xy, start=1):
        for col, basin_id in enumerate(basin_row, start=1):
            for model, label, line in model_styles:
                # Each trace is named after its model; legendgroup ties it to the
                # matching legend entry, so clicking that entry toggles the model
                # in every subplot.
                fig.add_trace(go.Scatter(x=q_scen.dayofyear,
                                         y=q_scen.sel(model=model, basin_id=basin_id),
                                         mode="lines", name=label, legendgroup=label,
                                         line=line, showlegend=False),
                              row=row, col=col)

    # Legend-only dummy traces (no data points). They are listed with
    # "LSTM + OGGM" first and use exactly the plotted line styles.
    for model, label, line in reversed(model_styles):
        fig.add_trace(go.Scatter(x=[None], y=[None], mode="lines", name=label,
                                 legendgroup=label, line=line, showlegend=True),
                      row=1, col=1)

    fig.update_yaxes(ticks="outside", range = [-70, 70], dtick = 25, griddash = "dot", tickfont = dict(size=12), ticksuffix = "%",  title_standoff = 100)
    # Month initials placed at roughly the middle day-of-year of each month.
    fig.update_xaxes(ticks="outside", griddash = "dot",
                     tickvals=[15,46,74,105,135,166,196,227,258,288,319,349], tickfont = dict(size=15),
                     ticktext = ["J", "F", "M", "A", "M", "J", "J", "A", "S", "O", "N", "D"])

    fig.update_layout(legend=dict(y=1.08, x=0.46, orientation="h", bgcolor = 'rgba(0,0,0,0.0)', font_size = 15))
    fig.update_layout(height=800, width=1200, template = "seaborn", margin = dict(l=65, r=20, b=20, t=20), hovermode = False)
    fig.update_annotations(font_size=16)

    # Save one figure per scenario.
    fig.write_image("reports/figures/Figure9_seasonal_changes_{}.png".format(scenario), scale=4)
#fig.show()

## Text

In [None]:
# Low-flow frequency for the Puelo basin (LSTM + OGGM). The cell evaluates the
# ratio between:
#   - the fraction of days with flow below q_threshold in 2070-2099
#     (averaged over GCMs), and
#   - the same fraction over the historical record.
# The result has one value per remaining dimension (e.g. "ssp").
text_basin = "Y00004842"  # Puelo (see dict_basins)
q_threshold = 250         # in the units produced by the conversion below

# Runoff depth -> volumetric flow.
# NOTE(review): the factor area * 1e6 / (1e3 * 86400) assumes Q is in mm/day and
# total_area is in km², which gives m³/s. Confirm against the data.
to_flow = metadata.total_area.loc[text_basin] * 1e6 / (1e3*86400)

# Open each file in a context manager so it gets closed. .load() reads the
# (small) selection into memory before the file is closed.
with xr.open_dataset("results/zenodo/Q_historical.nc") as ds:
    q_historical = ds.sel(basin_id = text_basin).sel(model = "LSTM_OGGM_on").Q.load()
q_historical = q_historical * to_flow

with xr.open_dataset("results/zenodo/Q_future.nc") as ds:
    q_future = ds.sel(basin_id = text_basin).sel(model = "LSTM_OGGM_on").Q
    q_future = q_future.sel(date = slice("2070-01-01", "2099-12-31")).load()
q_future = q_future * to_flow

# Fraction of days below the threshold = mean of the boolean mask over "date"
# (i.e. count / number of dates). Both periods take it explicitly over "date"
# rather than via len(DataArray), which measures whatever dimension comes first.
frac_below_future = (q_future < q_threshold).mean(dim = "date").mean(dim = "gcm")
frac_below_historical = (q_historical < q_threshold).mean(dim = "date")

frac_below_future / frac_below_historical