In [1]:
# Hack to import from a parent directory
import sys
path = '..'
if path not in sys.path:
    sys.path.append(path)

In [2]:
from typing import Dict, Tuple, Union
import os
from time import process_time

#PyData imports
import numpy as np

#Torch-related imports
import torch

#Module-specific imports
from obs_and_flow import LowerBound
from TruncatedNormal import *
from mean_field import *
from SBM_SDE_classes import *
from SBM_SDE_classes_multi_x import *
from training import *

'''
This script includes the linear and Arrhenius temperature dependence functions to induce temperature-based forcing in differential equation soil biogeochemical models (SBMs). It also includes the SBM SDE classes corresponding to the various parameterizations of the stochastic conventional (SCON), stochastic AWB (SAWB), and stochastic AWB-equilibrium chemistry approximation (SAWB) for incorporation with normalizing flow "neural stochastic differential equation" solvers. The following SBM SDE system parameterizations are contained in this script:
    1) SCON constant diffusion (SCON-C)
    2) SCON state scaling diffusion (SCON-SS)
    3) SAWB constant diffusion (SAWB-C)
    4) SAWB state scaling diffusion (SAWB-SS)
    5) SAWB constant diffusion (SAWB-C)
    6) SAWB state scaling diffusion (SAWB-SS)
The respective analytical steady state estimation functions derived from the deterministic ODE versions of the stochastic SBMs are no longer included in this script, as we are no longer initiating SBMs at steady state before starting simulations.
'''

DictOfTensors = Dict[str, torch.Tensor]
Number = Union[int, float]
TupleOfTensors = Tuple[torch.Tensor, torch.Tensor]

In [3]:
#PyTorch settings
torch.manual_seed(0)
print('cuda device available?: ', torch.cuda.is_available())
active_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.set_printoptions(precision = 8)

cuda device available?:  False


In [4]:
#Neural SDE parameters
dt_flow = 1.0 #Increased from 0.1 to reduce memory.
t = 1000 #In hours.
n = int(t / dt_flow) + 1
t_span = np.linspace(0, t, n)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n, 1]).to(active_device) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.
state_dim_SCON = 3 #Not including CO2 in STATE_DIM, because CO2 is an observation.
state_dim_SAWB = 4 #Not including CO2 in STATE_DIM, because CO2 is an observation.

#SBM temperature forcing parameters
temp_ref = 283.
temp_rise = 5. #High estimate of 5 celsius temperature rise by 2100.

In [5]:
#Training parameters
niter = 110
train_lr = 2e-5 #ELBO learning rate
batch_size = 20 #3 - number needed to fit UCI HPC3 RAM requirements with 16 GB RAM at t = 5000.
eval_batch_size = 20
obs_error_scale = 0.1 #Observation (y) standard deviation.
prior_scale_factor = 0.333 #Proportion of prior standard deviation to prior means.
num_layers = 5 #5 - number needed to fit UCI HPC3 RAM requirements with 16 GB RAM at t = 5000.

In [6]:
#Generate exogenous input vectors.
#Obtain temperature forcing function.
temp_tensor = temp_gen(t_span_tensor, temp_ref, temp_rise).to(active_device)

#Obtain SOC and DOC pool litter input vectors for use in flow SDE functions.
i_s_tensor = i_s(t_span_tensor).to(active_device) #Exogenous SOC input function
i_d_tensor = i_d(t_span_tensor).to(active_device) #Exogenous DOC input function

In [7]:
#Specify desired SBM SDE model type and details.
state_dim_SAWB = 4
sbm_sde_class = 'SAWB'
diffusion_type = 'SS'
learn_CO2 = True
theta_dist = 'TruncatedNormal' #String needs to be exact name of the distribution class. Options are 'TruncatedNormal' and 'RescaledLogitNormal'.

#Parameter prior means
u_Q_ref_mean = 0.2
Q_mean = 0.001
a_MSA_mean = 0.5
K_D_mean = 1850
K_U_mean = 0.2
V_D_ref_mean = 0.16
V_U_ref_mean = 0.012
Ea_V_D_mean = 65
Ea_V_U_mean = 55
r_M_mean = 0.0018
r_E_mean = 0.00003
r_L_mean = 0.000008
s_SOC_mean = 0.005
s_DOC_mean = 0.005
s_MBC_mean = 0.005
s_EEC_mean = 0.005

#SAWB theta truncated normal distribution parameter details in order of mean, sdev, lower, and upper.
u_Q_ref_details = torch.Tensor([u_Q_ref_mean, u_Q_ref_mean * prior_scale_factor, 0, 1])
Q_details = torch.Tensor([Q_mean, Q_mean * prior_scale_factor, 0, 1])
a_MSA_details = torch.Tensor([a_MSA_mean, a_MSA_mean * prior_scale_factor, 0, 1])
K_D_details = torch.Tensor([K_D_mean, K_D_mean * prior_scale_factor, 0, 10000])
K_U_details = torch.Tensor([K_U_mean, K_U_mean * prior_scale_factor, 0, 100])
V_D_ref_details = torch.Tensor([V_D_ref_mean, V_D_ref_mean * prior_scale_factor, 0, 10])
V_U_ref_details = torch.Tensor([V_U_ref_mean, V_U_ref_mean * prior_scale_factor, 0, 1])
Ea_V_D_details = torch.Tensor([Ea_V_D_mean, Ea_V_D_mean * prior_scale_factor, 10, 150])
Ea_V_U_details = torch.Tensor([Ea_V_U_mean, Ea_V_U_mean * prior_scale_factor, 10, 150])
r_M_details = torch.Tensor([r_M_mean, r_M_mean * prior_scale_factor, 0, 1])
r_E_details = torch.Tensor([r_E_mean, r_M_mean * prior_scale_factor, 0, 1])
r_L_details = torch.Tensor([r_L_mean, r_M_mean * prior_scale_factor, 0, 1])

#SAWB-SS diffusion matrix parameter distribution details
s_SOC_details = torch.Tensor([s_SOC_mean, s_SOC_mean * prior_scale_factor, 0, 1])
s_DOC_details = torch.Tensor([s_DOC_mean, s_DOC_mean * prior_scale_factor, 0, 1])
s_MBC_details = torch.Tensor([s_MBC_mean, s_MBC_mean * prior_scale_factor, 0, 1])
s_EEC_details = torch.Tensor([s_EEC_mean, s_EEC_mean * prior_scale_factor, 0, 1])

