# Load the necessary libraries

In [1]:
import optax
import equinox as eqx

import pickle
from functools import partial

from math import floor

# import numpy as np
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import pandas as pd

from jax_canoak.physics.energy_fluxes import get_dispersion_matrix

from jax_canoak.models import CanoakIFT
from jax_canoak.models import get_canle, update_canle
from jax_canoak.models import get_soilresp, update_soilresp
from jax_canoak.models import load_model, save_model

# from jax_canoak.shared_utilities.optim import perform_optimization
from jax_canoak.shared_utilities.optim import perform_optimization_batch
from jax_canoak.shared_utilities import compute_metrics
from jax_canoak.subjects import convert_met_to_batched_met
from jax_canoak.subjects import (
    # convert_batchedstates_to_states,
    convert_obs_to_batched_obs,
)
from jax_canoak.subjects import get_met_forcings, get_obs, initialize_parameters
from jax_canoak.models import run_canoak_in_batch

from jax_canoak.models import save_model

import matplotlib.pyplot as plt
from jax_canoak.shared_utilities.plot import (
    plot_daily,
    plot_imshow2,
    plot_timeseries_obs_1to1,
    plot_rad,
    plot_ir,
    visualize_tree_diff,
)
from jax_canoak.shared_utilities.plot import (
    plot_obs_1to1,
    plot_obs_comparison,
    plot_obs_energy_closure,
    plot_dij,
    plot_para_sensitivity_ranking,
    plot_le_gs_lai,
    get_time,
)

# from jax_canoak.shared_utilities.plot import plot_veg_temp, plot_dij
# from jax_canoak.shared_utilities.plot import plot_ir, plot_rad, plot_prof2
from jax_canoak.shared_utilities import tune_jax_naninfs_for_debug

# Enable 64-bit precision in JAX (the default is 32-bit).
jax.config.update("jax_enable_x64", True)
# Project helper; presumably switches JAX's NaN/Inf debug checks on or off
# (cf. the commented-out jax_debug_nans / jax_debug_infs lines below) -- TODO confirm.
tune_jax_naninfs_for_debug(False)
# jax.config.update("jax_debug_nans", False)
# jax.config.update("jax_debug_infs", False)
# jax.config.update("XLA_PYTHON_CLIENT_ALLOCATOR", 'platform')

# IPython autoreload: re-import edited modules automatically before running a cell.
%load_ext autoreload
%autoreload 2


# Model parameters and settings

In [2]:
# --- Site identification ---------------------------------------------------
site = "US-Hn1"
key = "default"

# --- Site location (presumably degrees and a UTC offset in hours -- confirm) ---
latitude = 46.4089
longitude = -119.275
time_zone = -8

# --- Vegetation settings ---------------------------------------------------
stomata = 0
veg_ht = 1.2
leafangle = 1  # alternative noted by the author: 2 = erectophile

# --- Vertical discretisation and heights -----------------------------------
n_can_layers = n_atmos_layers = 50
meas_ht = 5.0
soil_depth = 0.15

# --- Time stepping and solver iterations -----------------------------------
n_hr_per_day = 48
niter = 15  # the author also experimented with niter = 1

# --- Batching --------------------------------------------------------------
# None means "not chosen yet"; a later cell falls back to the full record
# length. Values tried previously: 2, 1274, int(74496/2).
batch_size = None
batch_size_test = None

# --- Forcing file ----------------------------------------------------------
f_forcing = f"../../data/fluxtower/{site}/{site}-forcings-v2.csv"


In [3]:
# Read the forcing CSV with pandas and display its (rows, columns) shape as a
# quick sanity check of the input file.
df = pd.read_csv(f_forcing)
df.shape


(27743, 14)

# GPU

In [4]:
# Make the CUDA (GPU) backend JAX's default platform for this section.
# Environment-variable alternative, kept for reference:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "gpu"
jax.config.update('jax_platform_name', 'cuda')

## Load and set the model forcings

In [5]:
# Clear JAX's compilation and staging caches before setting up this run.
jax.clear_caches()

In [6]:
# Load the forcings from the CSV; n_time comes back alongside them
# (presumably the number of time steps -- confirm against get_met_forcings).
met, n_time = get_met_forcings(f_forcing)

# No batch size configured: treat the whole record as a single batch.
if batch_size is None:
    batch_size = n_time

