In [1]:
from SBM_SDE_tensor import *
from obs_and_flow_classes_and_functions import *
import seaborn as sns
import torch
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
from tqdm import tqdm
import random
from torch.autograd import Function
import argparse
import os
import sys
from pathlib import Path
import shutil
import pandas as pd

In [2]:
torch.manual_seed(0)
devi = torch.device("".join(["cuda:",f'{cuda_id}']) if torch.cuda.is_available() else "cpu")

cuda_id = 1
dt = .2 #SDE discretization timestep.
t = 500 #Simulation run for T hours.
n = int(t / dt) 
t_span = np.linspace(0, t, n + 1)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n + 1, 1]) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.
l_r = 1e-3
niter = 20
piter = 1
batch_size = 3 #Number of sets of observation outputs to sample per set of parameters.
state_dim_SCON = 3 #Not including CO2 in STATE_DIM, because CO2 is an observation.
state_dim_SAWB = 4 #Not including CO2 in STATE_DIM, because CO2 is an observation.
prior_scale_factor = 0.25 #Prior standard deviations set at 0.25 of prior means.
obs_error_scale_factor = 0.1

In [3]:
temp_ref = 283

#System parameters from deterministic CON model
u_M = 0.002
a_SD = 0.33
a_DS = 0.33
a_M = 0.33
a_MSC = 0.5
k_S_ref = 0.000025
k_D_ref = 0.005
k_M_ref = 0.0002
Ea_S = 75
Ea_D = 50
Ea_M = 50

#SCON diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1

SCON_C_prior_means = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC}
SCON_SS_prior_means = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC}

#System parameters from deterministic AWB model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_D = 200
K_U = 1
V_D_ref = 0.4
V_U_ref = 0.02
Ea_V_D = 75
Ea_V_U = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#SAWB diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
c_EEC = 0.001
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1
s_EEC = 0.1

SAWB_C_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC}
SAWB_SS_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC}

#System parameters from deterministic model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_DE = 200
K_UE = 1
V_DE_ref = 0.4
V_UE_ref = 0.02
Ea_V_DE = 75
Ea_V_UE = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#Diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
c_EEC = 0.001
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1
s_EEC = 0.1

SAWB_ECA_C_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC}
SAWB_ECA_SS_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC}

In [4]:
#Obtain SOC and DOC pool litter inputs for all SBMs.
i_s_tensor = 0.001 + 0.0005 * torch.sin((2 * np.pi / (24 * 365)) * t_span_tensor) #Exogenous SOC input function
i_d_tensor = 0.0001 + 0.00005 * torch.sin((2 * np.pi / (24 * 365)) * t_span_tensor) #Exogenous DOC input function

In [39]:
#Mean field VI code block.

#Define mean-field class: consumes parameter dictionary with values used as initial mean values.
class MeanField(nn.Module):
    def __init__(self, init_params, sdev_scale_factor):
        super().__init__()

        #Use param dict to intialise the means for the mean-field approximations.
        means = []
        keys = []
        for key, value in init_params.items():
            keys += [key]
            means += [value]
        self.means = nn.Parameter(torch.Tensor(means))
        self.sds = nn.Parameter(self.means * sdev_scale_factor)
        #Save keys for forward output.
        self.keys = keys

    def forward(self, n = 30):
        #Update posterior.
        q_dist = D.normal.Normal(self.means, LowerBound.apply(self.sds, 1e-7))
        #Sample theta ~ q(theta).
        samples = q_dist.rsample([n])
        #Evaluate log prob of theta samples.
        log_q_theta = torch.sum(q_dist.log_prob(samples), -1) #Shape of n.
        #Return samples in same dictionary format.
        dict_out = {} #Define dictionary with n samples for each parameter.
        for key, sample in zip(self.keys, torch.split(samples, 1, -1),):
            dict_out[f"{key}"] = sample.squeeze(1) #Each sample is of shape [n].
        #Return samples in dictionary and tensor format.
        return dict_out, samples, log_q_theta

In [40]:
def calc_log_lik(C_PATH, T_SPAN_TENSOR, DT, I_S_TENSOR, I_D_TENSOR, DRIFT_DIFFUSION, PARAMS_DICT, TEMP_GEN, TEMP_REF):
    drift, diffusion_sqrt = DRIFT_DIFFUSION(C_PATH[:, :-1, :], T_SPAN_TENSOR[:, :-1, :], I_S_TENSOR[:, :-1, :], I_D_TENSOR[:, :-1, :], PARAMS_DICT, TEMP_GEN, TEMP_REF)
    print('\nDrift = ', drift)
    print('\nDiffusion = ', diffusion_sqrt)
    euler_maruyama_state_sample_object = D.multivariate_normal.MultivariateNormal(loc = C_PATH[:, :-1, :] + drift * DT, scale_tril = diffusion_sqrt * math.sqrt(DT))
    return euler_maruyama_state_sample_object.log_prob(C_PATH[:, 1:, :]).sum(-1)

