In [19]:
import torch
import torch.distributions as D

import SBM_SDE_tensor
import SBM_SDE
from obs_and_flow import *

In [20]:
torch.manual_seed(0)
devi = torch.device("".join(["cuda:",f'{cuda_id}']) if torch.cuda.is_available() else "cpu")

dt = .1 #SDE discretization timestep.
t = 1 #Simulation run for T hours.
n = int(t / dt) + 1 
t_span = np.linspace(0, t, n)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n, 1]) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.
pretrain_lr = 1e-3
elbo_lr = 1e-3
niter = 1000
piter = 2
batch_size = 2 #Number of sets of observation outputs to sample per set of parameters.
state_dim_SCON = 3 #Not including CO2 in STATE_DIM, because CO2 is an observation.
state_dim_SAWB = 4 #Not including CO2 in STATE_DIM, because CO2 is an observation.
obs_error_scale = 0.1

In [21]:
temp_ref = 283
temp_rise = 5

#System parameters from deterministic CON model
u_M = 0.002
a_SD = 0.33
a_DS = 0.33
a_M = 0.33
a_MSC = 0.5
k_S_ref = 0.000025
k_D_ref = 0.005
k_M_ref = 0.0002
Ea_S = 75
Ea_D = 50
Ea_M = 50

#SCON diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1

SCON_C_prior_means = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC}
SCON_SS_prior_means = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC}

#System parameters from deterministic AWB model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_D = 200
K_U = 1
V_D_ref = 0.4
V_U_ref = 0.02
Ea_V_D = 75
Ea_V_U = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#SAWB diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
c_EEC = 0.001
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1
s_EEC = 0.1

SAWB_C_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC}
SAWB_SS_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC}

#System parameters from deterministic model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_DE = 200
K_UE = 1
V_DE_ref = 0.4
V_UE_ref = 0.02
Ea_V_DE = 75
Ea_V_UE = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#Diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
c_EEC = 0.001
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1
s_EEC = 0.1

SAWB_ECA_C_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC}
SAWB_ECA_SS_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC}

In [22]:
#Obtain temperature forcing function.
temp_tensor = SBM_SDE.temp_gen(t_span_tensor, temp_ref, temp_rise)

#Obtain SOC and DOC pool litter input vectors for use in flow SDE functions.
i_s_tensor = SBM_SDE.i_s(t_span_tensor) #Exogenous SOC input function
i_d_tensor = SBM_SDE.i_d(t_span_tensor) #Exogenous DOC input function

In [33]:
x_single = torch.zeros([1, 11, 3])
print(x_single.size())
x_batch = x_single.expand([2, 11, 3]).clone()
print(x_batch.size())

SOC = torch.normal(mean = torch.linspace(45, 46.1, 11), std = 1)
print(SOC)
DOC = torch.normal(mean = torch.linspace(0.07, 0.081, 11), std = 0.01)
print(DOC)
MBC = torch.normal(mean = torch.linspace(0.7, 0.81, 11), std = 0.1)
print(MBC)

torch.Size([1, 11, 3])
torch.Size([2, 11, 3])
tensor([43.7905, 46.4541, 47.6032, 44.7635, 44.2864, 43.0477, 46.5356, 43.0974,
        45.8487, 46.4888, 45.5767])
tensor([0.0675, 0.0605, 0.0666, 0.0721, 0.0728, 0.0730, 0.0657, 0.0777, 0.0811,
        0.0717, 0.0742])
tensor([0.8425, 0.9075, 0.7988, 0.9320, 0.8158, 0.7053, 0.8970, 0.8166, 0.7969,
        0.7822, 0.7627])


In [34]:
x_single[:, :, 0] = SOC
x_single[:, :, 1] = DOC
x_single[:, :, 2] = MBC
print(x_single)

x_batch[:, :, 0] = SOC
x_batch[:, :, 1] = DOC
x_batch[:, :, 2] = MBC
print(x_batch)

tensor([[[43.7905,  0.0675,  0.8425],
         [46.4541,  0.0605,  0.9075],
         [47.6032,  0.0666,  0.7988],
         [44.7635,  0.0721,  0.9320],
         [44.2864,  0.0728,  0.8158],
         [43.0477,  0.0730,  0.7053],
         [46.5356,  0.0657,  0.8970],
         [43.0974,  0.0777,  0.8166],
         [45.8487,  0.0811,  0.7969],
         [46.4888,  0.0717,  0.7822],
         [45.5767,  0.0742,  0.7627]]])
