In [None]:
import numpy as np
import torch
from numpy.testing import assert_almost_equal, assert_array_almost_equal

from pism_emulator.models.debm import DEBMModel, TorchDEBMModel

In [None]:
# Default DEBM model instance.
debm = DEBMModel()

In [None]:
    # Paleo distance factor for two sample orbital configurations; the angles
    # are written in degrees and converted to radians before the call.
    eccentricity = np.array([0.0167, 0.03])
    obliquity = np.deg2rad(np.array([23.14, 22.10]))
    perihelion_longitude = np.deg2rad(np.array([102.94719, -44.3]))
    debm = DEBMModel()
    # NOTE(review): arguments are passed as (eccentricity, perihelion_longitude,
    # obliquity) — confirm this matches DEBMModel.distance_factor_paleo.
    d = debm.distance_factor_paleo(eccentricity, perihelion_longitude, obliquity)


In [None]:
    # Insolation for one sample point. hour_angle, latitude (pi/4) and
    # declination (pi/8) look like radians — TODO confirm against DEBMModel.insolation.
    solar_constant = np.array([1361.0])
    distance_factor = np.array([1.1])
    hour_angle = np.array([0.8])
    latitude = np.array([np.pi / 4])
    declination = np.array([np.pi / 8])

    debm = DEBMModel()
    insolation = debm.insolation(solar_constant, distance_factor, hour_angle, latitude, declination)


In [None]:
# Orbital parameters for a single time value (2022.25 — presumably a decimal year).
time = np.array([2022.25])
debm = DEBMModel()
debm.orbital_parameters(time)

In [None]:
# Same call with paleo orbital parameters enabled.
debm = DEBMModel(paleo_enabled=True)
debm.orbital_parameters(time)

In [None]:
# Atmosphere transmissivity at three surface elevations.
debm = DEBMModel()
elevation = np.array([0.0, 1000.0, 2000.0])
transmissivity = debm.atmosphere_transmissivity(elevation)


In [None]:
transmissivity

In [None]:
# Albedo for a single melt-rate value.
melt_rate = np.array([1.])
debm = DEBMModel()
albedo = debm.albedo(melt_rate)


In [None]:
albedo

In [None]:
# Units note: PISM configures the albedo slope in "m2 s kg-1", so the melt rate
# passed to debm.albedo is presumably expected in kg m-2 s-1 — TODO confirm.
#    pism_config:surface.debm_simple.albedo_slope_units = "m2 s kg-1";


In [None]:

import numpy as np
import torch
from numpy.testing import assert_array_almost_equal
import xarray as xr

from pism_emulator.models.pdd import ReferencePDDModel, TorchPDDModel