SAWB_SS_priors_details = {'u_Q_ref': u_Q_ref_details, 'Q': Q_details, 'a_MSA': a_MSA_details, 'K_D': K_D_details, 'K_U': K_U_details, 'V_D_ref': V_D_ref_details, 'V_U_ref': V_U_ref_details, 'Ea_V_D': Ea_V_D_details, 'Ea_V_U': Ea_V_U_details, 'r_M': r_M_details, 'r_E': r_E_details, 'r_L': r_L_details, 's_SOC': s_SOC_details, 's_DOC': s_DOC_details, 's_MBC': s_MBC_details, 's_EEC': s_EEC_details}

#Initial condition prior means
x0_SAWB = [65, 0.4, 2.5, 0.3]
x0_SAWB_tensor = torch.tensor(x0_SAWB).to(active_device)
x0_prior_SAWB = D.multivariate_normal.MultivariateNormal(x0_SAWB_tensor, scale_tril = torch.eye(state_dim_SAWB).to(active_device) * obs_error_scale * x0_SAWB_tensor)

In [8]:
param_names = list(SAWB_SS_priors_details.keys())
prior_list = list(zip(*(SAWB_SS_priors_details[k] for k in param_names))) #Unzip prior distribution details from dictionary values into individual lists.
prior_means_tensor, prior_sds_tensor, prior_lowers_tensor, prior_uppers_tensor = torch.tensor(prior_list).to(active_device) #Ensure conversion of lists into tensors.
priors = TruncatedNormal(loc = prior_means_tensor, scale = prior_sds_tensor, a = prior_lowers_tensor, b = prior_uppers_tensor)

In [9]:
#x = torch.rand(2, t_span_tensor.size(1), 3)
batch_size = 3

x_single = torch.zeros([1, 1001, 3])
print(x_single.size())
x = x_single.expand([batch_size, 1001, 3]).clone()
print(x.size())

x2_single = torch.zeros([1, 1001, 4])
x2 = x2_single.expand([batch_size, 1001, 4]).clone()

SOC = torch.normal(mean = torch.linspace(45, 20, 1001), std = 1.)
print(SOC)
DOC = torch.normal(mean = torch.linspace(1, 5, 1001), std = 0.1)
print(DOC)
MBC = torch.normal(mean = torch.linspace(1, 5, 1001), std = 0.1)
print(MBC)
EEC = torch.normal(mean = torch.linspace(0.01, 0.02, 1001), std = 0.001)

x[:, :, 0] = SOC
x[:, :, 1] = DOC
x[:, :, 2] = MBC
x = LowerBound.apply(x, 1e-6)
print(x)
print(x.size())

x2[:, :, 0] = SOC
x2[:, :, 1] = DOC
x2[:, :, 2] = MBC
x2[:, :, 3] = EEC
x2 = LowerBound.apply(x2, 1e-6)
print(x2)
print(x2.size())

torch.Size([1, 1001, 3])
torch.Size([3, 1001, 3])
tensor([43.87416077, 43.82263947, 44.69942093,  ..., 20.17641830,
        19.43882751, 19.59896469])
tensor([1.10990882, 1.02039123, 0.96598482,  ..., 4.98133755, 5.12637997,
        5.00414944])
tensor([1.05660713, 1.16628361, 1.00154042,  ..., 5.07834673, 5.07599831,
        4.99099350])