In [45]:
def train(DEVICE, L_R, NITER, PRETRAIN_ITER, BATCH_SIZE, PRIOR_SCALE_FACTOR, SDEFLOW, ObsModel, csv_to_obs_df, DATA_CSV, OBS_ERROR_SCALE_FACTOR, STATE_DIM, T, DT, N, T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, DRIFT_DIFFUSION, PARAM_PRIOR_MEANS_DICT, TEMP_GEN, TEMP_REF, ANALYTICAL_STEADY_STATE_INIT):
    if PRETRAIN_ITER >= NITER:
        raise Exception("PRETRAIN_ITER must be < NITER.")
    obs_times, obs_means, obs_error = csv_to_obs_df(DATA_CSV, STATE_DIM + 1, T, OBS_ERROR_SCALE_FACTOR) 
    obs_model = ObsModel(DEVICE, obs_times, DT, obs_means[:-1, :], obs_error[:, :-1]) #Hack for bypassing ObsModel and SDEFlow dimension mismatch issue.
    net = SDEFlow(DEVICE, BATCH_SIZE, obs_model, STATE_DIM, T, DT, N).to(DEVICE)
    prior_means_tensor = torch.Tensor(list(PARAM_PRIOR_MEANS_DICT.values()))
    priors = D.normal.Normal(prior_means_tensor, prior_means_tensor * PRIOR_SCALE_FACTOR)
    q_theta = MeanField(PARAM_PRIOR_MEANS_DICT, PRIOR_SCALE_FACTOR)
    pretrain_optimizer = optim.Adamax(net.parameters(), lr = L_R, eps = 1e-7)
    ELBO_optimizer = optim.Adam(list(net.parameters()) + list(q_theta.parameters()), lr = L_R)
    best_loss_norm = 1e10
    best_loss_ELBO = 1e20
    norm_losses = [best_loss_norm] * 10
    ELBO_losses = [best_loss_ELBO] * 10
    with tqdm(total = NITER, desc = f'\nTrain Diffusion', position = -1) as tq:
        for iter in range(NITER):
            net.train()
            C_PATH, log_prob = net() #Obtain paths with solutions at times after t0.
            theta_dict, theta, log_q_theta = q_theta(BATCH_SIZE)
            print('\ntheta_dict = ', theta_dict)
            C_0 = LowerBound.apply(ANALYTICAL_STEADY_STATE_INIT(I_S_TENSOR[0, 0, 0].item(), I_D_TENSOR[0, 0, 0].item(), theta_dict), 1e-5) #Calculate deterministic initial conditions.
            print('\nC_0 =', C_0)
            #C0 = C0[(None,) * 2].repeat(BATCH_SIZE, 1, 1).to(DEVICE) #Commenting out because analytical steady state init functions now output tensors with appropriate batch size if argument into MeanField forward function is BATCH_SIZE. #Assign initial conditions to C_PATH.
            C_PATH = torch.cat([C_0.unsqueeze(1), C_PATH], 1) #Append deterministic CON initial conditions conditional on parameter values to C path. 
            print('\nC_PATH =', C_PATH)
            print('\nC_PATH mean =', C_PATH.mean(-2))
            if iter <= PRETRAIN_ITER:
                pretrain_optimizer.zero_grad()
                #l1_norm_element = C_PATH - torch.mean(obs_model.mu, -1)
                #l1_norm = torch.sum(torch.abs(l1_norm_element)).mean()
                #best_loss_norm = l1_norm if l1_norm < best_loss_norm else best_loss_norm
                #norm_losses.append(l1_norm.item())
                l2_norm_element = C_PATH - torch.mean(obs_model.mu, -1)
                l2_norm = torch.sqrt(torch.sum(torch.square(l2_norm_element))).mean()
                best_loss_norm = l2_norm if l2_norm < best_loss_norm else best_loss_norm
                norm_losses.append(l2_norm.item())
                if len(norm_losses) > 10:
                    norm_losses.pop(0)
                if iter % 10 == 0:
                    print(f"\nMoving average norm loss at {iter} iterations is: {sum(norm_losses) / len(norm_losses)}. Best norm loss value is: {best_loss_norm}.")
                    print('\nC_PATH mean =', C_PATH.mean(-2))
                    print('\nC_PATH =', C_PATH)
                #l1_norm.backward()
                l2_norm.backward()
                pretrain_optimizer.step()
            else:
                log_p_theta = priors.log_prob(theta).sum(-1)
                ELBO_optimizer.zero_grad()
                log_lik = calc_log_lik(C_PATH, T_SPAN_TENSOR.to(DEVICE), dt, I_S_TENSOR.to(DEVICE), I_D_TENSOR.to(DEVICE), DRIFT_DIFFUSION, theta_dict, TEMP_GEN, TEMP_REF)
                neg_ELBO = -log_p_theta.mean() + log_q_theta.mean() - log_lik.mean() - obs_model(C_PATH) + log_prob.mean() #From equation 14 of Ryder et al., 2019.
                #neg_ELBO = -log_lik.mean() - obs_model(C_PATH) + log_prob.mean() #Old ELBO computation without joint density optimization.
                print('\nneg_ELBO_mean = ', neg_ELBO)
                best_loss_ELBO = neg_ELBO if neg_ELBO < best_loss_ELBO else best_loss_ELBO
                ELBO_losses.append(neg_ELBO)
                if len(ELBO_losses) > 10:
                    ELBO_losses.pop(0)
                if iter % 10 == 0:
                    print(f"\nMoving average ELBO loss at {iter} iterations is: {sum(ELBO_losses) / len(ELBO_losses)}. Best ELBO loss value is: {best_loss_ELBO}.")
                neg_ELBO.backward()
                ELBO_optimizer.step()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 3.0)
            if iter % 100000 == 0 and iter > 0:
                ELBO_optimizer.param_groups[0]['lr'] *= 0.1
            tq.update()

