# Plotting MSE with Predicted MSE Over Time

In [None]:
import hvplot.pandas
import holoviews as hv
import scripts.latexify as latexify
try:
    from Deriv_dask import Deriv_dask
except:
    from scripts.Deriv_dask import Deriv_dask
try:
    from latexify import in_params_dic, physical_params
except:
    from scripts.latexify import in_params_dic, physical_params
import numpy as np
import xarray as xr
import pandas as pd

from glob import glob
import sys
from scripts.plot_outcomes import plot_my_thing as plot
from timeit import default_timer as timer

# Select the matplotlib plotting backend for HoloViews (and therefore hvplot).
hv.extension("matplotlib")

def d_unnamed(df, in_params=None):
    """Keep only rows for the selected perturbed parameters and drop "Unnamed" columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a "Perturbed Parameter" column.
    in_params : list of str, optional
        Perturbed parameters whose rows are kept. If omitted, the module-level
        ``in_params`` is used when it has been defined (the previous behaviour);
        if it does not exist either, no rows are dropped.

    Returns
    -------
    pandas.DataFrame
        The filtered rows, without any column whose name starts with "Unnamed"
        (such columns typically come from a saved index read by ``pd.read_csv``).
    """
    if in_params is None:
        # The notebook defines ``in_params`` only in a later cell; use it when it
        # exists instead of raising NameError on a top-to-bottom run.
        in_params = globals().get("in_params")
    if in_params is not None:
        df = df.loc[df["Perturbed Parameter"].isin(in_params)]
    return df.loc[:, ~df.columns.str.contains('^Unnamed')]

# Sensitivity statistics: keep only "adjusted" ratios with a non-zero sensitivity.
reduced_df = d_unnamed(pd.read_csv("stats_full/conv_adjusted_mse_errorMean_sensMean.csv"))
reduced_df = reduced_df.loc[
    (reduced_df["Ratio Type"] == "adjusted") & (reduced_df["Sensitivity"] != 0)
]

# For each output parameter: the sorted, de-duplicated perturbed parameters that
# occur in its 20 rows with the largest sensitivity.
out_params = ["QV", "QC", "QR", "QI", "QG", "QH", "QS"]
top20_sens_dic = {}
for out_p in out_params:
    df = reduced_df.loc[reduced_df["Output Parameter"] == out_p]
    top20_sens_dic[out_p] = list(
        np.unique(df.nlargest(20, "Sensitivity")["Perturbed Parameter"])
    )


