In [1]:
# Hack to import from a parent directory: make '..' searchable so that modules
# living one level above this notebook can be imported (presumably the project
# modules imported in the next cell -- confirm against the repository layout).
import sys
path = '..'
# Only append when missing, so re-running this cell does not add duplicate entries.
if path not in sys.path:
    sys.path.append(path)

In [2]:
from typing import Dict, Tuple, Union
import os
import time
from time import process_time

#PyData imports
import numpy as np

#Torch-related imports
import torch

#Module-specific imports
from obs_and_flow import LowerBound
from TruncatedNormal import *
from mean_field import *
from SBM_SDE_classes import *
from SBM_SDE_classes_multi_x import *
from SBM_SDE_classes_optim import *
from SBM_SDE_classes_multi_x_optim import *
from training import *

'''
This script includes the linear and Arrhenius temperature dependence functions to induce temperature-based forcing in differential equation soil biogeochemical models (SBMs). It also includes the SBM SDE classes corresponding to the various parameterizations of the stochastic conventional (SCON), stochastic AWB (SAWB), and stochastic AWB-equilibrium chemistry approximation (SAWB-ECA) for incorporation with normalizing flow "neural stochastic differential equation" solvers. The following SBM SDE system parameterizations are contained in this script:
    1) SCON constant diffusion (SCON-C)
    2) SCON state scaling diffusion (SCON-SS)
    3) SAWB constant diffusion (SAWB-C)
    4) SAWB state scaling diffusion (SAWB-SS)
    5) SAWB-ECA constant diffusion (SAWB-ECA-C)
    6) SAWB-ECA state scaling diffusion (SAWB-ECA-SS)
The respective analytical steady state estimation functions derived from the deterministic ODE versions of the stochastic SBMs are no longer included in this script, as we are no longer initiating SBMs at steady state before starting simulations.
'''

#Type aliases. None of them is referenced again in the visible part of this notebook.
DictOfTensors = Dict[str, torch.Tensor] #Mapping from a string key to a tensor.
Number = Union[int, float] #Plain Python scalar.
TupleOfTensors = Tuple[torch.Tensor, torch.Tensor] #Pair of tensors.

In [3]:
#PyTorch settings: fixed seed, device selection, and print precision.
torch.manual_seed(0)
print('cuda device available?: ', torch.cuda.is_available())
if torch.cuda.is_available():
    #Run on the GPU and create new tensors there by default.
    active_device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    active_device = torch.device('cpu')
torch.set_printoptions(precision=8)

cuda device available?:  False


In [4]:
#Neural SDE parameters
dt_flow = 1.0 #Increased from 0.1 to reduce memory.
t = 1000 #In hours.
n = int(t / dt_flow) + 1 #Number of grid points, including both endpoints 0 and t.
t_span = np.linspace(0, t, n)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n, 1]).to(active_device) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.
state_dim_SCON = 3 #Not including CO2 in STATE_DIM, because CO2 is an observation.
state_dim_SAWB = 4 #Not including CO2 in STATE_DIM, because CO2 is an observation.
state_dim_SAWB_ECA = 4

#Training parameters
niter = 1000 #Also the number of repetitions used by test_time() below.
batch_size = 5
num_sequences = 10
prior_scale_factor = 0.333 #Proportion of prior standard deviation to prior means.

#Specify desired SBM SDE model type and details.
theta_dist = TruncatedNormal #The distribution class itself (not a string); it is called with loc/scale/a/b when theta is sampled. Options named by the original note are 'TruncatedNormal' and 'RescaledLogitNormal'.

#SBM temperature forcing parameters
temp_ref = 283.
temp_rise = 5. #High estimate of 5 celsius temperature rise by 2100.

In [5]:
#Generate exogenous input vectors.
#temp_gen, i_s and i_d are not defined in this notebook; they arrive through one of the
#star imports above (NOTE(review): confirm which module provides them).
#Obtain temperature forcing function.
temp_tensor = temp_gen(t_span_tensor, temp_ref, temp_rise).to(active_device)

#Obtain SOC and DOC pool litter input vectors for use in flow SDE functions.
i_s_tensor = i_s(t_span_tensor).to(active_device) #Exogenous SOC input function
i_d_tensor = i_d(t_span_tensor).to(active_device) #Exogenous DOC input function

## Generate dummy $\theta$ and $x$

In [6]:
#Prior means for every model parameter. Each mean is paired further down with a
#standard deviation of mean * prior_scale_factor and with truncation bounds.

#SCON drift parameters
u_M_mean = 0.0016
a_SD_mean = 0.5
a_DS_mean = 0.5
a_M_mean = 0.5
a_MSC_mean = 0.5
k_S_ref_mean = 0.0005
k_D_ref_mean = 0.0008
k_M_ref_mean = 0.0007
Ea_S_mean = 55
Ea_D_mean = 48
Ea_M_mean = 48

#SAWB drift parameters
u_Q_ref_mean = 0.25
Q_mean = 0.001
a_MSA_mean = 0.5
K_D_mean = 1000
K_U_mean = 0.1
V_D_ref_mean = 0.38
V_U_ref_mean = 0.04
Ea_V_D_mean = 55
Ea_V_U_mean = 50
r_M_mean = 0.002
r_E_mean = 0.000024
r_L_mean = 0.000015

