In [17]:
from SBM_SDE_tensor import *
from obs_and_flow_classes_and_functions import *
import seaborn as sns
import torch
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
from tqdm import tqdm
import random
from torch.autograd import Function
import argparse
import os
import sys
from pathlib import Path
import shutil
import pandas as pd

In [18]:
torch.manual_seed(0)
devi = torch.device("".join(["cuda:",f'{cuda_id}']) if torch.cuda.is_available() else "cpu")

cuda_id = 1
dt = .2 #SDE discretization timestep.
t = 500 #Simulation run for T hours.
n = int(t / dt) 
t_span = np.linspace(0, t, n + 1)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n + 1, 1]) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.
l_r = 1e-3
niter = 100
piter = 5
batch_size = 4 #Number of sets of observation outputs to sample per set of parameters.
state_dim_SCON = 3 #Not including CO2 in STATE_DIM, because CO2 is an observation.
state_dim_SAWB = 4 #Not including CO2 in STATE_DIM, because CO2 is an observation.
prior_scale_factor = 0.25 #Prior standard deviations set at 0.25 of prior means.
obs_error_scale_factor = 0.25

In [19]:
temp_ref = 283

#System parameters from deterministic CON model
u_M = 0.002
a_SD = 0.33
a_DS = 0.33
a_M = 0.33
a_MSC = 0.5
k_S_ref = 0.000025
k_D_ref = 0.005
k_M_ref = 0.0002
Ea_S = 75
Ea_D = 50
Ea_M = 50

#SCON diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1

SCON_C_prior_means = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC}
SCON_SS_prior_means = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC}

#System parameters from deterministic AWB model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_D = 200
K_U = 1
V_D_ref = 0.4
V_U_ref = 0.02
Ea_V_D = 75
Ea_V_U = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#SAWB diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
c_EEC = 0.001
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1
s_EEC = 0.1

SAWB_C_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC}
SAWB_SS_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC}

#System parameters from deterministic model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_DE = 200
K_UE = 1
V_DE_ref = 0.4
V_UE_ref = 0.02
Ea_V_DE = 75
Ea_V_UE = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#Diffusion matrix sigma scale parameters
c_SOC = 1.
c_DOC = 0.01
c_MBC = 0.1
c_EEC = 0.001
s_SOC = 0.1
s_DOC = 0.1
s_MBC = 0.1
s_EEC = 0.1

SAWB_ECA_C_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC}
SAWB_ECA_SS_prior_means = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC}

In [20]:
#Obtain SOC and DOC pool litter inputs for all SBMs.
i_s_tensor = 0.001 + 0.0005 * torch.sin((2 * np.pi / (24 * 365)) * t_span_tensor) #Exogenous SOC input function
i_d_tensor = 0.0001 + 0.00005 * torch.sin((2 * np.pi / (24 * 365)) * t_span_tensor) #Exogenous DOC input function

In [21]:
#Mean field VI code block.

#Define mean-field class: consumes parameter dictionary with values used as initial mean values.
class MeanField(nn.Module):
    def __init__(self, init_params, sdev_scale_factor):
        super().__init__()

        #Use param dict to intialise the means for the mean-field approximations.
        means = []
        keys = []
        for key, value in init_params.items():
            keys += [key]
            means += [value]
        self.means = nn.Parameter(torch.Tensor(means))
        self.sds = nn.Parameter(self.means * sdev_scale_factor)
        #Save keys for forward output.
        self.keys = keys

    def forward(self, n = 30):
        #Update posterior.
        q_dist = D.normal.Normal(LowerBound.apply(self.means, 1e-6), LowerBound.apply(self.sds, 1e-8))
        #q_dist = D.log_normal.LogNormal(LowerBound.apply(self.means, 1e-6), LowerBound.apply(self.sds, 1e-8)) #Testing LogNormal sampling object.
        #Sample theta ~ q(theta).
        samples = LowerBound.apply(q_dist.rsample([n]), 1e-6)
        #samples_log = LowerBound.apply(q_dist.rsample([n]), 1e-6)
        #samples = torch.exp(samples_log) #Exp transformation of log sample SDE theta values for SDE use.
        #Evaluate log prob of theta samples.
        log_q_theta = torch.sum(q_dist.log_prob(samples), -1) #Shape of n.
        #log_q_theta = torch.sum(q_dist.log_prob(samples_log), -1) #Shape of n.
        #Return samples in same dictionary format.
        dict_out = {} #Define dictionary with n samples for each parameter.
        for key, sample in zip(self.keys, torch.split(samples, 1, -1),):
            dict_out[f"{key}"] = sample.squeeze(1) #Each sample is of shape [n].
        #Return samples in dictionary and tensor format.
        return dict_out, samples, log_q_theta