In [None]:
def plot_sens_param(path, out_params, in_params, traj=0, min_x=None, max_x=None, plot_differences=False):
    """Plot perturbed-ensemble results against the predicted error.

    The data in ``path`` whose file names end in
    ``traj<traj>_notPerturbed.nc_wcb`` are loaded with ``Deriv_dask``, and the
    "adjusted" predicted errors for ``in_params`` are recomputed. Then, for every
    input parameter, the ensemble file
    ``<path>traj<traj>/<in_param without its first character>.nc_wcb`` is loaded.
    For every output parameter, one scatter plot over ``time_after_ascent`` is
    created. It shows the ensemble values, or with ``plot_differences`` their
    mean squared deviation from the unperturbed values. It is overlaid with the
    predicted error on a twin y-axis and with black vertical lines at
    -1000 + k * 1800. Each plot is also saved to ``pics/<out_param>_<in_param>``.

    Parameters
    ----------
    path : str
        Directory (ending in "/") with the unperturbed data and the
        ``traj<traj>/`` sub-directory holding the perturbed ensembles.
    out_params : list of str
        Output parameters to plot.
    in_params : list of str
        Perturbed input parameters (names as used in the sensitivity data).
    traj : int
        Trajectory index used to build the file names.
    min_x, max_x : float or None
        x-axis range passed on to ``Deriv_dask.cache_data``.
    plot_differences : bool
        Plot the ensemble MSE instead of the raw values and square the
        predicted error accordingly.

    Returns
    -------
    The combined HoloViews layout arranged in two columns. If the result has no
    ``cols`` method (e.g. a single plot), it is returned as is; None is returned
    when nothing was plotted.
    """

    def load_ensemble(directory, filename, out_ps, min_x, max_x):
        # Select the requested variables at times np.arange(min_x, max_x + 20, 20)
        # (20 is presumably the output step of the ensemble files - TODO confirm)
        # and load them into memory. The ``with`` block closes the file instead
        # of leaking one open handle per input parameter.
        with xr.open_dataset(directory + filename, decode_times=False) as ds:
            subset = ds[out_ps + ["time_after_ascent"]]
            return subset.sel(time=np.arange(min_x, max_x + 20, 20)).compute()

    def latex_label(word):
        # LaTeX form of a parameter name with "{" / "}" doubled, presumably so that
        # later format-style string handling keeps them literal - TODO confirm.
        return latexify.parse_word(word).replace("{", "{{").replace("}", "}}")

    def perturbation_lines(min_t, max_t, start=-1000, step=1800):
        # Overlay of black vertical lines at start + k * step for every such time
        # in [min_t, max_t); None when no such time exists.
        lines = None
        t = start
        while t < max_t:
            if t >= min_t:
                vline = hv.VLine(x=t).opts(color="black")
                lines = vline if lines is None else lines * vline
            t += step
        return lines

    sensitivities = Deriv_dask(
        direc=path,
        parquet=False,
        netcdf=True,
        columns=None,
        backend="matplotlib",
        file_ending="traj" + str(traj) + "_notPerturbed.nc_wcb")

    sensitivities.cache_data(
        in_params=in_params,
        out_params=out_params + ["pressure"],
        x_axis="time_after_ascent",
        y_axis="pressure",
        compute=True,
        trajectories=None,
        min_x=min_x,
        max_x=max_x)
    sens_df = sensitivities._recalc_ratios(sensitivities.cache, "adjusted", None, in_params)
    # The ensembles are loaded for exactly the time range present in the cache.
    min_x = np.min(sensitivities.cache["time"])
    max_x = np.max(sensitivities.cache["time"])
    if plot_differences:
        # Square the predicted error so that it is comparable to the ensemble MSE.
        sens_df[in_params] = sens_df[in_params]**2
        twin_label = "Squared Predicted Error (10% Perturbance)"
    else:
        twin_label = "Predicted Error (10% Perturbance)"

    # A single renderer is reused for every saved figure.
    renderer = hv.Store.renderers["matplotlib"].instance(fig='png', dpi=300)

    def plot_single(s_df, e_df, y, out_param):
        # Scatter ``out_param`` from ``e_df`` and the predicted error column ``y``
        # from ``s_df`` (on a twin y-axis) over time_after_ascent; save and return it.
        lower_y = np.min(s_df[y])
        upper_y = np.max(s_df[y])
        # Pad the twin-axis range by a sixth of the data range on both sides.
        delta = (upper_y - lower_y)/6
        lower_y -= delta
        upper_y += delta

        def twinx_per_timestep(plot, element):
            # matplotlib hook: draw this element on a new twin y-axis.
            ax = plot.handles["axis"]
            twinax = ax.twinx()
            twinax.set_ylim((lower_y, upper_y))
            twinax.set_ylabel(twin_label)
            plot.handles["axis"] = twinax

        twin = s_df.hvplot.scatter(
            x="time_after_ascent",
            y=y,
            color="red",
            label=latex_label(y),
            alpha=0.3,
            legend=True).opts(initial_hooks=[twinx_per_timestep], apply_ranges=False)
        title = latex_label(out_param)
        if plot_differences:
            title = "MSE " + title
        # NOTE(review): without plot_differences ``e_df`` is an xarray Dataset; its
        # ``.hvplot`` accessor presumably requires ``import hvplot.xarray``, while
        # this notebook only imports ``hvplot.pandas`` - confirm.
        pl = e_df.hvplot.scatter(
            x="time_after_ascent", y=out_param, aspect=10/10, alpha=0.3,
            label=title, ylabel=title) * twin
        perturb_lines = perturbation_lines(
            np.min(s_df["time_after_ascent"]), np.max(s_df["time_after_ascent"]))
        # Only add the vertical lines if at least one falls into the time range;
        # multiplying an overlay by None would raise TypeError.
        if perturb_lines is not None:
            pl = pl * perturb_lines
        pl = pl.opts(fig_inches=10, title=title + " for one trajectory")
        renderer.save(pl, "pics/" + out_param + "_" + y)
        return pl

    if plot_differences:
        # Unperturbed values per output parameter, reshaped to (n, 1, 1),
        # presumably to broadcast over the two ensemble dimensions that are
        # averaged below (axis=(1, 2)) - TODO confirm.
        mean_dic = {}
        for out_param in out_params:
            values = np.asarray(
                sensitivities.cache.loc[sensitivities.cache["Output Parameter"] == out_param][out_param])
            mean_dic[out_param] = np.reshape(values, (len(values), 1, 1))
        time_axis = np.unique(sensitivities.cache["time_after_ascent"])

    # Predicted errors per output parameter; identical for every input parameter.
    sens_by_out = {out_p: sens_df.loc[sens_df["Output Parameter"] == out_p] for out_p in out_params}

    all_pl = None
    for in_p in in_params:
        # Ensemble files are named after the parameter without its first character.
        ens_df = load_ensemble(
            path + "traj" + str(traj) + "/",
            in_p[1:] + ".nc_wcb",
            out_params,
            min_x,
            max_x)

        if plot_differences:
            plot_dic = {"time_after_ascent": time_axis}
            for out_p in out_params:
                plot_dic[out_p] = np.mean((ens_df[out_p] - mean_dic[out_p])**2, axis=(1, 2))
            plot_df = pd.DataFrame.from_dict(plot_dic)
        else:
            plot_df = ens_df
        for out_p in out_params:
            pl = plot_single(sens_by_out[out_p], plot_df, in_p, out_p)
            all_pl = pl if all_pl is None else all_pl + pl
    try:
        return all_pl.cols(2)
    except AttributeError:
        # Nothing was plotted (None), or the result has no ``cols``; return it as is.
        return all_pl

