# ESA CCI SM Probes

In [1]:
from pathlib import Path
import numpy as np 
import xarray as xr 
import matplotlib.pyplot as plt
import seaborn as sns
import warnings 
import pandas as pd

import sys
sys.path.append("/home/tommy/neuralhydrology")
from scripts.read_nh_results import (
    get_test_filepath,
    get_all_station_ds,
    calculate_all_error_metrics,
    get_ensemble_path,
)

from scripts.read_model import (get_model, _load_weights)
from scripts.read_nh_results import (read_multi_experiment_results, calculate_member_errors)
from neuralhydrology.utils.config import Config

%load_ext autoreload
%autoreload 2

In [2]:
# Passed as `per_basin=` to the normalisation helpers below.
PER_BASIN = False
# NOTE(review): absolute, machine-specific location — change `data_dir` when running elsewhere.
data_dir = Path("/datadrive/data")
# Directory of the trained run whose cell states are probed in this notebook.
run_dir = data_dir / "runs/complexity_AZURE/hs_064_0306_205514"
# Passed as `out_dir=` to the cell-state normalisation helper.
out_dir = run_dir / "cell_states"
# Run configuration read from the run directory.
cfg = Config(run_dir / "config.yml")

In [3]:
import geopandas as gpd
from scripts.geospatial import initialise_gb_spatial_plot, load_latlon_points
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Static catchment data is read from the same directory as the other inputs.
st_data_dir = data_dir
static = xr.open_dataset(st_data_dir / "camels_static.nc")
points = load_latlon_points(st_data_dir)

# Dataset stored under RUNOFF/ (it is opened again in the train/test cell below).
ds = xr.open_dataset(data_dir / "RUNOFF/ALL_dynamic_ds.nc")

# Input / Target Data

In [4]:
from scripts.cell_state.utils import (
    read_basin_list, 
    get_train_test_cell_states, 
    normalize_and_convert_dimension_to_variable_for_cell_state_data, 
    create_train_test_default_dict_for_all_target_vars, 
    train_and_evaluate_models
)

## Get training/test dataset

In [5]:
ds = xr.open_dataset(data_dir / "RUNOFF/ALL_dynamic_ds.nc")

# Basin lists of the trained run; both expose a `station_id` field.
train_sids = read_basin_list(cfg.train_basin_file)
test_sids = read_basin_list(cfg.test_basin_file)

# Subset by period (from the run config) and by basin membership.
train_ds = ds.sel(
    time=slice(cfg.train_start_date, cfg.train_end_date),
    station_id=np.isin(ds.station_id, train_sids.station_id),
)
test_ds = ds.sel(
    time=slice(cfg.test_start_date, cfg.test_end_date),
    station_id=np.isin(ds.station_id, test_sids.station_id),
)

# "Out of sample" means at least one test basin is absent from the training basins.
# The check compares the `station_id` values (as the selections above do) and is
# computed once, then reused for the printed message.
out_of_sample = not np.isin(test_sids.station_id, train_sids.station_id).all()
print(f"Out of Sample: {out_of_sample}")

Out of Sample: False


## Probe the basins

In [6]:
import warnings
# Ignore FutureWarnings from here on (the filter is process-wide, not limited to this cell).
warnings.simplefilter(action='ignore', category=FutureWarning)

# Cell states of the trained run, returned as a (train, test) pair.
train_cn, test_cn = get_train_test_cell_states(run_dir, cfg)

# Normalisation strategy - global or local?

In [7]:
# Statistics are taken over the "date" dimension of the TRAIN cell states only.
cn_station_mean_ = train_cn.mean(dim="date")
cn_station_std_ = train_cn.std(dim="date")

# Train and test are stacked along "date" and standardised with the train statistics.
cn = xr.concat([train_cn, test_cn], dim="date")
station_norm_cn = (cn - cn_station_mean_) / cn_station_std_

In [8]:
from scripts.cell_state.utils import normalize_and_convert_dimension_to_variable_for_cell_state_data


# Options shared by the train and the test normalisation call.
_norm_kwargs = dict(out_dir=out_dir, per_basin=PER_BASIN, time_dim="date", reload=True)

# Train data: the helper returns the normalised data together with (mean_, std_).
train_cs, (mean_, std_) = normalize_and_convert_dimension_to_variable_for_cell_state_data(
    cn=train_cn, train_test="train", **_norm_kwargs
)

# Test data: the (mean_, std_) returned for the train data are passed back in.
test_cs, (mean_, std_) = normalize_and_convert_dimension_to_variable_for_cell_state_data(
    cn=test_cn, train_test="test", mean_=mean_, std_=std_, **_norm_kwargs
)

print("DONE")

Calculating Normalisation for `train` data: global
Calculating Normalisation for `test` data: global
DONE


In [9]:
# Station ids are cast to integers in both splits.
train_cs["station_id"] = train_cs["station_id"].astype(int)
test_cs["station_id"] = test_cs["station_id"].astype(int)

# Use "time" as the name of the temporal coordinate if it is still called "date".
if "date" in train_cs.coords:
    train_cs = train_cs.rename({"date": "time"})