# Number of complete batches of length batch_size within n_time.
n_batch = floor(n_time / batch_size)
batched_met = convert_met_to_batched_met(met, n_batch, batch_size)

# Time axis derived from the forcings via the plotting helper get_time.
timesteps = get_time(met)


In [7]:
# Build the run setup and parameter set for this site; para_min / para_max are
# presumably the parameter bounds requested via get_para_bounds=True.
setup, para, para_min, para_max = initialize_parameters(
    # --- site location ---
    latitude=latitude,
    longitude=longitude,
    time_zone=time_zone,
    # --- vegetation / canopy ---
    veg_ht=veg_ht,
    leafangle=leafangle,
    stomata=stomata,
    # --- vertical discretisation and heights ---
    n_can_layers=n_can_layers,
    n_atmos_layers=n_atmos_layers,
    meas_ht=meas_ht,
    soil_depth=soil_depth,
    # --- time axis ---
    n_hr_per_day=n_hr_per_day,
    n_time=n_time,
    # --- inputs (no observations are supplied here) ---
    met=met,
    obs=None,
    # --- numerical settings ---
    npart=1_000_000,
    niter=niter,
    get_para_bounds=True,
)


In [8]:
# Dispersion matrix for this site; the site's Dij CSV path is passed to the helper.
dij = get_dispersion_matrix(
    setup,
    para,
    f"../../data/dij/Dij_{site}.csv",
)

In [9]:
# Show which device(s) hold the dispersion matrix, to confirm the active backend.
# `devices()` (returns a set of devices) is used instead of the array method
# `device()`, which JAX deprecated and later removed.
dij.devices()

cuda(id=0)

## Run the model

In [10]:
# Instantiate the model from the parameters, the setup and the dispersion matrix.
canoak_eqx_ift = CanoakIFT(para, setup, dij)


In [11]:
%timeit canoak_eqx_ift(met)

1.96 s ± 62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# CPU

In [4]:
# Make the CPU backend JAX's default platform for this section.
# Environment-variable alternative, kept for reference:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
jax.config.update('jax_platform_name', 'cpu')

## Load and set the model forcings

In [5]:
# Clear JAX's compilation and staging caches before setting up this run.
jax.clear_caches()

In [6]:
# Load the forcings from the CSV; n_time comes back alongside them
# (presumably the number of time steps -- confirm against get_met_forcings).
met, n_time = get_met_forcings(f_forcing)

# No batch size configured: treat the whole record as a single batch.
if batch_size is None:
    batch_size = n_time

# Number of complete batches of length batch_size within n_time.
n_batch = floor(n_time / batch_size)
batched_met = convert_met_to_batched_met(met, n_batch, batch_size)

# Time axis derived from the forcings via the plotting helper get_time.
timesteps = get_time(met)

In [7]:
# Build the run setup and parameter set for this site; para_min / para_max are
# presumably the parameter bounds requested via get_para_bounds=True.
setup, para, para_min, para_max = initialize_parameters(
    # --- site location ---
    latitude=latitude,
    longitude=longitude,
    time_zone=time_zone,
    # --- vegetation / canopy ---
    veg_ht=veg_ht,
    leafangle=leafangle,
    stomata=stomata,
    # --- vertical discretisation and heights ---
    n_can_layers=n_can_layers,
    n_atmos_layers=n_atmos_layers,
    meas_ht=meas_ht,
    soil_depth=soil_depth,
    # --- time axis ---
    n_hr_per_day=n_hr_per_day,
    n_time=n_time,
    # --- inputs (no observations are supplied here) ---
    met=met,
    obs=None,
    # --- numerical settings ---
    npart=1_000_000,
    niter=niter,
    get_para_bounds=True,
)

In [8]:
# Dispersion matrix for this site; the site's Dij CSV path is passed to the helper.
dij = get_dispersion_matrix(
    setup,
    para,
    f"../../data/dij/Dij_{site}.csv",
)

In [9]:
# Show which device(s) hold the dispersion matrix, to confirm the active backend.
# `devices()` (returns a set of devices) is used instead of the array method
# `device()`, which JAX deprecated and later removed.
dij.devices()

CpuDevice(id=0)

## Run the model

In [10]:
# Instantiate the model from the parameters, the setup and the dispersion matrix.
canoak_eqx_ift = CanoakIFT(para, setup, dij)


In [11]:
%timeit canoak_eqx_ift(met)


49.7 s ± 2.02 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