### Create plots and save them

In [None]:
path = "/data/project/wcb/netcdf/perturbed_ensembles/conv_600_0_traj_t000000_p001/"
traj = 0
out_params = ["QV", "QC", "QR", "QI", "QG", "QH", "QS"]
# One set of plots per output parameter, using the parameters in top20_sens_dic.
for out_p in out_params:
    _ = plot_sens_param(
        path,
        out_params=[out_p],
        in_params=top20_sens_dic[out_p],
        traj=traj,
        max_x=None,
        plot_differences=True,
    )

### Create plots for QC only

In [None]:
path = "/data/project/wcb/netcdf/perturbed_ensembles/conv_600_0_traj_t000000_p001/"
traj = 0
# QC only; max_x=5000 caps the time_after_ascent range. The returned plot is
# left as the last expression so the notebook displays it.
plot_sens_param(
    path,
    out_params=["QC"],
    in_params=top20_sens_dic["QC"],
    traj=traj,
    max_x=5000,
    plot_differences=True,
)

### Create plots for every model parameter
This produces a large number of plots.

In [None]:
# Collect every model parameter, then drop the physical parameters.
in_params = []
for key in in_params_dic:
    in_params.extend(in_params_dic[key])
for e in physical_params:
    in_params.remove(e)
# We need to remove parameters for the one-moment scheme obviously.
# list.remove raises ValueError if a name is missing, flagging a mismatch with
# in_params_dic early.
for e in ("da_1", "da_2", "de_1", "de_2", "dd", "dN_c", "dgamma",
          "dbeta_c", "dbeta_r", "ddelta1", "ddelta2", "dzeta"):
    in_params.remove(e)

path = "/data/project/wcb/netcdf/perturbed_ensembles/conv_600_0_traj_t000000_p001/"
traj = 0
# Plot QC against every remaining model parameter (the list built above), not
# only the top-20 selection used in the previous cells.
_ = plot_sens_param(path, ["QC"], in_params, traj, max_x=5000, plot_differences=True)