tensor([[[43.87416077,  1.10990882,  1.05660713],
         [43.82263947,  1.02039123,  1.16628361],
         [44.69942093,  0.96598482,  1.00154042],
         ...,
         [20.17641830,  4.98133755,  5.07834673],
         [19.43882751,  5.12637997,  5.07599831],
         [19.59896469,  5.00414944,  4.99099350]],

        [[43.87416077,  1.10990882,  1.05660713],
         [43.82263947,  1.02039123,  1.16628361],
         [44.69942093,  0.96598482,  1.00154042],
         ...,
         [20.17641830,  4.98133755,  5.07834673],
         [19.43882751,  5.12637997,  5.07599831],
         [19.59896469,  5.00414944,  4.99099350]],

        [[43.87416077,  1.1

In [10]:
q_theta_SAWB_SS = MeanField(active_device, param_names, SAWB_SS_priors_details, TruncatedNormal, False)

In [11]:
SAWB_SS_dict_out, SAWB_SS_samples, _, _ = q_theta_SAWB_SS(batch_size)
SAWB_SS_dict_out

{'u_Q_ref': tensor([0.19216022, 0.29522306, 0.20716973], grad_fn=<SqueezeBackward1>),
 'Q': tensor([0.00063818, 0.00179973, 0.00080447], grad_fn=<SqueezeBackward1>),
 'a_MSA': tensor([0.29899603, 0.78918350, 0.54942733], grad_fn=<SqueezeBackward1>),
 'K_D': tensor([2050.52075195, 2170.28173828, 1256.52612305],
        grad_fn=<SqueezeBackward1>),
 'K_U': tensor([0.29699630, 0.12118475, 0.32046181], grad_fn=<SqueezeBackward1>),
 'V_D_ref': tensor([0.13269615, 0.10744360, 0.24209984], grad_fn=<SqueezeBackward1>),
 'V_U_ref': tensor([0.01172355, 0.01322458, 0.01446030], grad_fn=<SqueezeBackward1>),
 'Ea_V_D': tensor([41.85933685, 31.13208771, 80.69321442], grad_fn=<SqueezeBackward1>),
 'Ea_V_U': tensor([59.18062973, 71.61378479, 12.36256027], grad_fn=<SqueezeBackward1>),
 'r_M': tensor([0.00201972, 0.00214535, 0.00134339], grad_fn=<SqueezeBackward1>),
 'r_E': tensor([3.44519591e-04, 5.38871100e-04, 1.47498868e-05],
        grad_fn=<SqueezeBackward1>),
 'r_L': tensor([9.37776407e-04, 4.830

In [12]:
class SBM_SDE:
    '''
    This is the base class for evaluating the SBM SDE SSMs.
    '''

    def __init__(
            self,
            T_SPAN_TENSOR: torch.Tensor,
            I_S_TENSOR: torch.Tensor,
            I_D_TENSOR: torch.Tensor,
            TEMP_TENSOR: torch.Tensor,
            TEMP_REF: Number
            ):
        
        self.times = T_SPAN_TENSOR
        self.i_S = I_S_TENSOR
        self.i_D = I_D_TENSOR
        self.temps = TEMP_TENSOR
        self.temp_ref = TEMP_REF
        
class SAWB(SBM_SDE):
    '''
    Class contains SAWB SDE drift (alpha) and diffusion (beta) equations.
    Constant (C) and state-scaling (SS) diffusion paramterizations are included. DIFFUSION_TYPE must thereby be specified as 'C' or 'SS'. 
    Other diffusion parameterizations are not included.
    '''
    def __init__(
            self,
            T_SPAN_TENSOR: torch.Tensor,
            I_S_TENSOR: torch.Tensor,
            I_D_TENSOR: torch.Tensor,
            TEMP_TENSOR: torch.Tensor,
            TEMP_REF: Number,
            DIFFUSION_TYPE: str
            ):
        super().__init__(T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, TEMP_TENSOR, TEMP_REF)

        if DIFFUSION_TYPE not in {'C', 'SS'}:
            raise NotImplementedError('Other diffusion parameterizations aside from constant (c) or state-scaling (ss) have not been implemented.')

        self.DIFFUSION_TYPE = DIFFUSION_TYPE
        self.state_dim = 4

    def drift_diffusion(
        self,
        C_PATH: torch.Tensor, 
        SAWB_params_dict: DictOfTensors,
        ) -> TupleOfTensors:
        '''
        Accepts states x and dictionary of parameter samples.
        Returns SAWB drift and diffusion tensors corresponding to state values and parameter samples.  
        Expected SAWB_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, '[cs]_SOC': [cs]_SOC, '[cs]_DOC': [cs]_DOC, '[cs]_MBC': [cs]_MBC, '[cs]_EEC': [cs]_EEC}
        '''
        #Appropriately index tensors based on order of operations in data generating process.
        c_path_drift_diffusion = C_PATH[:, :-1, :]
        t_span_tensor_drift_diffusion = self.times[:, 1:, :]        
        i_S_tensor_drift_diffusion = self.i_S[:, 1:, :]
        i_D_tensor_drift_diffusion = self.i_D[:, 1:, :]
        temp_tensor_drift_diffusion = self.temps[:, 1:, :]
        #Partition SOC, DOC, MBC, EEC values. Split based on final c_path_drift_diffusion dim, which specifies state variables and is also indexed as dim #2 in tensor.
        SOC, DOC, MBC, EEC = torch.chunk(c_path_drift_diffusion, self.state_dim, -1)
        #Repeat and permute parameter values to match dimension sizes.
        SAWB_params_dict_rep = dict((k, v.repeat(1, t_span_tensor_drift_diffusion.size(1), 1).permute(2, 1, 0)) for k, v in SAWB_params_dict.items())
        #Initiate tensor with same dims as c_path_drift_diffusion to assign drift.
        drift = torch.empty_like(c_path_drift_diffusion, device = C_PATH.device)
        #Decay parameters are forced by temperature changes.
        u_Q = linear_temp_dep(SAWB_params_dict_rep['u_Q_ref'], temp_tensor_drift_diffusion, SAWB_params_dict_rep['Q'], self.temp_ref) #Apply linear temperature-dependence to u_Q.
        V_D = arrhenius_temp_dep(SAWB_params_dict_rep['V_D_ref'], temp_tensor_drift_diffusion, SAWB_params_dict_rep['Ea_V_D'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_D.
        V_U = arrhenius_temp_dep(SAWB_params_dict_rep['V_U_ref'], temp_tensor_drift_diffusion, SAWB_params_dict_rep['Ea_V_U'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_U.
        #Drift is calculated.
        drift_SOC = i_S_tensor_drift_diffusion + SAWB_params_dict_rep['a_MSA'] * SAWB_params_dict_rep['r_M'] * MBC - ((V_D * EEC * SOC) / (SAWB_params_dict_rep['K_D'] + SOC))
        drift_DOC = i_D_tensor_drift_diffusion + (1 - SAWB_params_dict_rep['a_MSA']) * SAWB_params_dict_rep['r_M'] * MBC + ((V_D * EEC * SOC) / (SAWB_params_dict_rep['K_D'] + SOC)) + SAWB_params_dict_rep['r_L'] * EEC - ((V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC))
        drift_MBC = (u_Q * (V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC)) - (SAWB_params_dict_rep['r_M'] + SAWB_params_dict_rep['r_E']) * MBC
        drift_EEC = SAWB_params_dict_rep['r_E'] * MBC - SAWB_params_dict_rep['r_L'] * EEC
        #Assign elements to drift vector.
        drift[:, :, 0 : 1] = drift_SOC
        drift[:, :, 1 : 2] = drift_DOC
        drift[:, :, 2 : 3] = drift_MBC
        drift[:, :, 3 : 4] = drift_EEC
        #Diffusion matrix is computed based on diffusion type.
        diffusion_sqrt = torch.zeros([drift.size(0), drift.size(1), self.state_dim, self.state_dim], device = drift.device) #Create tensor to assign diffusion matrix elements.            
        if self.DIFFUSION_TYPE == 'C':
            diffusion_sqrt[:, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_SOC'], 1e-8)) #SOC diffusion standard deviation
            diffusion_sqrt[:, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_DOC'], 1e-8)) #DOC diffusion standard deviation
            diffusion_sqrt[:, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_MBC'], 1e-8)) #MBC diffusion standard deviation
            diffusion_sqrt[:, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_EEC'], 1e-8)) #EEC diffusion standard deviation            
            #diffusion_sqrt_single = torch.diag_embed(torch.sqrt(LowerBound.apply(torch.as_tensor([SAWB_params_dict['c_SOC'], SAWB_params_dict['c_DOC'], SAWB_params_dict['c_MBC'], SAWB_params_dict['c_EEC'], SAWB_params_dict['c_CO2']]), 1e-8))) #Create single diffusion matrix by diagonalizing constant noise scale parameters.            
            #diffusion_sqrt = diffusion_sqrt_single.unsqueeze(1).expand(-1, t_span_tensor_drift_diffusion.size(1), -1, -1) #Expand diffusion matrices across all paths and across discretized time steps.
        elif self.DIFFUSION_TYPE == 'SS':
            diffusion_sqrt[:, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SOC * SAWB_params_dict_rep['s_SOC'], 1e-8)) #SOC diffusion standard deviation
            diffusion_sqrt[:, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(DOC * SAWB_params_dict_rep['s_DOC'], 1e-8)) #DOC diffusion standard deviation
            diffusion_sqrt[:, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(MBC * SAWB_params_dict_rep['s_MBC'], 1e-8)) #MBC diffusion standard deviation
            diffusion_sqrt[:, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(EEC * SAWB_params_dict_rep['s_EEC'], 1e-8)) #EEC diffusion standard deviation            
        return drift, diffusion_sqrt

    def drift_diffusion_add_CO2(
        self,
        C_PATH: torch.Tensor, 
        SAWB_params_dict: DictOfTensors,
        ) -> TupleOfTensors:
        '''
        Accepts states x and dictionary of parameter samples.
        Returns SAWB drift and diffusion tensors corresponding to state values and parameter samples, along with tensor of states x concatenated with CO2.  
        Expected SAWB_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, '[cs]_SOC': [cs]_SOC, '[cs]_DOC': [cs]_DOC, '[cs]_MBC': [cs]_MBC, '[cs]_EEC': [cs]_EEC}
        '''
        c_path_drift_diffusion = C_PATH[:, :-1, :]
        i_S_tensor_drift_diffusion = self.i_S[:, 1:, :]
        i_D_tensor_drift_diffusion = self.i_D[:, 1:, :]
        #Partition SOC, DOC, MBC, EEC values. Split based on final C_PATH dim, which specifies state variables and is also indexed as dim #2 in tensor.
        SOC_full, DOC_full, MBC_full, EEC_full = torch.chunk(C_PATH, self.state_dim, -1)
        SOC = SOC_full[:, :-1, :]
        DOC = DOC_full[:, :-1, :]
        MBC = MBC_full[:, :-1, :]
        EEC = EEC_full[:, :-1, :]
        #Repeat and permute parameter values to match dimension sizes.
        SAWB_params_dict_rep_full = dict((k, v.repeat(1, self.times.size(1), 1).permute(2, 1, 0)) for k, v in SAWB_params_dict.items())
        SAWB_params_dict_rep = dict((k, v[:, 1:, :]) for k, v in SAWB_params_dict_rep_full.items())        
        #Initiate tensor with same dims as C_PATH to assign drift.
        drift = torch.empty_like(c_path_drift_diffusion, device = C_PATH.device)
        #Decay parameters are forced by temperature changes.
        u_Q_full = linear_temp_dep(SAWB_params_dict_rep_full['u_Q_ref'], self.temps, SAWB_params_dict_rep_full['Q'], self.temp_ref) #Apply linear temperature-dependence to u_Q.
        u_Q = u_Q_full[:, 1:, :]
        V_D_full = arrhenius_temp_dep(SAWB_params_dict_rep_full['V_D_ref'], self.temps, SAWB_params_dict_rep_full['Ea_V_D'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_D.
        V_D = V_D_full[:, 1:, :]
        V_U_full = arrhenius_temp_dep(SAWB_params_dict_rep_full['V_U_ref'], self.temps, SAWB_params_dict_rep_full['Ea_V_U'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_U.
        V_U = V_U_full[:, 1:, :]
        #Drift is calculated.
        drift_SOC = i_S_tensor_drift_diffusion + SAWB_params_dict_rep['a_MSA'] * SAWB_params_dict_rep['r_M'] * MBC - ((V_D * EEC * SOC) / (SAWB_params_dict_rep['K_D'] + SOC))
        drift_DOC = i_D_tensor_drift_diffusion + (1 - SAWB_params_dict_rep['a_MSA']) * SAWB_params_dict_rep['r_M'] * MBC + ((V_D * EEC * SOC) / (SAWB_params_dict_rep['K_D'] + SOC)) + SAWB_params_dict_rep['r_L'] * EEC - ((V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC))
        drift_MBC = (u_Q * (V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC)) - (SAWB_params_dict_rep['r_M'] + SAWB_params_dict_rep['r_E']) * MBC
        drift_EEC = SAWB_params_dict_rep['r_E'] * MBC - SAWB_params_dict_rep['r_L'] * EEC
        #Assign elements to drift vector.
        drift[:, :, 0 : 1] = drift_SOC
        drift[:, :, 1 : 2] = drift_DOC
        drift[:, :, 2 : 3] = drift_MBC
        drift[:, :, 3 : 4] = drift_EEC
        #Diffusion matrix is computed based on diffusion type.
        diffusion_sqrt = torch.zeros([drift.size(0), drift.size(1), self.state_dim, self.state_dim], device = drift.device) #Create tensor to assign diffusion matrix elements.            
        if self.DIFFUSION_TYPE == 'C':
            diffusion_sqrt[:, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_SOC'], 1e-8)) #SOC diffusion standard deviation
            diffusion_sqrt[:, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_DOC'], 1e-8)) #DOC diffusion standard deviation
            diffusion_sqrt[:, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_MBC'], 1e-8)) #MBC diffusion standard deviation
            diffusion_sqrt[:, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_EEC'], 1e-8)) #EEC diffusion standard deviation            
            #diffusion_sqrt_single = torch.diag_embed(torch.sqrt(LowerBound.apply(torch.as_tensor([SAWB_params_dict['c_SOC'], SAWB_params_dict['c_DOC'], SAWB_params_dict['c_MBC'], SAWB_params_dict['c_EEC'], SAWB_params_dict['c_CO2']]), 1e-8))) #Create single diffusion matrix by diagonalizing constant noise scale parameters.            
            #diffusion_sqrt = diffusion_sqrt_single.unsqueeze(1).expand(-1, T_SPAN_TENSOR.size(1), -1, -1) #Expand diffusion matrices across all paths and across discretized time steps.
        elif self.DIFFUSION_TYPE == 'SS':
            diffusion_sqrt[:, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SOC * SAWB_params_dict_rep['s_SOC'], 1e-8)) #SOC diffusion standard deviation
            diffusion_sqrt[:, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(DOC * SAWB_params_dict_rep['s_DOC'], 1e-8)) #DOC diffusion standard deviation
            diffusion_sqrt[:, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(MBC * SAWB_params_dict_rep['s_MBC'], 1e-8)) #MBC diffusion standard deviation
            diffusion_sqrt[:, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(EEC * SAWB_params_dict_rep['s_EEC'], 1e-8)) #EEC diffusion standard deviation           
        #Compute CO2.
        CO2 = (1 - u_Q_full) * (V_U_full * MBC_full * DOC_full) / (SAWB_params_dict_rep_full['K_U'] + DOC_full)
        #Add CO2 as additional dimension to original x matrix.
        x_add_CO2 = torch.cat([C_PATH, CO2], -1)

        return drift, diffusion_sqrt, x_add_CO2

    def add_CO2(
        self,
        C_PATH: torch.Tensor,
        SAWB_params_dict: DictOfTensors,
        ) -> TupleOfTensors:
        '''
        Accepts input of states x and dictionary of parameter samples.
        Returns matrix (re-sized from x) that not only includes states, but added CO2 values in expanded third dimension of tensor.
        '''
        #Partition SOC, DOC, MBC, and EEC values. Split based on final C_PATH dim, which specifies state variables and is also indexed as dim #2 in tensor. 
        SOC, DOC, MBC, EEC = torch.chunk(C_PATH, self.state_dim, -1)
        #Repeat and permute parameter values to match dimension sizes.
        SAWB_params_dict_rep = dict((k, v.repeat(1, self.times.size(1), 1).permute(2, 1, 0)) for k, v in SAWB_params_dict.items())
        #Decay parameters are forced by temperature changes.
        u_Q = linear_temp_dep(SAWB_params_dict_rep['u_Q_ref'], self.temps, SAWB_params_dict_rep['Q'], self.temp_ref) #Apply linear temperature-dependence to u_Q.
        V_D = arrhenius_temp_dep(SAWB_params_dict_rep['V_D_ref'], self.temps, SAWB_params_dict_rep['Ea_V_D'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_D.
        V_U = arrhenius_temp_dep(SAWB_params_dict_rep['V_U_ref'], self.temps, SAWB_params_dict_rep['Ea_V_U'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_U.
        #Compute CO2.
        CO2 = (1 - u_Q) * (V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC)
        #Add CO2 as additional dimension to original x matrix.
        x_add_CO2 = torch.cat([C_PATH, CO2], -1)
        
        return x_add_CO2

class SAWB_multi(SBM_SDE):
    '''
    Class contains SAWB SDE drift (alpha) and diffusion (beta) equations.
    Constant (C) and state-scaling (SS) diffusion paramterizations are included. DIFFUSION_TYPE must thereby be specified as 'C' or 'SS'. 
    Other diffusion parameterizations are not included.
    '''
    def __init__(
            self,
            T_SPAN_TENSOR: torch.Tensor,
            I_S_TENSOR: torch.Tensor,
            I_D_TENSOR: torch.Tensor,
            TEMP_TENSOR: torch.Tensor,
            TEMP_REF: Number,
            DIFFUSION_TYPE: str
            ):
        super().__init__(T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, TEMP_TENSOR, TEMP_REF)

        if DIFFUSION_TYPE not in {'C', 'SS'}:
            raise NotImplementedError('Other diffusion parameterizations aside from constant (c) or state-scaling (ss) have not been implemented.')

        self.DIFFUSION_TYPE = DIFFUSION_TYPE
        self.state_dim = 4

    def drift_diffusion(
        self,
        C_PATH: torch.Tensor, 
        SAWB_params_dict: DictOfTensors,
        ) -> TupleOfTensors:
        '''
        Accepts states x and dictionary of parameter samples.
        Returns SAWB drift and diffusion tensors corresponding to state values and parameter samples.  
        Expected SAWB_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, '[cs]_SOC': [cs]_SOC, '[cs]_DOC': [cs]_DOC, '[cs]_MBC': [cs]_MBC, '[cs]_EEC': [cs]_EEC}
        '''
        #Appropriately index tensors based on order of operations in data generating process.
        c_path_drift_diffusion = C_PATH[:, :, :-1, :]
        t_span_tensor_drift_diffusion = self.times[:, 1:, :].unsqueeze(1)      
        i_S_tensor_drift_diffusion = self.i_S[:, 1:, :].unsqueeze(1)
        i_D_tensor_drift_diffusion = self.i_D[:, 1:, :].unsqueeze(1)
        temp_tensor_drift_diffusion = self.temps[:, 1:, :].unsqueeze(1)
        #Partition SOC, DOC, MBC, EEC values. Split based on final c_path_drift_diffusion dim, which specifies state variables and is also indexed as dim #2 in tensor.
        SOC, DOC, MBC, EEC =  torch.chunk(c_path_drift_diffusion, self.state_dim, -1)
        #Repeat and permute parameter values to match dimension sizes.
        SAWB_params_dict_rep = dict((k, v.repeat(1, t_span_tensor_drift_diffusion.size(1), 1).permute(2, 1, 0).unsqueeze(1)) for k, v in SAWB_params_dict.items())
        #Initiate tensor with same dims as c_path_drift_diffusion to assign drift.
        drift = torch.empty_like(c_path_drift_diffusion, device = C_PATH.device)
        #Decay parameters are forced by temperature changes.
        u_Q = linear_temp_dep(SAWB_params_dict_rep['u_Q_ref'], temp_tensor_drift_diffusion, SAWB_params_dict_rep['Q'], self.temp_ref) #Apply linear temperature-dependence to u_Q.
        V_D = arrhenius_temp_dep(SAWB_params_dict_rep['V_D_ref'], temp_tensor_drift_diffusion, SAWB_params_dict_rep['Ea_V_D'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_D.
        V_U = arrhenius_temp_dep(SAWB_params_dict_rep['V_U_ref'], temp_tensor_drift_diffusion, SAWB_params_dict_rep['Ea_V_U'], self.temp_ref) #Apply vectorized temperature-dependent transformation to V_U.
        #Drift is calculated.
        drift_SOC = i_S_tensor_drift_diffusion + SAWB_params_dict_rep['a_MSA'] * SAWB_params_dict_rep['r_M'] * MBC - ((V_D * EEC * SOC) / (SAWB_params_dict_rep['K_D'] + SOC))
        drift_DOC = i_D_tensor_drift_diffusion + (1 - SAWB_params_dict_rep['a_MSA']) * SAWB_params_dict_rep['r_M'] * MBC + ((V_D * EEC * SOC) / (SAWB_params_dict_rep['K_D'] + SOC)) + SAWB_params_dict_rep['r_L'] * EEC - ((V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC))
        drift_MBC = (u_Q * (V_U * MBC * DOC) / (SAWB_params_dict_rep['K_U'] + DOC)) - (SAWB_params_dict_rep['r_M'] + SAWB_params_dict_rep['r_E']) * MBC
        drift_EEC = SAWB_params_dict_rep['r_E'] * MBC - SAWB_params_dict_rep['r_L'] * EEC
        #Assign elements to drift vector.
        drift[:, :, :, 0 : 1] = drift_SOC
        drift[:, :, :, 1 : 2] = drift_DOC
        drift[:, :, :, 2 : 3] = drift_MBC
        drift[:, :, :, 3 : 4] = drift_EEC
        #Diffusion matrix is computed based on diffusion type.
        diffusion_sqrt = torch.zeros([drift.size(0), drift.size(1), drift.size(2), self.state_dim, self.state_dim], device = drift.device) #Create tensor to assign diffusion matrix elements.            
        if self.DIFFUSION_TYPE == 'C':
            diffusion_sqrt[:, :, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_SOC'], 1e-8)) #SOC diffusion standard deviation
            diffusion_sqrt[:, :, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_DOC'], 1e-8)) #DOC diffusion standard deviation
            diffusion_sqrt[:, :, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_MBC'], 1e-8)) #MBC diffusion standard deviation
            diffusion_sqrt[:, :, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(SAWB_params_dict_rep['c_EEC'], 1e-8)) #EEC diffusion standard deviation
        elif self.DIFFUSION_TYPE == 'SS':
            diffusion_sqrt[:, :, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SOC * SAWB_params_dict_rep['s_SOC'], 1e-8)) #SOC diffusion standard deviation
            diffusion_sqrt[:, :, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(DOC * SAWB_params_dict_rep['s_DOC'], 1e-8)) #DOC diffusion standard deviation
            diffusion_sqrt[:, :, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(MBC * SAWB_params_dict_rep['s_MBC'], 1e-8)) #MBC diffusion standard deviation
            diffusion_sqrt[:, :, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(EEC * SAWB_params_dict_rep['s_EEC'], 1e-8)) #EEC diffusion standard deviation            
        return drift, diffusion_sqrt

In [13]:
SAWB_SS_test = SAWB(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)
print(SAWB_SS_test.times)
print(SAWB_SS_test.temps)
print(SAWB_SS_test.temp_ref)
print(SAWB_SS_test.i_S)
print(SAWB_SS_test.i_D)

tensor([[[   0.],
         [   1.],
         [   2.],
         ...,
         [ 998.],
         [ 999.],
         [1000.]]])
tensor([[[283.00000000],
         [285.59536743],
         [288.01434326],
         ...,
         [284.56951904],
         [282.50381470],
         [280.92001343]]])
283.0
tensor([[[0.00100000],
         [0.00100036],
         [0.00100072],
         ...,
         [0.00132812],
         [0.00132839],
         [0.00132866]]])
tensor([[[9.99999975e-05],
         [1.00035861e-04],
         [1.00071724e-04],
         ...,
         [1.32811969e-04],
         [1.32839021e-04],
         [1.32866058e-04]]])


In [14]:
SAWB_SS_drift_1, SAWB_SS_diffusion_sqrt_1, SAWB_SS_x_add_CO2_1 = SAWB_SS_test.drift_diffusion_add_CO2(x2, SAWB_SS_dict_out)

In [15]:
print(SAWB_SS_drift_1)
print(SAWB_SS_diffusion_sqrt_1)
print(SAWB_SS_x_add_CO2_1)

tensor([[[ 1.60982891e-03, -1.06490459e-02, -1.58324372e-04,  3.55813478e-04],
         [ 1.66351791e-03, -1.46064730e-02,  3.43389809e-04,  3.91523849e-04],
         [ 1.56258454e-03, -1.50396442e-02,  7.48788007e-04,  3.35764169e-04],
         ...,
         [ 4.27675713e-03, -5.56097999e-02,  3.35224904e-04,  1.68078684e-03],
         [ 4.37142514e-03, -4.63937856e-02, -1.65920611e-03,  1.73184555e-03],
         [ 4.37143864e-03, -3.93303186e-02, -2.96656042e-03,  1.72942039e-03]],

        [[ 2.76826113e-03, -1.60124358e-02,  1.99041795e-03,  5.68952179e-04],
         [ 2.94596562e-03, -2.27624122e-02,  3.57216178e-03,  6.27946807e-04],
         [ 2.66709039e-03, -2.41859406e-02,  4.30789683e-03,  5.39222849e-04],
         ...,
         [ 9.65883769e-03, -7.29899108e-02,  8.79621226e-03,  2.65787658e-03],
         [ 9.90813319e-03, -5.97021766e-02,  4.77249175e-03,  2.73566018e-03],
         [ 9.90487076e-03, -4.98985611e-02,  2.02458445e-03,  2.73431139e-03]],

        [[ 1.6825927

In [16]:
SAWB_SS_drift_2, SAWB_SS_diffusion_sqrt_2 = SAWB_SS_test.drift_diffusion(x2, SAWB_SS_dict_out)

In [17]:
print(SAWB_SS_drift_2)
print(SAWB_SS_diffusion_sqrt_2)

tensor([[[ 1.60982891e-03, -1.06490459e-02, -1.58324372e-04,  3.55813478e-04],
         [ 1.66351791e-03, -1.46064730e-02,  3.43389809e-04,  3.91523849e-04],
         [ 1.56258454e-03, -1.50396442e-02,  7.48788007e-04,  3.35764169e-04],
         ...,
         [ 4.27675713e-03, -5.56097999e-02,  3.35224904e-04,  1.68078684e-03],
         [ 4.37142514e-03, -4.63937856e-02, -1.65920611e-03,  1.73184555e-03],
         [ 4.37143864e-03, -3.93303186e-02, -2.96656042e-03,  1.72942039e-03]],

        [[ 2.76826113e-03, -1.60124358e-02,  1.99041795e-03,  5.68952179e-04],
         [ 2.94596562e-03, -2.27624122e-02,  3.57216178e-03,  6.27946807e-04],
         [ 2.66709039e-03, -2.41859406e-02,  4.30789683e-03,  5.39222849e-04],
         ...,
         [ 9.65883769e-03, -7.29899108e-02,  8.79621226e-03,  2.65787658e-03],
         [ 9.90813319e-03, -5.97021766e-02,  4.77249175e-03,  2.73566018e-03],
         [ 9.90487076e-03, -4.98985611e-02,  2.02458445e-03,  2.73431139e-03]],

        [[ 1.6825927

In [18]:
SAWB_SS_x_add_CO2_2 = SAWB_SS_test.add_CO2(x2, SAWB_SS_dict_out)
print(SAWB_SS_x_add_CO2_2)

tensor([[[4.38741608e+01, 1.10990882e+00, 1.05660713e+00, 8.75303708e-03,
          7.89442006e-03],
         [4.38226395e+01, 1.02039123e+00, 1.16628361e+00, 1.09660393e-02,
          1.07745826e-02],
         [4.46994209e+01, 9.65984821e-01, 1.00154042e+00, 9.90229100e-03,
          1.12855956e-02],
         ...,
         [2.01764183e+01, 4.98133755e+00, 5.07834673e+00, 1.89218055e-02,
          5.22086285e-02],
         [1.94388275e+01, 5.12637997e+00, 5.07599831e+00, 2.06450876e-02,
          4.34600599e-02],
         [1.95989647e+01, 5.00414944e+00, 4.99099350e+00, 2.00175680e-02,
          3.69774513e-02]],

        [[4.38741608e+01, 1.10990882e+00, 1.05660713e+00, 8.75303708e-03,
          8.87857936e-03],
         [4.38226395e+01, 1.02039123e+00, 1.16628361e+00, 1.09660393e-02,
          1.28971329e-02],
         [4.46994209e+01, 9.65984821e-01, 1.00154042e+00, 9.90229100e-03,
          1.42704295e-02],
         ...,
         [2.01764183e+01, 4.98133755e+00, 5.07834673e+00, 1.8

In [19]:
num_sequences = 6
x2_multi = x2.unsqueeze(1).repeat(1, num_sequences, 1, 1) #batch_size, num_sequences, time_steps, state_dims
print(x2_multi.size())

torch.Size([3, 6, 1001, 4])


In [20]:
SOC, DOC, MBC, EEC = torch.chunk(x2_multi, state_dim_SAWB, -1)
print(SOC.size())
print(SOC)

torch.Size([3, 6, 1001, 1])
tensor([[[[43.87416077],
          [43.82263947],
          [44.69942093],
          ...,
          [20.17641830],
          [19.43882751],
          [19.59896469]],

         [[43.87416077],
          [43.82263947],
          [44.69942093],
          ...,
          [20.17641830],
          [19.43882751],
          [19.59896469]],

         [[43.87416077],
          [43.82263947],
          [44.69942093],
          ...,
          [20.17641830],
          [19.43882751],
          [19.59896469]],

         [[43.87416077],
          [43.82263947],
          [44.69942093],
          ...,
          [20.17641830],
          [19.43882751],
          [19.59896469]],

         [[43.87416077],
          [43.82263947],
          [44.69942093],
          ...,
          [20.17641830],
          [19.43882751],
          [19.59896469]],

         [[43.87416077],
          [43.82263947],
          [44.69942093],
          ...,
          [20.17641830],
          [19.43882751

In [21]:
t_span_tensor_multi = t_span_tensor[:, 1:, :].unsqueeze(1)
print(t_span_tensor_multi.size())

torch.Size([1, 1, 1000, 1])


In [22]:
test_params_dict_values_multi = torch.rand([batch_size, n, 1]).unsqueeze(1)
print(test_params_dict_values_multi.size())

torch.Size([3, 1, 1001, 1])


In [23]:
SAWB_SS_multi_test = SAWB_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

In [24]:
SAWB_SS_drift_multi, SAWB_SS_diffusion_sqrt_multi = SAWB_SS_multi_test.drift_diffusion(x2_multi, SAWB_SS_dict_out)

In [28]:
print(SAWB_SS_drift_2)
print(SAWB_SS_drift_2.size())
print(SAWB_SS_drift_multi)
print(SAWB_SS_drift_multi.size())
SAWB_SS_drift_multi[:, 0, :, :] == SAWB_SS_drift_2

tensor([[[ 1.60982891e-03, -1.06490459e-02, -1.58324372e-04,  3.55813478e-04],
         [ 1.66351791e-03, -1.46064730e-02,  3.43389809e-04,  3.91523849e-04],
         [ 1.56258454e-03, -1.50396442e-02,  7.48788007e-04,  3.35764169e-04],
         ...,
         [ 4.27675713e-03, -5.56097999e-02,  3.35224904e-04,  1.68078684e-03],
         [ 4.37142514e-03, -4.63937856e-02, -1.65920611e-03,  1.73184555e-03],
         [ 4.37143864e-03, -3.93303186e-02, -2.96656042e-03,  1.72942039e-03]],

        [[ 2.76826113e-03, -1.60124358e-02,  1.99041795e-03,  5.68952179e-04],
         [ 2.94596562e-03, -2.27624122e-02,  3.57216178e-03,  6.27946807e-04],
         [ 2.66709039e-03, -2.41859406e-02,  4.30789683e-03,  5.39222849e-04],
         ...,
         [ 9.65883769e-03, -7.29899108e-02,  8.79621226e-03,  2.65787658e-03],
         [ 9.90813319e-03, -5.97021766e-02,  4.77249175e-03,  2.73566018e-03],
         [ 9.90487076e-03, -4.98985611e-02,  2.02458445e-03,  2.73431139e-03]],

        [[ 1.6825927

tensor([[[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         ...,
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         ...,
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True],
         ...,
         [True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]])

In [29]:
#Parameter prior means
u_Q_ref_mean = 0.2
Q_mean = 0.001
a_MSA_mean = 0.5
K_DE_mean = 1850
K_UE_mean = 0.2
V_DE_ref_mean = 0.16
V_UE_ref_mean = 0.012
Ea_V_DE_mean = 65
Ea_V_UE_mean = 55
r_M_mean = 0.0018
r_E_mean = 0.00003
r_L_mean = 0.000008
s_SOC_mean = 0.005
s_DOC_mean = 0.005
s_MBC_mean = 0.005
s_EEC_mean = 0.005

#SAWB-ECA theta truncated normal distribution parameter details in order of mean, sdev, lower, and upper.
u_Q_ref_details = torch.Tensor([u_Q_ref_mean, u_Q_ref_mean * prior_scale_factor, 0, 1])
Q_details = torch.Tensor([Q_mean, Q_mean * prior_scale_factor, 0, 1])
a_MSA_details = torch.Tensor([a_MSA_mean, a_MSA_mean * prior_scale_factor, 0, 1])
K_DE_details = torch.Tensor([K_DE_mean, K_DE_mean * prior_scale_factor, 0, 10000])
K_UE_details = torch.Tensor([K_UE_mean, K_UE_mean * prior_scale_factor, 0, 100])
V_DE_ref_details = torch.Tensor([V_DE_ref_mean, V_DE_ref_mean * prior_scale_factor, 0, 10])
V_UE_ref_details = torch.Tensor([V_UE_ref_mean, V_UE_ref_mean * prior_scale_factor, 0, 1])
Ea_V_DE_details = torch.Tensor([Ea_V_DE_mean, Ea_V_DE_mean * prior_scale_factor, 10, 150])
Ea_V_UE_details = torch.Tensor([Ea_V_UE_mean, Ea_V_UE_mean * prior_scale_factor, 10, 150])
r_M_details = torch.Tensor([r_M_mean, r_M_mean * prior_scale_factor, 0, 1])
r_E_details = torch.Tensor([r_E_mean, r_M_mean * prior_scale_factor, 0, 1])
r_L_details = torch.Tensor([r_L_mean, r_M_mean * prior_scale_factor, 0, 1])

#SAWB-ECA-SS diffusion matrix parameter distribution details
s_SOC_details = torch.Tensor([s_SOC_mean, s_SOC_mean * prior_scale_factor, 0, 1])
s_DOC_details = torch.Tensor([s_DOC_mean, s_DOC_mean * prior_scale_factor, 0, 1])
s_MBC_details = torch.Tensor([s_MBC_mean, s_MBC_mean * prior_scale_factor, 0, 1])
s_EEC_details = torch.Tensor([s_EEC_mean, s_EEC_mean * prior_scale_factor, 0, 1])

SAWB_ECA_SS_priors_details = {'u_Q_ref': u_Q_ref_details, 'Q': Q_details, 'a_MSA': a_MSA_details, 'K_DE': K_DE_details, 'K_UE': K_UE_details, 'V_DE_ref': V_DE_ref_details, 'V_UE_ref': V_UE_ref_details, 'Ea_V_DE': Ea_V_DE_details, 'Ea_V_UE': Ea_V_UE_details, 'r_M': r_M_details, 'r_E': r_E_details, 'r_L': r_L_details, 's_SOC': s_SOC_details, 's_DOC': s_DOC_details, 's_MBC': s_MBC_details, 's_EEC': s_EEC_details}

In [30]:
param_names_ECA = list(SAWB_ECA_SS_priors_details.keys())
prior_list_ECA = list(zip(*(SAWB_ECA_SS_priors_details[k] for k in param_names_ECA))) #Unzip prior distribution details from dictionary values into individual lists.
prior_means_tensor_ECA, prior_sds_tensor_ECA, prior_lowers_tensor_ECA, prior_uppers_tensor_ECA = torch.tensor(prior_list_ECA).to(active_device) #Ensure conversion of lists into tensors.
priors_ECA = TruncatedNormal(loc = prior_means_tensor_ECA, scale = prior_sds_tensor_ECA, a = prior_lowers_tensor_ECA, b = prior_uppers_tensor_ECA)

In [32]:
q_theta_SAWB_ECA_SS = MeanField(active_device, param_names_ECA, SAWB_ECA_SS_priors_details, TruncatedNormal, False)
SAWB_ECA_SS_dict_out, SAWB_ECA_SS_samples, _, _ = q_theta_SAWB_ECA_SS(batch_size)

In [33]:
SAWB_ECA_SS_multi_test = SAWB_ECA_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type)

In [34]:
SAWB_ECA_SS_drift_multi, SAWB_ECA_SS_diffusion_sqrt_multi = SAWB_ECA_SS_multi_test.drift_diffusion(x2_multi, SAWB_ECA_SS_dict_out)

In [35]:
#SCON theta truncated normal prior distribution parameter details in order of mean, lower, and upper. Distribution sdev assumed to be some proportion of the mean. 
u_M_details = torch.Tensor([0.0016, 0.0016 * prior_scale_factor, 0, 0.1]).to(active_device)
a_SD_details = torch.Tensor([0.5, 0.5 * prior_scale_factor, 0, 1]).to(active_device)
a_DS_details = torch.Tensor([0.5, 0.5 * prior_scale_factor, 0, 1]).to(active_device)
a_M_details = torch.Tensor([0.5, 0.5 * prior_scale_factor, 0, 1]).to(active_device)
a_MSC_details = torch.Tensor([0.5, 0.5 * prior_scale_factor, 0, 1]).to(active_device)
k_S_ref_details = torch.Tensor([0.0005, 0.0005 * prior_scale_factor, 0, 0.01]).to(active_device)
k_D_ref_details = torch.Tensor([0.0008, 0.0008 * prior_scale_factor, 0, 0.01]).to(active_device)
k_M_ref_details = torch.Tensor([0.0007, 0.0007 * prior_scale_factor, 0, 0.01]).to(active_device)
Ea_S_details = torch.Tensor([55, 55 * prior_scale_factor, 10, 100]).to(active_device)
Ea_D_details = torch.Tensor([48, 48 * prior_scale_factor, 10, 100]).to(active_device)
Ea_M_details = torch.Tensor([48, 48 * prior_scale_factor, 10, 100]).to(active_device)

#SCON-C diffusion matrix parameter truncated normal prior distribution parameter details in order of mean, lower, and upper. 
c_SOC_details = torch.Tensor([0.1, 0.1 * prior_scale_factor, 0, 1]).to(active_device)
c_DOC_details = torch.Tensor([0.002, 0.002 * prior_scale_factor, 0, 0.02]).to(active_device)
c_MBC_details = torch.Tensor([0.002, 0.002 * prior_scale_factor, 0, 0.02]).to(active_device)

SCON_C_priors_details = {'u_M': u_M_details, 'a_SD': a_SD_details, 'a_DS': a_DS_details, 'a_M': a_M_details, 'a_MSC': a_MSC_details, 'k_S_ref': k_S_ref_details, 'k_D_ref': k_D_ref_details, 'k_M_ref': k_M_ref_details, 'Ea_S': Ea_S_details, 'Ea_D': Ea_D_details, 'Ea_M': Ea_M_details, 'c_SOC': c_SOC_details, 'c_DOC': c_DOC_details, 'c_MBC': c_MBC_details}

diffusion_type_C = 'C'

In [36]:
param_names_SCON_C = list(SCON_C_priors_details.keys())
prior_list_SCON_C = list(zip(*(SCON_C_priors_details[k] for k in param_names_SCON_C))) #Unzip prior distribution details from dictionary values into individual lists.
prior_means_tensor_SCON_C, prior_sds_tensor_SCON_C, prior_lowers_tensor_SCON_C, prior_uppers_tensor_SCON_C = torch.tensor(prior_list_SCON_C).to(active_device) #Ensure conversion of lists into tensors.
priors_SCON_C = TruncatedNormal(loc = prior_means_tensor_SCON_C, scale = prior_sds_tensor_SCON_C, a = prior_lowers_tensor_SCON_C, b = prior_uppers_tensor_SCON_C)

In [39]:
q_theta_SCON_C = MeanField(active_device, param_names_SCON_C, SCON_C_priors_details, TruncatedNormal, False)
SCON_C_dict_out, SCON_C_samples, _, _ = q_theta_SCON_C(batch_size)

In [41]:
SCON_C_multi_test = SCON_multi(t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref, diffusion_type_C)

In [42]:
x_multi = x.unsqueeze(1).repeat(1, num_sequences, 1, 1) #batch_size, num_sequences, time_steps, state_dims
print(x_multi.size())

torch.Size([3, 6, 1001, 3])


In [43]:
SCON_C_drift_multi, SCON_C_diffusion_sqrt_multi = SCON_C_multi_test.drift_diffusion(x_multi, SCON_C_dict_out)