In [46]:
train(devi, l_r, niter, piter, batch_size, prior_scale_factor, SDEFlow, ObsModel, csv_to_obs_df, 'CON_synthetic_sol_df.csv', 0.1, state_dim_SCON, t, dt, n, t_span_tensor, i_s_tensor, i_d_tensor, drift_diffusion_SCON_C, SCON_C_prior_means, temp_gen, temp_ref, analytical_steady_state_init_CON)



Train Diffusion:   0%|          | 0/20 [00:00<?, ?it/s][A


self.means =  Parameter containing:
tensor([2.0000e-03, 3.3000e-01, 3.3000e-01, 3.3000e-01, 5.0000e-01, 2.5000e-05,
        5.0000e-03, 2.0000e-04, 7.5000e+01, 5.0000e+01, 5.0000e+01, 1.0000e+00,
        1.0000e-02, 1.0000e-01], requires_grad=True)

theta_dict =  {'u_M': tensor([0.0020, 0.0025, 0.0018], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3826, 0.3257, 0.2473], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3308, 0.2032, 0.3824], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.1883, 0.3267, 0.3422], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.2446, 0.4189, 0.6453], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([3.1356e-05, 1.2812e-05, 2.2550e-05], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0039, 0.0044, 0.0035], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0002, 0.0002, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([68.8244, 84.2626, 57.6730], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([41.5139, 53.9760, 41.8441], grad_fn=<SqueezeBackward1>), 'Ea_M': 



Train Diffusion:   5%|▌         | 1/20 [00:04<01:23,  4.39s/it][A


theta_dict =  {'u_M': tensor([0.0026, 0.0018, 0.0025], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3009, 0.4424, 0.3464], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4915, 0.3570, 0.3932], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3192, 0.2657, 0.4194], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4254, 0.5928, 0.6229], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([2.4629e-05, 2.2129e-05, 1.7798e-05], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0048, 0.0048, 0.0049], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0002, 0.0002, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([84.2107, 63.6709, 67.8526], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([59.3475, 54.2353, 54.1754], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([61.8855, 49.7605, 51.6899], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9404, 1.1273, 1.1205], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0117, 0.0127, 0.0093], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0751, 0.1089, 0.1281], g



Train Diffusion:  10%|█         | 2/20 [00:07<01:08,  3.83s/it][A


theta_dict =  {'u_M': tensor([0.0021, 0.0025, 0.0017], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3848, 0.2322, 0.3811], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4789, 0.3763, 0.2910], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4690, 0.3053, 0.3525], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6852, 0.4054, 0.5098], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([1.2496e-05, 2.3817e-05, 3.0567e-05], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0036, 0.0044, 0.0047], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0002, 0.0002, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([72.5217, 92.1215, 74.5550], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([46.6540, 59.7879, 53.7580], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([66.6014, 51.1724, 37.8529], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9792, 0.9915, 1.1275], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0055, 0.0123, 0.0072], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1196, 0.1024, 0.0989], g