In [22]:
def calc_log_lik(C_PATH, T_SPAN_TENSOR, DT, I_S_TENSOR, I_D_TENSOR, DRIFT_DIFFUSION, PARAMS_DICT, TEMP_GEN, TEMP_REF):
    drift, diffusion_sqrt = DRIFT_DIFFUSION(C_PATH[:, :-1, :], T_SPAN_TENSOR[:, :-1, :], I_S_TENSOR[:, :-1, :], I_D_TENSOR[:, :-1, :], PARAMS_DICT, TEMP_GEN, TEMP_REF)
    print('\nDrift = ', drift)
    print('\nDiffusion = ', diffusion_sqrt)
    euler_maruyama_state_sample_object = D.multivariate_normal.MultivariateNormal(loc = C_PATH[:, :-1, :] + drift * DT, scale_tril = diffusion_sqrt * math.sqrt(DT))
    return euler_maruyama_state_sample_object.log_prob(C_PATH[:, 1:, :]).sum(-1)

In [23]:
def train(DEVICE, L_R, NITER, PRETRAIN_ITER, BATCH_SIZE, PRIOR_SCALE_FACTOR, SDEFLOW, ObsModel, csv_to_obs_df, DATA_CSV, OBS_ERROR_SCALE_FACTOR, STATE_DIM, T, DT, N, T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, DRIFT_DIFFUSION, PARAM_PRIOR_MEANS_DICT, TEMP_GEN, TEMP_REF, ANALYTICAL_STEADY_STATE_INIT):
    if PRETRAIN_ITER >= NITER:
        raise Exception("PRETRAIN_ITER must be < NITER.")
    obs_times, obs_means, obs_error = csv_to_obs_df(DATA_CSV, STATE_DIM + 1, T, OBS_ERROR_SCALE_FACTOR) 
    obs_model = ObsModel(DEVICE, obs_times, DT, obs_means[:-1, :], obs_error[:, :-1]) #Hack for bypassing ObsModel and SDEFlow dimension mismatch issue.
    net = SDEFlow(DEVICE, BATCH_SIZE, obs_model, STATE_DIM, T, DT, N).to(DEVICE)
    prior_means_tensor = torch.Tensor(list(PARAM_PRIOR_MEANS_DICT.values()))
    priors = D.normal.Normal(prior_means_tensor, prior_means_tensor * PRIOR_SCALE_FACTOR)
    q_theta = MeanField(PARAM_PRIOR_MEANS_DICT, PRIOR_SCALE_FACTOR)
    pretrain_optimizer = optim.Adamax(net.parameters(), lr = L_R, eps = 1e-7)
    ELBO_optimizer = optim.Adam(list(net.parameters()) + list(q_theta.parameters()), lr = L_R)
    best_loss_norm = 1e10
    best_loss_ELBO = 1e20
    norm_losses = [best_loss_norm] * 10
    ELBO_losses = [best_loss_ELBO] * 10
    with tqdm(total = NITER, desc = f'\nTrain Diffusion', position = -1) as tq:
        for iter in range(NITER):
            net.train()
            C_PATH, log_prob = net() #Obtain paths with solutions at times after t0.
            theta_dict, theta, log_q_theta = q_theta(BATCH_SIZE)
            print('\ntheta_dict = ', theta_dict)
            C_0 = LowerBound.apply(ANALYTICAL_STEADY_STATE_INIT(I_S_TENSOR[0, 0, 0].item(), I_D_TENSOR[0, 0, 0].item(), theta_dict), 1e-7) #Calculate deterministic initial conditions.
            print('\nC_0 =', C_0)
            #C0 = C0[(None,) * 2].repeat(BATCH_SIZE, 1, 1).to(DEVICE) #Commenting out because analytical steady state init functions now output tensors with appropriate batch size if argument into MeanField forward function is BATCH_SIZE. #Assign initial conditions to C_PATH.
            C_PATH = torch.cat([C_0.unsqueeze(1), C_PATH], 1) #Append deterministic CON initial conditions conditional on parameter values to C path. 
            print('\nC_PATH =', C_PATH)
            print('\nC_PATH mean =', C_PATH.mean(-2))
            if iter <= PRETRAIN_ITER:
                pretrain_optimizer.zero_grad()
                #l1_norm_element = C_PATH - torch.mean(obs_model.mu, -1)
                #l1_norm = torch.sum(torch.abs(l1_norm_element)).mean()
                #best_loss_norm = l1_norm if l1_norm < best_loss_norm else best_loss_norm
                #norm_losses.append(l1_norm.item())
                l2_norm_element = C_PATH - torch.mean(obs_model.mu, -1)
                l2_norm = torch.sqrt(torch.sum(torch.square(l2_norm_element))).mean()
                best_loss_norm = l2_norm if l2_norm < best_loss_norm else best_loss_norm
                norm_losses.append(l2_norm.item())
                if len(norm_losses) > 10:
                    norm_losses.pop(0)
                if iter % 10 == 0:
                    print(f"\nMoving average norm loss at {iter} iterations is: {sum(norm_losses) / len(norm_losses)}. Best norm loss value is: {best_loss_norm}.")
                    print('\nC_PATH mean =', C_PATH.mean(-2))
                    print('\nC_PATH =', C_PATH)
                #l1_norm.backward()
                l2_norm.backward()
                pretrain_optimizer.step()
            else:
                log_p_theta = priors.log_prob(theta).sum(-1)
                ELBO_optimizer.zero_grad()
                log_lik = calc_log_lik(C_PATH, T_SPAN_TENSOR.to(DEVICE), dt, I_S_TENSOR.to(DEVICE), I_D_TENSOR.to(DEVICE), DRIFT_DIFFUSION, theta_dict, TEMP_GEN, TEMP_REF)
                neg_ELBO = -log_p_theta.mean() + log_q_theta.mean() - log_lik.mean() - obs_model(C_PATH) + log_prob.mean() #From equation 14 of Ryder et al., 2019.
                #neg_ELBO = -log_lik.mean() - obs_model(C_PATH) + log_prob.mean() #Old ELBO computation without joint density optimization.
                print('\nneg_ELBO_mean = ', neg_ELBO)
                best_loss_ELBO = neg_ELBO if neg_ELBO < best_loss_ELBO else best_loss_ELBO
                ELBO_losses.append(neg_ELBO)
                if len(ELBO_losses) > 10:
                    ELBO_losses.pop(0)
                if iter % 10 == 0:
                    print(f"\nMoving average ELBO loss at {iter} iterations is: {sum(ELBO_losses) / len(ELBO_losses)}. Best ELBO loss value is: {best_loss_ELBO}.")
                neg_ELBO.backward()
                ELBO_optimizer.step()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 3.0)
            if iter % 100000 == 0 and iter > 0:
                ELBO_optimizer.param_groups[0]['lr'] *= 0.1
            tq.update()