if "date" in test_cs.coords:
    test_cs = test_cs.rename({"date": "time"})

# Target Data

In [10]:
# Which target dataset to probe for; the loading cell below handles "ERA5" and "ESA".
TARGET = "ESA"

In [11]:
from scripts.cell_state.normalize import normalize_2d_dataset


# Load the probe target selected by TARGET and standardise it.
#   - PER_BASIN False: one mean/std per variable, reduced over the whole dataset.
#   - PER_BASIN True : delegated to `normalize_2d_dataset` for every variable.
# `xr.open_dataset` returns an `xr.Dataset`, so no Dataset conversion is needed.
if TARGET == "ERA5":
    filepath = data_dir / "camels_basin_ERA5Land_sm.nc"
    era5_ds = xr.open_dataset(filepath)

    if not PER_BASIN:
        target_mean = era5_ds.mean()
        target_std = era5_ds.std()
        era5_ds = (era5_ds - target_mean) / target_std
    else:
        for var in era5_ds.data_vars:
            era5_ds[var] = normalize_2d_dataset(era5_ds, variable_str=var, per_basin=PER_BASIN)

    era5_ds["station_id"] = era5_ds["station_id"].astype(int)

    # snow depth ("sd") is not used as a probe target
    era5_ds = era5_ds.drop("sd")
    target_ds = era5_ds

elif TARGET == "ESA":
    filepath = data_dir / "SOIL_MOISTURE/interp_full_timeseries_esa_cci_sm.nc"
    esa_ds = xr.open_dataset(filepath).drop("spatial_ref")

    if not PER_BASIN:
        target_mean = esa_ds.mean()
        target_std = esa_ds.std()
        esa_ds = (esa_ds - target_mean) / target_std
    else:
        for var in esa_ds.data_vars:
            esa_ds[var] = normalize_2d_dataset(esa_ds, variable_str=var, per_basin=PER_BASIN)

    esa_ds["station_id"] = esa_ds["station_id"].astype(int)
    target_ds = esa_ds
else:
    # An explicit error replaces a bare `assert False`, which carries no message
    # and is removed entirely when Python runs with -O.
    raise ValueError(f'Unknown TARGET {TARGET!r}: expected "ERA5" or "ESA"')
    

In [12]:
if TARGET == "ESA":
    # NOTE: `esa_ds` has already been standardised at this point, so "orig" refers to the
    # un-smoothed variable, not to un-normalised values.
    orig_target_ds = esa_ds[["sm"]]
    # Probe only the 7-day smoothed soil moisture.
    target_ds = target_ds[["7_day_smooth_sm"]]

In [13]:
# Display the selected target dataset.
target_ds

# Probes Train

In [14]:
# Dimension names and sequence length (also used by the PyTorch section further down).
basin_dim = "station_id"
time_dim = "time"
seq_length = 1

# Every variable of the normalised cell-state dataset is used as a probe input.
input_variables = list(train_cs.data_vars)

# Keep only the stations and timestamps of the target that also occur in the cell states.
train_target_ds = target_ds.sel(
    station_id=np.isin(target_ds.station_id, train_cs.station_id),
    time=np.isin(target_ds.time, train_cs.time),
)
test_target_ds = target_ds.sel(
    station_id=np.isin(target_ds.station_id, test_cs.station_id),
    time=np.isin(target_ds.time, test_cs.time),
)

In [15]:
# Build the probe samples from the cell states (inputs) and the target datasets.
# The result is indexed below as all_train_test[target_var]["train"].
all_train_test = create_train_test_default_dict_for_all_target_vars(
    train_cs=train_cs,
    test_cs=test_cs,
    train_target_ds=train_target_ds,
    test_target_ds=test_target_ds,
    input_variables=input_variables,
)

** STARTING 7_day_smooth_sm **


Creating Train Samples: 100%|██████████| 668/668 [00:12<00:00, 52.96it/s]
Creating Test Samples: 100%|██████████| 668/668 [00:08<00:00, 82.50it/s]
Extracting Data: 100%|██████████| 16424/16424 [02:03<00:00, 133.28it/s]


Merging and reshaping arrays


Extracting Data: 100%|██████████| 10482/10482 [01:16<00:00, 137.10it/s]


Merging and reshaping arrays


In [16]:
# Fit and evaluate the probes; the result is indexed below as
# all_models_preds[target_var]["preds" | "errors" | "model"].
all_models_preds = train_and_evaluate_models(all_train_test)

** 7_day_smooth_sm linear model **
-- Epoch 1
Norm: 0.56, NNZs: 34, Bias: -0.171983, T: 3783967, Avg. loss: 0.051959
Total training time: 3.10 seconds.
-- Epoch 2
Norm: 0.56, NNZs: 36, Bias: -0.172275, T: 7567934, Avg. loss: 0.051943
Total training time: 6.49 seconds.
-- Epoch 3
Norm: 0.56, NNZs: 36, Bias: -0.175400, T: 11351901, Avg. loss: 0.051941
Total training time: 9.85 seconds.
-- Epoch 4
Norm: 0.56, NNZs: 36, Bias: -0.175678, T: 15135868, Avg. loss: 0.051940
Total training time: 13.20 seconds.
-- Epoch 5
Norm: 0.56, NNZs: 36, Bias: -0.175004, T: 18919835, Avg. loss: 0.051940
Total training time: 16.60 seconds.
-- Epoch 6
Norm: 0.56, NNZs: 35, Bias: -0.176591, T: 22703802, Avg. loss: 0.051938
Total training time: 19.94 seconds.
Convergence after 6 epochs took 20.24 seconds