Train Diffusion:  15%|█▌        | 3/20 [00:10<00:57,  3.40s/it][A


theta_dict =  {'u_M': tensor([0.0054, 0.0051, 0.0022], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2761, 0.4131, 0.4496], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2924, 0.2675, 0.2789], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.2946, 0.3987, 0.3827], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.3228, 0.5008, 0.4126], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0010, 0.0010, 0.0010], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0027, 0.0046, 0.0046], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0012, 0.0012, 0.0012], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([ 74.1029, 105.1288,  79.6316], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([51.7568, 52.8430, 75.0106], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([63.3019, 73.1481, 36.0229], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.2095, 0.6746, 1.0667], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0106, 0.0135, 0.0126], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1382, 0.1218, 0.0906], grad_fn=<S



Train Diffusion:  20%|██        | 4/20 [00:13<00:51,  3.21s/it][A


theta_dict =  {'u_M': tensor([0.0006, 0.0064, 0.0018], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3583, 0.4374, 0.2907], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3901, 0.2384, 0.2824], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4611, 0.2790, 0.3078], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4598, 0.4492, 0.5157], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0017, 0.0017, 0.0017], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0009, 0.0067, 0.0011], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0012, 0.0012, 0.0012], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([104.1323,  81.4529,  72.9225], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([51.9309, 45.7012, 74.9756], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([44.0806, 58.0966, 61.9146], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.1103, 0.7439, 1.1606], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0138, 0.0107, 0.0111], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0941, 0.1129, 0.1257], grad_fn=<S



Train Diffusion:  25%|██▌       | 5/20 [00:16<00:47,  3.14s/it][A


theta_dict =  {'u_M': tensor([0.0047, 0.0038, 0.0061], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2537, 0.2670, 0.3910], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3614, 0.2842, 0.3609], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3069, 0.2686, 0.2043], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.2312, 0.5170, 0.6063], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0022, 0.0022, 0.0022], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0043, 0.0083, 0.0056], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([ 0.0010,  0.0013, -0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([65.7265, 65.9810, 65.4483], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([56.3194, 33.5958, 59.3738], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([63.1084, 55.1167, 54.3606], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.2589, 0.9354, 1.0086], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0118, 0.0126, 0.0127], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1041, 0.1152, 0.0783], grad_fn=<S



Train Diffusion:  30%|███       | 6/20 [00:19<00:44,  3.20s/it][A


theta_dict =  {'u_M': tensor([0.0050, 0.0043, 0.0049], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3899, 0.2745, 0.3996], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2192, 0.4026, 0.2549], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3740, 0.3371, 0.3274], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6093, 0.3426, 0.4608], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0026, 0.0026, 0.0026], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0033, 0.0060, 0.0046], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0025, 0.0012, 0.0012], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([86.3265, 55.5956, 64.0247], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([52.7295, 52.2945, 51.1400], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([49.8267, 51.5785, 67.1572], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9822, 1.1820, 0.6678], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0131, 0.0131, 0.0131], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1263, 0.0685, 0.1061], grad_fn=<Sque



Train Diffusion:  35%|███▌      | 7/20 [00:22<00:40,  3.13s/it][A


theta_dict =  {'u_M': tensor([0.0055, 0.0059, 0.0054], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.1033, 0.3588, 0.2276], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4117, 0.2677, 0.3045], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3777, 0.3526, 0.3630], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.3983, 0.4898, 0.3519], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0029, 0.0029, 0.0029], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0030, 0.0058, 0.0050], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0008, -0.0033,  0.0020], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([70.8286, 36.6402, 87.2752], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([44.0596, 35.8639, 46.8634], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([53.6996, 35.8523, 47.1079], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.1908, 1.2829, 0.9007], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0137, 0.0137, 0.0137], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0817, 0.1305, 0.1712], grad_fn=<S



Train Diffusion:  40%|████      | 8/20 [00:25<00:36,  3.05s/it][A


theta_dict =  {'u_M': tensor([0.0060, 0.0058, 0.0060], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3016, 0.2600, 0.2680], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2678, 0.3567, 0.3719], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3246, 0.2676, 0.3400], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5983, 0.2999, 0.5382], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0031, 0.0031, 0.0031], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0036, 0.0064, 0.0039], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0021, -0.0006, -0.0003], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([57.1533, 50.6028, 77.7596], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([31.4842, 55.1576, 62.2642], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([44.6571, 47.4306, 34.5639], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.3553, 0.6548, 1.1785], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0141, 0.0142, 0.0143], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1387, 0.1112, 0.1112], grad_fn=<S



Train Diffusion:  45%|████▌     | 9/20 [00:28<00:33,  3.00s/it][A


theta_dict =  {'u_M': tensor([0.0062, 0.0062, 0.0062], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2652, 0.3681, 0.3789], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3780, 0.2825, 0.3589], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3038, 0.2783, 0.3726], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5652, 0.4947, 0.5787], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0033, 0.0033, 0.0033], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0072, 0.0058, 0.0068], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0007,  0.0027,  0.0024], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([51.1794, 92.2540, 86.4188], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([53.1253, 49.4698, 45.8401], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([39.2076, 45.8729, 55.1345], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.6449, 1.2962, 0.9539], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0147, 0.0146, 0.0147], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0978, 0.1232, 0.0554], grad_fn=<S