#SAWB-ECA drift parameters (currently the same values as their SAWB counterparts above)
K_DE_mean = 1000
K_UE_mean = 0.1
V_DE_ref_mean = 0.38
V_UE_ref_mean = 0.04
Ea_V_DE_mean = 55
Ea_V_UE_mean = 50

#C diffusion parameters
c_SOC_mean = 0.5
c_DOC_mean = 0.01
c_MBC_mean = 0.01
c_EEC_mean = 0.001

#SS diffusion parameters
s_SOC_mean = 0.01
s_DOC_mean = 0.01
s_MBC_mean = 0.01
s_EEC_mean = 0.001

#Every *_details tensor below has four entries in the order [mean, sdev, lower, upper];
#sdev is always mean * prior_scale_factor. The four entries are later unpacked as
#(loc, scale, a, b) when the distribution is constructed.

#SCON theta truncated normal distribution parameter details in order of mean, sdev, lower, and upper.
u_M_details = torch.Tensor([u_M_mean, u_M_mean * prior_scale_factor, 0, 1])
a_SD_details = torch.Tensor([a_SD_mean, a_SD_mean * prior_scale_factor, 0, 1])
a_DS_details = torch.Tensor([a_DS_mean, a_DS_mean * prior_scale_factor, 0, 1])
a_M_details = torch.Tensor([a_M_mean, a_M_mean * prior_scale_factor, 0, 1])
a_MSC_details = torch.Tensor([a_MSC_mean, a_MSC_mean * prior_scale_factor, 0, 1])
k_S_ref_details = torch.Tensor([k_S_ref_mean, k_S_ref_mean * prior_scale_factor, 0, 1])
k_D_ref_details = torch.Tensor([k_D_ref_mean, k_D_ref_mean * prior_scale_factor, 0, 1])
k_M_ref_details = torch.Tensor([k_M_ref_mean, k_M_ref_mean * prior_scale_factor, 0, 1])
Ea_S_details = torch.Tensor([Ea_S_mean, Ea_S_mean * prior_scale_factor, 10, 100])
Ea_D_details = torch.Tensor([Ea_D_mean, Ea_D_mean * prior_scale_factor, 10, 100])
Ea_M_details = torch.Tensor([Ea_M_mean, Ea_M_mean * prior_scale_factor, 10, 100])

#SAWB theta truncated normal distribution parameter details in order of mean, sdev, lower, and upper.
u_Q_ref_details = torch.Tensor([u_Q_ref_mean, u_Q_ref_mean * prior_scale_factor, 0, 1])
Q_details = torch.Tensor([Q_mean, Q_mean * prior_scale_factor, 0, 1])
a_MSA_details = torch.Tensor([a_MSA_mean, a_MSA_mean * prior_scale_factor, 0, 1])
K_D_details = torch.Tensor([K_D_mean, K_D_mean * prior_scale_factor, 0, 10000])
K_U_details = torch.Tensor([K_U_mean, K_U_mean * prior_scale_factor, 0, 1])
V_D_ref_details = torch.Tensor([V_D_ref_mean, V_D_ref_mean * prior_scale_factor, 0, 5])
V_U_ref_details = torch.Tensor([V_U_ref_mean, V_U_ref_mean * prior_scale_factor, 0, 1])
Ea_V_D_details = torch.Tensor([Ea_V_D_mean, Ea_V_D_mean * prior_scale_factor, 10, 150])
Ea_V_U_details = torch.Tensor([Ea_V_U_mean, Ea_V_U_mean * prior_scale_factor, 10, 150])
r_M_details = torch.Tensor([r_M_mean, r_M_mean * prior_scale_factor, 0, 1])
r_E_details = torch.Tensor([r_E_mean, r_E_mean * prior_scale_factor, 0, 1])
r_L_details = torch.Tensor([r_L_mean, r_L_mean * prior_scale_factor, 0, 1])

#SAWB-ECA theta truncated normal distribution parameter details in order of mean, sdev, lower, and upper.
K_DE_details = torch.Tensor([K_DE_mean, K_DE_mean * prior_scale_factor, 0, 10000])
K_UE_details = torch.Tensor([K_UE_mean, K_UE_mean * prior_scale_factor, 0, 1])
V_DE_ref_details = torch.Tensor([V_DE_ref_mean, V_DE_ref_mean * prior_scale_factor, 0, 5])
V_UE_ref_details = torch.Tensor([V_UE_ref_mean, V_UE_ref_mean * prior_scale_factor, 0, 1])
Ea_V_DE_details = torch.Tensor([Ea_V_DE_mean, Ea_V_DE_mean * prior_scale_factor, 10, 150])
Ea_V_UE_details = torch.Tensor([Ea_V_UE_mean, Ea_V_UE_mean * prior_scale_factor, 10, 150])

#C diffusion matrix parameter distribution details
c_SOC_details = torch.Tensor([c_SOC_mean, c_SOC_mean * prior_scale_factor, 0, 1])
c_DOC_details = torch.Tensor([c_DOC_mean, c_DOC_mean * prior_scale_factor, 0, 1])
c_MBC_details = torch.Tensor([c_MBC_mean, c_MBC_mean * prior_scale_factor, 0, 1])
c_EEC_details = torch.Tensor([c_EEC_mean, c_EEC_mean * prior_scale_factor, 0, 1])