Calculating Errors: 100%|██████████| 668/668 [00:15<00:00, 43.18it/s, 106001]






In [18]:
from scripts.cell_state.analysis import save_probe_components

# Store the probe outputs in an "esa_cci" folder inside the run directory.
probe_dir = run_dir / "esa_cci"
probe_dir.mkdir(exist_ok=True)
save_probe_components(probe_dir, all_models_preds)

Saving data: 100%|██████████| 1/1 [00:00<00:00, 27.12it/s]


In [None]:
# Inspect the training samples built for the 7-day smoothed target.
train = all_train_test["7_day_smooth_sm"]["train"]
times = train["times"].astype("datetime64[ns]")
sids = train["station_ids"]

# Number of distinct timestamps among the training samples (shown as the cell output).
# The original cell evaluated this expression twice; the duplicate had no effect.
np.unique(times).shape

# Explore results

### plot timeseries

In [None]:
# time = slice("2000", "2007")
# time = slice("1998", "2008")

# N = 3
# # pixels = [27030, 38012, 39017]
# # pixels = [61001, 14001, 95001]
# # pixels = [34012, 52010, 85003]
# # pixels = [int(min_station), int(med_station), int(max_station)]

# with plt.rc_context({"figure.dpi": 400}):
#     for px in pixels:
#         f, ax = plt.subplots(1, 1, figsize=(12, 4), sharex=True)
#         for ix, target_var in enumerate(target_vars):
# #             ax = axs[np.unravel_index(ix, (2, 2))]
#             preds = nl_models_preds[target_var]["preds"]
#             data = preds.sel(station_id=px, time=time)

#         #     f, ax = plt.subplots(figsize=(12, 4))
#             ax.plot(data.time, data.obs, color="k", ls="--", alpha=0.3, label="Observed")
# #             ax = ax.twinx()
#             ax.plot(data.time, data.sim, color=f"C{ix}", ls="-", alpha=0.6, label="Simulated")
#             ax.set_title(f"{target_var}")
#             if ix == 0:
#                 ax.legend()
#             sns.despine()
#         f.suptitle(px)

In [None]:
# Target variables to plot. (The original cell first derived this list from
# `all_models_preds.keys()` and then immediately overwrote it with this value.)
target_vars = ["7_day_smooth_sm"]

# Period and basins to display.
time = slice("2000", "2008")
pixels = [54018, 15021]
# other basins tried:
# pixels = [61001, 14001, 95001]
# pixels = [54018, 15021, 48003]
# pixels = [33032, 39108, 85003]
# p = all_models_preds[target_vars[0]]["preds"]
# pixels = np.random.choice(p.station_id.values, N, replace=False)


with plt.rc_context({"figure.dpi": 400}):
    for px in pixels:
        # One figure per basin; every target variable is drawn on the same axes.
        # `plt.subplots` with a single panel returns one Axes object, so the previous
        # `isinstance(axs, list)` branch could never be taken.
        f, ax = plt.subplots(figsize=(12, 4))
        for ix, target_var in enumerate(target_vars):
            preds = all_models_preds[target_var]["preds"]
            data = preds.sel(station_id=px, time=time)

            ax.plot(data.time, data.obs, color="k", ls="--", alpha=0.3, label="Observation")
            ax.plot(data.time, data.sim, color=f"C{ix}", ls="-", alpha=0.6, label="Simulation")
        ax.set_title(px)
        ax.legend()
        sns.despine()

### plot error metric distributions

In [None]:
metric = "Pearson-r"

with plt.rc_context({"figure.dpi": 400}):
    f, ax = plt.subplots(figsize=(12, 4))

    colors = sns.color_palette("viridis", n_colors=len(target_vars))
    for color, target_var in zip(colors, target_vars):
        scores = all_models_preds[target_var]["errors"][metric]
        median_score = scores.median()

        # entries that are not greater than -1 are replaced by -1 before binning
        clipped = scores.where(scores > -1, -1)
        ax.hist(clipped, bins=100, density=True, label=f"{target_var}: {median_score.values:.3f}", alpha=0.6, color=color)
        # dashed line marks the median of the metric
        ax.axvline(median_score, color=color, ls="--", alpha=0.5)

    xlabel = "Correlation" if metric == "Pearson-r" else metric
    ax.set_xlabel(xlabel)
    ax.set_xlim(-1, 1)
    ax.legend()
    sns.despine()

### plot the weights

In [None]:
from scripts.cell_state.analysis import get_model_weights, plot_weights

