In [2]:
from pathlib import Path
import numpy as np 
import xarray as xr 
import matplotlib.pyplot as plt
import seaborn as sns
import warnings 
import pandas as pd

import sys
# Add a local checkout directory to the module search path before the project
# imports below. Presumably this is where the `scripts` package lives — TODO confirm.
# NOTE(review): absolute, user-specific path; the notebook will not import on another machine.
sys.path.append("/home/tommy/neuralhydrology")
from scripts.read_nh_results import (
    get_test_filepath,
    get_all_station_ds,
    calculate_all_error_metrics,
    get_ensemble_path,
)

from scripts.read_model import (get_model, _load_weights)
from scripts.read_nh_results import (read_multi_experiment_results, calculate_member_errors)
from neuralhydrology.utils.config import Config

%load_ext autoreload
%autoreload 2

In [3]:
# Root of the data volume, the model run being analysed, and the folder
# (inside that run) from which the cell states are read further down.
data_dir = Path("/datadrive/data")
run_dir = data_dir.joinpath("runs/complexity_AZURE/hs_064_0306_205514")
out_dir = run_dir.joinpath("cell_states")

In [4]:
import geopandas as gpd
from scripts.geospatial import initialise_gb_spatial_plot, load_latlon_points
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Directory with the spatio-temporal inputs: station lat/lon points and the
# static attribute file are both read from here.
st_data_dir = Path("/home/tommy/spatio_temporal/data")
points = load_latlon_points(st_data_dir)
static = xr.open_dataset(st_data_dir.joinpath("camels_static.nc"))

# Input / Target Data

In [5]:
from scripts.cell_state.normalize import normalize_cstate
from scripts.cell_state.cell_state_dataset import dataset_dimensions_to_variable


# Cached, normalised cell states. The path is defined once so that the existence
# check, the write and the read can never drift apart.
norm_cs_path = data_dir / "SOIL_MOISTURE/norm_cs_data_FINAL.nc"

if not norm_cs_path.exists():
    # No cache yet: normalise the raw cell states saved for this run.
    cn = xr.open_dataset(out_dir / "cell_states.nc")
    norm_cs_data = normalize_cstate(cn, variable_str="c_n")
    # Cast the station ids to integers.
    norm_cs_data["station_id"] = [int(sid) for sid in norm_cs_data["station_id"]]
    if "date" in norm_cs_data.dims:
        norm_cs_data = norm_cs_data.rename({"date": "time"})

    if isinstance(norm_cs_data, xr.DataArray):
        norm_cs_data = norm_cs_data.to_dataset()

    # `data_vars` is a mapping, so membership can be tested on it directly.
    if "c_n" in norm_cs_data.data_vars:
        norm_cs_data = norm_cs_data.rename({"c_n": "cell_state"})

    # Make sure the cache folder exists so the normalisation result is not
    # lost to a failed write.
    norm_cs_path.parent.mkdir(parents=True, exist_ok=True)
    norm_cs_data.to_netcdf(norm_cs_path)
else:
    norm_cs_data = xr.open_dataset(norm_cs_path)

# NOTE(review): the branch above renames "c_n" to "cell_state" before saving, yet
# `variable="c_n"` is passed here. Confirm against `dataset_dimensions_to_variable`
# whether this should be "cell_state".
cs = dataset_dimensions_to_variable(
    ds=norm_cs_data,
    variable="c_n",
    dimension_to_convert_to_variable_dim="dimension",
)

In [6]:
# Dynamic per-station time series. The file sits in the same directory as the
# static data, so reuse `st_data_dir` instead of repeating the absolute path.
ds = xr.open_dataset(st_data_dir / "ALL_dynamic_ds.nc")
ds.data_vars

Data variables:
    precipitation   (time, station_id) float64 ...
    pet             (time, station_id) float64 ...
    temperature     (time, station_id) float64 ...
    discharge_spec  (time, station_id) float64 ...
    discharge_vol   (time, station_id) float64 ...
    peti            (time, station_id) float64 ...
    humidity        (time, station_id) float64 ...
    shortwave_rad   (time, station_id) float64 ...
    longwave_rad    (time, station_id) float64 ...
    windspeed       (time, station_id) float64 ...

In [7]:
# Display the precipitation variable of the dynamic dataset for a quick look.
ds["precipitation"]

# Train the probes