#SS diffusion matrix parameter distribution details
s_SOC_details = torch.Tensor([s_SOC_mean, s_SOC_mean * prior_scale_factor, 0, 1])
s_DOC_details = torch.Tensor([s_DOC_mean, s_DOC_mean * prior_scale_factor, 0, 1])
s_MBC_details = torch.Tensor([s_MBC_mean, s_MBC_mean * prior_scale_factor, 0, 1])
s_EEC_details = torch.Tensor([s_EEC_mean, s_EEC_mean * prior_scale_factor, 0, 1])

#Map every parameter name to its OWN [mean, sdev, lower, upper] tensor.
#Key order is kept unchanged because it fixes the order in which samples are drawn
#from the seeded random number generator.
theta_hyperparams = {
    #SCON drift parameters
    'u_M': u_M_details, 'a_SD': a_SD_details, 'a_DS': a_DS_details, 'a_M': a_M_details, 'a_MSC': a_MSC_details,
    'k_S_ref': k_S_ref_details, 'k_D_ref': k_D_ref_details, 'k_M_ref': k_M_ref_details,
    'Ea_S': Ea_S_details, 'Ea_D': Ea_D_details, 'Ea_M': Ea_M_details,
    #SAWB drift parameters (use the SAWB tensors, not the SAWB-ECA *_DE/*_UE ones)
    'u_Q_ref': u_Q_ref_details, 'Q': Q_details, 'a_MSA': a_MSA_details,
    'K_D': K_D_details, 'K_U': K_U_details, 'V_D_ref': V_D_ref_details, 'V_U_ref': V_U_ref_details,
    'Ea_V_D': Ea_V_D_details, 'Ea_V_U': Ea_V_U_details,
    'r_M': r_M_details, 'r_E': r_E_details, 'r_L': r_L_details,
    #SAWB-ECA drift parameters
    'K_DE': K_DE_details, 'K_UE': K_UE_details, 'V_DE_ref': V_DE_ref_details, 'V_UE_ref': V_UE_ref_details,
    'Ea_V_DE': Ea_V_DE_details, 'Ea_V_UE': Ea_V_UE_details,
    #C diffusion parameters (use the c_* tensors; c_SOC has a different mean than s_SOC)
    'c_SOC': c_SOC_details, 'c_DOC': c_DOC_details, 'c_MBC': c_MBC_details, 'c_EEC': c_EEC_details,
    #SS diffusion parameters
    's_SOC': s_SOC_details, 's_DOC': s_DOC_details, 's_MBC': s_MBC_details, 's_EEC': s_EEC_details,
}
#Draw batch_size samples per parameter; each details tensor unpacks into (loc, scale, a, b).
theta_samples = {k: theta_dist(loc=loc, scale=scale, a=a, b=b).sample((batch_size, )) for k, (loc, scale, a, b) in theta_hyperparams.items()}

In [7]:
#Dummy state trajectories: one noisy path per batch element around a linear trend over the n grid points.
#The Normal class is referenced through torch explicitly, because no name `D` is imported
#in this notebook (it was only available if one of the star imports happened to leak it).
SOC = torch.distributions.normal.Normal(loc=torch.linspace(45, 20, n), scale=1.0).sample((batch_size, ))
DOC = torch.distributions.normal.Normal(loc=torch.linspace(1, 5, n), scale=0.1).sample((batch_size, ))
MBC = torch.distributions.normal.Normal(loc=torch.linspace(1, 5, n), scale=0.1).sample((batch_size, ))
EEC = torch.distributions.normal.Normal(loc=torch.linspace(0.01, 0.02, n), scale=0.001).sample((batch_size, ))
#Stack the pools along a new last axis: 3 pools for SCON, 4 pools (adds EEC) for SAWB / SAWB-ECA.
x = torch.stack((SOC, DOC, MBC), -1)
x1 = torch.stack((SOC, DOC, MBC, EEC), -1)
x.shape, x1.shape

(torch.Size([5, 1001, 3]), torch.Size([5, 1001, 4]))

In [8]:
#Insert a sequence axis after the batch axis and tile each path num_sequences times,
#giving (batch_size, num_sequences, n, state_dim) inputs for the *_multi code paths.
x_multi = x.unsqueeze(1).repeat(1, num_sequences, 1, 1)
x_multi1 = x1.unsqueeze(1).repeat(1, num_sequences, 1, 1)
x_multi.shape, x_multi1.shape

(torch.Size([5, 10, 1001, 3]), torch.Size([5, 10, 1001, 4]))

In [9]:
def test_time(f, x=x):
    """Time `niter` back-to-back calls of `f(x, theta_samples)` and print the result.

    Args:
        f: Callable invoked as `f(x, theta_samples)`; its return value is discarded.
        x: First argument passed to `f`. Defaults to the notebook-level `x` as it
            was bound when this function was defined.

    Prints the total elapsed seconds and the mean seconds per call. Returns None.
    """
    #CUDA work is queued asynchronously, so wait for the device before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    #perf_counter is monotonic and high resolution, unlike the wall clock time.time().
    start = time.perf_counter()
    for _ in range(niter):
        f(x, theta_samples)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    print('Total time: {}, time/iter: {}'.format(elapsed, elapsed / niter))