# One row of weights per target variable.
n_plots = len(target_vars)
# squeeze=False makes `plt.subplots` always return a 2-D array of Axes, so the same
# indexing works for one or many targets. The previous `isinstance(axs, list)` check
# was never true (matplotlib returns an Axes or an ndarray), so with several targets
# the whole array was passed on as `ax`.
f, axs = plt.subplots(n_plots, 1, figsize=(12, 2*n_plots), squeeze=False)

for ix, target_var in enumerate(target_vars):
    model = all_models_preds[target_var]["model"]
    ax = axs[ix, 0]
    w, b = get_model_weights(model)
    # absolute weights, with a fixed colour range
    plot_weights(np.abs(w), kwargs={"vmin": 0.0, "vmax": 0.15}, ax=ax, cbar=False)
    ax.set_title(f"Target: {target_var}")
plt.tight_layout()

In [None]:
from typing import Any, Dict
import seaborn as sns
import matplotlib.pyplot as plt


from typing import Optional


def empirical_cdf(errors: np.ndarray, kwargs: Optional[Dict[str, Any]] = None, ax=None):
    """Plot the empirical cumulative distribution of `errors`.

    The values are sorted ascending and plotted against i / n for i = 0..n-1,
    so the curve starts at 0 and ends at (n - 1) / n.

    Args:
        errors: values to plot. NaNs are not removed: `np.sort` places them last and
            they count towards n.
        kwargs: optional keyword arguments forwarded to the `plot` call (e.g. label,
            color). Defaults to None (no extra arguments) instead of a shared mutable
            `{}` default.
        ax: optional object with a matplotlib-style `plot` method to draw on.
            Defaults to the current pyplot axes, which is what `plt.plot` uses.

    Returns:
        None
    """
    plot_kwargs = {} if kwargs is None else kwargs
    x = np.sort(errors)
    y = np.arange(len(x)) / float(len(x))
    target_ax = plt.gca() if ax is None else ax
    target_ax.plot(x, y, **plot_kwargs)

    
f, ax = plt.subplots(figsize=(12, 4))
colors = sns.color_palette("viridis", n_colors=len(target_vars))

# One ECDF curve of the per-station NSE for each target variable.
for color, target_var in zip(colors, target_vars):
    errors = all_models_preds[target_var]["errors"]
    nse = errors["NSE"]
    empirical_cdf(nse, kwargs={"label": target_var, "color": color})

# Show the 0..1 part of the NSE axis only.
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.1)
plt.legend()
sns.despine()

# What about the spatial performances?

In [None]:
import geopandas as gpd
from scripts.geospatial import initialise_gb_spatial_plot, load_latlon_points
from mpl_toolkits.axes_grid1 import make_axes_locatable

In [None]:
# Point data joined to the per-station errors in the map cell below.
points = load_latlon_points(data_dir)

In [None]:
metric = "Pearson-r"
for target_var in target_vars:
    # per-station metric joined to the station points
    errors = all_models_preds[target_var]["errors"]
    gdf = gpd.GeoDataFrame(errors[metric].to_dataframe().join(points))

    # base map
    ax = initialise_gb_spatial_plot()

    # colour scale limits; set `cbar = True` to draw a colourbar next to the map
    cbar = False
    vmax = 0.8
    vmin = 0.3
    plot_kwargs = dict(ax=ax, vmin=vmin, vmax=vmax, cmap="viridis_r")

    if cbar:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        plot_kwargs.update(cax=cax, legend=True)

    gdf.plot(metric, **plot_kwargs)

    ax.set_title(f"{target_var}")

# Location of basins

In [None]:
# static["soil_depth_pelettier"]

In [None]:
# Add natural-log versions of `area` and `p_mean` to the static dataset
# (re-running the cell recomputes the same values).
static["log_area"] = np.log(static["area"])
static["log_p_mean"] = np.log(static["p_mean"])


from scripts.geospatial import plot_spatial_location
from scripts.plots import plot_context

pixels = [54018, 15021]
context_variables = ["log_area", "log_p_mean", "gauge_elev", "pet_mean"]

with plt.rc_context({"figure.dpi": 400}):
    for px in pixels:
        # spatial context, with the gauge name in the figure title
        plot_context(static, variables=context_variables, sids=[px])
        name = static.sel(station_id=px)["gauge_name"].values
        fig = plt.gcf()
        fig.suptitle(f"{px}: {name}")

        # spatial location
        plot_spatial_location(px, points)
        

# Does the model performance change with size?

In [None]:
errors = all_models_preds["7_day_smooth_sm"]["errors"]

# Join gauge name and area to the per-station correlation.
static_df = static[["gauge_name", "area"]].to_dataframe()
test_df = static_df.join(errors["Pearson-r"].to_dataframe())
plt.scatter(test_df["area"], test_df["Pearson-r"])

# Visualising individual dimensions

In [None]:
def get_ws_bs_for_target_var(all_models_preds, target_var: str):
    """Return the (weights, bias) pair of the probe fitted for `target_var`.

    Args:
        all_models_preds: mapping indexed as all_models_preds[target_var]["model"].
        target_var: key of the target variable whose model is inspected.

    Returns:
        Tuple (w, b) as unpacked from `get_model_weights(model)`.
    """
    weights, bias = get_model_weights(all_models_preds[target_var]["model"])
    return weights, bias