In [11]:
from scripts.cell_state.sklearn_models import (
    init_linear_model,
    evaluate,
    create_analysis_dataset,
    fit_and_predict,
)
from scripts.cell_state.timeseries_dataset import TimeSeriesDataset, get_time_basin_aligned_dictionary
from collections import defaultdict

In [12]:
# NOTE(review): `era5_ds` is not defined in any cell visible above (execution
# counts jump from 7 to 11) — confirm where it is loaded so that
# Restart & Run All still works.
target_ds = era5_ds
input_ds = cs

# Train / test periods.
train_start_date: pd.Timestamp = pd.Timestamp("1998-01-01")
train_end_date: pd.Timestamp = pd.Timestamp("2006-09-30")
test_start_date: pd.Timestamp = pd.Timestamp("2006-10-01")
test_end_date: pd.Timestamp = pd.Timestamp("2009-10-01")

seq_length = 1
basin_dim = "station_id"
time_dim = "time"
# One input variable per cell-state dimension: "dim0" ... "dim63".
input_variables = [f"dim{i}" for i in range(64)]

# Train / test split by time.
train_period = slice(train_start_date, train_end_date)
test_period = slice(test_start_date, test_end_date)

target_data = target_ds.sel(time=train_period)
input_data = input_ds.sel(time=train_period)

test_target_data = target_ds.sel(time=test_period)
test_input_data = input_ds.sel(time=test_period)

# Arguments that are identical for the train and the test dataset.
shared_dataset_kwargs = dict(
    input_variables=input_variables,
    seq_length=seq_length,
    basin_dim=basin_dim,
    time_dim=time_dim,
)

# Maps target variable -> {"train": ..., "test": ...}.
all_train_test = defaultdict(dict)
# Every target variable whose name contains "swvl" is processed in turn.
swvl_variables = [name for name in target_data.data_vars if "swvl" in name]
for target_var in swvl_variables:
    print(f"** STARTING {target_var} **")

    train_dataset = TimeSeriesDataset(
        input_data=input_data,
        target_data=target_data,
        target_variable=target_var,
        desc="Creating Train Samples",
        **shared_dataset_kwargs,
    )
    train = get_time_basin_aligned_dictionary(train_dataset)

    test_dataset = TimeSeriesDataset(
        input_data=test_input_data,
        target_data=test_target_data,
        target_variable=target_var,
        desc="Creating Test Samples",
        **shared_dataset_kwargs,
    )
    test = get_time_basin_aligned_dictionary(test_dataset)

    all_train_test[target_var] = {"train": train, "test": test}

** STARTING swvl1 **


Creating Train Samples: 100%|██████████| 668/668 [00:09<00:00, 69.89it/s]
Extracting Data: 100%|██████████| 8322/8322 [01:05<00:00, 127.19it/s]


Merging and reshaping arrays


Creating Test Samples: 100%|██████████| 668/668 [00:05<00:00, 114.69it/s]
Extracting Data: 100%|██████████| 2142/2142 [00:16<00:00, 133.00it/s]


Merging and reshaping arrays
** STARTING swvl2 **


Creating Train Samples: 100%|██████████| 668/668 [00:08<00:00, 75.02it/s]
Extracting Data: 100%|██████████| 8322/8322 [01:05<00:00, 127.59it/s]


Merging and reshaping arrays


Creating Test Samples: 100%|██████████| 668/668 [00:05<00:00, 114.54it/s]
Extracting Data: 100%|██████████| 2142/2142 [00:16<00:00, 130.30it/s]


Merging and reshaping arrays
** STARTING swvl3 **


Creating Train Samples: 100%|██████████| 668/668 [00:09<00:00, 72.06it/s]
Extracting Data: 100%|██████████| 8322/8322 [01:03<00:00, 130.27it/s]


Merging and reshaping arrays


Creating Test Samples: 100%|██████████| 668/668 [00:05<00:00, 115.04it/s]
Extracting Data: 100%|██████████| 2142/2142 [00:16<00:00, 128.26it/s]


Merging and reshaping arrays
** STARTING swvl4 **


Creating Train Samples: 100%|██████████| 668/668 [00:09<00:00, 73.15it/s]
Extracting Data: 100%|██████████| 8322/8322 [01:04<00:00, 128.45it/s]


Merging and reshaping arrays


Creating Test Samples: 100%|██████████| 668/668 [00:05<00:00, 114.91it/s]
Extracting Data: 100%|██████████| 2142/2142 [00:16<00:00, 128.61it/s]

Merging and reshaping arrays