Train Diffusion:  50%|█████     | 10/20 [00:31<00:29,  2.96s/it][A


theta_dict =  {'u_M': tensor([0.0069, 0.0067, 0.0066], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3400, 0.3981, 0.2349], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4129, 0.2475, 0.3148], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3314, 0.2695, 0.2901], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5764, 0.2691, 0.5189], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0034, 0.0034, 0.0034], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0060, 0.0056, 0.0051], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0026, -0.0010, -0.0036], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([70.4369, 68.3119, 65.6571], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([49.0431, 64.7878, 41.5219], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([49.4812, 40.0822, 37.9402], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9457, 1.2127, 1.1023], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0152, 0.0151, 0.0151], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0795, 0.0994, 0.0641], grad_fn=<S



Train Diffusion:  55%|█████▌    | 11/20 [00:34<00:26,  2.97s/it][A


theta_dict =  {'u_M': tensor([0.0061, 0.0068, 0.0071], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.4093, 0.2773, 0.2908], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3117, 0.3911, 0.3433], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4079, 0.3704, 0.3690], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.3629, 0.5787, 0.4697], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0035, 0.0035, 0.0035], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0063, 0.0061, 0.0057], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0010,  0.0010, -0.0050], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([87.5604, 72.8392, 55.4077], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([68.8082, 59.7781, 54.9345], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([31.7388, 36.7008, 51.3445], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.1421, 1.1612, 1.3039], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0157, 0.0156, 0.0156], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0773, 0.0663, 0.0849], grad_fn=<S



Train Diffusion:  60%|██████    | 12/20 [00:37<00:23,  2.95s/it][A


theta_dict =  {'u_M': tensor([0.0052, 0.0073, 0.0057], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3404, 0.4388, 0.2710], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2680, 0.3916, 0.3107], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.2972, 0.4732, 0.3022], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6379, 0.5411, 0.4977], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0036, 0.0036, 0.0036], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0055, 0.0064, 0.0062], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([ 0.0079,  0.0053, -0.0013], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([76.3335, 41.9329, 62.9047], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([54.5048, 65.7630, 50.2730], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([63.5116, 43.9257, 34.2162], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.3584, 0.8178, 1.1641], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0160, 0.0163, 0.0159], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0726, 0.1391, 0.1356], grad_fn=<S



Train Diffusion:  65%|██████▌   | 13/20 [00:40<00:20,  2.93s/it][A


theta_dict =  {'u_M': tensor([0.0066, 0.0073, 0.0064], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2508, 0.2570, 0.3779], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.1094, 0.2945, 0.4040], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3334, 0.3152, 0.4436], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6252, 0.4889, 0.3658], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0036, 0.0036, 0.0036], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0062, 0.0062, 0.0064], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([ 3.2508e-05, -2.3947e-03,  1.3076e-03], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([54.1589, 46.3608, 90.9043], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([46.5046, 46.5814, 22.7039], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([19.7028, 36.3432, 55.4218], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.1606, 1.1582, 0.8848], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0165, 0.0164, 0.0165], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0987, 0.0571, 0.1368]



Train Diffusion:  70%|███████   | 14/20 [00:43<00:17,  2.91s/it][A


theta_dict =  {'u_M': tensor([0.0048, 0.0064, 0.0055], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3347, 0.3583, 0.3715], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3190, 0.2990, 0.3473], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4374, 0.3542, 0.3715], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5471, 0.4321, 0.6308], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0036, 0.0036, 0.0036], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0068, 0.0068, 0.0068], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([ 0.0028, -0.0056,  0.0097], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([58.8059, 93.9638, 62.6865], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([55.6986, 60.2831, 61.1648], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([68.4312, 31.3172, 51.8481], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.7086, 0.9970, 0.7229], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0169, 0.0168, 0.0170], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.0934, 0.1335, 0.1346], grad_fn=<S



Train Diffusion:  75%|███████▌  | 15/20 [00:46<00:14,  3.00s/it][A


theta_dict =  {'u_M': tensor([0.0070, 0.0044, 0.0075], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2462, 0.3950, 0.3386], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2254, 0.2518, 0.2961], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.2931, 0.2356, 0.2912], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5110, 0.3382, 0.2476], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0035, 0.0035, 0.0035], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0072, 0.0072, 0.0069], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0013,  0.0033,  0.0043], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([59.4217, 63.7127, 42.9244], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([25.1559, 58.5430, 50.4697], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([60.0166, 43.8823, 69.9030], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9728, 1.1216, 1.4043], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0174, 0.0173, 0.0173], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1148, 0.1163, 0.0906], grad_fn=<S