def make_fake_climate() -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build a single-point, twelve-month synthetic climate for surface-model tests.

    Returns
    -------
    temp, precip, sd : np.ndarray
        Three float64 arrays of shape (12, 1): one row per month and one
        column for the single grid point.
    """
    monthly_temp = [
        -3.12, -2.41, -0.62, 1.93, 4.41, 6.20,
        6.91, 6.21, 4.40, 1.92, -0.61, -2.41,
    ]
    monthly_precip = [
        1.58, 1.47, 1.18, 0.79, 0.39, 0.11,
        -0.01, 0.10, 0.39, 0.79, 1.18, 1.47,
    ]
    monthly_sd = [
        0.0, 0.18, 0.70, 1.40, 2.11, 2.61,
        2.81, 2.61, 2.10, 1.40, 0.72, 0.18,
    ]

    # Turn each monthly series into a (n_months, 1) column array.
    columns = tuple(
        np.array([[value] for value in series], dtype=np.float64)
        for series in (monthly_temp, monthly_precip, monthly_sd)
    )
    return columns


def make_fake_climate_2d(filename=None):
    """Create an idealized (time, x, y) climate dataset for PDD model tests.

    The dataset has 12 monthly time steps, time = (k + 0.5) / 12 for
    k = 0..11, on a 201 x 201 grid whose x and y coordinates run from
    -750000 to 750000. It holds three float32 variables with dimension
    order (time, x, y):

    - ``temp``: near-surface air temperature, 5 * season - 10 * x / lx
    - ``prec``: precipitation rate, y / ly * (season + sign(y))
    - ``stdv``: standard deviation of near-surface air temperature,
      (2 + y / ly - x / lx) * (1 + season)

    Here season = -cos(2 * pi * k / 12). The attributes of the three variables
    are taken from ``ATTRIBUTES``.

    Parameters
    ----------
    filename : str, optional
        If given, the dataset is also written to this path with
        ``Dataset.to_netcdf``.

    Returns
    -------
    xr.Dataset
        The synthetic climate. ``time_bounds`` is stored as a (time, nv)
        coordinate.
    """

    # Variable metadata. Only the 'temp', 'prec' and 'stdv' entries are applied
    # below; the coordinate and output-variable entries are not used here.
    ATTRIBUTES = {

        # coordinate variables
        'x': {
            'axis': 'X',
            'long_name': 'x-coordinate in Cartesian system',
            'standard_name': 'projection_x_coordinate',
            'units': 'm'},
        'y': {
            'axis': 'Y',
            'long_name': 'y-coordinate in Cartesian system',
            'standard_name': 'projection_y_coordinate',
            'units': 'm'},
        'time': {
            'axis': 'T',
            'long_name': 'time',
            'standard_name': 'time',
            'bounds': 'time_bounds',
            'units': 'yr'},
        'time_bounds': {},

        # climatic variables
        'temp': {
            'long_name': 'near-surface air temperature',
            'units':     'degC'},
        'prec': {
            'long_name': 'ice-equivalent precipitation rate',
            'units':     'm yr-1'},
        'stdv': {
            'long_name': 'standard deviation of near-surface air temperature',
            'units':     'K'},

        # cumulative quantities
        'smb': {
            'standard_name': 'land_ice_surface_specific_mass_balance',
            'long_name': 'cumulative ice-equivalent surface mass balance',
            'units':     'm yr-1'},
        'pdd': {
            'long_name': 'cumulative number of positive degree days',
            'units':     'degC day'},
        'accu': {
            'long_name': 'cumulative ice-equivalent surface accumulation',
            'units':     'm'},
        'snow_melt': {
            'long_name': 'cumulative ice-equivalent surface melt of snow',
            'units':     'm'},
        'ice_melt': {
            'long_name': 'cumulative ice-equivalent surface melt of ice',
            'units':     'm'},
        'melt': {
            'long_name': 'cumulative ice-equivalent surface melt',
            'units':     'm'},
        'runoff': {
            'long_name': 'cumulative ice-equivalent surface meltwater runoff',
            'units':     'm yr-1'},

        # instantaneous quantities
        'inst_pdd': {
            'long_name': 'instantaneous positive degree days',
            'units':     'degC day'},
        'accu_rate': {
            'long_name': 'instantaneous ice-equivalent surface accumulation rate',
            'units':     'm yr-1'},
        'snow_melt_rate': {
            'long_name': 'instantaneous ice-equivalent surface melt rate of snow',
            'units':     'm yr-1'},
        'ice_melt_rate': {
            'long_name': 'instantaneous ice-equivalent surface melt rate of ice',
            'units':     'm yr-1'},
        'melt_rate': {
            'long_name': 'instantaneous ice-equivalent surface melt rate',
            'units':     'm yr-1'},
        'runoff_rate': {
            'long_name': 'instantaneous ice-equivalent surface runoff rate',
            'units':     'm yr-1'},
        'inst_smb': {
            'long_name': 'instantaneous ice-equivalent surface mass balance',
            'units':     'm yr-1'},
        'snow_depth': {
            'long_name': 'depth of snow cover',
            'units':     'm'}}

    # FIXME code could be simplified a lot more but we need a better test not
    # relying on exact reproducibility of this toy climate data.

    # assign coordinate values
    lx = ly = 750000
    x = xr.DataArray(np.linspace(-lx, lx, 201, dtype='f4'), dims='x')
    y = xr.DataArray(np.linspace(-ly, ly, 201, dtype='f4'), dims='y')
    # Monthly samples centred in their interval: (k + 0.5) / 12, bounds +/- 1/24.
    time = xr.DataArray((np.arange(12, dtype='f4')+0.5) / 12, dims='time')
    tboundsvar = np.empty((12, 2), dtype='f4')
    tboundsvar[:, 0] = time[:] - 1.0/24
    tboundsvar[:, 1] = time[:] + 1.0/24

    # seasonality index from winter to summer
    # (-cos(2*pi*k/12): -1 for k = 0, +1 for k = 6)
    season = xr.DataArray(-np.cos(np.arange(12)*2*np.pi/12), dims='time')

    # order of operation is dictated by test md5sum and legacy f4 dtype
    # (the `0 * y` / `0 * x` terms only broadcast the result to all three dims)
    temp = 5 * season - 10 * x / lx + 0 * y
    prec = y / ly * (season.astype('f4') + 0 * x + np.sign(y))
    stdv = (2+y/ly-x/lx) * (1+season)

    # this is also why transpose is needed here, and final type conversion
    temp = temp.transpose('time', 'x', 'y').astype('f4')
    prec = prec.transpose('time', 'x', 'y').astype('f4')
    stdv = stdv.transpose('time', 'x', 'y').astype('f4')

    # assign variable attributes
    temp.attrs.update(ATTRIBUTES['temp'])
    prec.attrs.update(ATTRIBUTES['prec'])
    stdv.attrs.update(ATTRIBUTES['stdv'])

    # make a dataset
    ds = xr.Dataset(
        data_vars={'temp': temp, 'prec': prec, 'stdv': stdv},
        coords={
            'time': time, 'x': x, 'y': y,
            'time_bounds': (['time', 'nv'], tboundsvar[:]),
        },
    )

    # write dataset to file
    if filename is not None:
        ds.to_netcdf(filename)

    # return dataset
    return ds

def test_torch_model():
    """
    Check that TorchPDDModel reproduces ReferencePDDModel on the 1-D fake climate.

    The two models share their refreeze, threshold-temperature and
    interpolation settings. Only the degree-day factors are written
    differently: 0.003 / 0.008 for the reference and 3.0 / 8.0 for the torch
    model (presumably a unit difference — TODO confirm). Every compared output
    must agree to 3 decimal places.
    """
    temp, precip, sd = make_fake_climate()

    shared_settings = {
        "refreeze_snow": 0.6,
        "refreeze_ice": 0.1,
        "temp_snow": 0.0,
        "temp_rain": 2.0,
        "interpolate_rule": "linear",
        "interpolate_n": 52,
    }

    reference = ReferencePDDModel(
        pdd_factor_snow=0.003, pdd_factor_ice=0.008, **shared_settings
    )
    expected = reference(temp, precip, sd)

    emulator = TorchPDDModel(pdd_factor_snow=3.0, pdd_factor_ice=8.0, **shared_settings)
    actual = emulator.forward(temp, precip, sd)

    compared_variables = (
        "temp",
        "prec",
        "accumulation_rate",
        "inst_pdd",
        "snow_depth",
        "snow_melt_rate",
        "ice_melt_rate",
        "melt_rate",
        "smb",
    )
    for m_var in compared_variables:
        print(f"Comparing Reference and Torch implementation for variable {m_var}")
        assert_array_almost_equal(expected[m_var], actual[m_var], decimal=3)

def test_torch_model_2d():
    """
    Check TorchPDDModel against ReferencePDDModel on the 2-D synthetic climate.

    The (time, x, y) fields of ``make_fake_climate_2d`` are converted to numpy
    arrays and fed to both models. The degree-day factors are 0.003 / 0.008
    for the reference and 3.0 / 8.0 for the torch model; all other settings
    are shared. Every compared output must agree to 3 decimal places.
    """
    climate = make_fake_climate_2d()
    temp, precip, sd = (climate[key].to_numpy() for key in ("temp", "prec", "stdv"))

    common = dict(
        refreeze_snow=0.6,
        refreeze_ice=0.1,
        temp_snow=0.0,
        temp_rain=2.0,
        interpolate_rule="linear",
        interpolate_n=52,
    )

    reference_model = ReferencePDDModel(pdd_factor_snow=0.003, pdd_factor_ice=0.008, **common)
    reference_out = reference_model(temp, precip, sd)

    torch_model = TorchPDDModel(pdd_factor_snow=3.0, pdd_factor_ice=8.0, **common)
    torch_out = torch_model.forward(temp, precip, sd)

    for m_var in (
        "temp",
        "prec",
        "accumulation_rate",
        "inst_pdd",
        "snow_depth",
        "snow_melt_rate",
        "ice_melt_rate",
        "melt_rate",
        "smb",
    ):
        print(f"Comparing Reference and Torch implementation for variable {m_var}")
        assert_array_almost_equal(reference_out[m_var], torch_out[m_var], decimal=3)



In [None]:
    # Reference PDD run on the 2-D synthetic climate. The temperature is also
    # kept in Kelvin (temp_K) for the DEBM model cells that follow.
    ds = make_fake_climate_2d()
    
    temp = ds["temp"].to_numpy()
    temp_K = ds["temp"].to_numpy() + 273.15
    precip = ds["prec"].to_numpy()
    sd = ds["stdv"].to_numpy()

    pdd_ref = ReferencePDDModel(
        pdd_factor_snow=0.003,
        pdd_factor_ice=0.008,
        refreeze_snow=0.6,
        refreeze_ice=0.1,
        temp_snow=0.0,
        temp_rain=2.0,
        interpolate_rule="linear",
        interpolate_n=52,
    )
    result_ref = pdd_ref(temp, precip, sd)

In [None]:
import pylab as plt

# Reference snow depth at the first time step.
plt.imshow(result_ref["snow_depth"][0])
plt.colorbar()

In [None]:
# Run the numpy DEBM model on the same synthetic climate (temperature in K).
debm = DEBMModel()
# Constant surface elevation of 10 and zero latitude everywhere (presumably
# radians, i.e. the equator — TODO confirm the units DEBMModel expects).
surface_elevation = np.zeros_like(temp) + 10
latitude = np.zeros_like(temp) 
# NOTE(review): `albedo` is not passed to `debm(...)` below, so it has no effect here.
albedo = np.zeros_like(temp) + 0.47
result_debm = debm(temp_K, sd, precip, surface_elevation, latitude)

In [None]:
# NOTE(review): repeats the previous call with identical inputs.
result_debm = debm(temp_K, sd, precip, surface_elevation, latitude)

In [None]:
# DEBM temperature at the first time step, converted from K to degC.
plt.imshow(result_debm["temperature"][0]-273.15, origin="lower")
plt.colorbar()

In [None]:
# Reference temperature (degC) minus DEBM temperature (K -> degC), first time step.
diff = result_ref["temp"][0]-(result_debm["temperature"][0]-273.15)
plt.imshow(diff, origin="lower")
plt.colorbar()
print(diff.min(), diff.max())

In [None]:
# Snow depth difference (reference minus DEBM), first time step.
diff = result_ref["snow_depth"][0]-(result_debm["snow_depth"][0])
plt.imshow(diff, origin="lower")
plt.colorbar()
print(diff.min(), diff.max())

In [None]:
# Accumulation difference (reference minus DEBM); no time index is taken here.
diff = result_ref["accumulation"]-(result_debm["accumulation"])
plt.imshow(diff, origin="lower")
plt.colorbar()
print(diff.min(), diff.max())

In [None]:
# One row per variable: reference | DEBM | reference - DEBM.
m_vars = ["accumulation", "melt", "runoff", "smb"]
fig, axs = plt.subplots(len(m_vars),3, sharex=True, figsize=(24,24))
for k, m_var in enumerate(m_vars):
    axs[k, 0].imshow(result_ref[m_var], origin="lower")
    axs[k, 1].imshow(result_debm[m_var], origin="lower")
    dc = axs[k, 2].imshow(result_ref[m_var]-result_debm[m_var], origin="lower")
    # NOTE(review): only the difference panel gets a colorbar.
    plt.colorbar(dc)

In [None]:
# Sum of the three DEBM melt components. The name `diff` is reused even though
# this is a sum, not a difference.
diff = (result_debm["insolation_melt"] + result_debm["temperature_melt"] + result_debm["offset_melt"]) 
plt.imshow(diff, origin="lower", vmin=0)
plt.colorbar()
print(diff.min(), diff.max())

In [None]:
# Ratio of reference melt to DEBM melt.
# NOTE(review): where result_debm["melt"] is 0 this yields inf/nan (with a
# RuntimeWarning), and nan then propagates into diff.min()/diff.max().
diff = result_ref["melt"] / result_debm["melt"]
plt.imshow(diff, origin="lower")
plt.colorbar()
print(diff.min(), diff.max())

In [None]:
# Reference melt.
plt.imshow(result_ref["melt"], origin="lower")
plt.colorbar()

In [None]:
# DEBM melt.
diff = result_debm["melt"]
plt.imshow(diff, origin="lower")
plt.colorbar()

In [None]:
result_debm["melt"].shape

In [None]:
# Input temperature at the first time step, via xarray's plotting.
ds["temp"].isel(time=0).plot()

In [None]:
# Same field from the numpy array, on a symmetric -15..15 diverging colour scale.
plt.imshow(temp[0], vmin=-15, vmax=15, origin="lower", cmap="RdBu_r")
plt.colorbar()

In [None]:
# Single-point call of DEBMModel.melt with scalar inputs.
# NOTE(review): this rebinds `temp` (previously the 2-D temperature array) and
# `albedo` to scalars.
year_fraction = 0
dt = 1/ 12
# NOTE(review): 323 K is about 50 degC, which is unusually warm; confirm it is intended.
temp = 323.0
temp_sd = 12.0
# NOTE(review): the meaning and units of `s` are not visible here; check the DEBMModel.melt signature.
s = 1000
# NOTE(review): 3*pi/4 is larger than pi/2; if `lat` is a latitude in radians it is out of range.
lat = np.pi/4 * 3
albedo = 0.47
melt_info = debm.melt(temp, temp_sd, albedo, s, lat, year_fraction, dt)


In [None]:
# IPython help lookup, left over from interactive exploration.
make_fake_climate?

In [None]:
# NOTE(review): both arguments are derived from temp_sd. If the second argument
# is meant to be a temperature minus the threshold, `temp` may have been
# intended — confirm against DEBMModel.CalovGreveIntegrand.
debm.CalovGreveIntegrand(temp_sd, temp_sd - debm.positive_threshold_temp)

In [None]:
melt_info["total_melt"]

In [None]:
melt_info

In [None]:
# DEBM accumulation.
plt.imshow(result_debm["accumulation"])
plt.colorbar()

In [None]:
# Reference melt rate at the first time step.
plt.imshow(result_ref["melt_rate"][0])
plt.colorbar()

In [None]:
# Reference precipitation at the first time step.
plt.imshow(result_ref["prec"][0])
plt.colorbar()

In [None]:
# DEBM precipitation at the first time step.
plt.imshow(result_debm["precipitation"][0])
plt.colorbar()

In [None]:
# result_debm["temperature"] is in Kelvin (it is plotted as `- 273.15`), while
# pdd_ref is driven with degC elsewhere in this notebook (pdd_ref(temp, ...) with
# `temp` in degC, and the T vs T + 273.15 comparison). Convert to degC first:
# Kelvin values all lie far above temp_rain=2.0 and would presumably count as rain.
temperature_degC = result_debm["temperature"][0] - 273.15
snow_acc = pdd_ref.accumulation_rate(temperature_degC, result_debm["precipitation"][0])
plt.imshow(snow_acc)
plt.colorbar()

In [None]:
# DEBM snow accumulation at the first time step (DEBM takes temperature in K).
snow_acc = debm.snow_accumulation(result_debm["temperature"][0], result_debm["precipitation"][0])
plt.imshow(snow_acc)
plt.colorbar()

In [None]:
# Temperature and precipitation at one grid point (first time step, first x, first y).
result_debm["temperature"][0, 0, 0], result_debm["precipitation"][0, 0, 0]

In [None]:
# pdd_ref expects degC (cf. pdd_ref.accumulation_rate(T, P) vs debm.snow_accumulation(T+273.15, P)),
# but result_debm["temperature"] is in Kelvin, so convert before calling.
pdd_ref.accumulation_rate(result_debm["temperature"][0, 0, 0] - 273.15, result_debm["precipitation"][0, 0, 0])

In [None]:
# DEBM accumulation for the same grid point (temperature in K).
debm.snow_accumulation(result_debm["temperature"][0, 0, 0], result_debm["precipitation"][0, 0, 0])

In [None]:
# Hand-made temperatures (degC) and precipitation values for a side-by-side
# check of the two accumulation schemes.
T = np.array([-10, -5, 0, 1, 4, 8])
P = np.array([10, 0.2, 1.0, 0.2, 0.1, 0.4])

In [None]:
# Reference scheme: temperature in degC.
pdd_ref.accumulation_rate(T, P)

In [None]:
# DEBM scheme: temperature in Kelvin, hence + 273.15.
debm.snow_accumulation(T+273.15, P)

In [None]:
# DEBM snow depth at the second time step.
result_debm["snow_depth"][1]

In [None]:
# NOTE(review): `Union` is not used in the visible cells.
from typing import Union

In [None]:
import torch

# Rebuild the 2-D synthetic climate as numpy arrays (temperature also in K).
ds = make_fake_climate_2d()
    
temp = ds["temp"].to_numpy()
temp_K = ds["temp"].to_numpy() + 273.15
precip = ds["prec"].to_numpy()
sd = ds["stdv"].to_numpy()


# Torch views of the same data (torch.from_numpy shares memory with the numpy array).
temp_K_t = torch.from_numpy(temp_K)
sd_t = torch.from_numpy(sd)
precip_t = torch.from_numpy(precip)



In [None]:
# NOTE(review): this numpy `debm` and `albedo_t` are not used in this cell.
debm = DEBMModel()
surface_elevation_t = torch.zeros_like(temp_K_t) + 10.0
latitude_t = torch.zeros_like(temp_K_t) 
albedo_t = torch.zeros_like(temp_K_t) + 0.47
# Torch DEBM run; this rebinds `result_debm` to the torch model's output.
debm_torch = TorchDEBMModel()
result_debm = debm_torch(temp_K_t, sd_t, precip_t, surface_elevation_t, latitude_t)

In [None]:
# Torch DEBM melt.
plt.imshow(result_debm["melt"], origin="lower")
plt.colorbar()

In [None]:
# Spatial (x, y) sizes of the torch temperature field, i.e. everything but the
# leading time axis. torch.rand_like only accepts a Tensor, so calling it on the
# numpy array temp_K raises TypeError; query temp_K_t (same shape) directly.
list(temp_K_t.size()[1:])

In [None]:
sd.shape

In [None]:
# IPython help lookup, left over from interactive exploration.
debm_torch?

In [None]:
    def step(
        max_melt: np.ndarray,
        old_snow_depth: np.ndarray,
        accumulation: np.ndarray,
        refreeze: float = 0.6,
    ) -> dict:
        """
        Advance the snow layer by one step and split melt into snow and ice melt.

        Accumulation is added to the old snow depth. The available melt
        (``max_melt`` clipped at zero) is taken from snow first, up to the snow
        present; any remainder melts ice.

        Parameters
        ----------
        max_melt : float or np.ndarray
            Melt available during the step. Negative values mean no melt.
        old_snow_depth : float or np.ndarray
            Snow depth at the start of the step. It is not modified.
        accumulation : float or np.ndarray
            Snow accumulated during the step.
        refreeze : float, optional
            Fraction of melted snow that refreezes (default 0.6). Previously
            this was read from a notebook-level ``refreeze`` variable.

        Returns
        -------
        dict
            "snow_depth": change in snow depth over the step (new - old);
            "melt": snow melt + ice melt; "runoff": melt minus refrozen snow
            melt; "smb": accumulation minus runoff.
        """
        # Build a new value rather than aliasing and using `+=`. The old code
        # mutated the caller's array in place, which also corrupted the
        # `new - old` difference below for array inputs.
        snow_depth = old_snow_depth + accumulation

        # Negative melt means no melt (the old code produced negative ice melt
        # here). Snow melts first, limited by the snow available.
        available_melt = np.maximum(max_melt, 0.0)
        snow_melted = np.minimum(available_melt, snow_depth)
        ice_melted = available_melt - snow_melted

        new_snow_depth = np.maximum(snow_depth - snow_melted, 0.0)

        total_melt = snow_melted + ice_melted
        ice_created_by_refreeze = refreeze * snow_melted
        runoff = total_melt - ice_created_by_refreeze
        smb = accumulation - runoff
        return {
            "snow_depth": new_snow_depth - old_snow_depth,
            "melt": total_melt,
            "runoff": runoff,
            "smb": smb,
        }


In [None]:
    # Scalar inputs for a one-point `step` call.
    max_melt = 2.0
    old_snow_depth = 1.0
    accumulation = 0.1

    # Refreeze fraction as a plain float. The former `refreeze: float = 0.6,`
    # had a trailing comma, which bound the tuple (0.6,) and made
    # `refreeze * snow_melted` raise TypeError for scalar float inputs.
    refreeze = 0.6


In [None]:
# One-point call of `step` with the scalar inputs.
step(max_melt, old_snow_depth, accumulation)

In [None]:
    # The same inputs as length-1 numpy arrays.
    max_melt = np.array([2.0])
    old_snow_depth = np.array([1.0])
    accumulation = np.array([0.1])


In [None]:
# Array version of the same call.
step(max_melt, old_snow_depth, accumulation)

In [None]:
np.ones((12, 1, 1))

In [None]:
 # Load the Greenland velocity mosaic (rebinds `ds`); decode_times=False leaves time values undecoded.
 ds = xr.open_dataset("../data/observed_speeds/greenland_vel_mosaic250_v1_g9000m.nc", decode_times=False)

In [None]:
def preprocess(ds, thinning_factor: int = 1, mapplane_vars: list[str] = ["x", "y"]):
    """
    Subsample a dataset along its map-plane dimensions.

    Every dimension of ``ds.sizes`` whose name appears in ``mapplane_vars`` is
    sliced with step ``thinning_factor``. All other dimensions are kept whole,
    and names in ``mapplane_vars`` that are absent from ``ds`` are ignored.

    Parameters
    ----------
    ds : xr.Dataset (or any object with ``sizes`` and ``isel``)
    thinning_factor : int
        Keep every ``thinning_factor``-th index along the map-plane dimensions.
    mapplane_vars : list of str
        Dimension names to thin. The default list is only read, never mutated.

    Returns
    -------
    The result of ``ds.isel`` with the thinning slices.
    """
    selection = {}
    for dim, length in ds.sizes.items():
        if dim in mapplane_vars:
            selection[dim] = slice(0, length, thinning_factor)
    return ds.isel(selection)


In [None]:

# thinning_factor=1 keeps every x/y index, so this selects the full map plane.
ds = preprocess(ds, thinning_factor=1)


In [None]:
# `velsurf_mag` (presumably surface speed) with length-1 dimensions removed.
ds.variables["velsurf_mag"].squeeze()

In [None]:
import numpy as np

In [None]:
# Small component arrays for trying out `magnitude`.
a = np.array([1, 2, 3])
b = np.array([2, 3, 4])

In [None]:
def magnitude(a, b):
    """
    Return sqrt(sum(a**2 + b**2)) taken over all elements of ``a`` and ``b``.

    NOTE(review): the result is a single scalar (the norm of all components
    together), not an element-wise magnitude. Drop the summation if
    element-wise speeds are intended.
    """
    squares = a**2 + b**2
    total = np.sum(squares)
    return np.sqrt(total)

In [None]:
# Single scalar: square root of the summed squares over all elements of a and b.
magnitude(a, b)

In [None]:
# Copyright (C) 2019-21 Andy Aschwanden
#
# This file is part of pism-emulator.
#
# PISM-EMULATOR is free software; you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation; either version 3 of the License, or (at your option) any later
# version.
#
# PISM-EMULATOR is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License
# along with PISM; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

# utils.py contains generic functions to read data or perform statistical analyses.

import sys
from math import sqrt
from os import mkdir
from os.path import isdir, join
from typing import Union

import numpy as np
import pandas as pd
import pylab as plt
import xarray as xr
from matplotlib.colors import LogNorm
from pyDOE2 import lhs
from SALib.sample import saltelli
from scipy.stats.distributions import gamma, randint, truncnorm, uniform
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# Seed numpy's global RNG so that random draws made through it are reproducible.
np.random.seed(0)


# LaTeX-style axis labels for model parameters, keyed by parameter name.
# Labels containing backslashes are raw strings: in a normal string, "\p",
# "\m", "\s", ... are invalid escape sequences (DeprecationWarning, and
# SyntaxWarning since Python 3.12). Their values are unchanged.
param_keys_dict = {
    "GCM": "GCM (1)",
    "FICE": "$f_i$ (mm K$^{-1}$ day$^{-1}$)",
    "FSNOW": "$f_s$ (mm K$^{-1}$ day$^{-1}$)",
    "RFR": r"$\psi (1)$",
    "PRS": r"$\omega$ (% K$^{-1}$)",
    "OCM": "$m_{t}$ (1)",
    "OCS": "$m_{x}$ (1)",
    "TCT": r"$h_{\mathrm{min}}$ (1)",
    "VCM": r"$\sigma_{\mathrm{max}}$ (MPa)",
    "SIAE": r"$E_{\mathrm{SIA}}$ (1)",
    "SSAN": r"$n_{\mathrm{SSA}}$ (1)",
    "TEFO": r"$\delta$ (1)",
    "PPQ": "$q$ (1)",
    "PHIMIN": r"$\phi_{\mathrm{min}}$ ($^{\circ}$)",
    "PHIMAX": r"$\phi_{\mathrm{max}}$ ($^{\circ}$)",
    "ZMIN": r"$z_{\mathrm{min}}$ (m)",
    "ZMAX": r"$z_{\mathrm{max}}$ (m)",
    # Was "A (Pa^{-n} s^{-1})" without math delimiters, so it rendered as plain text.
    "a_glen": "$A$ (Pa$^{-n}$ s$^{-1}$)",
    "sia_e": r"$E_{\mathrm{SIA}}$ (1)",
    "ssa_e": r"$E_{\mathrm{SSA}}$ (1)",
    "ssa_n": r"$n_{\mathrm{SSA}}$ (1)",
    "ppq": "$q$ (1)",
    "tefo": r"$\delta$ (1)",
    "till_effective_fraction_overburden": r"$\delta$ (1)",
    # Misspelled key kept for backward compatibility. The label was malformed
    # ("u_{\mathrm{thr} (m yr^{-1})}": unbalanced braces, no $), so it now
    # matches "pseudo_plastic_uthreshold".
    "pseudo_plastic_uthershold": r"$u_{\mathrm{th}}$ (m yr$^{-1}$)",
    "phi_min": r"$\phi_{\mathrm{min}}$ ($^{\circ}$)",
    "z_min": r"$z_{\mathrm{min}}$ (m)",
    "z_max": r"$z_{\mathrm{max}}$ (m)",
    "pseudo_plastic_uthreshold": r"$u_{\mathrm{th}}$ (m yr$^{-1}$)",
    "SIAe": r"$E_{\mathrm{SIA}}$ (1)",
    "SSAe": r"$E_{\mathrm{SSA}}$ (1)",
    "topg_to_phi_base": r"$b_{\mathrm{base}}$ (m)",
    "topg_to_phi_range": r"$b_{\mathrm{range}}$ (m)",
}


def load_hirham_climate(file="DMI-HIRHAM5_1980_MM.nc", thinning_factor=1):
    """
    Read monthly HIRHAM5 fields from ``file`` and return them as numpy arrays.

    Grid points are stacked into one "z" dimension (rlat/rlon for most fields,
    ncl4/ncl5 for refreeze). Points containing NaN are dropped, and every
    ``thinning_factor``-th remaining point is kept.

    Returns
    -------
    tuple
        (temp, precip, snowfall, melt, runoff, refreeze, smb), where:

        - temp is ``tas - 273.15``;
        - precip is rainfall + snfall, each scaled by 365.242198781 / 1000;
        - the other five are sums over axis 0;
        - smb, refreeze, melt and runoff are additionally divided by 12.
    """
    days_per_year = 365.242198781
    keep = slice(None, None, thinning_factor)

    def _values(da):
        # Drop any grid point that contains NaN and return the raw numpy values.
        return da.dropna(dim="z").values

    with xr.open_dataset(file) as Obs:
        grid = Obs.stack(z=("rlat", "rlon"))
        ncl_grid = Obs.stack(z=("ncl4", "ncl5"))

        temp = _values(grid.tas) - 273.15
        rainfall = _values(grid.rainfall) * days_per_year / 1000
        snowfall = _values(grid.snfall) * days_per_year / 1000
        smb = _values(grid.gld) * days_per_year / 1000 / 12
        refreeze = _values(ncl_grid.rfrz) * days_per_year / 1000 / 12
        melt = _values(grid.snmel) * days_per_year / 1000 / 12
        runoff = _values(grid.rogl) * days_per_year / 1000 / 12
        precip = rainfall + snowfall

    yearly_totals = tuple(
        field.sum(axis=0)[keep] for field in (snowfall, melt, runoff, refreeze, smb)
    )
    return (temp[..., keep], precip[..., keep]) + yearly_totals


def load_hirham_climate_w_std_dev(
    file="DMI-HIRHAM5_1980_2020_MMS.nc", thinning_factor=1
):
    """
    Read HIRHAM5 monthly climate and surface-mass-balance fields, grouped by year.

    Grid points are stacked into one "z" dimension (rlat/rlon for most fields,
    ncl4/ncl5 for refreeze). Within each calendar year, points containing NaN
    are dropped, and the yearly blocks are concatenated column-wise with
    ``np.hstack``.

    Parameters
    ----------
    file : str
        Path to the HIRHAM5 netCDF file.
    thinning_factor : int
        Keep every ``thinning_factor``-th index along rlat, rlon, ncl4 and ncl5.

    Returns
    -------
    temp : np.ndarray
        ``tas - 273.15``.
    precip : np.ndarray
        ``rainfall + snfall``, each scaled by 365.242198781 / 1000.
    temp_std_dev : np.ndarray
        ``tas_std_dev``, unscaled.
    obs : dict
        ``"snow_depth"``: ``sn`` minus its first row. ``"accumulation"``,
        ``"melt"``, ``"runoff"``, ``"refreeze"`` and ``"smb"``: sums over
        axis 0 of snfall, snmel, rogl, rfrz and gld. Accumulation is scaled by
        365.242198781 / 1000; the others by 365.242198781 / 1000 / 12.
    """
    days_per_year = 365.242198781

    def _by_year(da):
        # Drop NaN points separately for each year, then join the years column-wise.
        return np.hstack(
            [d.dropna(dim="z").values for _, d in da.groupby("time.year")]
        )

    with xr.open_dataset(file) as Obs:
        # The old code sliced with notebook globals `nlat`/`nlon`, which are not
        # defined inside this function and raise NameError on a fresh kernel.
        # An open-ended slice covers each dimension's full length. This assumes,
        # as the old code did, that ncl4/ncl5 are thinned like rlat/rlon.
        thin = slice(None, None, thinning_factor)
        Obs = Obs.isel(rlat=thin, rlon=thin, ncl4=thin, ncl5=thin)
        stacked = Obs.stack(z=("rlat", "rlon"))
        ncl_stacked = Obs.stack(z=("ncl4", "ncl5"))

        temp = _by_year(stacked.tas) - 273.15
        temp_std_dev = _by_year(stacked.tas_std_dev)
        rainfall = _by_year(stacked.rainfall) * days_per_year / 1000
        snowfall = _by_year(stacked.snfall) * days_per_year / 1000
        smb = _by_year(stacked.gld) * days_per_year / 1000 / 12
        refreeze = _by_year(ncl_stacked.rfrz) * days_per_year / 1000 / 12
        snowmelt = _by_year(stacked.snmel) * days_per_year / 1000 / 12
        snowdepth = _by_year(stacked.sn)
        runoff = _by_year(stacked.rogl) * days_per_year / 1000 / 12
        precip = rainfall + snowfall

        obs = {
            # Snow depth relative to the first row.
            "snow_depth": snowdepth - snowdepth[0],
            "accumulation": snowfall.sum(axis=0),
            "melt": snowmelt.sum(axis=0),
            "runoff": runoff.sum(axis=0),
            "refreeze": refreeze.sum(axis=0),
            "smb": smb.sum(axis=0),
        }

    return (
        temp,
        precip,
        temp_std_dev,
        obs,
    )


In [None]:
# HIRHAM5 monthly file, path relative to the notebook (rebinds `ds`).
file="../pddsampler/DMI-HIRHAM5_1980_2020_MMS.nc"
ds = xr.open_dataset(file)

In [None]:
# Quick smoke test of the loader on a heavily thinned grid (every 100th index).
load_hirham_climate_w_std_dev(file, thinning_factor=100)

In [None]:
ds

In [None]:
# IPython help lookup, left over from interactive exploration.
slice?

In [None]:
# Grid sizes along rlat/rlon, used in the isel call below.
nlat = len(ds["rlat"])
nlon = len(ds["rlon"])

In [None]:
thinning_factor = 2

In [None]:
# Keep every `thinning_factor`-th index along rlat/rlon and ncl4/ncl5.
ds.isel(rlat=slice(0, nlat, thinning_factor), rlon=slice(0, nlon, thinning_factor), ncl4=slice(0, nlat, thinning_factor), ncl5=slice(0, nlon, thinning_factor))

In [None]:
# NOTE(review): stray introspection line; `.i` is not a valid name.
.i?