tensor([[[43.7905,  0.0675,  0.8425],
         [46.4541,  0.0605,  0.9075],
         [47.6032,  0.0666,  0.7988],
         [44.7635,  0.0721,  0.9320],
         [44.2864,  0.0728,  0.8158],
         [43.0477,  0.0730,  0.7053],
         [46.5356,  0.0657,  0.8970],
         [43.0974,  0.0777,  0.8166],
         [45.8487,  0.0811,  0.7969],
         [46.4888,  0.0717,  0.7822],
         [45.5767,  0.0742,  0.7627]],

        [[43.7905,  0.0675,  0.8425],
         [46.4541,  0.0605,  0.9075],
         [47.6032,  0.0666,  0.7988],
         [44.7635,  0.0721,  0.9320],
        

In [14]:
SCON_C_prior_means_tensor = {k: torch.tensor(v).expand(batch_size) for k, v in SCON_C_prior_means.items()}
print(SCON_C_prior_means)
print(SCON_C_prior_means_tensor)

{'u_M': 0.002, 'a_SD': 0.33, 'a_DS': 0.33, 'a_M': 0.33, 'a_MSC': 0.5, 'k_S_ref': 2.5e-05, 'k_D_ref': 0.005, 'k_M_ref': 0.0002, 'Ea_S': 75, 'Ea_D': 50, 'Ea_M': 50, 'c_SOC': 1.0, 'c_DOC': 0.01, 'c_MBC': 0.1}
{'u_M': tensor([0.0020, 0.0020]), 'a_SD': tensor([0.3300, 0.3300]), 'a_DS': tensor([0.3300, 0.3300]), 'a_M': tensor([0.3300, 0.3300]), 'a_MSC': tensor([0.5000, 0.5000]), 'k_S_ref': tensor([2.5000e-05, 2.5000e-05]), 'k_D_ref': tensor([0.0050, 0.0050]), 'k_M_ref': tensor([0.0002, 0.0002]), 'Ea_S': tensor([75, 75]), 'Ea_D': tensor([50, 50]), 'Ea_M': tensor([50, 50]), 'c_SOC': tensor([1., 1.]), 'c_DOC': tensor([0.0100, 0.0100]), 'c_MBC': tensor([0.1000, 0.1000])}


In [19]:
#C_path_single = torch.cat([C_0_single[(None,) * 2].repeat(batch_size, 1, 1).to(devi), C_path], 1)
#C_path_batch = torch.cat([C_0_batch.unsqueeze(1), C_path], 1)

In [35]:
drift_single, diffusion_sqrt_single = SBM_SDE.drift_diffusion_SCON_C(x_single[:, :-1, :], t_span_tensor[:, :-1, :], i_s_tensor[:, :-1, :], i_d_tensor[:, :-1, :], temp_tensor[:, :-1, :], temp_ref, SCON_C_prior_means)
print(drift_single)
drift_batch, diffusion_sqrt_batch = SBM_SDE_tensor.drift_diffusion_SCON_C(x_batch[:, :-1, :], t_span_tensor[:, :-1, :], i_s_tensor[:, :-1, :], i_d_tensor[:, :-1, :], temp_tensor[:, :-1, :], temp_ref, SCON_C_prior_means_tensor)
print(drift_batch)

tensor([[[ 4.4392e-05,  1.6678e-05, -3.3531e-05],
         [-6.3708e-05,  9.5449e-05, -6.4015e-05],
         [-1.2061e-04,  6.4407e-05, -3.2951e-05],
         [-6.3565e-05,  9.4547e-06, -5.3504e-05],
         [-8.6241e-05,  1.0708e-06, -3.0946e-05],
         [-8.7628e-05, -1.1515e-05, -9.5743e-06],
         [-2.3146e-04,  9.0107e-05, -7.0261e-05],
         [-1.4378e-04, -3.3844e-05, -3.1678e-05],
         [-2.5975e-04, -2.7796e-05, -2.3933e-05],
         [-3.3771e-04,  5.8758e-05, -4.2833e-05]]])
tensor([[[ 4.4392e-05,  1.6678e-05, -3.3531e-05],
         [-6.3708e-05,  9.5449e-05, -6.4015e-05],
         [-1.2061e-04,  6.4407e-05, -3.2951e-05],
         [-6.3565e-05,  9.4547e-06, -5.3504e-05],
         [-8.6241e-05,  1.0708e-06, -3.0946e-05],
         [-8.7628e-05, -1.1515e-05, -9.5743e-06],
         [-2.3146e-04,  9.0107e-05, -7.0261e-05],
         [-1.4378e-04, -3.3844e-05, -3.1678e-05],
         [-2.5975e-04, -2.7796e-05, -2.3933e-05],
         [-3.3771e-04,  5.8758e-05, -4.2833e-05]