Train Diffusion:  80%|████████  | 16/20 [00:49<00:11,  2.97s/it][A


theta_dict =  {'u_M': tensor([0.0057, 0.0055, 0.0080], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3080, 0.2106, 0.2462], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4122, 0.3386, 0.2203], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3645, 0.2622, 0.4132], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6685, 0.6711, 0.4854], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0035, 0.0035, 0.0035], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0062, 0.0082, 0.0066], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0080, -0.0004, -0.0021], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([ 84.1558,  68.7681, 101.4191], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([72.4512, 64.9997, 41.1190], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([72.0215, 43.3870, 31.3825], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.3281, 0.9595, 1.2842], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0176, 0.0175, 0.0176], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1340, 0.1096, 0.1321], grad_fn



Train Diffusion:  85%|████████▌ | 17/20 [00:52<00:09,  3.07s/it][A


theta_dict =  {'u_M': tensor([0.0059, 0.0034, 0.0072], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2582, 0.3119, 0.3318], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3259, 0.2305, 0.3531], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3927, 0.3828, 0.2759], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6457, 0.3911, 0.5995], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0034, 0.0034, 0.0034], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0083, 0.0067, 0.0075], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0034,  0.0032, -0.0068], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([119.9724,  97.8644,  65.6210], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([51.9786, 66.0006, 43.4749], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([57.4737, 62.3338, 45.9963], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.3737, 1.0898, 0.6281], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0180, 0.0180, 0.0178], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1345, 0.1030, 0.1166], grad_fn