target_var = "7_day_smooth_sm"
w, b = get_ws_bs_for_target_var(all_models_preds, target_var)

# Rank the probe inputs by absolute weight.
abs_w = np.abs(w)
max_idx = np.argmax(abs_w)
feature = f"dim{max_idx}"

# Indices of the `n` largest absolute weights, largest first.
n = 1
largest_n = abs_w.argsort()[-n:][::-1]
features = [f"dim{idx}" for idx in largest_n]

# The selected names must exist as variables of the cell-state dataset.
assert all(np.isin(features, test_cs.data_vars)), f"Expect {features} to be in cs.data_vars"

In [None]:
# Year and basins to display. (The original cell assigned `pxs` three times and only the
# last assignment took effect; the unused random draw and single-basin variants are dropped.)
time = "2005"
pxs = [61001, 14001, 95001]

# squeeze=False keeps `axs` two-dimensional, so indexing also works for a single basin.
f, axs = plt.subplots(
    len(pxs), 1, figsize=(12, 4*len(pxs)), tight_layout=True, sharex=True, sharey=True, squeeze=False
)
for row, px in enumerate(pxs):
    ax = axs[row, 0]
    data = test_cs.sel(station_id=px, time=time).to_dataframe()
    target = esa_ds.sel(station_id=px, time=time).to_dataframe()

    # target on the left y axis
    ax.plot(target.index, target[target_var], label=target_var, color="k", ls="--", alpha=0.6)
    ax.legend(loc="upper right")

    # selected cell-state dimensions on a secondary y axis; a separate index avoids
    # re-using (shadowing) the row index of the outer loop
    twin_ax = ax.twinx()
    for feat_ix, feature in enumerate(features):  # features  ["dim20"]  ["dim58"]
        twin_ax.plot(
            data.index, data[feature], label=f"{feature}: {float(np.median(data[feature])):.2f}", color=f"C{feat_ix}", alpha=0.6
        )
        twin_ax.legend(loc="upper left")

    twin_ax.set_title(px)
    sns.despine()

In [None]:
# Cell-state dimension inspected in the following cells.
DIM = 20

with plt.rc_context({"figure.dpi": 400}):
    f, ax = plt.subplots(figsize=(12, 4))

    # Draw up to 10 distinct basins. `replace=False` prevents the same basin from being
    # drawn (and listed in the legend) more than once.
    # NOTE: the draw is unseeded, so the selection differs between runs.
    station_ids = test_cs.station_id.values
    pxs = np.random.choice(station_ids, min(10, len(station_ids)), replace=False)
    for sid in pxs:
        test_cs[f"dim{DIM}"].sel(station_id=sid).plot(ax=ax, label=sid)
    ax.legend()
    ax.set_title(f"Comparison of Dimension {DIM} across basins\nWhy is it variable in some basins and basically flat in others?")
    sns.despine()

In [None]:
from scripts.geospatial import plot_spatial_location
# Locations of the basins sampled above, one colour-cycle entry per basin.
basin_colors = [f"C{ix}" for ix in np.arange(len(pxs))]
with plt.rc_context({"figure.dpi": 100}):
    plot_spatial_location(pxs, points=points, plot_kwargs={"color": basin_colors})

In [None]:
# Temporal standard deviation and mean of dimension DIM for every station.
dim_values = test_cs[f"dim{DIM}"]
dim_std = dim_values.std(dim="time")
dim_mean = dim_values.mean(dim="time")

# Static attributes of the run config, joined with the two statistics.
df = (
    static[cfg.static_attributes]
    .to_dataframe()
    .join(dim_std.rename(f"std_{DIM}").to_dataframe())
    .join(dim_mean.rename(f"mn_{DIM}").to_dataframe())
)

In [None]:
# correlations ?
# Correlate the per-basin mean and std of dimension DIM with every other column of `df`.
attribute_cols = [v for v in df.columns if v not in [f'std_{DIM}', f'mn_{DIM}']]

mn_corrs = pd.DataFrame(
    {v: df[f"mn_{DIM}"].corr(df[v]) for v in attribute_cols},
    index=[f"mn_{DIM}_corr"],
).T.reset_index()
std_corrs = pd.DataFrame(
    {v: df[f"std_{DIM}"].corr(df[v]) for v in attribute_cols},
    index=[f"std_{DIM}_corr"],
).T.reset_index()

# one row per attribute: "index", std_<DIM>_corr, mn_<DIM>_corr
corrs = std_corrs.join(mn_corrs.drop("index", axis=1))