## SCON-C

In [24]:
#Build the four SCON implementations under test with constant ('C') diffusion.
#All receive identical constructor arguments so their outputs can be compared directly.
diffusion_type = 'C'
SCON_C_test_optim = SCON_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SCON_C_test = SCON(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SCON_C_test_multi = SCON_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
#Note the naming: the *_multi1 object is the SCON_multi_optim instance.
SCON_C_test_multi1 = SCON_multi_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

### `drift_diffusion`

In [None]:
#Same x and theta_samples through SCON_optim and SCON; results are compared for exact equality below.
drift_optim, diffusion_optim = SCON_C_test_optim.drift_diffusion(x, theta_samples)
drift, diffusion = SCON_C_test.drift_diffusion(x, theta_samples)

In [None]:
diffusion_optim.shape == (batch_size, 1, state_dim_SCON, state_dim_SCON)

In [None]:
torch.all(drift_optim == drift), torch.all(diffusion_optim == diffusion)

In [None]:
test_time(SCON_C_test_optim.drift_diffusion)
test_time(SCON_C_test.drift_diffusion)

### `drift_diffusion_multi`

In [25]:
#Multi-sequence input through three code paths: SCON_optim.drift_diffusion_multi,
#SCON_multi.drift_diffusion and SCON_multi_optim.drift_diffusion.
drift_multi_optim, diffusion_multi_optim = SCON_C_test_optim.drift_diffusion_multi(x_multi, theta_samples)
drift_multi, diffusion_multi = SCON_C_test_multi.drift_diffusion(x_multi, theta_samples)
drift_multi1, diffusion_multi1 = SCON_C_test_multi1.drift_diffusion(x_multi, theta_samples)

In [None]:
torch.all(drift_multi_optim == drift_multi), torch.all(diffusion_multi_optim == diffusion_multi)

In [26]:
torch.all(drift_multi_optim == drift_multi1), torch.all(diffusion_multi_optim == diffusion_multi1)

(tensor(True), tensor(True))

In [None]:
torch.all(drift_multi_optim == drift_optim.unsqueeze(1)), torch.all(diffusion_multi_optim == diffusion_optim.unsqueeze(1))

In [None]:
#Time the three multi-sequence code paths on the same input.
test_time(SCON_C_test_optim.drift_diffusion_multi, x_multi)
test_time(SCON_C_test_multi.drift_diffusion, x_multi)
#The SCON_multi_optim instance is named SCON_C_test_multi1; `SCON_C_test_multi_optim` was never defined.
test_time(SCON_C_test_multi1.drift_diffusion, x_multi)

### `drift_diffusion_add_CO2`

In [None]:
#Combined call that returns drift, diffusion and a third value (bound to x_add_CO2_*), for SCON_optim and SCON.
drift_optim_alt, diffusion_optim_alt, x_add_CO2_optim_alt = SCON_C_test_optim.drift_diffusion_add_CO2(x, theta_samples)
drift_alt, diffusion_alt, x_add_CO2_alt = SCON_C_test.drift_diffusion_add_CO2(x, theta_samples)

In [None]:
torch.all(drift_optim_alt == drift_alt), torch.all(diffusion_optim_alt == diffusion_alt), torch.all(x_add_CO2_optim_alt == x_add_CO2_alt)

In [None]:
torch.all(drift_optim_alt == drift_optim), torch.all(diffusion_optim_alt == diffusion_optim)

In [None]:
test_time(SCON_C_test_optim.drift_diffusion_add_CO2)
test_time(SCON_C_test.drift_diffusion_add_CO2)

### `add_CO2`

In [None]:
#Stand-alone add_CO2 call on both implementations; also checked below against the combined call's output.
x_add_CO2_optim = SCON_C_test_optim.add_CO2(x, theta_samples)
x_add_CO2 = SCON_C_test.add_CO2(x, theta_samples)

In [None]:
torch.all(x_add_CO2_optim == x_add_CO2)

In [None]:
torch.all(x_add_CO2_optim_alt == x_add_CO2_optim)

In [None]:
test_time(SCON_C_test_optim.add_CO2)
test_time(SCON_C_test.add_CO2)

## SCON-SS

In [None]:
diffusion_type = 'SS'
#Build the four SCON implementations under test with state-scaling ('SS') diffusion, using identical constructor arguments.
SCON_SS_test_optim = SCON_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SCON_SS_test = SCON(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SCON_SS_test_multi = SCON_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
#The *_multi1 object is the SCON_multi_optim instance.
SCON_SS_test_multi1 = SCON_multi_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

### `drift_diffusion`

In [None]:
#Same x and theta_samples through SCON_optim and SCON; results are compared for exact equality below.
drift_optim, diffusion_optim = SCON_SS_test_optim.drift_diffusion(x, theta_samples)
drift, diffusion = SCON_SS_test.drift_diffusion(x, theta_samples)

In [None]:
diffusion_optim.shape == (batch_size, n - 1, state_dim_SCON, state_dim_SCON)

In [None]:
torch.all(drift_optim == drift), torch.all(diffusion_optim == diffusion)

In [None]:
test_time(SCON_SS_test_optim.drift_diffusion)
test_time(SCON_SS_test.drift_diffusion)

### `drift_diffusion_multi`

In [None]:
#Multi-sequence input through three code paths: SCON_optim.drift_diffusion_multi,
#SCON_multi.drift_diffusion and SCON_multi_optim.drift_diffusion.
drift_multi_optim, diffusion_multi_optim = SCON_SS_test_optim.drift_diffusion_multi(x_multi, theta_samples)
drift_multi, diffusion_multi = SCON_SS_test_multi.drift_diffusion(x_multi, theta_samples)
drift_multi1, diffusion_multi1 = SCON_SS_test_multi1.drift_diffusion(x_multi, theta_samples)

In [None]:
torch.all(drift_multi_optim == drift_multi), torch.all(diffusion_multi_optim == diffusion_multi)

In [None]:
torch.all(drift_multi_optim == drift_multi1), torch.all(diffusion_multi_optim == diffusion_multi1)

In [None]:
torch.all(drift_multi_optim == drift_optim.unsqueeze(1)), torch.all(diffusion_multi_optim == diffusion_optim.unsqueeze(1))

In [None]:
test_time(SCON_SS_test_optim.drift_diffusion_multi, x_multi)
test_time(SCON_SS_test_multi.drift_diffusion, x_multi)
test_time(SCON_SS_test_multi1.drift_diffusion, x_multi)

### `drift_diffusion_add_CO2`

In [None]:
#Combined call that returns drift, diffusion and a third value (bound to x_add_CO2_*), for SCON_optim and SCON.
drift_optim_alt, diffusion_optim_alt, x_add_CO2_optim_alt = SCON_SS_test_optim.drift_diffusion_add_CO2(x, theta_samples)
drift_alt, diffusion_alt, x_add_CO2_alt = SCON_SS_test.drift_diffusion_add_CO2(x, theta_samples)

In [None]:
torch.all(drift_optim_alt == drift_alt), torch.all(diffusion_optim_alt == diffusion_alt), torch.all(x_add_CO2_optim_alt == x_add_CO2_alt)

In [None]:
torch.all(drift_optim_alt == drift_optim), torch.all(diffusion_optim_alt == diffusion_optim)

In [None]:
test_time(SCON_SS_test_optim.drift_diffusion_add_CO2)
test_time(SCON_SS_test.drift_diffusion_add_CO2)

### `add_CO2`

In [None]:
#Stand-alone add_CO2 call on both implementations; also checked below against the combined call's output.
x_add_CO2_optim = SCON_SS_test_optim.add_CO2(x, theta_samples)
x_add_CO2 = SCON_SS_test.add_CO2(x, theta_samples)

In [None]:
torch.all(x_add_CO2_optim == x_add_CO2)

In [None]:
torch.all(x_add_CO2_optim_alt == x_add_CO2_optim)

In [None]:
test_time(SCON_SS_test_optim.add_CO2)
test_time(SCON_SS_test.add_CO2)

## SAWB-C

In [None]:
#Build the four SAWB implementations under test with constant ('C') diffusion, using identical constructor arguments.
diffusion_type = 'C'
SAWB_C_test_optim = SAWB_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_C_test = SAWB(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_C_test_multi = SAWB_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
#The *_multi1 object is the SAWB_multi_optim instance.
SAWB_C_test_multi1 = SAWB_multi_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

### `drift_diffusion`

In [None]:
#Same x1 (four-pool state) and theta_samples through SAWB_optim and SAWB; results are compared below.
drift_optim, diffusion_optim = SAWB_C_test_optim.drift_diffusion(x1, theta_samples)
drift, diffusion = SAWB_C_test.drift_diffusion(x1, theta_samples)

In [None]:
drift_optim.shape == (batch_size, n - 1, state_dim_SAWB), diffusion_optim.shape == (batch_size, 1, state_dim_SAWB, state_dim_SAWB)

In [None]:
torch.all(drift_optim == drift), torch.all(diffusion_optim == diffusion)

In [None]:
test_time(SAWB_C_test_optim.drift_diffusion, x1)
test_time(SAWB_C_test.drift_diffusion, x1)

### `drift_diffusion_multi`

In [None]:
#Multi-sequence input through three code paths: SAWB_optim.drift_diffusion_multi,
#SAWB_multi.drift_diffusion and SAWB_multi_optim.drift_diffusion.
drift_multi_optim, diffusion_multi_optim = SAWB_C_test_optim.drift_diffusion_multi(x_multi1, theta_samples)
drift_multi, diffusion_multi = SAWB_C_test_multi.drift_diffusion(x_multi1, theta_samples)
drift_multi1, diffusion_multi1 = SAWB_C_test_multi1.drift_diffusion(x_multi1, theta_samples)

In [None]:
torch.all(drift_multi_optim == drift_multi), torch.all(diffusion_multi_optim == diffusion_multi)

In [None]:
torch.all(drift_multi_optim == drift_multi1), torch.all(diffusion_multi_optim == diffusion_multi1)

In [None]:
torch.all(drift_multi_optim == drift_optim.unsqueeze(1)), torch.all(diffusion_multi_optim == diffusion_optim.unsqueeze(1))

In [None]:
test_time(SAWB_C_test_optim.drift_diffusion_multi, x_multi1)
test_time(SAWB_C_test_multi.drift_diffusion, x_multi1)
test_time(SAWB_C_test_multi1.drift_diffusion, x_multi1)

### `drift_diffusion_add_CO2`

In [None]:
#Combined call that returns drift, diffusion and a third value (bound to x_add_CO2_*), for SAWB_optim and SAWB.
drift_optim_alt, diffusion_optim_alt, x_add_CO2_optim_alt = SAWB_C_test_optim.drift_diffusion_add_CO2(x1, theta_samples)
drift_alt, diffusion_alt, x_add_CO2_alt = SAWB_C_test.drift_diffusion_add_CO2(x1, theta_samples)

In [None]:
torch.all(drift_optim_alt == drift_alt), torch.all(diffusion_optim_alt == diffusion_alt), torch.all(x_add_CO2_optim_alt == x_add_CO2_alt)

In [None]:
torch.all(drift_optim_alt == drift_optim), torch.all(diffusion_optim_alt == diffusion_optim)

In [None]:
test_time(SAWB_C_test_optim.drift_diffusion_add_CO2, x1)
test_time(SAWB_C_test.drift_diffusion_add_CO2, x1)

### `add_CO2`

In [None]:
#Stand-alone add_CO2 call on both implementations; also checked below against the combined call's output.
x_add_CO2_optim = SAWB_C_test_optim.add_CO2(x1, theta_samples)
x_add_CO2 = SAWB_C_test.add_CO2(x1, theta_samples)

In [None]:
torch.all(x_add_CO2_optim == x_add_CO2)

In [None]:
torch.all(x_add_CO2_optim_alt == x_add_CO2_optim)

In [None]:
test_time(SAWB_C_test_optim.add_CO2, x1)
test_time(SAWB_C_test.add_CO2, x1)

## SAWB-SS

In [None]:
#Build the four SAWB implementations under test with state-scaling ('SS') diffusion, using identical constructor arguments.
diffusion_type = 'SS'
SAWB_SS_test_optim = SAWB_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_SS_test = SAWB(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_SS_test_multi = SAWB_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
#The *_multi1 object is the SAWB_multi_optim instance.
SAWB_SS_test_multi1 = SAWB_multi_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

### `drift_diffusion`

In [None]:
#Same x1 (four-pool state) and theta_samples through SAWB_optim and SAWB; results are compared below.
drift_optim, diffusion_optim = SAWB_SS_test_optim.drift_diffusion(x1, theta_samples)
drift, diffusion = SAWB_SS_test.drift_diffusion(x1, theta_samples)

In [None]:
drift_optim.shape == (batch_size, n - 1, state_dim_SAWB), diffusion_optim.shape == (batch_size, n - 1, state_dim_SAWB, state_dim_SAWB)

In [None]:
torch.all(drift_optim == drift), torch.all(diffusion_optim == diffusion)

In [None]:
test_time(SAWB_SS_test_optim.drift_diffusion, x1)
test_time(SAWB_SS_test.drift_diffusion, x1)

### `drift_diffusion_multi`

In [None]:
#Multi-sequence input through three code paths: SAWB_optim.drift_diffusion_multi,
#SAWB_multi.drift_diffusion and SAWB_multi_optim.drift_diffusion.
drift_multi_optim, diffusion_multi_optim = SAWB_SS_test_optim.drift_diffusion_multi(x_multi1, theta_samples)
drift_multi, diffusion_multi = SAWB_SS_test_multi.drift_diffusion(x_multi1, theta_samples)
drift_multi1, diffusion_multi1 = SAWB_SS_test_multi1.drift_diffusion(x_multi1, theta_samples)

In [None]:
torch.all(drift_multi_optim == drift_multi), torch.all(diffusion_multi_optim == diffusion_multi)

In [None]:
torch.all(drift_multi_optim == drift_multi1), torch.all(diffusion_multi_optim == diffusion_multi1)

In [None]:
torch.all(drift_multi_optim == drift_optim.unsqueeze(1)), torch.all(diffusion_multi_optim == diffusion_optim.unsqueeze(1))

In [None]:
test_time(SAWB_SS_test_optim.drift_diffusion_multi, x_multi1)
test_time(SAWB_SS_test_multi.drift_diffusion, x_multi1)
test_time(SAWB_SS_test_multi1.drift_diffusion, x_multi1)

### `drift_diffusion_add_CO2`

In [None]:
#Combined call that returns drift, diffusion and a third value (bound to x_add_CO2_*), for SAWB_optim and SAWB.
drift_optim_alt, diffusion_optim_alt, x_add_CO2_optim_alt = SAWB_SS_test_optim.drift_diffusion_add_CO2(x1, theta_samples)
drift_alt, diffusion_alt, x_add_CO2_alt = SAWB_SS_test.drift_diffusion_add_CO2(x1, theta_samples)

In [None]:
torch.all(drift_optim_alt == drift_alt), torch.all(diffusion_optim_alt == diffusion_alt), torch.all(x_add_CO2_optim_alt == x_add_CO2_alt)

In [None]:
torch.all(drift_optim_alt == drift_optim), torch.all(diffusion_optim_alt == diffusion_optim)

In [None]:
test_time(SAWB_SS_test_optim.drift_diffusion_add_CO2, x1)
test_time(SAWB_SS_test.drift_diffusion_add_CO2, x1)

### `add_CO2`

In [None]:
#Stand-alone add_CO2 call on both implementations; also checked below against the combined call's output.
x_add_CO2_optim = SAWB_SS_test_optim.add_CO2(x1, theta_samples)
x_add_CO2 = SAWB_SS_test.add_CO2(x1, theta_samples)

In [None]:
torch.all(x_add_CO2_optim == x_add_CO2)

In [None]:
torch.all(x_add_CO2_optim_alt == x_add_CO2_optim)

In [None]:
test_time(SAWB_SS_test_optim.add_CO2, x1)
test_time(SAWB_SS_test.add_CO2, x1)

## SAWB-ECA-C

In [10]:
#Build the four SAWB-ECA implementations under test with constant ('C') diffusion, using identical constructor arguments.
diffusion_type = 'C'
SAWB_ECA_C_test_optim = SAWB_ECA_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_ECA_C_test = SAWB_ECA(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_ECA_C_test_multi = SAWB_ECA_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
#The *_multi1 object is the SAWB_ECA_multi_optim instance.
SAWB_ECA_C_test_multi1 = SAWB_ECA_multi_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

### `drift_diffusion`

In [14]:
#Same x1 (four-pool state) and theta_samples through SAWB_ECA_optim and SAWB_ECA; results are compared below.
drift_optim, diffusion_optim = SAWB_ECA_C_test_optim.drift_diffusion(x1, theta_samples)
drift, diffusion = SAWB_ECA_C_test.drift_diffusion(x1, theta_samples)

In [None]:
drift_optim.shape == (batch_size, n-1, state_dim_SAWB), diffusion_optim.shape == (batch_size, 1, state_dim_SAWB, state_dim_SAWB)

In [None]:
torch.all(drift_optim == drift), torch.all(diffusion_optim == diffusion)

In [None]:
test_time(SAWB_ECA_C_test_optim.drift_diffusion, x1)
test_time(SAWB_ECA_C_test.drift_diffusion, x1)

### `drift_diffusion_multi`

In [11]:
#Multi-sequence input through three code paths: SAWB_ECA_optim.drift_diffusion_multi,
#SAWB_ECA_multi.drift_diffusion and SAWB_ECA_multi_optim.drift_diffusion.
drift_multi_optim, diffusion_multi_optim = SAWB_ECA_C_test_optim.drift_diffusion_multi(x_multi1, theta_samples)
drift_multi, diffusion_multi = SAWB_ECA_C_test_multi.drift_diffusion(x_multi1, theta_samples)
drift_multi1, diffusion_multi1 = SAWB_ECA_C_test_multi1.drift_diffusion(x_multi1, theta_samples)

In [12]:
torch.all(drift_multi_optim == drift_multi), torch.all(diffusion_multi_optim == diffusion_multi)

(tensor(True), tensor(True))

In [13]:
torch.all(drift_multi_optim == drift_multi1), torch.all(diffusion_multi_optim == diffusion_multi1)

(tensor(True), tensor(True))

In [15]:
torch.all(drift_multi_optim == drift_optim.unsqueeze(1)), torch.all(diffusion_multi_optim == diffusion_optim.unsqueeze(1))

(tensor(True), tensor(True))

In [16]:
test_time(SAWB_ECA_C_test_optim.drift_diffusion_multi, x_multi1)
test_time(SAWB_ECA_C_test_multi.drift_diffusion, x_multi1)
test_time(SAWB_ECA_C_test_multi1.drift_diffusion, x_multi1)

Total time: 1.9806022644042969, time/iter: 0.0019806022644042968
Total time: 3.1617491245269775, time/iter: 0.0031617491245269774
Total time: 1.8239428997039795, time/iter: 0.0018239428997039796


### `drift_diffusion_add_CO2`

In [None]:
#Combined call that returns drift, diffusion and a third value (bound to x_add_CO2_*), for SAWB_ECA_optim and SAWB_ECA.
drift_optim_alt, diffusion_optim_alt, x_add_CO2_optim_alt = SAWB_ECA_C_test_optim.drift_diffusion_add_CO2(x1, theta_samples)
drift_alt, diffusion_alt, x_add_CO2_alt = SAWB_ECA_C_test.drift_diffusion_add_CO2(x1, theta_samples)

In [None]:
torch.all(drift_optim_alt == drift_alt), torch.all(diffusion_optim_alt == diffusion_alt), torch.all(x_add_CO2_optim_alt == x_add_CO2_alt)

In [None]:
torch.all(drift_optim_alt == drift_optim), torch.all(diffusion_optim_alt == diffusion_optim)

In [None]:
test_time(SAWB_ECA_C_test_optim.drift_diffusion_add_CO2, x1)
test_time(SAWB_ECA_C_test.drift_diffusion_add_CO2, x1)

### `add_CO2`

In [None]:
#Stand-alone add_CO2 call on both implementations; also checked below against the combined call's output.
x_add_CO2_optim = SAWB_ECA_C_test_optim.add_CO2(x1, theta_samples)
x_add_CO2 = SAWB_ECA_C_test.add_CO2(x1, theta_samples)

In [None]:
torch.all(x_add_CO2_optim == x_add_CO2)

In [None]:
torch.all(x_add_CO2_optim_alt == x_add_CO2_optim)

In [None]:
test_time(SAWB_ECA_C_test_optim.add_CO2, x1)
test_time(SAWB_ECA_C_test.add_CO2, x1)

## SAWB-ECA-SS

In [18]:
#Build the four SAWB-ECA implementations under test with state-scaling ('SS') diffusion, using identical constructor arguments.
diffusion_type = 'SS'
SAWB_ECA_SS_test_optim = SAWB_ECA_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_ECA_SS_test = SAWB_ECA(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
SAWB_ECA_SS_test_multi = SAWB_ECA_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
#The *_multi1 object is the SAWB_ECA_multi_optim instance.
SAWB_ECA_SS_test_multi1 = SAWB_ECA_multi_optim(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

### `drift_diffusion`

In [None]:
#Same x1 (four-pool state) and theta_samples through SAWB_ECA_optim and SAWB_ECA; results are compared below.
drift_optim, diffusion_optim = SAWB_ECA_SS_test_optim.drift_diffusion(x1, theta_samples)
drift, diffusion = SAWB_ECA_SS_test.drift_diffusion(x1, theta_samples)

In [None]:
drift_optim.shape == (batch_size, n - 1, state_dim_SAWB), diffusion_optim.shape == (batch_size, n - 1, state_dim_SAWB, state_dim_SAWB)

In [None]:
torch.all(drift_optim == drift), torch.all(diffusion_optim == diffusion)

In [None]:
test_time(SAWB_ECA_SS_test_optim.drift_diffusion, x1)
test_time(SAWB_ECA_SS_test.drift_diffusion, x1)

### `drift_diffusion_multi`

In [19]:
#Multi-sequence input through three code paths: SAWB_ECA_optim.drift_diffusion_multi,
#SAWB_ECA_multi.drift_diffusion and SAWB_ECA_multi_optim.drift_diffusion.
drift_multi_optim, diffusion_multi_optim = SAWB_ECA_SS_test_optim.drift_diffusion_multi(x_multi1, theta_samples)
drift_multi, diffusion_multi = SAWB_ECA_SS_test_multi.drift_diffusion(x_multi1, theta_samples)
drift_multi1, diffusion_multi1 = SAWB_ECA_SS_test_multi1.drift_diffusion(x_multi1, theta_samples)

In [20]:
torch.all(drift_multi_optim == drift_multi), torch.all(diffusion_multi_optim == diffusion_multi)

(tensor(True), tensor(True))

In [21]:
torch.all(drift_multi_optim == drift_multi1), torch.all(diffusion_multi_optim == diffusion_multi1)

(tensor(True), tensor(True))

In [None]:
torch.all(drift_multi_optim == drift_optim.unsqueeze(1)), torch.all(diffusion_multi_optim == diffusion_optim.unsqueeze(1))

In [23]:
test_time(SAWB_ECA_SS_test_optim.drift_diffusion_multi, x_multi1)
test_time(SAWB_ECA_SS_test_multi.drift_diffusion, x_multi1)
test_time(SAWB_ECA_SS_test_multi1.drift_diffusion, x_multi1)

Total time: 2.665102958679199, time/iter: 0.0026651029586791993
Total time: 3.2730610370635986, time/iter: 0.0032730610370635986
Total time: 2.869656801223755, time/iter: 0.002869656801223755


### `drift_diffusion_add_CO2`

In [None]:
#Combined call that returns drift, diffusion and a third value (bound to x_add_CO2_*), for SAWB_ECA_optim and SAWB_ECA.
drift_optim_alt, diffusion_optim_alt, x_add_CO2_optim_alt = SAWB_ECA_SS_test_optim.drift_diffusion_add_CO2(x1, theta_samples)
drift_alt, diffusion_alt, x_add_CO2_alt = SAWB_ECA_SS_test.drift_diffusion_add_CO2(x1, theta_samples)

In [None]:
torch.all(drift_optim_alt == drift_alt), torch.all(diffusion_optim_alt == diffusion_alt), torch.all(x_add_CO2_optim_alt == x_add_CO2_alt)

In [None]:
torch.all(drift_optim_alt == drift_optim), torch.all(diffusion_optim_alt == diffusion_optim)

In [None]:
test_time(SAWB_ECA_SS_test_optim.drift_diffusion_add_CO2, x1)
test_time(SAWB_ECA_SS_test.drift_diffusion_add_CO2, x1)

### `add_CO2`

In [None]:
#Stand-alone add_CO2 call on both implementations; also checked below against the combined call's output.
x_add_CO2_optim = SAWB_ECA_SS_test_optim.add_CO2(x1, theta_samples)
x_add_CO2 = SAWB_ECA_SS_test.add_CO2(x1, theta_samples)

In [None]:
torch.all(x_add_CO2_optim == x_add_CO2)

In [None]:
torch.all(x_add_CO2_optim_alt == x_add_CO2_optim)

In [None]:
test_time(SAWB_ECA_SS_test_optim.add_CO2, x1)
test_time(SAWB_ECA_SS_test.add_CO2, x1)