In [None]:
train(devi, l_r, niter, piter, batch_size, prior_scale_factor, SDEFlow, ObsModel, csv_to_obs_df, 'CON_synthetic_sol_df.csv', 0.1, state_dim_SCON, t, dt, n, t_span_tensor, i_s_tensor, i_d_tensor, drift_diffusion_SCON_C, SCON_C_prior_means, temp_gen, temp_ref, analytical_steady_state_init_CON)



Train Diffusion:   0%|          | 0/100 [00:00<?, ?it/s][A


theta_dict =  {'u_M': tensor([0.0019, 0.0017, 0.0023, 0.0015], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.1630, 0.3559, 0.4440, 0.3448], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.1850, 0.2362, 0.3637, 0.4131], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3385, 0.4074, 0.1838, 0.2952], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5613, 0.4229, 0.4906, 0.3998], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([2.2214e-05, 1.7081e-05, 2.5876e-05, 2.9317e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0043, 0.0068, 0.0062, 0.0045], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0002, 0.0002, 0.0001, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([67.5441, 77.0345, 90.4182, 89.8730], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([53.1042, 41.5767, 58.1754, 49.2347], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([37.6380, 73.8201, 66.0964, 39.2428], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.6889, 1.0593, 1.0529, 0.9332], grad_fn=<SqueezeBackward1>), 'c_D



Train Diffusion:   1%|          | 1/100 [00:04<07:20,  4.45s/it][A


theta_dict =  {'u_M': tensor([0.0023, 0.0020, 0.0022, 0.0024], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3193, 0.2961, 0.4017, 0.3978], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3032, 0.2692, 0.3255, 0.2112], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3602, 0.1746, 0.3883, 0.3251], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4128, 0.2403, 0.6219, 0.3587], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([1.8799e-05, 3.9144e-05, 1.4167e-05, 2.2314e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0060, 0.0048, 0.0032, 0.0034], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0002, 0.0002, 0.0002, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([67.5699, 84.4437, 71.3358, 87.8786], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([54.8305, 44.2855, 39.9967, 46.0273], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([50.3059, 43.5638, 51.1576, 56.1779], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9113, 0.5562, 1.6316, 1.3007], grad_fn=<SqueezeBackward1>), 'c_D



Train Diffusion:   2%|▏         | 2/100 [00:08<06:53,  4.22s/it][A


theta_dict =  {'u_M': tensor([0.0014, 0.0018, 0.0021, 0.0030], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.4093, 0.3669, 0.3110, 0.3841], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4689, 0.2748, 0.3263, 0.2711], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.1892, 0.4107, 0.3952, 0.3549], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5230, 0.4508, 0.2130, 0.5823], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([1.7361e-05, 2.6384e-05, 3.1766e-05, 2.2827e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0047, 0.0072, 0.0058, 0.0041], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([7.8981e-05, 1.1182e-04, 2.3391e-04, 1.6512e-04],
       grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([52.0409, 54.3603, 83.5033, 83.5308], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([59.5335, 50.1159, 52.6264, 49.9327], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([58.8833, 53.2625, 46.7369, 54.5674], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9266, 0.8039, 1.0662, 1.0338], grad_fn=<S



Train Diffusion:   3%|▎         | 3/100 [00:13<07:04,  4.37s/it][A


theta_dict =  {'u_M': tensor([0.0027, 0.0022, 0.0032, 0.0012], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2778, 0.3373, 0.3564, 0.4195], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3050, 0.2973, 0.3269, 0.4105], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3746, 0.3965, 0.3421, 0.3333], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4801, 0.3823, 0.3419, 0.6623], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([1.1393e-05, 2.6882e-05, 2.5971e-05, 2.4583e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0070, 0.0062, 0.0053, 0.0056], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0001, 0.0002, 0.0002, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([112.0539,  68.9533,  64.4187,  30.0656], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([43.4255, 62.8867, 69.1306, 47.5510], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([45.3163, 59.1884, 56.1983, 39.0235], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.1567, 1.0312, 1.1094, 1.4813], grad_fn=<SqueezeBackward1>), 



Train Diffusion:   4%|▍         | 4/100 [00:16<06:38,  4.15s/it][A


theta_dict =  {'u_M': tensor([0.0018, 0.0024, 0.0024, 0.0024], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3189, 0.4368, 0.3415, 0.3930], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3493, 0.3009, 0.2175, 0.4331], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3664, 0.3231, 0.2300, 0.2506], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4857, 0.3249, 0.9142, 0.3771], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([2.8919e-05, 1.9810e-05, 2.0680e-05, 1.3042e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0056, 0.0059, 0.0041, 0.0023], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0003, 0.0002, 0.0002, 0.0003], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([45.7199, 96.7022, 32.6958, 47.5744], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([41.5376, 42.1935, 57.2472, 20.0566], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([42.6875, 53.2548, 32.5001, 54.0254], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.7969, 0.6581, 0.8964, 0.8927], grad_fn=<SqueezeBackward1>), 'c_D



Train Diffusion:   5%|▌         | 5/100 [00:21<06:44,  4.26s/it][A


theta_dict =  {'u_M': tensor([0.0030, 0.0013, 0.0017, 0.0013], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.4095, 0.2801, 0.3730, 0.2506], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.1878, 0.5412, 0.3315, 0.2886], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4041, 0.3794, 0.3357, 0.3574], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4991, 0.5373, 0.6099, 0.4217], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([2.2699e-05, 3.7250e-05, 2.2788e-05, 2.0708e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0037, 0.0041, 0.0037, 0.0046], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0003, 0.0002, 0.0001, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([ 60.6719,  68.3832,  78.3684, 127.4213], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([44.4108, 50.3591, 57.1622, 38.9582], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([56.7658, 50.8335, 56.6264, 46.2271], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9075, 1.0471, 1.0939, 0.6906], grad_fn=<SqueezeBackward1>), 



Train Diffusion:   6%|▌         | 6/100 [00:25<06:44,  4.30s/it][A


theta_dict =  {'u_M': tensor([0.0016, 0.0013, 0.0022, 0.0027], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2939, 0.4021, 0.4862, 0.3530], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2956, 0.2936, 0.1774, 0.2505], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4129, 0.2959, 0.3545, 0.1882], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.3038, 0.4971, 0.4999, 0.3187], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([3.2115e-05, 3.2562e-05, 1.7579e-05, 2.5546e-05],
       grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0056, 0.0039, 0.0040, 0.0042], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0003, 0.0001, 0.0003, 0.0002], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([85.6528, 57.3141, 75.4553, 59.8952], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([67.4842, 26.2894, 47.3807, 26.4844], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([31.2309, 27.0744, 49.8253, 43.5244], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.8154, 0.8572, 1.1749, 0.7850], grad_fn=<SqueezeBackward1>), 'c_D



Train Diffusion:   7%|▋         | 7/100 [00:29<06:24,  4.14s/it][A


theta_dict =  {'u_M': tensor([0.0030, 0.0030, 0.0030, 0.0030], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2764, 0.3304, 0.2516, 0.3158], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.5016, 0.2341, 0.4532, 0.2667], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3938, 0.3188, 0.3446, 0.2376], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.6199, 0.3349, 0.5288, 0.5271], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0010, 0.0010, 0.0010, 0.0010], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0062, 0.0062, 0.0057, 0.0059], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0012, 0.0012, 0.0012, 0.0012], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([58.5809, 77.9943, 76.4782, 51.9912], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([62.5633, 69.6788, 60.7702, 40.1682], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([54.7816, 59.0423, 51.7987, 52.7036], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.3964, 1.5569, 0.5921, 0.7583], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0152, 0.



Train Diffusion:   8%|▊         | 8/100 [00:33<06:23,  4.17s/it][A


theta_dict =  {'u_M': tensor([0.0034, 0.0038, 0.0036, 0.0032], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3863, 0.2303, 0.2399, 0.3295], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.4030, 0.3940, 0.2728, 0.4274], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.3755, 0.4531, 0.3529, 0.3574], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5792, 0.5274, 0.1899, 0.5859], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0017, 0.0017, 0.0017, 0.0017], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0067, 0.0067, 0.0067, 0.0067], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0009, 0.0009, 0.0009, 0.0009], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([57.3185, 94.1432, 69.5947, 77.8228], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([45.4376, 61.3014, 41.4896, 66.5174], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([52.3674, 56.1935, 51.1523, 31.8127], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.7213, 1.3667, 1.4964, 0.9576], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0044, 0.



Train Diffusion:   9%|▉         | 9/100 [00:37<06:15,  4.12s/it][A


theta_dict =  {'u_M': tensor([0.0031, 0.0047, 0.0041, 0.0053], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2641, 0.3395, 0.3820, 0.3012], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2855, 0.2996, 0.3196, 0.3741], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.4044, 0.3895, 0.2216, 0.4367], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5218, 0.6464, 0.3581, 0.3897], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0021, 0.0021, 0.0021, 0.0021], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0071, 0.0070, 0.0076, 0.0072], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([9.4937e-04, 7.7053e-05, 4.1179e-04, 1.0276e-03],
       grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([71.4379, 65.2876, 62.5110, 81.9737], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([49.6182, 51.9454, 43.2633, 31.1722], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([65.8665, 55.3030, 47.5366, 49.7051], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.7266, 1.0225, 0.6354, 1.1158], grad_fn=<SqueezeBackward1>), 'c_D



Train Diffusion:  10%|█         | 10/100 [00:42<06:14,  4.17s/it][A


theta_dict =  {'u_M': tensor([0.0058, 0.0029, 0.0056, 0.0030], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2801, 0.3492, 0.1923, 0.2803], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2930, 0.3079, 0.3721, 0.2394], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.0829, 0.3143, 0.1924, 0.3119], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.3762, 0.4666, 0.3191, 0.5391], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0024, 0.0024, 0.0024, 0.0024], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0082, 0.0073, 0.0091, 0.0086], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0005, 0.0005, 0.0010, 0.0010], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([ 96.0622, 100.7471,  79.6521,  66.0739], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([47.0687, 47.1643, 53.5283, 30.8996], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([59.5668, 40.6946, 64.2781, 63.9662], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9374, 0.9341, 0.4332, 1.0988], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0111



Train Diffusion:  11%|█         | 11/100 [00:45<06:01,  4.07s/it][A


theta_dict =  {'u_M': tensor([0.0045, 0.0050, 0.0054, 0.0047], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.2980, 0.4017, 0.2423, 0.1412], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3126, 0.2671, 0.2006, 0.2797], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.5315, 0.3627, 0.0371, 0.3458], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.4802, 0.5014, 0.6279, 0.4738], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0022, 0.0021, 0.0030, 0.0019], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0082, 0.0060, 0.0079, 0.0077], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([0.0049, 0.0020, 0.0009, 0.0021], grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([82.8046, 77.9603, 89.4224, 71.7405], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([37.1029, 34.9862, 39.6007, 33.5043], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([42.2146, 58.4784, 31.4035, 36.8494], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([0.9361, 0.9512, 0.9715, 0.6084], grad_fn=<SqueezeBackward1>), 'c_DOC': tensor([0.0065, 0.



Train Diffusion:  12%|█▏        | 12/100 [00:50<06:02,  4.12s/it][A


theta_dict =  {'u_M': tensor([0.0039, 0.0031, 0.0026, 0.0020], grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3574, 0.3538, 0.3660, 0.3626], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.2320, 0.3288, 0.2643, 0.2975], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.1897, 0.2506, 0.3075, 0.3009], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5749, 0.7258, 0.3230, 0.5246], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0020, 0.0030, 0.0038, 0.0038], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0093, 0.0085, 0.0091, 0.0076], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([5.5973e-04, 7.6662e-03, 1.2691e-03, 1.0000e-06],
       grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([93.4618, 85.4348, 63.5249, 94.1094], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([46.2307, 70.5635, 39.8989, 72.3885], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([42.5664, 19.3457, 82.7874, 44.5809], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.0450, 0.7804, 0.7821, 0.7648], grad_fn=<SqueezeBackward1>), 'c_D



Train Diffusion:  13%|█▎        | 13/100 [00:54<06:06,  4.22s/it][A


theta_dict =  {'u_M': tensor([6.1694e-03, 1.0000e-06, 2.8642e-03, 7.5608e-03],
       grad_fn=<SqueezeBackward1>), 'a_SD': tensor([0.3750, 0.2935, 0.4584, 0.2579], grad_fn=<SqueezeBackward1>), 'a_DS': tensor([0.3072, 0.5599, 0.3683, 0.4244], grad_fn=<SqueezeBackward1>), 'a_M': tensor([0.2678, 0.2435, 0.2699, 0.2229], grad_fn=<SqueezeBackward1>), 'a_MSC': tensor([0.5535, 0.2846, 0.5376, 0.4371], grad_fn=<SqueezeBackward1>), 'k_S_ref': tensor([0.0014, 0.0022, 0.0020, 0.0039], grad_fn=<SqueezeBackward1>), 'k_D_ref': tensor([0.0106, 0.0112, 0.0064, 0.0111], grad_fn=<SqueezeBackward1>), 'k_M_ref': tensor([2.3454e-03, 4.0899e-03, 1.0000e-06, 1.0000e-06],
       grad_fn=<SqueezeBackward1>), 'Ea_S': tensor([78.2226, 90.2575, 41.3053, 91.7673], grad_fn=<SqueezeBackward1>), 'Ea_D': tensor([50.4739, 59.1815, 43.5821, 43.4856], grad_fn=<SqueezeBackward1>), 'Ea_M': tensor([68.9219, 39.4261, 41.4138, 39.2733], grad_fn=<SqueezeBackward1>), 'c_SOC': tensor([1.0877, 0.8913, 0.9500, 0.5787], grad_fn=<S

In [9]:
t_span_tensor.size()

torch.Size([1, 2501, 1])

In [12]:
prior_means_tensor = torch.Tensor(list(SCON_C_prior_means.values()))
priors = D.normal.Normal(prior_means_tensor, prior_means_tensor * prior_scale_factor)
q_theta = MeanField(SCON_C_prior_means, prior_scale_factor)
theta_dict, theta, log_q_theta = q_theta(batch_size)
print(log_q_theta)

tensor([31.1428, 22.5302, 33.7238], grad_fn=<SumBackward1>)


In [10]:
theta_dict

{'u_M': tensor([0.0014, 0.0026, 0.0019], grad_fn=<SqueezeBackward1>),
 'a_SD': tensor([0.2349, 0.3096, 0.2740], grad_fn=<SqueezeBackward1>),
 'a_DS': tensor([0.3093, 0.2184, 0.4074], grad_fn=<SqueezeBackward1>),
 'a_M': tensor([0.2942, 0.1901, 0.3703], grad_fn=<SqueezeBackward1>),
 'a_MSC': tensor([0.6061, 0.5708, 0.4159], grad_fn=<SqueezeBackward1>),
 'k_S_ref': tensor([2.9325e-05, 2.9959e-05, 3.0455e-05], grad_fn=<SqueezeBackward1>),
 'k_D_ref': tensor([0.0046, 0.0057, 0.0035], grad_fn=<SqueezeBackward1>),
 'k_M_ref': tensor([9.4239e-05, 1.2225e-04, 1.9976e-04], grad_fn=<SqueezeBackward1>),
 'Ea_S': tensor([81.0427, 68.5995, 65.2861], grad_fn=<SqueezeBackward1>),
 'Ea_D': tensor([34.2083, 73.1626, 46.1662], grad_fn=<SqueezeBackward1>),
 'Ea_M': tensor([54.3748, 59.3774, 30.2376], grad_fn=<SqueezeBackward1>),
 'c_SOC': tensor([1.0770, 0.8536, 1.4267], grad_fn=<SqueezeBackward1>),
 'c_DOC': tensor([0.0103, 0.0114, 0.0089], grad_fn=<SqueezeBackward1>),
 'c_MBC': tensor([0.1309, 0.1065, 

In [13]:
print(priors.log_prob(theta))
log_p_theta_1 = priors.log_prob(theta).sum(-1)
print(log_p_theta_1)
log_p_theta_2 = priors.log_prob(theta).sum()
print(log_p_theta_2)
#log_p_theta_1 is correct way of computing log_p_theta

tensor([[ 6.5178,  1.4218,  1.3994,  0.3872,  1.0318, 10.8893,  5.6037,  8.5614,
         -3.8525, -3.5671, -4.5635,  0.4486,  4.9733,  1.8918],
        [ 5.3546,  1.4249,  0.1196, -1.2571,  0.5270, 11.0137,  5.1689,  8.9809,
         -3.8545, -3.6807, -4.7232,  0.4672,  1.6578,  1.3309],
        [ 6.6559,  1.5368,  1.5733,  0.9563,  1.1207, 10.9697,  5.6613,  8.6101,
         -4.1554, -3.8741, -3.4660,  0.4118,  4.9806,  2.7428]],
       grad_fn=<SubBackward0>)
tensor([31.1428, 22.5302, 33.7238], grad_fn=<SumBackward1>)
tensor(87.3968, grad_fn=<SumBackward0>)


In [21]:
obs_times, obs_means, obs_error = csv_to_obs_df('CON_synthetic_sol_df.csv', 3 + 1, t, obs_error_scale_factor)
obs_model = ObsModel(devi, obs_times, dt, obs_means[:-1, :], obs_error[:, :-1])
net = SDEFlow(devi, batch_size, obs_model, 3, t, dt, n).to(devi)

In [22]:
C_PATH, log_prob = net()
print(C_PATH)
C_PATH.size()

tensor([[[0.7703, 0.8147, 0.7190],
         [0.3577, 0.6447, 0.7709],
         [0.5693, 0.3108, 0.8354],
         ...,
         [0.8004, 0.8717, 1.5919],
         [0.7038, 0.8529, 1.4649],
         [0.7325, 0.7718, 0.8040]],

        [[0.8571, 0.8375, 0.7950],
         [0.2069, 0.7113, 0.9912],
         [0.5616, 0.5460, 0.8096],
         ...,
         [0.7837, 0.7142, 1.3335],
         [0.9051, 0.8525, 1.5034],
         [0.6693, 0.6562, 4.9073]],

        [[0.5143, 0.5829, 0.7754],
         [1.4471, 0.6817, 0.8258],
         [0.7452, 0.8609, 0.9399],
         ...,
         [0.8733, 0.7922, 1.5084],
         [0.7237, 0.9007, 1.5550],
         [0.7362, 0.8360, 7.2313]]], grad_fn=<AddBackward0>)


torch.Size([3, 2500, 3])

In [47]:
C_0 = analytical_steady_state_init_CON(i_s_tensor[0, 0, 0].item(), i_d_tensor[0, 0, 0].item(), theta_dict)
print(C_0)
C_0.unsqueeze(1).size()

tensor([[6.3284e+01, 1.3084e-01, 9.9657e-01],
        [4.6585e+01, 5.1670e-02, 4.4936e-01],
        [6.1308e+01, 3.6414e-02, 4.5395e-01]], grad_fn=<StackBackward>)


torch.Size([3, 1, 3])

In [24]:
C_PATH = torch.cat([C_0.unsqueeze(1), C_PATH], 1)
C_PATH

tensor([[[6.3284e+01, 1.3084e-01, 9.9657e-01],
         [7.7030e-01, 8.1470e-01, 7.1902e-01],
         [3.5773e-01, 6.4470e-01, 7.7092e-01],
         ...,
         [8.0044e-01, 8.7174e-01, 1.5919e+00],
         [7.0381e-01, 8.5288e-01, 1.4649e+00],
         [7.3249e-01, 7.7180e-01, 8.0398e-01]],

        [[4.6585e+01, 5.1670e-02, 4.4936e-01],
         [8.5708e-01, 8.3745e-01, 7.9501e-01],
         [2.0688e-01, 7.1126e-01, 9.9121e-01],
         ...,
         [7.8371e-01, 7.1421e-01, 1.3335e+00],
         [9.0510e-01, 8.5253e-01, 1.5034e+00],
         [6.6934e-01, 6.5619e-01, 4.9073e+00]],

        [[6.1308e+01, 3.6414e-02, 4.5395e-01],
         [5.1430e-01, 5.8294e-01, 7.7539e-01],
         [1.4471e+00, 6.8167e-01, 8.2578e-01],
         ...,
         [8.7326e-01, 7.9220e-01, 1.5084e+00],
         [7.2368e-01, 9.0072e-01, 1.5550e+00],
         [7.3621e-01, 8.3604e-01, 7.2313e+00]]], grad_fn=<CatBackward>)

In [25]:
SOC, DOC, MBC =  torch.chunk(C_PATH, 3, -1)
SOC.size()

torch.Size([3, 2501, 1])

In [26]:
current_temp = temp_gen(t_span_tensor, temp_ref)

In [27]:
theta_dict['k_S_ref'].size()

torch.Size([3])

In [28]:
k_S = arrhenius_temp_dep(theta_dict['k_S_ref'], current_temp, theta_dict['Ea_S'], temp_ref)
k_S = k_S.permute(2, 1, 0)
k_D = arrhenius_temp_dep(theta_dict['k_D_ref'], current_temp, theta_dict['Ea_D'], temp_ref)
k_D = k_D.permute(2, 1, 0)
k_M = arrhenius_temp_dep(theta_dict['k_M_ref'], current_temp, theta_dict['Ea_M'], temp_ref)
k_M = k_M.permute(2, 1, 0)

In [29]:
theta_dict_repeat = dict((k, v.repeat(1, 2501, 1).permute(2, 1, 0)) for k, v in theta_dict.items())
drift_SOC = i_s_tensor + theta_dict_repeat['a_DS'] * k_D * DOC + theta_dict_repeat['a_M'] * theta_dict_repeat['a_MSC'] * k_M * MBC - k_S * SOC
drift_DOC = i_d_tensor + theta_dict_repeat['a_SD'] * k_S * SOC + theta_dict_repeat['a_M'] * (1 - theta_dict_repeat['a_MSC']) * k_M * MBC - (theta_dict_repeat['u_M'] + k_D) * DOC
drift_MBC = theta_dict_repeat['u_M'] * DOC - k_M * MBC

In [30]:
drift_MBC.size()

torch.Size([3, 2501, 1])

In [31]:
theta_dict_u_M_test = theta_dict['u_M'].repeat(1, 2501, 1)
theta_dict_u_M_test = theta_dict_u_M_test.permute(2, 1, 0)
print(theta_dict_u_M_test.size())
test = theta_dict_u_M_test * DOC
print(test.size())

torch.Size([3, 2501, 1])
torch.Size([3, 2501, 1])


In [32]:
C_PATH.size()

torch.Size([3, 2501, 3])

In [33]:
drift = torch.empty_like(C_PATH, device = C_PATH.device) #Initiate tensor with same dims as C_PATH to assign drift.
drift[:, :, 0 : 1] = drift_SOC
drift[:, :, 1 : 2] = drift_DOC
drift[:, :, 2 : 3] = drift_MBC

In [34]:
test_m = torch.stack([theta_dict['c_SOC'], theta_dict['c_DOC'], theta_dict['c_MBC']], 1)
print(LowerBound.apply(test_m, 1e-6))
test_m_sqrt = torch.sqrt(test_m)
torch.diag_embed(test_m_sqrt)

tensor([[1.1623, 0.0138, 0.1381],
        [1.4763, 0.0127, 0.0956],
        [0.6716, 0.0106, 0.1080]], grad_fn=<LowerBoundBackward>)


tensor([[[1.0781, 0.0000, 0.0000],
         [0.0000, 0.1175, 0.0000],
         [0.0000, 0.0000, 0.3717]],

        [[1.2150, 0.0000, 0.0000],
         [0.0000, 0.1127, 0.0000],
         [0.0000, 0.0000, 0.3093]],

        [[0.8195, 0.0000, 0.0000],
         [0.0000, 0.1029, 0.0000],
         [0.0000, 0.0000, 0.3286]]], grad_fn=<CopySlices>)

In [35]:
diffusion_sqrt_single = torch.diag_embed(torch.sqrt(LowerBound.apply(torch.stack([theta_dict['c_SOC'], theta_dict['c_DOC'], theta_dict['c_MBC']], 1), 1e-6)))
diffusion_sqrt_single

tensor([[[1.0781, 0.0000, 0.0000],
         [0.0000, 0.1175, 0.0000],
         [0.0000, 0.0000, 0.3717]],

        [[1.2150, 0.0000, 0.0000],
         [0.0000, 0.1127, 0.0000],
         [0.0000, 0.0000, 0.3093]],

        [[0.8195, 0.0000, 0.0000],
         [0.0000, 0.1029, 0.0000],
         [0.0000, 0.0000, 0.3286]]], grad_fn=<CopySlices>)

In [36]:
diffusion_sqrt_single.unsqueeze(1).expand(-1, 2501, -1, -1)

tensor([[[[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         ...,

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]],

         [[1.0781, 0.0000, 0.0000],
          [0.0000, 0.1175, 0.0000],
          [0.0000, 0.0000, 0.3717]]],


        [[[1.2150, 0.0000, 0.0000],
          [0.0000, 0.1127, 0.0000],
          [0.0000, 0.0000, 0.3093]],

         [[1.2150, 0.0000, 0.0000],
          [0.0000, 0.1127, 0.0000],
          [0.0000, 0.0000, 0.3093]],

         [[1.2150, 0.0000, 0.0000],
          [0.0000, 0.1127, 0.0000],
          [0.0000, 0.0000, 0.30