with plt.rc_context({"figure.dpi": 400}):
    f, axs = plt.subplots(2, 1, figsize=(12, 4*2), tight_layout=True)

    # top panel: correlations with the mean, sorted ascending
    ax = axs[0]
    sns.barplot(x="index", y=f"mn_{DIM}_corr", data=corrs.sort_values(f"mn_{DIM}_corr"), palette="RdBu_r", ax=ax)
    ax.set_title(f"The mean value of Dimension{DIM} is positively correlated with wetness, elevation, steep slopes")
    ax.set_xticklabels(ax.get_xticklabels(), rotation=40)

    # bottom panel: correlations with the std, sorted ascending
    ax = axs[1]
    sns.barplot(x="index", y=f"std_{DIM}_corr", data=corrs.sort_values(f"std_{DIM}_corr"), palette="RdBu_r", ax=ax)
    ax.set_title(f"The std value of Dimension{DIM} is positively correlated with agriculture, dryness, low elevation")

    plt.xticks(rotation=40)

    sns.despine()

In [None]:
# Per-station temporal median and std of dimension DIM, joined to the station points.
# (Names no longer carry the stale "_31" suffix; the dimension is given by DIM.)
dim_name = f"dim{DIM}"
dim_gdf = gpd.GeoDataFrame(test_cs[dim_name].median(dim="time").to_dataframe().join(points))
dim_gdf = dim_gdf.join(test_cs[dim_name].std(dim="time").rename(f"std_dim{DIM}").to_dataframe())
dim_gdf[f"abs_dim{DIM}"] = abs(dim_gdf[dim_name])

f, axs = plt.subplots(1, 2, figsize=(5*2, 8))

# Left: absolute value of the temporal median; colour limits are the 10% and 70% quantiles.
ax = axs[0]
mn_vmin, mn_vmax = dim_gdf[f"abs_dim{DIM}"].quantile(q=0.1), dim_gdf[f"abs_dim{DIM}"].quantile(q=0.7)
initialise_gb_spatial_plot(ax=ax)
dim_gdf.plot(f"abs_dim{DIM}", legend=True, ax=ax, vmin=mn_vmin, vmax=mn_vmax)
# The plotted quantity is |median|, not the mean, so the title states that.
ax.set_title(f"|MEDIAN| Dimension{DIM}")

# Right: temporal standard deviation, with the same quantile-based colour limits.
ax = axs[1]
std_vmin, std_vmax = dim_gdf[f"std_dim{DIM}"].quantile(q=0.1), dim_gdf[f"std_dim{DIM}"].quantile(q=0.7)
initialise_gb_spatial_plot(ax=ax)
dim_gdf.plot(f"std_dim{DIM}", legend=True, ax=ax, vmin=std_vmin, vmax=std_vmax)
ax.set_title(f"STD Dimension{DIM}")

# Non linear probe

In [None]:
# Train the non-linear probe only once per kernel session: skipped when the
# result already exists in the notebook namespace.
if "nl_models_preds" not in globals():
    nl_models_preds = train_and_evaluate_models(all_train_test, hidden_sizes=[20, 10])

In [None]:
with plt.rc_context({'figure.dpi': 400}):
    f, ax = plt.subplots(figsize=(12, 4))

    metric = "Pearson-r"
    colors = sns.color_palette("viridis", n_colors=len(target_vars))
    for color, target_var in zip(colors, target_vars):
        scores = nl_models_preds[target_var]["errors"][metric]
        median_score = scores.median()

        # entries that are not greater than -1 are replaced by -1 before binning
        clipped = scores.where(scores > -1, -1)
        ax.hist(clipped, bins=100, density=True, label=f"{target_var}: {median_score.values:.2f}", alpha=0.6, color=color)
        ax.axvline(median_score, color=color, ls="--", alpha=0.5)

    xlabel = "Pearson Correlation" if metric == "Pearson-r" else metric
    ax.set_xlabel(xlabel)
    ax.set_xlim(0.25, 1)
    ax.legend()
    sns.despine()

In [None]:
with plt.rc_context({'figure.dpi': 400}):
    f, ax = plt.subplots(figsize=(12, 4))

    metric = "RMSE"
    colors = sns.color_palette("viridis", n_colors=len(target_vars))
    for color, target_var in zip(colors, target_vars):
        scores = nl_models_preds[target_var]["errors"][metric]
        median_score = scores.median()

        # entries that are not greater than -1 are replaced by -1 before binning
        clipped = scores.where(scores > -1, -1)
        ax.hist(clipped, bins=100, density=True, label=f"{target_var}: {median_score.values:.2f}", alpha=0.6, color=color)
        ax.axvline(median_score, color=color, ls="--", alpha=0.5)

    xlabel = "Pearson Correlation" if metric == "Pearson-r" else metric
    ax.set_xlabel(xlabel)
    ax.set_xlim(0, 1)
    ax.legend()
    sns.despine()

In [None]:
# Period and basins to display (only the last assignment of each was effective).
time = slice("1998", "2008")
N = 3
pixels = [54018, 15021]
# other choices tried:
# time = slice("2000", "2007")
# pixels = [27030, 38012, 39017]
# pixels = [61001, 14001, 95001]
# pixels = [34012, 52010, 85003]
# pixels = [54018, 15021, 48003]
# pixels = [33032, 39108, 85003]
# pixels = [int(min_station), int(med_station), int(max_station)]

