# Soil Moisture Simulation LSTM

In [1]:
from pathlib import Path
import numpy as np 
import xarray as xr 
import matplotlib.pyplot as plt
import seaborn as sns
import warnings 
import pandas as pd
import torch 

import sys
sys.path.append("/home/tommy/neuralhydrology")
from scripts.read_nh_results import (
    get_test_filepath,
    get_all_station_ds,
    calculate_all_error_metrics,
    get_ensemble_path,
)

from scripts.read_model import (get_model, _load_weights)
from scripts.read_nh_results import (read_multi_experiment_results, calculate_member_errors)
from neuralhydrology.utils.config import Config

%load_ext autoreload 
%autoreload 2

# Data

In [2]:
run_dir = Path("/datadrive/data/runs/esa_cci_sm_lstm_1406_114743")
run_dir = Path("/datadrive/data/runs/esa_cci_sm_SMOOTH_lstm_2106_153936")
expt_dir = Path("/datadrive/data/runs/ERA5Land_SoilMoisture")

run_dir = sorted(list(expt_dir.iterdir()))[0]

# GET config
cfg = Config(run_dir / "config.yml")
cfg.run_dir = run_dir

# GET preds
res_fp = get_test_filepath(run_dir, epoch=30)

preds = read_multi_experiment_results(expt_dir, ensemble_members=False)
preds = preds.sortby("member")
preds["member"] = ["swvl1", "swvl2", "swvl3", "swvl4"]

# GET trained model
model = get_model(cfg).to(cfg.device)
_load_weights(model, cfg)

AssertionError: Has validation been run? ipython --pdb neuralhydrology/nh_run.py evaluate -- --run-dir /datadrive/data/runs/ERA5Land_SoilMoisture/ERA5Land_SoilMoistureVolumeLevel1_1108_202913

In [None]:
data_dir = Path("/datadrive/data")
preds.to_netcdf(data_dir / "SOIL_MOISTURE/results/lstm_direct_sm_preds.nc")

In [None]:
obs_var=[v for v in preds.data_vars if "obs" in v][0]
sim_var=[v for v in preds.data_vars if "sim" in v][0]

sm_errors = calculate_member_errors(
    preds,
    basin_coord="station_id",
    time_coord="time",
    obs_var=obs_var,
    sim_var=sim_var,
    metrics=["NSE", "Pearson-r"],
)

In [None]:
# unq_vars = np.unique(["_".join(v.split("_")[0:-1]) for v in preds.data_vars])
from scripts.cell_state.analysis import finite_flat, histogram_plot

In [None]:
variable = "Pearson-r"
f, axs = plt.subplots(2, 1, figsize=(12, 4*2))

colors = sns.color_palette("viridis", 4)
for jx, member in enumerate(preds.member.values):
    color = colors[jx]
    for i, variable in enumerate(["NSE", "Pearson-r"]):
        ax = axs[i]
        arr = finite_flat(sm_errors[variable].sel(member=member))
        med = np.median(arr)
        histogram_plot(np.clip(arr, 0, 1), hist_kwargs={"color": color, "label": f"{member}: {med:.2f}"}, ax=ax)
        ax.legend()
        ax.set_xlabel(variable)
        ax.set_xlim(0, 1)
        if i == 0:
            ax.set_title("The Single-Output LSTM SM simulations produce NSE scores comparable with discharge")

plt.tight_layout();

# How does it compare with the probe?

In [None]:
from scripts.cell_state.analysis import (save_probe_components, load_probe_components)

probe_run_dir = Path('/datadrive/data/runs/complexity_AZURE/hs_064_0306_205514')
all_models_preds = load_probe_components(probe_run_dir)

In [None]:
from typing import Any, Dict
import seaborn as sns
import matplotlib.pyplot as plt


def empirical_cdf(errors: np.ndarray, kwargs: Dict[str, Any] = {}):
    x = np.sort(errors)
    y = np.arange(len(x))/float(len(x))
    plt.plot(x, y, **kwargs)

    
f, ax = plt.subplots(figsize=(12, 4))
# nse = sm_errors["NSE"]
# empirical_cdf(nse, kwargs={"label": f"SWVL1 LSTM: {np.median(nse):.2f}", "color": "C0"})

# plot probe results
target_vars = list(all_models_preds.keys())
colors = sns.color_palette("viridis", n_colors=len(target_vars))
for ix, target_var in enumerate(target_vars):
    errors = all_models_preds[target_var]["errors"]
    nse = errors["NSE"]
    empirical_cdf(nse, kwargs={"label": f"Probe {target_var}: {np.median(nse):.2f}", "color": colors[ix], "ls": "--"})
    
    lstm_nse = sm_errors.sel(member=target_var)["NSE"]
    empirical_cdf(lstm_nse, kwargs={"label": f"{target_var} LSTM: {np.median(lstm_nse):.2f}", "color": colors[ix]})

ax.set_xlim(0, 1)
ax.set_ylim(0, 1.1)
plt.legend()
sns.despine()

# Timeseries

In [None]:
pixels = np.random.choice(preds.station_id.values, size=10)

times = slice("04-01-2000", "01-01-2010")

f, axs = plt.subplots(len(pixels), 1, figsize=(12, 4*len(pixels)))
member = "swvl1"
for i, px in enumerate(pixels):
    ax = axs[i]
    preds.sel(station_id=px, member=member)[obs_var].sel(time=times).plot(ax=ax, color="k", alpha=0.6, ls="--", label=f"{member} Obs")
    preds.sel(station_id=px, member=member)[sim_var].sel(time=times).plot(ax=ax, label=f"LSTM {member}", color="C2")
    ax.legend()
    sns.despine()
    
plt.tight_layout();

# Spatial Errors

In [None]:
import geopandas as gpd
from scripts.geospatial import initialise_gb_spatial_plot, load_latlon_points
from mpl_toolkits.axes_grid1 import make_axes_locatable

st_data_dir = Path("/home/tommy/spatio_temporal/data")
points = load_latlon_points(st_data_dir)
static = xr.open_dataset(st_data_dir / "camels_static.nc")

In [None]:
sm_errors["station_id"] = sm_errors["station_id"].astype(int)

for member in sm_errors.member.values:
    err_ = sm_errors.sel(member=member)
    gdf = gpd.GeoDataFrame(err_.to_dataframe().join(points))

    ax = initialise_gb_spatial_plot()
    gdf.plot("NSE", vmin=0.7, vmax=1, ax=ax, cmap="viridis_r", legend=True)