Train Diffusion:  90%|█████████ | 18/20 [00:55<00:06,  3.03s/it][A


theta_dict =  {'u_M': tensor([0.0039, 0.0079, 0.0032], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.1864, 0.3790, 0.3983], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3412, 0.4060, 0.3682], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3430, 0.3596, 0.1842], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.1985, 0.5674, 0.5558], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0033, 0.0033, 0.0033], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0089, 0.0088, 0.0088], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([-0.0014, -0.0037, -0.0004], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([63.3714, 75.3852, 94.7993], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([43.5556, 45.4817, 50.5184], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([49.7568, 51.6728, 49.9340], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.2622, 0.8134, 1.2069], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0183, 0.0184, 0.0183], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1077, 0.1281, 0.1324], grad_fn=<S



Train Diffusion:  95%|█████████▌| 19/20 [00:59<00:03,  3.15s/it][A


theta_dict =  {'u_M': tensor([0.0006, 0.0069, 0.0094], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3315, 0.3103, 0.4193], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4148, 0.3116, 0.4803], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.2676, 0.1894, 0.2687], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4122, 0.3256, 0.3986], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0032, 0.0032, 0.0032], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0103, 0.0044, 0.0076], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([ 0.0036, -0.0088,  0.0030], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([88.7469, 85.2899, 62.9724], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([32.9099, 23.9865, 58.0463], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([40.3350, 43.8602, 50.0882], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.7891, 1.1378, 0.7736], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0187, 0.0187, 0.0188], grad_fn=<SqueezeBackward1>), 'c_MBC': tensor([0.1089, 0.1128, 0.1176], grad_fn=<S



Train Diffusion: 100%|██████████| 20/20 [01:02<00:00,  3.20s/it][A
Train Diffusion: 100%|██████████| 20/20 [01:02<00:00,  3.12s/it]


In [18]:
t_span_tensor.size()

torch.Size([1, 2501, 1])

In [19]:
prior_means_tensor = torch.Tensor(list(SCON_C_prior_means.values()))
priors = D.normal.Normal(prior_means_tensor, prior_means_tensor * prior_scale_factor)
q_theta = MeanField(SCON_C_prior_means, prior_scale_factor)
theta_dict, theta, log_q_theta = q_theta(batch_size)

In [20]:
theta_dict

{'u_M': tensor([0.0016, 0.0018, 0.0026], grad_fn=<SqueezeBackward1>),
 'a_SD': tensor([0.4986, 0.2236, 0.2067], grad_fn=<SqueezeBackward1>),
 'a_DS': tensor([0.4280, 0.4224, 0.1910], grad_fn=<SqueezeBackward1>),
 'a_M': tensor([0.3747, 0.3460, 0.3120], grad_fn=<SqueezeBackward1>),
 'a_MSC': tensor([0.4961, 0.5466, 0.4423], grad_fn=<SqueezeBackward1>),
 'k_S_ref': tensor([2.0269e-05, 2.4345e-05, 1.7274e-05], grad_fn=<SqueezeBackward1>),
 'k_D_ref': tensor([0.0044, 0.0054, 0.0066], grad_fn=<SqueezeBackward1>),
 'k_M_ref': tensor([0.0002, 0.0002, 0.0002], grad_fn=<SqueezeBackward1>),
 'Ea_S': tensor([84.5666, 72.5103, 81.0127], grad_fn=<SqueezeBackward1>),
 'Ea_D': tensor([54.5815, 40.6330, 63.7570], grad_fn=<SqueezeBackward1>),
 'Ea_M': tensor([34.0350, 41.4008, 56.2262], grad_fn=<SqueezeBackward1>),
 'c_SOC': tensor([1.1623, 1.4763, 0.6716], grad_fn=<SqueezeBackward1>),
 'c_DOC': tensor([0.0138, 0.0127, 0.0106], grad_fn=<SqueezeBackward1>),
 'c_MBC': tensor([0.1381, 0.0956, 0.1080], gra

In [21]:
obs_times, obs_means, obs_error = csv_to_obs_df('CON_synthetic_sol_df.csv', 3 + 1, t, obs_error_scale_factor)
obs_model = ObsModel(devi, obs_times, dt, obs_means[:-1, :], obs_error[:, :-1])
net = SDEFlow(devi, batch_size, obs_model, 3, t, dt, n).to(devi)

In [22]:
C_PATH, log_prob = net()
print(C_PATH)
C_PATH.size()

tensor([[[0.7703, 0.8147, 0.7190],
         [0.3577, 0.6447, 0.7709],
         [0.5693, 0.3108, 0.8354],
         ...,
         [0.8004, 0.8717, 1.5919],
         [0.7038, 0.8529, 1.4649],
         [0.7325, 0.7718, 0.8040]],

        [[0.8571, 0.8375, 0.7950],
         [0.2069, 0.7113, 0.9912],
         [0.5616, 0.5460, 0.8096],
         ...,
         [0.7837, 0.7142, 1.3335],
         [0.9051, 0.8525, 1.5034],
         [0.6693, 0.6562, 4.9073]],

        [[0.5143, 0.5829, 0.7754],
         [1.4471, 0.6817, 0.8258],
         [0.7452, 0.8609, 0.9399],
         ...,
         [0.8733, 0.7922, 1.5084],
         [0.7237, 0.9007, 1.5550],
         [0.7362, 0.8360, 7.2313]]], grad_fn=<AddBackward0>)


torch.Size([3, 2500, 3])

In [47]:
C_0 = analytical_steady_state_init_CON(i_s_tensor[0, 0, 0].item(), i_d_tensor[0, 0, 0].item(), theta_dict)
print(C_0)
C_0.unsqueeze(1).size()

tensor([[6.3284e+01, 1.3084e-01, 9.9657e-01],
        [4.6585e+01, 5.1670e-02, 4.4936e-01],
        [6.1308e+01, 3.6414e-02, 4.5395e-01]], grad_fn=<StackBackward>)


torch.Size([3, 1, 3])

In [24]:
C_PATH = torch.cat([C_0.unsqueeze(1), C_PATH], 1)
C_PATH

tensor([[[6.3284e+01, 1.3084e-01, 9.9657e-01],
         [7.7030e-01, 8.1470e-01, 7.1902e-01],
         [3.5773e-01, 6.4470e-01, 7.7092e-01],
         ...,
         [8.0044e-01, 8.7174e-01, 1.5919e+00],
         [7.0381e-01, 8.5288e-01, 1.4649e+00],
         [7.3249e-01, 7.7180e-01, 8.0398e-01]],

        [[4.6585e+01, 5.1670e-02, 4.4936e-01],
         [8.5708e-01, 8.3745e-01, 7.9501e-01],
         [2.0688e-01, 7.1126e-01, 9.9121e-01],
         ...,
         [7.8371e-01, 7.1421e-01, 1.3335e+00],
         [9.0510e-01, 8.5253e-01, 1.5034e+00],
         [6.6934e-01, 6.5619e-01, 4.9073e+00]],

        [[6.1308e+01, 3.6414e-02, 4.5395e-01],
         [5.1430e-01, 5.8294e-01, 7.7539e-01],
         [1.4471e+00, 6.8167e-01, 8.2578e-01],
         ...,
         [8.7326e-01, 7.9220e-01, 1.5084e+00],
         [7.2368e-01, 9.0072e-01, 1.5550e+00],
         [7.3621e-01, 8.3604e-01, 7.2313e+00]]], grad_fn=<CatBackward>)

In [25]:
SOC, DOC, MBC =  torch.chunk(C_PATH, 3, -1)
SOC.size()

torch.Size([3, 2501, 1])

In [26]:
current_temp = temp_gen(t_span_tensor, temp_ref)

In [27]:
theta_dict['k_S_ref'].size()

torch.Size([3])

In [28]:
k_S = arrhenius_temp_dep(theta_dict['k_S_ref'], current_temp, theta_dict['Ea_S'], temp_ref)
k_S = k_S.permute(2, 1, 0)
k_D = arrhenius_temp_dep(theta_dict['k_D_ref'], current_temp, theta_dict['Ea_D'], temp_ref)
k_D = k_D.permute(2, 1, 0)
k_M = arrhenius_temp_dep(theta_dict['k_M_ref'], current_temp, theta_dict['Ea_M'], temp_ref)
k_M = k_M.permute(2, 1, 0)

In [29]:
theta_dict_repeat = dict((k, v.repeat(1, 2501, 1).permute(2, 1, 0)) for k, v in theta_dict.items())
drift_SOC = i_s_tensor + theta_dict_repeat['a_DS'] * k_D * DOC + theta_dict_repeat['a_M'] * theta_dict_repeat['a_MSC'] * k_M * MBC - k_S * SOC
drift_DOC = i_d_tensor + theta_dict_repeat['a_SD'] * k_S * SOC + theta_dict_repeat['a_M'] * (1 - theta_dict_repeat['a_MSC']) * k_M * MBC - (theta_dict_repeat['u_M'] + k_D) * DOC
drift_MBC = theta_dict_repeat['u_M'] * DOC - k_M * MBC

In [30]:
drift_MBC.size()

torch.Size([3, 2501, 1])

In [31]:
theta_dict_u_M_test = theta_dict['u_M'].repeat(1, 2501, 1)
theta_dict_u_M_test = theta_dict_u_M_test.permute(2, 1, 0)
print(theta_dict_u_M_test.size())
test = theta_dict_u_M_test * DOC
print(test.size())

torch.Size([3, 2501, 1])
torch.Size([3, 2501, 1])


In [32]:
C_PATH.size()

torch.Size([3, 2501, 3])

In [33]:
drift = torch.empty_like(C_PATH, device = C_PATH.device) #Initiate tensor with same dims as C_PATH to assign drift.
drift[:, :, 0 : 1] = drift_SOC
drift[:, :, 1 : 2] = drift_DOC
drift[:, :, 2 : 3] = drift_MBC

In [34]:
test_m = torch.stack([theta_dict['c_SOC'], theta_dict['c_DOC'], theta_dict['c_MBC']], 1)
print(LowerBound.apply(test_m, 1e-6))
test_m_sqrt = torch.sqrt(test_m)
torch.diag_embed(test_m_sqrt)

tensor([[1.1623, 0.0138, 0.1381],
        [1.4763, 0.0127, 0.0956],
        [0.6716, 0.0106, 0.1080]], grad_fn=<LowerBoundBackward>)


tensor([[[1.0781, 0.0000, 0.0000],
         [0.0000, 0.1175, 0.0000],
         [0.0000, 0.0000, 0.3717]],

        [[1.2150, 0.0000, 0.0000],
         [0.0000, 0.1127, 0.0000],
         [0.0000, 0.0000, 0.3093]],

        [[0.8195, 0.0000, 0.0000],
         [0.0000, 0.1029, 0.0000],
         [0.0000, 0.0000, 0.3286]]], grad_fn=<CopySlices>)

In [35]:
diffusion_sqrt_single = torch.diag_embed(torch.sqrt(LowerBound.apply(torch.stack([theta_dict['c_SOC'], theta_dict['c_DOC'], theta_dict['c_MBC']], 1), 1e-6)))
diffusion_sqrt_single

tensor([[[1.0781, 0.0000, 0.0000],
         [0.0000, 0.1175, 0.0000],
         [0.0000, 0.0000, 0.3717]],

        [[1.2150, 0.0000, 0.0000],
         [0.0000, 0.1127, 0.0000],
         [0.0000, 0.0000, 0.3093]],

        [[0.8195, 0.0000, 0.0000],
         [0.0000, 0.1029, 0.0000],
         [0.0000, 0.0000, 0.3286]]], grad_fn=<CopySlices>)

In [36]:
diffusion_sqrt_single.unsqueeze(1).expand(-1, 2501, -1, -1)

tensor([[[[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         ...,

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]]],


        [[[1.2150, 0.0000, 0.0000],
          [0.0000, 0.1127, 0.0000],
          [0.0000, 0.0000, 0.3093]],

         [[1.2150, 0.0000, 0.0000],
          [0.0000, 0.1127, 0.0000],
          [0.0000, 0.0000, 0.3093]],

         [[1.2150, 0.0000, 0.0000],
          [0.0000, 0.1127, 0.0000],
          [0.0000, 0.0000, 0.30