with plt.rc_context({"figure.dpi": 400}):
    for px in pixels:
        # one figure per basin, all target variables on the same axes
        f, ax = plt.subplots(1, 1, figsize=(12, 4), sharex=True)
        for ix, target_var in enumerate(target_vars):
            preds = nl_models_preds[target_var]["preds"]
            data = preds.sel(station_id=px, time=time)

            ax.plot(data.time, data.obs, color="k", ls="--", alpha=0.3, label="Observed")
            ax.plot(data.time, data.sim, color=f"C{ix}", ls="-", alpha=0.6, label="Simulated")
            # legend is drawn after the first target variable only
            if ix == 0:
                ax.legend()
            sns.despine()
        f.suptitle(px)

# Generating INTERACTIVE maps of Soil Moisture

In [None]:
from scripts.cell_state.timeseries_model import _round_time_to_hour

# Catchment boundary shapefile. NOTE: this rebinds `gdf`, which the spatial-performance
# cell above used for the per-station error table.
gdf = gpd.read_file(data_dir / "CAMELS_GB_DATASET/Catchment_Boundaries/CAMELS_GB_catchment_boundaries.shp")

In [None]:
# Use the predictions of the last target variable for the maps.
LEVEL = target_vars[-1]
preds = all_models_preds[LEVEL]["preds"]
# NOTE: `preds` is the object stored in `all_models_preds`, so this assignment also
# changes the time coordinate there.
preds["time"] = _round_time_to_hour(preds["time"].values)

In [None]:
d = preds["sim"]

# pivot table simulations: one variable per timestep, named by its integer position
pixel_dim = "station_id"
time_dim = "time"

# assumes `d` is laid out as (station_id, time) — TODO confirm
sim_values = d.values
timestep_vars = {}
for i in range(len(d[time_dim].values)):
    timestep_vars[f"{i}"] = (pixel_dim, sim_values[:, i])

new_ds = xr.Dataset(timestep_vars, coords={pixel_dim: d[pixel_dim]})
df = new_ds.to_dataframe()

In [None]:
# Simulations (one column per timestep) joined to the catchment polygons by station id.
time_gdf = gpd.GeoDataFrame(df.join(gdf.set_index("ID")))

# Column names of the per-timestep simulations. They are derived from the number of
# timesteps in `d` (as in the pivot cell) instead of the previously hard-coded 822.
timestep_cols = [f"{ts}" for ts in range(len(d[time_dim]))]
timestep_values = time_gdf[timestep_cols]

# Colour limits: extremes of the per-timestep means, padded by the mean per-timestep std.
mean_std = timestep_values.std().mean()
vmax = timestep_values.mean().max() + mean_std
vmin = timestep_values.mean().min() - mean_std

(vmin, vmax)

In [None]:
# Map the integer timestep position (used as column name above) to its timestamp.
lookup_times = dict(enumerate(d[time_dim].values))

In [None]:
# https://www.earthdatascience.org/courses/scientists-guide-to-plotting-data-in-python/plot-spatial-data/customize-vector-plots/python-customize-map-legends-geopandas/
# ax = initialise_gb_spatial_plot()
# 

# Disabled by default: writes one PNG per timestep. Change the condition to run it.
if False:
    from tqdm import tqdm

    # A single directory is used for both creating and writing. Previously
    # `animate_{LEVEL}` was created while the frames were saved into `animate/`.
    animate_dir = data_dir / f"animate_{LEVEL}"
    animate_dir.mkdir(exist_ok=True)

    with plt.style.context("dark_background"):
        # iterate over the available timesteps instead of a hard-coded 822
        for i in tqdm(range(len(lookup_times))):
            f, ax = plt.subplots(figsize=(5, 8))
            time_gdf.plot(f"{i}", ax=ax, vmin=vmin, vmax=vmax, legend=True)
            ax.axis("off")
            ax.set_title(np.datetime_as_string(lookup_times[i], unit="D"))
            f.savefig(animate_dir / f"{i:03}.png")
            # release the figure so memory does not grow with the number of frames
            plt.close(f)

# Pytorch Probe

In [None]:
# Deliberate stop: this always raises, so "Run All" halts here and the PyTorch cells
# below are only executed manually.
assert False, "Do not run automagically"

In [None]:


# NOTE(review): `TimeSeriesDataset`, `input_data`, `target_data`, `test_input_data` and
# `test_target_data` are neither imported nor defined in the cells shown in this notebook —
# they must be provided before this cell can run.
# NOTE(review): the target variable here is "swvl1", whereas the probes above use
# "7_day_smooth_sm" — confirm which target is intended.
train_dataset = TimeSeriesDataset(
    input_data=input_data,
    target_data=target_data,
    target_variable="swvl1",
    input_variables=input_variables,
    seq_length=seq_length,
    basin_dim=basin_dim,
    time_dim=time_dim,
    desc="Creating Train Samples",
)

# Same configuration for the test split.
test_dataset = TimeSeriesDataset(
    input_data=test_input_data,
    target_data=test_target_data,
    target_variable="swvl1",
    input_variables=input_variables,
    seq_length=seq_length,
    basin_dim=basin_dim,
    time_dim=time_dim,
    desc="Creating Test Samples",
)

In [None]:
from torch.utils.data import DataLoader
import torch
from torch import nn
from scripts.cell_state.cell_state_model import LinearModel

# NOTE(review): D_in=64 presumably equals the number of cell-state input variables
# (`len(input_variables)`) — confirm before changing the run.
model = LinearModel(D_in=64)

In [None]:
# Training hyper-parameters.
batch_size = int(1e4)
num_workers = 0
learning_rate = 1e-4
l2_penalty = 0

# Both loaders use the same settings (the test loader is shuffled as well).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# number of batches per epoch: train, then test
print(len(train_loader), len(test_loader), sep="\n")

In [None]:
# number of passes over the training loader
n_epochs = 10

# GET loss
loss_fn = nn.MSELoss()

# GET optimizer (Adam; `l2_penalty` is applied as weight decay)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=l2_penalty,
)

### Train the linear model

In [None]:
from tqdm import tqdm


# Share of the L1 term in the penalty; the L2 term is weighted by (1 - l1_ratio).
# NOTE(review): the original loop referenced an undefined `regularization_lambda` while
# `l1_ratio` was defined and unused; they are treated as the same quantity here — confirm.
l1_ratio = 0.15
epoch_losses = []
mean_epoch_loss = 9999  # placeholder shown in the progress bar during the first epoch
for epoch in range(n_epochs):
    losses = []
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for data in pbar:
        # show the mean loss of the previous epoch
        pbar.set_postfix_str(f"{mean_epoch_loss:.2f}")
        y_pred = model(data)
        y = data["y"].squeeze(1)

        # calculate loss
        loss = loss_fn(y_pred, y)

        # add regularisation terms, computed on all parameters flattened into one vector
        flat_params = torch.cat([param.view(-1) for param in model.parameters()])
        # l1 loss-penalty (1st order magnitude, vector of weights)
        loss = loss + (l1_ratio * torch.norm(flat_params, p=1))
        # l2 loss-penalty (2nd order magnitude, vector of weights)
        loss = loss + ((1 - l1_ratio) * torch.square(torch.norm(flat_params, p=2)))

        # store the scalar value only
        losses.append(loss.item())

        # train/update the weight
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    mean_epoch_loss = float(np.mean(losses))
    epoch_losses.append(mean_epoch_loss)

In [None]:
# Mean training loss (including the regularisation terms) per epoch.
plt.plot(epoch_losses)

In [None]:
# First and second parameter tensors of the model, as numpy arrays.
params = [p for p in model.parameters()]
ws = params[0].detach().numpy()
b = params[1].detach().numpy()

# NOTE(review): np.concatenate needs arrays of equal rank — presumably both parameters
# are 1-D for `LinearModel`; confirm.
ext = np.concatenate([ws, b])
f, ax = plt.subplots(1, 1, figsize=(12, 2*1))
plot_weights(np.abs(ext), kwargs={"vmin": 0.0, "vmax": 0.3}, ax=ax, cbar=False)

In [None]:
from scripts.cell_state.timeseries_model import _round_time_to_hour
from scripts.cell_state.cell_state_model import to_xarray

# `defaultdict` was used here without being imported anywhere in the notebook.
from collections import defaultdict

# def evaluate(test_loader: Dataloader) -> xr.Dataset:
# Flat lists of coordinates and values, accumulated over all test batches.
predictions = defaultdict(list)   # : DefaultDict[str, List]

with torch.no_grad():
    for data in tqdm(test_loader, "Evaluation Forward Pass"):
        y_hat = model(data).squeeze()
        y = data["y"].squeeze()
        # `basins` / `times` (instead of `basin` / `time`) so that the notebook-level
        # `time` used by the plotting cells is not overwritten.
        basins = data["meta"]["spatial_unit"].numpy()
        times = data["meta"]["time"].numpy()

        #  Coords / Dimensions
        # the first entry of every sample's time array is used as its timestamp
        predictions["time"].extend(_round_time_to_hour(pd.to_datetime([t[0] for t in times.astype("datetime64[ns]")])))
        predictions["station_id"].extend(basins)

        # Variables
        predictions["y_hat"].extend(y_hat.detach().cpu().numpy().flatten())
        predictions["y"].extend(y.detach().cpu().numpy().flatten())


In [None]:
# Convert the accumulated prediction lists with the project helper `to_xarray`.
model_p = to_xarray(predictions)

In [None]:
f, ax = plt.subplots(figsize=(12, 4))

# Station at integer position 100 along `station_id`.
data = model_p.isel(station_id=100).drop("station_id")

# simulated first, observed on top (same drawing order as before)
series_styles = (
    ("y_hat", dict(color="C0", label="sim")),
    ("y", dict(color="k", label="obs", alpha=0.5, ls="--")),
)
for var_name, line_kwargs in series_styles:
    ax.plot(data.time, data[var_name], **line_kwargs)
# data.to_dataframe().drop("station_id", axis=1).plot(ax=ax)
sns.despine()