In [45]:
import torch
from torch import nn
import torch.distributions as d
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm
import random
from torch.autograd import Function
from torch.utils.tensorboard import SummaryWriter
import argparse
import os
import sys
from pathlib import Path
import shutil
import pandas as pd

In [1]:
torch.manual_seed(0)
STATE_DIM = 3
dt = .1
T = 500 #Run simulation for 500 hours.
N = int(T / dt) + 1
T_span = np.linspace(0, T, N)
T_span_tensor = torch.Tensor(T_span)[(None,) * 2] #T_span needs to be converted to tensor object.

BATCH_SIZE = 1
device = torch.device("".join(["cuda:",f'{args.CUDA_ID}']) if torch.cuda.is_available() else "cpu")
LR = 1e-3
niter = 1000000

NameError: name 'torch' is not defined

In [47]:
obs_df_full = pd.read_csv('CON_synthetic_sol_df.csv') #Must be link to raw Github output if in Colab.
#obs_df_full = pd.read_csv('https://raw.githubusercontent.com/wallyxie/varInferenceSoilBiogeoModelSyntheticData/main/python_sde_variational_inference_code/CON_synthetic_sol_df.csv') #Must be link to raw Github output if in Colab.
obs_df = obs_df_full[obs_df_full['hour'] <= T] #Test with just first 1,000 hours of data.
obs_df

Unnamed: 0,hour,SOC,DOC,MBC
0,0,45.66034,0.071469,0.714689
1,10,45.646801,0.073782,0.713685
2,20,45.650479,0.073273,0.714127
3,30,45.644809,0.074233,0.713766
4,40,45.639221,0.07505,0.71344
5,50,45.642405,0.07462,0.713863
6,60,45.6283,0.076746,0.712869
7,70,45.633973,0.075864,0.713526
8,80,45.622501,0.077571,0.712763
9,90,45.621995,0.077491,0.712908


In [48]:
obs_times = torch.Tensor(np.array(obs_df['hour']))
obs_times

tensor([  0.,  10.,  20.,  30.,  40.,  50.,  60.,  70.,  80.,  90., 100., 110.,
        120., 130., 140., 150., 160., 170., 180., 190., 200., 210., 220., 230.,
        240., 250., 260., 270., 280., 290., 300., 310., 320., 330., 340., 350.,
        360., 370., 380., 390., 400., 410., 420., 430., 440., 450., 460., 470.,
        480., 490., 500.])

In [49]:
obs_means = torch.Tensor(np.array(obs_df.drop(columns = 'hour'))) #Convert to tensor.
obs_means

tensor([[45.6603,  0.0715,  0.7147],
        [45.6468,  0.0738,  0.7137],
        [45.6505,  0.0733,  0.7141],
        [45.6448,  0.0742,  0.7138],
        [45.6392,  0.0750,  0.7134],
        [45.6424,  0.0746,  0.7139],
        [45.6283,  0.0767,  0.7129],
        [45.6340,  0.0759,  0.7135],
        [45.6225,  0.0776,  0.7128],
        [45.6220,  0.0775,  0.7129],
        [45.6214,  0.0774,  0.7130],
        [45.6097,  0.0789,  0.7122],
        [45.6149,  0.0781,  0.7129],
        [45.6001,  0.0800,  0.7119],
        [45.6039,  0.0792,  0.7124],
        [45.5954,  0.0804,  0.7119],
        [45.5887,  0.0810,  0.7116],
        [45.5904,  0.0806,  0.7120],
        [45.5785,  0.0817,  0.7112],
        [45.5834,  0.0808,  0.7118],
        [45.5716,  0.0821,  0.7111],
        [45.5695,  0.0821,  0.7112],
        [45.5674,  0.0822,  0.7113],
        [45.5540,  0.0837,  0.7105],
        [45.5598,  0.0825,  0.7112],
        [45.5406,  0.0850,  0.7100],
        [45.5450,  0.0839,  0.7106],
 

In [50]:
obs_std = obs_means * 0.5 #Assume observations are distributed normally about the mean with standard deviation at half of the mean.
obs_std

tensor([[22.8302,  0.0357,  0.3573],
        [22.8234,  0.0369,  0.3568],
        [22.8252,  0.0366,  0.3571],
        [22.8224,  0.0371,  0.3569],
        [22.8196,  0.0375,  0.3567],
        [22.8212,  0.0373,  0.3569],
        [22.8141,  0.0384,  0.3564],
        [22.8170,  0.0379,  0.3568],
        [22.8113,  0.0388,  0.3564],
        [22.8110,  0.0387,  0.3565],
        [22.8107,  0.0387,  0.3565],
        [22.8049,  0.0394,  0.3561],
        [22.8074,  0.0390,  0.3564],
        [22.8000,  0.0400,  0.3559],
        [22.8019,  0.0396,  0.3562],
        [22.7977,  0.0402,  0.3560],
        [22.7944,  0.0405,  0.3558],
        [22.7952,  0.0403,  0.3560],
        [22.7892,  0.0409,  0.3556],
        [22.7917,  0.0404,  0.3559],
        [22.7858,  0.0411,  0.3555],
        [22.7847,  0.0411,  0.3556],
        [22.7837,  0.0411,  0.3556],
        [22.7770,  0.0419,  0.3552],
        [22.7799,  0.0413,  0.3556],
        [22.7703,  0.0425,  0.3550],
        [22.7725,  0.0420,  0.3553],
 

In [51]:
I_S = 0.001 + 0.0005 * torch.sin((2 * math.pi / (24 * 365)) * T_span_tensor) #Exogenous SOC input function
I_D =  0.0001 + 0.00005 * torch.sin((2 * math.pi / (24 * 365)) * T_span_tensor) #Exogenous DOC input function

In [52]:
temp_ref = 283

u_M = 0.002
a_SD = 0.33
a_DS = 0.33
a_M = 0.33
a_MSC = 0.5
k_S_ref = 0.000025
k_D_ref = 0.005
k_M_ref = 0.0002
Ea_S = 75
Ea_D = 50
Ea_M = 50

scon_params_dict = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M}

In [53]:
############################################################
##SOIL BIOGEOCHEMICAL MODEL TEMPERATURE RESPONSE FUNCTIONS##
############################################################

def temp_gen(t, temp_ref):
    temp = temp_ref + t / (20 * 24 * 365) + 10 * np.sin((2 * np.pi / 24) * t) + 10 * np.sin((2 * np.pi / (24 * 365)) * t)
    return temp

def arrhenius_temp_dep(parameter, temp, Ea, temp_ref):
    '''
    For a parameter with Arrhenius temperature dependence, returns the transformed parameter value.
    0.008314 is the gas constant. Temperatures are in K.
    '''
    decayed_parameter = parameter * np.exp(-Ea / 0.008314 * (1 / temp - 1 / temp_ref))
    return decayed_parameter

def linear_temp_dep(parameter, temp, Q, temp_ref):
    '''
    For a parameter with linear temperature dependence, returns the transformed parameter value.
    Q is the slope of the temperature dependence and is a varying parameter.
    Temperatures are in K.
    '''
    modified_parameter = parameter - Q * (temp - temp_ref)
    return modified_parameter

##########################################################################
##DETERMINISTIC SOIL BIOGEOCHEMICAL MODEL INITIAL STEADY STATE ESTIMATES##
##########################################################################

#Analytical_steady_state_init_awb to be coded later.
def analytical_steady_state_init_con(SOC_input, DOC_input, scon_params_dict):
    '''
    Returns a vector of C pool values to initialize an SCON system corresponding to set of parameter values using the analytical steady state equations of the deterministic CON system.
    Vector elements are in order of S_0, D_0, and M_0.
    Expected scon_params_dict = {scon_params_dict = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M}    
    '''
    D_0 = (DOC_input + SOC_input * scon_params_dict['a_SD']) / (scon_params_dict['u_M'] + scon_params_dict['k_D_ref'] + scon_params_dict['u_M'] * scon_params_dict['a_M'] * (scon_params_dict['a_MSC'] - 1 - scon_params_dict['a_MSC'] * scon_params_dict['a_SD']) - scon_params_dict['a_DS'] * scon_params_dict['k_D_ref'] * scon_params_dict['a_SD'])
    S_0 = (SOC_input + D_0 * (scon_params_dict['a_DS'] * scon_params_dict['k_D_ref'] + scon_params_dict['u_M'] * scon_params_dict['a_M'] * scon_params_dict['a_MSC'])) / scon_params_dict['k_S_ref']
    M_0 = scon_params_dict['u_M'] * D_0 / scon_params_dict['k_M_ref']
    C_0_vector = torch.as_tensor([S_0, D_0, M_0])
    return C_0_vector

####################################################
##STOCHASTIC DIFFERENTIAL EQUATION MODEL FUNCTIONS##
#################################################### 

def drift_and_diffusion_scon(N, T_span_tensor, dt, I_S, I_D, analytical_steady_state_init_con, scon_params_dict, temp_ref, path_count):
    '''
    Returns SCON solutions using the Euler-Maruyama scheme.
    current_temp is output from temp_gen function. 
    Expected scon_params_dict = {scon_params_dict = {'u_M': u_M, 'a_SD': a_SD, 'a_DS': a_DS, 'a_M': a_M, 'a_MSC': a_MSC, 'k_S_ref': k_S_ref, 'k_D_ref': k_D_ref, 'k_M_ref': k_M_ref, 'Ea_S': Ea_S, 'Ea_D': Ea_D, 'Ea_M': Ea_M}
    '''
    system_var_size = 3 #SCON and CON will always have three state variables.
    C_vector = torch.zeros(path_count, system_var_size, N) #Create tensor for storing state variable values.
    drift_vector = C_vector.clone() #Create tensor to assign drift.
    C0 = analytical_steady_state_init_con(I_S[0, 0, 0].item(), I_D[0, 0, 0].item(), scon_params_dict)
    print('\n Initial pre-perturbation SOC, DOC, MBC = ', C0)
    C_vector[:, :, 0] = C0 #Assign deterministically generated initial conditions to all paths.
    diffusion_matrix_sqrt = torch.zeros([drift_vector.size(0), system_var_size, system_var_size, drift_vector.size(2)], device = drift_vector.device) #Create tensor to assign diffusion matrix elements.
    diffusion_matrix_sqrt[:, 0, 0, 0] = torch.sqrt(C0[0]) #Assigned S0 to element 1, 1 of matrix.
    diffusion_matrix_sqrt[:, 1, 1, 0] = torch.sqrt(C0[1]) #Assigned D0 to element 2, 2 of matrix.
    diffusion_matrix_sqrt[:, 2, 2, 0] = torch.sqrt(C0[2]) #Assigned M0 to element 3, 3 of matrix.
    current_temp = temp_gen(T_span_tensor, temp_ref) #Obtain temperature function vector across span of times.
    #Decay parameters are forced by temperature changes.
    k_S = arrhenius_temp_dep(scon_params_dict['k_S_ref'], current_temp, scon_params_dict['Ea_S'], temp_ref) #Apply vectorized temperature-dependent transformation to k_S_ref. 
    k_D = arrhenius_temp_dep(scon_params_dict['k_D_ref'], current_temp, scon_params_dict['Ea_D'], temp_ref) #Apply vectorized temperature-dependent transformation to k_D_ref.
    k_M = arrhenius_temp_dep(scon_params_dict['k_M_ref'], current_temp, scon_params_dict['Ea_M'], temp_ref) #Apply vectorized temperature-dependent transformation to k_M_ref.
    #print('\n k_S, k_D, k_M =', [k_S, k_D, k_M])
    #Drift vector is calculated (without litter input).
    for i in range(1, N):
        SOC, DOC, MBC = [C_vector[:, l : l + 1, i - 1] for l in range(system_var_size)]
        drift_SOC = I_S[0, 0, i - 1] + scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC + scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC - k_S[0, 0, i - 1] * SOC
        drift_DOC = I_D[0, 0, i - 1] + scon_params_dict['a_SD'] * k_S[0, 0, i - 1] * SOC + scon_params_dict['a_M'] * (1 - scon_params_dict['a_MSC']) * k_M[0, 0, i - 1] * MBC - (scon_params_dict['u_M'] + k_D[0, 0, i - 1]) * DOC
        drift_MBC = scon_params_dict['u_M'] * DOC - k_M[0, 0, i - 1] * MBC
        drift_vector[:, :, i - 1] = torch.cat([drift_SOC, drift_DOC, drift_MBC], 1) #Assign drift means to all paths.
        #Diffusion matrix is calculated.
        diffusion_matrix_sqrt[:, 0, 0, i - 1] = torch.sqrt(LowerBound.apply(drift_SOC, 0)).squeeze() #Assigned to element 1, 1 of matrix.
        diffusion_matrix_sqrt[:, 1, 1, i - 1] = torch.sqrt(LowerBound.apply(drift_DOC, 0)).squeeze() #Assigned to element 2, 2 of matrix.
        diffusion_matrix_sqrt[:, 2, 2, i - 1] = torch.sqrt(LowerBound.apply(drift_MBC, 0)).squeeze() #Assigned to element 3, 3 of matrix.
        C_vector[:, :, i] = d.multivariate_normal.MultivariateNormal(loc = C_vector[:, :, i - 1] + drift_vector[:, :, i - 1] * dt, scale_tril = diffusion_matrix_sqrt[:, :, :, i - 1] * math.sqrt(dt)).rsample()
        C_vector_test[:, :, i][C_vector_test[:, :, i] < 1e-9] = 1e-9 #Prevent state variables from dropping below 0.
    return C_vector, drift_vector, diffusion_matrix_sqrt

In [54]:
path_test = 2
C_vector_test = torch.zeros(path_test, 3, N)
drift_vector_test = torch.zeros(path_test, 3, N)
C0 = analytical_steady_state_init_con(I_S[0, 0, 0].item(), I_D[0, 0, 0].item(), scon_params_dict)
C_vector_test[:, :, 0] = C0
print(C_vector_test)
SOC, DOC, MBC = [C_vector_test[:, l : l + 1, 0] for l in range(3)]
print(SOC, DOC, MBC)

tensor([[[45.6603,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0715,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.7147,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[45.6603,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0715,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.7147,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])
tensor([[45.6603],
        [45.6603]]) tensor([[0.0715],
        [0.0715]]) tensor([[0.7147],
        [0.7147]])


In [55]:
i = 1
SOC, DOC, MBC = [C_vector_test[:, l : l + 1, i - 1] for l in range(3)]
print(SOC, DOC, MBC)

tensor([[45.6603],
        [45.6603]]) tensor([[0.0715],
        [0.0715]]) tensor([[0.7147],
        [0.7147]])


In [56]:
print(scon_params_dict)

current_temp = temp_gen(T_span_tensor, temp_ref)

k_S = arrhenius_temp_dep(scon_params_dict['k_S_ref'], current_temp, scon_params_dict['Ea_S'], temp_ref)
k_D = arrhenius_temp_dep(scon_params_dict['k_D_ref'], current_temp, scon_params_dict['Ea_D'], temp_ref)
k_M = arrhenius_temp_dep(scon_params_dict['k_M_ref'], current_temp, scon_params_dict['Ea_M'], temp_ref)

print(k_S, k_D, k_M)
print(k_S[0, 0, 5])

{'u_M': 0.002, 'a_SD': 0.33, 'a_DS': 0.33, 'a_M': 0.33, 'a_MSC': 0.5, 'k_S_ref': 2.5e-05, 'k_D_ref': 0.005, 'k_M_ref': 0.0002, 'Ea_S': 75, 'Ea_D': 50, 'Ea_M': 50}
tensor([[[2.5000e-05, 2.5750e-05, 2.6519e-05,  ..., 1.3450e-05,
          1.3644e-05, 1.3851e-05]]]) tensor([[[0.0050, 0.0051, 0.0052,  ..., 0.0033, 0.0033, 0.0034]]]) tensor([[[0.0002, 0.0002, 0.0002,  ..., 0.0001, 0.0001, 0.0001]]])
tensor(2.8952e-05)


In [57]:
class LowerBound(Function):
    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones(inputs.size()) * bound
        b = b.to(inputs.device)
        b = b.type(inputs.dtype)
        ctx.save_for_backward(inputs, b)
        return torch.max(inputs, b)
    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors

        pass_through_1 = inputs >= b
        pass_through_2 = grad_output < 0

        pass_through = pass_through_1 | pass_through_2
        return pass_through.type(grad_output.dtype) * grad_output, None

class MaskedConv1d(nn.Conv1d):
    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv1d, self).__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}
        self.register_buffer('mask', self.weight.data.clone())
        _, _, kW = self.weight.size()
        self.mask.fill_(1)
        self.mask[:, :, kW // 2 + 1*(mask_type == 'B'):] = 0

    def forward(self, x):
        self.weight.data *= self.mask
        return super(MaskedConv1d, self).forward(x)

class ResNetBlock(nn.Module):

    def __init__(self, inp_cha, out_cha, stride = 1, batch_norm=True):
        super().__init__()
        self.conv1 = MaskedConv1d('B', inp_cha,  out_cha, 15, stride, 7, bias=False)
        self.conv2 = MaskedConv1d('B', out_cha,  out_cha, 15, 1, 7, bias=False)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        if batch_norm:
            self.bn1 = nn.BatchNorm1d(out_cha)
            self.bn2 = nn.BatchNorm1d(out_cha)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()

        # If dimensions change, transform shortcut with a conv layer
        if inp_cha != out_cha or stride > 1:
            self.conv_skip = MaskedConv1d('B', inp_cha,  out_cha, 3, stride, 1, bias=False)
        else:
            self.conv_skip = nn.Identity()

    def forward(self, x):
        residual = x
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x) + self.conv_skip(residual)))
        return x

class ResNetBlockUnMasked(nn.Module):

    def __init__(self, inp_cha, out_cha, stride = 1, batch_norm=True):
        super().__init__()
        self.conv1 = nn.Conv1d(inp_cha,  out_cha, 15, stride, 7)
        self.conv2 = nn.Conv1d(out_cha,  out_cha, 15, 1, 7)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        if batch_norm:
            self.bn1 = nn.BatchNorm1d(out_cha)
            self.bn2 = nn.BatchNorm1d(out_cha)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()

        # If dimensions change, transform shortcut with a conv layer
        if inp_cha != out_cha or stride > 1:
            self.conv_skip = nn.Conv1d(inp_cha,  out_cha, 15, stride, 7, bias=False)
        else:
            self.conv_skip = nn.Identity()

    def forward(self, x):
        residual = x
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x) + self.conv_skip(residual)))
        return x

class CouplingLayer(nn.Module):

    def __init__(self, cond_inputs, stride):
        super().__init__()
        self.net = nn.Sequential(ResNetBlock(1+cond_inputs, 96),
                                 MaskedConv1d('B', 96,  2, 15, stride, 7, bias=False))

        self.feature_net = nn.Sequential(ResNetBlockUnMasked(cond_inputs, 96),
                                          ResNetBlockUnMasked(96, cond_inputs))

        self.unpack = True if cond_inputs > 1 else False

    def forward(self, x, cond_inputs):
        if self.unpack:
            cond_inputs = torch.cat([*cond_inputs], 1)
        cond_inputs = self.feature_net(cond_inputs)
        feature_vec = torch.cat([x, cond_inputs], 1)
        output = self.net(feature_vec)
        mu, sigma = torch.chunk(output, 2, 1)
        mu = self._pass_through_units(mu)
        sigma = self._pass_through_units(sigma, mu=False)
        x = mu + sigma*x
        return x, sigma

    def _pass_through_units(self, params, mu=True):
        B, _, L = params.shape
        padding = 2 if STATE_DIM % 2==1 else 1
        if mu:
            pad = torch.zeros([B, padding, L], device=params.device)
        else:
            params = LowerBound.apply(params, 1e-6)
            pad = torch.ones([B, padding, L], device=params.device)
        return torch.cat([pad, params], 1).transpose(2, 1).reshape(B, 1, -1)

class PermutationLayer(nn.Module):

    def __init__(self):
        super().__init__()
        self.index = torch.randperm(STATE_DIM)

    def forward(self, x):
        B, S, L = x.shape
        x_reshape = x.reshape(B, S, -1, STATE_DIM)
        x_perm = x_reshape[:, :, :, self.index]
        x = x_perm.reshape(B, S, L)
        return x

In [58]:
path_test = 2
C_vector_test = torch.zeros(path_test, 3, N)
C0 = analytical_steady_state_init_con(I_S[0, 0, 0].item(), I_D[0, 0, 0].item(), scon_params_dict)
C_vector_test[:, :, 0] = C0

system_var_size = 3

#diffusion_matrix_sqrt = torch.zeros([drift_vector_test.size(0), drift_vector_test.size(2), system_var_size, system_var_size], device = C_vector_test.device) #Create 3 x 3 zeros tensor to assign diffusion matrix elements.
diffusion_matrix_sqrt_test = torch.zeros([C_vector_test.size(0), system_var_size, system_var_size, C_vector_test.size(2)]) #Create 3 x 3 zeros tensor to assign diffusion matrix elements.
diffusion_matrix_sqrt_test[:, 0, 0, 0] = torch.sqrt(C0[0]) #Assigned S0 to element 1, 1 of matrix.
diffusion_matrix_sqrt_test[:, 1, 1, 0] = torch.sqrt(C0[1]) #Assigned D0 to element 2, 2 of matrix.
diffusion_matrix_sqrt_test[:, 2, 2, 0] = torch.sqrt(C0[2]) #Assigned M0 to element 3, 3 of matrix.

drift_vector_test = torch.zeros(path_test, 3, N)

#First step from 0 to dt trial
i = 1
SOC, DOC, MBC = [C_vector_test[:, l : l + 1, i - 1] for l in range(system_var_size)]
print('\n C_vector_test:', C_vector_test)
print('\n C0:', SOC, DOC, MBC)
print('\n k:', k_D[0, 0, i - 1], k_S[0, 0, i - 1], k_M[0, 0, i - 1])
print('\n I_S:', I_S[0, 0, i - 1])
print("\n scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC:", scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC)
print("\n scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC:", scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC)
print('\n -k_S[0, 0, i - 1] * SOC:', -k_S[0, 0, i - 1] * SOC)
print("\n I_S[0, 0, i - 1] + scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC + scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC - k_S[0, 0, i - 1] * SOC:", I_S[0, 0, i - 1] + scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC + scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC - k_S[0, 0, i - 1] * SOC)
drift_SOC = I_S[0, 0, i - 1] + scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC + scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC - k_S[0, 0, i - 1] * SOC
drift_DOC = I_D[0, 0, i - 1] + scon_params_dict['a_SD'] * k_S[0, 0, i - 1] * SOC + scon_params_dict['a_M'] * (1 - scon_params_dict['a_MSC']) * k_M[0, 0, i - 1] * MBC - (scon_params_dict['u_M'] + k_D[0, 0, i - 1]) * DOC
print("\n scon_params_dict['u_M'] * DOC:", scon_params_dict['u_M'] * DOC)
print("\n -k_M[0, 0, i - 1] * MBC:", -k_M[0, 0, i - 1] * MBC)
drift_MBC = scon_params_dict['u_M'] * DOC - k_M[0, 0, i - 1] * MBC
print('\n drift_SOC:', drift_SOC)
print('\n drift_DOC:', drift_DOC)
print('\n drift_MBC:', drift_MBC)
print('\n torch.cat([drift_SOC, drift_DOC, drift_MBC], 1):', torch.cat([drift_SOC, drift_DOC, drift_MBC], 1))
#drift_vector_test[:, :, i - 1] = torch.cat([drift_SOC, drift_DOC, drift_MBC], 1) #Assign drift to all paths.
drift_vector_test[:, 0 : 1, i - 1] = drift_SOC
drift_vector_test[:, 1 : 2, i - 1] = drift_DOC
drift_vector_test[:, 2 : 3, i - 1] = drift_MBC
#Diffusion matrix is calculated (recall litter input is not a part of the drift vector or diffusion matrix).
diffusion_matrix_sqrt_test[:, 0, 0, i - 1] = torch.sqrt(LowerBound.apply(drift_SOC, 1e-10)).squeeze() #Assigned to element 1, 1 of matrix.
diffusion_matrix_sqrt_test[:, 1, 1, i - 1] = torch.sqrt(LowerBound.apply(drift_DOC, 1e-10)).squeeze() #Assigned to element 2, 2 of matrix.
diffusion_matrix_sqrt_test[:, 2, 2, i - 1] = torch.sqrt(LowerBound.apply(drift_MBC, 1e-10)).squeeze() #Assigned to element 3, 3 of matrix.

print('\n Check initial conditions not erased:', C_vector_test[:, :, 0])
print('\n Check drift vector:', drift_vector_test[:, :, i - 1] * dt)
update_test = C_vector_test[:, :, i - 1] + drift_vector_test[:, :, i - 1] * dt
print('\n Check diffusion matrix:', diffusion_matrix_sqrt_test[:, :, :, i - 1])
print('\n C_vector_test[0, :, i - 1]:', C_vector_test[0, :, i - 1])
print('\n drift_vector_test[:, :, i - 1] * dt:', drift_vector_test[0, :, i - 1] * dt)
print('\n Check Euler step without noise:', update_test)
C_vector_test[:, :, i] = d.multivariate_normal.MultivariateNormal(loc = C_vector_test[:, :, i - 1] + drift_vector_test[:, :, i - 1] * dt, scale_tril = diffusion_matrix_sqrt_test[:, :, :, i - 1] * math.sqrt(dt)).rsample()
#Error message for the above. Probably need to permute somewhere.
#for k in range(path_test):
    #C_vector_test[k, :, i] = d.multivariate_normal.MultivariateNormal(loc = C_vector_test[k, :, i - 1] + drift_vector_test[k, :, i - 1] * dt, scale_tril = diffusion_matrix_sqrt[k, :, :, i - 1] * math.sqrt(dt))


 C_vector_test: tensor([[[45.6603,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0715,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.7147,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[45.6603,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0715,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.7147,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])

 C0: tensor([[45.6603],
        [45.6603]]) tensor([[0.0715],
        [0.0715]]) tensor([[0.7147],
        [0.7147]])

 k: tensor(0.0050) tensor(2.5000e-05) tensor(0.0002)

 I_S: tensor(0.0010)

 scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC: tensor([[0.0001],
        [0.0001]])

 scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC: tensor([[2.3585e-05],
        [2.3585e-05]])

 -k_S[0, 0, i - 1] * SOC: tensor([[-0.0011],
        [-0.0011]])

 I_S[0, 0, i - 1] + scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC + scon

In [59]:
C_vector_test = torch.zeros(path_test, 3, N)
C0 = analytical_steady_state_init_con(I_S[0, 0, 0].item(), I_D[0, 0, 0].item(), scon_params_dict)
C_vector_test[:, :, 0] = C0

system_var_size = 3

diffusion_matrix_sqrt_test = torch.zeros([C_vector_test.size(0), system_var_size, system_var_size, C_vector_test.size(2)]) #Create 3 x 3 zeros tensor to assign diffusion matrix elements.
diffusion_matrix_sqrt_test[:, 0, 0, 0] = torch.sqrt(C0[0]) #Assigned S0 to element 1, 1 of matrix.
diffusion_matrix_sqrt_test[:, 1, 1, 0] = torch.sqrt(C0[1]) #Assigned D0 to element 2, 2 of matrix.
diffusion_matrix_sqrt_test[:, 2, 2, 0] = torch.sqrt(C0[2]) #Assigned M0 to element 3, 3 of matrix.

drift_vector_test = torch.zeros(path_test, 3, N)


#Euler-Maruyama loop trial
#for i in range(1, len(T_span_tensor[0, 0, :])):
#for i in range(1, 2):
for i in range(1, N):
    SOC, DOC, MBC = [C_vector_test[:, l : l + 1, i - 1] for l in range(system_var_size)]
    print('SOC, DOC, MBC = ', SOC, ",", DOC, ",", MBC)
    drift_SOC = I_S[0, 0, i - 1] + scon_params_dict['a_DS'] * k_D[0, 0, i - 1] * DOC + scon_params_dict['a_M'] * scon_params_dict['a_MSC'] * k_M[0, 0, i - 1] * MBC - k_S[0, 0, i - 1] * SOC
    drift_DOC = I_D[0, 0, i - 1] + scon_params_dict['a_SD'] * k_S[0, 0, i - 1] * SOC + scon_params_dict['a_M'] * (1 - scon_params_dict['a_MSC']) * k_M[0, 0, i - 1] * MBC - (scon_params_dict['u_M'] + k_D[0, 0, i - 1]) * DOC
    drift_MBC = scon_params_dict['u_M'] * DOC - k_M[0, 0, i - 1] * MBC
    #drift_vector_test[:, :, i - 1] = torch.cat([drift_SOC, drift_DOC, drift_MBC], 1) #Assign drift to all paths.
    drift_vector_test[:, 0 : 1, i - 1] = drift_SOC
    drift_vector_test[:, 1 : 2, i - 1] = drift_DOC
    drift_vector_test[:, 2 : 3, i - 1] = drift_MBC    #Diffusion matrix is calculated (recall litter input is not a part of the drift vector or diffusion matrix).
    diffusion_matrix_sqrt_test[:, 0, 0, i - 1] = torch.sqrt(LowerBound.apply(drift_SOC, 1e-10)).squeeze() #Assigned to element 1, 1 of matrix.
    diffusion_matrix_sqrt_test[:, 1, 1, i - 1] = torch.sqrt(LowerBound.apply(drift_DOC, 1e-10)).squeeze() #Assigned to element 2, 2 of matrix.
    diffusion_matrix_sqrt_test[:, 2, 2, i - 1] = torch.sqrt(LowerBound.apply(drift_MBC, 1e-10)).squeeze() #Assigned to element 3, 3 of matrix.
    C_vector_test[:, :, i] = d.multivariate_normal.MultivariateNormal(loc = C_vector_test[:, :, i - 1] + drift_vector_test[:, :, i - 1] * dt, scale_tril = diffusion_matrix_sqrt_test[:, :, :, i - 1] * math.sqrt(dt)).rsample()
    C_vector_test[:, :, i][C_vector_test[:, :, i] < 1e-9] = 1e-9
print(C_vector_test)

SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0715],
        [0.0715]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0715],
        [0.0715]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0722],
        [0.0711]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0713],
        [0.0721]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0729],
        [0.0721]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0716],
        [0.0694]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tensor([[0.0708],
        [0.0693]]) , tensor([[0.7147],
        [0.7147]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.6603]]) , tens

SOC, DOC, MBC =  tensor([[45.6479],
        [45.6468]]) , tensor([[0.1026],
        [0.0280]]) , tensor([[0.7141],
        [0.7134]])
SOC, DOC, MBC =  tensor([[45.6478],
        [45.6466]]) , tensor([[0.1026],
        [0.0399]]) , tensor([[0.7141],
        [0.7134]])
SOC, DOC, MBC =  tensor([[45.6477],
        [45.6465]]) , tensor([[0.1026],
        [0.0478]]) , tensor([[0.7141],
        [0.7134]])
SOC, DOC, MBC =  tensor([[45.6476],
        [45.6465]]) , tensor([[0.1026],
        [0.0432]]) , tensor([[0.7141],
        [0.7134]])
SOC, DOC, MBC =  tensor([[45.6475],
        [45.6463]]) , tensor([[0.1025],
        [0.0417]]) , tensor([[0.7141],
        [0.7134]])
SOC, DOC, MBC =  tensor([[45.6474],
        [45.6463]]) , tensor([[0.1025],
        [0.0390]]) , tensor([[0.7141],
        [0.7133]])
SOC, DOC, MBC =  tensor([[45.6474],
        [45.6462]]) , tensor([[0.1025],
        [0.0427]]) , tensor([[0.7141],
        [0.7133]])
SOC, DOC, MBC =  tensor([[45.6473],
        [45.6461]]) , tens

SOC, DOC, MBC =  tensor([[45.6837],
        [45.6792]]) , tensor([[0.1006],
        [0.0667]]) , tensor([[0.6955],
        [0.7261]])
SOC, DOC, MBC =  tensor([[45.6921],
        [45.6851]]) , tensor([[0.1006],
        [0.0667]]) , tensor([[0.6985],
        [0.7267]])
SOC, DOC, MBC =  tensor([[45.6942],
        [45.6826]]) , tensor([[0.1006],
        [0.0667]]) , tensor([[0.6890],
        [0.7249]])
SOC, DOC, MBC =  tensor([[45.6897],
        [45.6760]]) , tensor([[0.1005],
        [0.0667]]) , tensor([[0.6924],
        [0.7244]])
SOC, DOC, MBC =  tensor([[45.6913],
        [45.6698]]) , tensor([[0.1005],
        [0.0667]]) , tensor([[0.6883],
        [0.7254]])
SOC, DOC, MBC =  tensor([[45.6922],
        [45.6676]]) , tensor([[0.1005],
        [0.0667]]) , tensor([[0.6835],
        [0.7238]])
SOC, DOC, MBC =  tensor([[45.6848],
        [45.6694]]) , tensor([[0.1005],
        [0.0667]]) , tensor([[0.6783],
        [0.7221]])
SOC, DOC, MBC =  tensor([[45.6997],
        [45.6691]]) , tens

        [45.6885]]) , tensor([[0.0993],
        [0.0341]]) , tensor([[0.7070],
        [0.7113]])
SOC, DOC, MBC =  tensor([[45.6752],
        [45.6885]]) , tensor([[0.0993],
        [0.0425]]) , tensor([[0.7070],
        [0.7113]])
SOC, DOC, MBC =  tensor([[45.6751],
        [45.6884]]) , tensor([[0.0993],
        [0.0440]]) , tensor([[0.7070],
        [0.7113]])
SOC, DOC, MBC =  tensor([[45.6750],
        [45.6883]]) , tensor([[0.0992],
        [0.0454]]) , tensor([[0.7070],
        [0.7113]])
SOC, DOC, MBC =  tensor([[45.6750],
        [45.6882]]) , tensor([[0.0992],
        [0.0456]]) , tensor([[0.7070],
        [0.7113]])
SOC, DOC, MBC =  tensor([[45.6749],
        [45.6881]]) , tensor([[0.0992],
        [0.0529]]) , tensor([[0.7070],
        [0.7112]])
SOC, DOC, MBC =  tensor([[45.6748],
        [45.6880]]) , tensor([[0.0992],
        [0.0527]]) , tensor([[0.7070],
        [0.7112]])
SOC, DOC, MBC =  tensor([[45.6747],
        [45.6879]]) , tensor([[0.0992],
        [0.0518]]) , t

SOC, DOC, MBC =  tensor([[45.6621],
        [45.6750]]) , tensor([[0.0887],
        [0.1006]]) , tensor([[0.7063],
        [0.7112]])
SOC, DOC, MBC =  tensor([[45.6621],
        [45.6750]]) , tensor([[0.0887],
        [0.1005]]) , tensor([[0.7063],
        [0.7135]])
SOC, DOC, MBC =  tensor([[45.6620],
        [45.6749]]) , tensor([[0.0887],
        [0.1005]]) , tensor([[0.7063],
        [0.7124]])
SOC, DOC, MBC =  tensor([[45.6620],
        [45.6749]]) , tensor([[0.0886],
        [0.1005]]) , tensor([[0.7063],
        [0.7123]])
SOC, DOC, MBC =  tensor([[45.6619],
        [45.6748]]) , tensor([[0.0886],
        [0.1005]]) , tensor([[0.7063],
        [0.7123]])
SOC, DOC, MBC =  tensor([[45.6619],
        [45.6748]]) , tensor([[0.0886],
        [0.1005]]) , tensor([[0.7063],
        [0.7138]])
SOC, DOC, MBC =  tensor([[45.6619],
        [45.6748]]) , tensor([[0.0886],
        [0.1005]]) , tensor([[0.7065],
        [0.7128]])
SOC, DOC, MBC =  tensor([[45.6618],
        [45.6747]]) , tens

SOC, DOC, MBC =  tensor([[45.6436],
        [45.7055]]) , tensor([[0.0878],
        [0.0993]]) , tensor([[0.6945],
        [0.6641]])
SOC, DOC, MBC =  tensor([[45.6515],
        [45.7047]]) , tensor([[0.0878],
        [0.0993]]) , tensor([[0.6939],
        [0.6659]])
SOC, DOC, MBC =  tensor([[45.6443],
        [45.7101]]) , tensor([[0.0878],
        [0.0993]]) , tensor([[0.7010],
        [0.6701]])
SOC, DOC, MBC =  tensor([[45.6471],
        [45.7065]]) , tensor([[0.0878],
        [0.0992]]) , tensor([[0.6982],
        [0.6672]])
SOC, DOC, MBC =  tensor([[45.6497],
        [45.7104]]) , tensor([[0.0878],
        [0.0992]]) , tensor([[0.6959],
        [0.6684]])
SOC, DOC, MBC =  tensor([[45.6590],
        [45.6989]]) , tensor([[0.0877],
        [0.0992]]) , tensor([[0.7020],
        [0.6655]])
SOC, DOC, MBC =  tensor([[45.6505],
        [45.6930]]) , tensor([[0.0877],
        [0.0992]]) , tensor([[0.7000],
        [0.6674]])
SOC, DOC, MBC =  tensor([[45.6397],
        [45.6869]]) , tens

SOC, DOC, MBC =  tensor([[45.6574],
        [45.7995]]) , tensor([[0.0869],
        [0.0980]]) , tensor([[0.7446],
        [0.6689]])
SOC, DOC, MBC =  tensor([[45.6534],
        [45.8040]]) , tensor([[0.0869],
        [0.0980]]) , tensor([[0.7434],
        [0.6682]])
SOC, DOC, MBC =  tensor([[45.6573],
        [45.8042]]) , tensor([[0.0869],
        [0.0979]]) , tensor([[0.7418],
        [0.6689]])
SOC, DOC, MBC =  tensor([[45.6542],
        [45.8035]]) , tensor([[0.0869],
        [0.0979]]) , tensor([[0.7407],
        [0.6682]])
SOC, DOC, MBC =  tensor([[45.6517],
        [45.8023]]) , tensor([[0.0869],
        [0.0979]]) , tensor([[0.7411],
        [0.6664]])
SOC, DOC, MBC =  tensor([[45.6534],
        [45.8004]]) , tensor([[0.0868],
        [0.0979]]) , tensor([[0.7418],
        [0.6638]])
SOC, DOC, MBC =  tensor([[45.6534],
        [45.8010]]) , tensor([[0.0868],
        [0.0979]]) , tensor([[0.7417],
        [0.6627]])
SOC, DOC, MBC =  tensor([[45.6534],
        [45.8010]]) , tens

        [0.1015]]) , tensor([[0.7455],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6462],
        [45.7940]]) , tensor([[0.0994],
        [0.0990]]) , tensor([[0.7454],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6460],
        [45.7938]]) , tensor([[0.1011],
        [0.1026]]) , tensor([[0.7454],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6458],
        [45.7936]]) , tensor([[0.0970],
        [0.1002]]) , tensor([[0.7454],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6455],
        [45.7934]]) , tensor([[0.1000],
        [0.0981]]) , tensor([[0.7454],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6453],
        [45.7931]]) , tensor([[0.1026],
        [0.1062]]) , tensor([[0.7454],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6451],
        [45.7929]]) , tensor([[0.1024],
        [0.1062]]) , tensor([[0.7454],
        [0.6597]])
SOC, DOC, MBC =  tensor([[45.6449],
        [45.7927]]) , tensor([[0.1029],
        [0.1062]]) , tensor([[0.7454],
        [0.6596]])
SOC,

        [45.7958]]) , tensor([[0.0771],
        [0.1035]]) , tensor([[0.7650],
        [0.6537]])
SOC, DOC, MBC =  tensor([[45.5833],
        [45.8015]]) , tensor([[0.0771],
        [0.1035]]) , tensor([[0.7650],
        [0.6552]])
SOC, DOC, MBC =  tensor([[45.5935],
        [45.7957]]) , tensor([[0.0771],
        [0.1035]]) , tensor([[0.7595],
        [0.6585]])
SOC, DOC, MBC =  tensor([[45.5927],
        [45.7986]]) , tensor([[0.0771],
        [0.1034]]) , tensor([[0.7575],
        [0.6615]])
SOC, DOC, MBC =  tensor([[45.5935],
        [45.8062]]) , tensor([[0.0771],
        [0.1034]]) , tensor([[0.7564],
        [0.6564]])
SOC, DOC, MBC =  tensor([[45.5908],
        [45.8070]]) , tensor([[0.0771],
        [0.1034]]) , tensor([[0.7532],
        [0.6492]])
SOC, DOC, MBC =  tensor([[45.5861],
        [45.8022]]) , tensor([[0.0770],
        [0.1034]]) , tensor([[0.7562],
        [0.6440]])
SOC, DOC, MBC =  tensor([[45.5996],
        [45.7990]]) , tensor([[0.0770],
        [0.1034]]) , t

SOC, DOC, MBC =  tensor([[45.6061],
        [45.8078]]) , tensor([[0.0927],
        [0.1022]]) , tensor([[0.7517],
        [0.6494]])
SOC, DOC, MBC =  tensor([[45.6060],
        [45.8076]]) , tensor([[0.0952],
        [0.1022]]) , tensor([[0.7517],
        [0.6494]])
SOC, DOC, MBC =  tensor([[45.6058],
        [45.8075]]) , tensor([[0.0956],
        [0.1022]]) , tensor([[0.7517],
        [0.6494]])
SOC, DOC, MBC =  tensor([[45.6056],
        [45.8073]]) , tensor([[0.0976],
        [0.1022]]) , tensor([[0.7517],
        [0.6494]])
SOC, DOC, MBC =  tensor([[45.6055],
        [45.8071]]) , tensor([[0.0974],
        [0.1022]]) , tensor([[0.7517],
        [0.6494]])
SOC, DOC, MBC =  tensor([[45.6053],
        [45.8070]]) , tensor([[0.0984],
        [0.1022]]) , tensor([[0.7517],
        [0.6493]])
SOC, DOC, MBC =  tensor([[45.6051],
        [45.8068]]) , tensor([[0.0975],
        [0.1022]]) , tensor([[0.7516],
        [0.6493]])
SOC, DOC, MBC =  tensor([[45.6049],
        [45.8066]]) , tens

        [0.6511]])
SOC, DOC, MBC =  tensor([[45.5948],
        [45.7965]]) , tensor([[0.0922],
        [0.1048]]) , tensor([[0.7510],
        [0.6523]])
SOC, DOC, MBC =  tensor([[45.5947],
        [45.7964]]) , tensor([[0.0922],
        [0.1048]]) , tensor([[0.7510],
        [0.6524]])
SOC, DOC, MBC =  tensor([[45.5947],
        [45.7964]]) , tensor([[0.0922],
        [0.1048]]) , tensor([[0.7510],
        [0.6534]])
SOC, DOC, MBC =  tensor([[45.5946],
        [45.7963]]) , tensor([[0.0922],
        [0.1048]]) , tensor([[0.7510],
        [0.6532]])
SOC, DOC, MBC =  tensor([[45.5945],
        [45.7963]]) , tensor([[0.0922],
        [0.1047]]) , tensor([[0.7510],
        [0.6553]])
SOC, DOC, MBC =  tensor([[45.5945],
        [45.7962]]) , tensor([[0.0922],
        [0.1047]]) , tensor([[0.7510],
        [0.6537]])
SOC, DOC, MBC =  tensor([[45.5944],
        [45.7962]]) , tensor([[0.0922],
        [0.1047]]) , tensor([[0.7510],
        [0.6524]])
SOC, DOC, MBC =  tensor([[45.5944],
       

SOC, DOC, MBC =  tensor([[45.6714],
        [45.8364]]) , tensor([[0.0912],
        [0.1032]]) , tensor([[0.7378],
        [0.6378]])
SOC, DOC, MBC =  tensor([[45.6737],
        [45.8429]]) , tensor([[0.0911],
        [0.1032]]) , tensor([[0.7415],
        [0.6365]])
SOC, DOC, MBC =  tensor([[45.6821],
        [45.8502]]) , tensor([[0.0911],
        [0.1032]]) , tensor([[0.7469],
        [0.6447]])
SOC, DOC, MBC =  tensor([[45.6634],
        [45.8574]]) , tensor([[0.0911],
        [0.1032]]) , tensor([[0.7488],
        [0.6488]])
SOC, DOC, MBC =  tensor([[45.6636],
        [45.8527]]) , tensor([[0.0911],
        [0.1031]]) , tensor([[0.7494],
        [0.6484]])
SOC, DOC, MBC =  tensor([[45.6723],
        [45.8716]]) , tensor([[0.0911],
        [0.1031]]) , tensor([[0.7424],
        [0.6479]])
SOC, DOC, MBC =  tensor([[45.6763],
        [45.8742]]) , tensor([[0.0911],
        [0.1031]]) , tensor([[0.7440],
        [0.6457]])
SOC, DOC, MBC =  tensor([[45.6789],
        [45.8764]]) , tens

SOC, DOC, MBC =  tensor([[45.7058],
        [45.9785]]) , tensor([[0.0901],
        [0.1017]]) , tensor([[0.7676],
        [0.6216]])
SOC, DOC, MBC =  tensor([[45.7058],
        [45.9785]]) , tensor([[0.0901],
        [0.1017]]) , tensor([[0.7678],
        [0.6172]])
SOC, DOC, MBC =  tensor([[45.7058],
        [45.9785]]) , tensor([[0.0900],
        [0.1017]]) , tensor([[0.7678],
        [0.6171]])
SOC, DOC, MBC =  tensor([[45.7058],
        [45.9785]]) , tensor([[0.0900],
        [0.1017]]) , tensor([[0.7678],
        [0.6188]])
SOC, DOC, MBC =  tensor([[45.7057],
        [45.9784]]) , tensor([[0.0900],
        [0.1017]]) , tensor([[0.7678],
        [0.6187]])
SOC, DOC, MBC =  tensor([[45.7057],
        [45.9784]]) , tensor([[0.0900],
        [0.1016]]) , tensor([[0.7678],
        [0.6189]])
SOC, DOC, MBC =  tensor([[45.7057],
        [45.9784]]) , tensor([[0.0900],
        [0.1016]]) , tensor([[0.7678],
        [0.6190]])
SOC, DOC, MBC =  tensor([[45.7056],
        [45.9783]]) , tens

SOC, DOC, MBC =  tensor([[45.6964],
        [45.9690]]) , tensor([[0.1061],
        [0.0733]]) , tensor([[0.7672],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6962],
        [45.9687]]) , tensor([[0.1061],
        [0.0786]]) , tensor([[0.7672],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6960],
        [45.9685]]) , tensor([[0.1061],
        [0.0763]]) , tensor([[0.7672],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6958],
        [45.9683]]) , tensor([[0.1061],
        [0.0694]]) , tensor([[0.7672],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6956],
        [45.9681]]) , tensor([[0.1061],
        [0.0679]]) , tensor([[0.7672],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6954],
        [45.9679]]) , tensor([[0.1061],
        [0.0687]]) , tensor([[0.7671],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6952],
        [45.9676]]) , tensor([[0.1061],
        [0.0666]]) , tensor([[0.7671],
        [0.6123]])
SOC, DOC, MBC =  tensor([[45.6950],
        [45.9674]]) , tens

        [45.9535]]) , tensor([[0.1049],
        [0.0743]]) , tensor([[0.7724],
        [0.6021]])
SOC, DOC, MBC =  tensor([[45.6865],
        [45.9629]]) , tensor([[0.1049],
        [0.0742]]) , tensor([[0.7711],
        [0.6019]])
SOC, DOC, MBC =  tensor([[45.6790],
        [45.9702]]) , tensor([[0.1049],
        [0.0742]]) , tensor([[0.7696],
        [0.5970]])
SOC, DOC, MBC =  tensor([[45.6694],
        [45.9743]]) , tensor([[0.1049],
        [0.0742]]) , tensor([[0.7675],
        [0.5970]])
SOC, DOC, MBC =  tensor([[45.6729],
        [45.9778]]) , tensor([[0.1049],
        [0.0742]]) , tensor([[0.7677],
        [0.5959]])
SOC, DOC, MBC =  tensor([[45.6804],
        [45.9801]]) , tensor([[0.1048],
        [0.0742]]) , tensor([[0.7693],
        [0.5972]])
SOC, DOC, MBC =  tensor([[45.6839],
        [45.9850]]) , tensor([[0.1048],
        [0.0742]]) , tensor([[0.7721],
        [0.5991]])
SOC, DOC, MBC =  tensor([[45.6797],
        [45.9884]]) , tensor([[0.1048],
        [0.0742]]) , t

        [45.9704]]) , tensor([[0.1031],
        [0.0736]]) , tensor([[0.7729],
        [0.6052]])
SOC, DOC, MBC =  tensor([[45.6965],
        [45.9611]]) , tensor([[0.1031],
        [0.0736]]) , tensor([[0.7720],
        [0.6043]])
SOC, DOC, MBC =  tensor([[45.6871],
        [45.9582]]) , tensor([[0.1031],
        [0.0736]]) , tensor([[0.7685],
        [0.6024]])
SOC, DOC, MBC =  tensor([[45.6885],
        [45.9513]]) , tensor([[0.1031],
        [0.0736]]) , tensor([[0.7684],
        [0.5991]])
SOC, DOC, MBC =  tensor([[45.6865],
        [45.9393]]) , tensor([[0.1030],
        [0.0736]]) , tensor([[0.7706],
        [0.5994]])
SOC, DOC, MBC =  tensor([[45.6815],
        [45.9511]]) , tensor([[0.1030],
        [0.0736]]) , tensor([[0.7671],
        [0.5996]])
SOC, DOC, MBC =  tensor([[45.6666],
        [45.9562]]) , tensor([[0.1030],
        [0.0736]]) , tensor([[0.7685],
        [0.6022]])
SOC, DOC, MBC =  tensor([[45.6633],
        [45.9548]]) , tensor([[0.1030],
        [0.0736]]) , t

        [0.0950]]) , tensor([[0.7716],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6679],
        [45.9396]]) , tensor([[0.0848],
        [0.0943]]) , tensor([[0.7716],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6677],
        [45.9394]]) , tensor([[0.0865],
        [0.0897]]) , tensor([[0.7716],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6675],
        [45.9391]]) , tensor([[0.0871],
        [0.0824]]) , tensor([[0.7716],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6673],
        [45.9389]]) , tensor([[0.1030],
        [0.0870]]) , tensor([[0.7716],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6671],
        [45.9387]]) , tensor([[0.1048],
        [0.0892]]) , tensor([[0.7715],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6668],
        [45.9384]]) , tensor([[0.1049],
        [0.0860]]) , tensor([[0.7715],
        [0.6049]])
SOC, DOC, MBC =  tensor([[45.6666],
        [45.9382]]) , tensor([[0.1059],
        [0.0858]]) , tensor([[0.7715],
        [0.6049]])
SOC,

SOC, DOC, MBC =  tensor([[45.7227],
        [45.9271]]) , tensor([[0.1047],
        [0.0996]]) , tensor([[0.7788],
        [0.6232]])
SOC, DOC, MBC =  tensor([[45.7210],
        [45.9312]]) , tensor([[0.1047],
        [0.0996]]) , tensor([[0.7793],
        [0.6218]])
SOC, DOC, MBC =  tensor([[45.7218],
        [45.9352]]) , tensor([[0.1047],
        [0.0995]]) , tensor([[0.7776],
        [0.6213]])
SOC, DOC, MBC =  tensor([[45.7252],
        [45.9253]]) , tensor([[0.1047],
        [0.0995]]) , tensor([[0.7808],
        [0.6256]])
SOC, DOC, MBC =  tensor([[45.7292],
        [45.9258]]) , tensor([[0.1047],
        [0.0995]]) , tensor([[0.7826],
        [0.6313]])
SOC, DOC, MBC =  tensor([[45.7381],
        [45.9220]]) , tensor([[0.1046],
        [0.0995]]) , tensor([[0.7787],
        [0.6247]])
SOC, DOC, MBC =  tensor([[45.7387],
        [45.9182]]) , tensor([[0.1046],
        [0.0995]]) , tensor([[0.7777],
        [0.6251]])
SOC, DOC, MBC =  tensor([[45.7357],
        [45.9191]]) , tens

SOC, DOC, MBC =  tensor([[45.7711],
        [45.9057]]) , tensor([[0.1025],
        [0.0977]]) , tensor([[0.7514],
        [0.6061]])
SOC, DOC, MBC =  tensor([[45.7711],
        [45.9057]]) , tensor([[0.1025],
        [0.0976]]) , tensor([[0.7516],
        [0.6083]])
SOC, DOC, MBC =  tensor([[45.7711],
        [45.9056]]) , tensor([[0.1025],
        [0.0976]]) , tensor([[0.7509],
        [0.6066]])
SOC, DOC, MBC =  tensor([[45.7711],
        [45.9056]]) , tensor([[0.1025],
        [0.0976]]) , tensor([[0.7537],
        [0.6024]])
SOC, DOC, MBC =  tensor([[45.7710],
        [45.9056]]) , tensor([[0.1025],
        [0.0976]]) , tensor([[0.7527],
        [0.6020]])
SOC, DOC, MBC =  tensor([[45.7710],
        [45.9056]]) , tensor([[0.1024],
        [0.0976]]) , tensor([[0.7545],
        [0.6038]])
SOC, DOC, MBC =  tensor([[45.7710],
        [45.9055]]) , tensor([[0.1024],
        [0.0975]]) , tensor([[0.7559],
        [0.6020]])
SOC, DOC, MBC =  tensor([[45.7710],
        [45.9055]]) , tens

        [45.8929]]) , tensor([[0.0771],
        [0.0470]]) , tensor([[0.7547],
        [0.5969]])
SOC, DOC, MBC =  tensor([[45.7584],
        [45.8927]]) , tensor([[0.0728],
        [0.0400]]) , tensor([[0.7547],
        [0.5969]])
SOC, DOC, MBC =  tensor([[45.7582],
        [45.8925]]) , tensor([[0.0588],
        [0.0412]]) , tensor([[0.7547],
        [0.5969]])
SOC, DOC, MBC =  tensor([[45.7581],
        [45.8923]]) , tensor([[0.0605],
        [0.0572]]) , tensor([[0.7547],
        [0.5969]])
SOC, DOC, MBC =  tensor([[45.7579],
        [45.8922]]) , tensor([[0.0611],
        [0.0539]]) , tensor([[0.7546],
        [0.5969]])
SOC, DOC, MBC =  tensor([[45.7577],
        [45.8920]]) , tensor([[0.0448],
        [0.0535]]) , tensor([[0.7546],
        [0.5968]])
SOC, DOC, MBC =  tensor([[45.7575],
        [45.8918]]) , tensor([[0.0418],
        [0.0577]]) , tensor([[0.7546],
        [0.5968]])
SOC, DOC, MBC =  tensor([[45.7574],
        [45.8916]]) , tensor([[0.0378],
        [0.0560]]) , t

SOC, DOC, MBC =  tensor([[45.7167],
        [45.9347]]) , tensor([[0.0589],
        [0.0688]]) , tensor([[0.7591],
        [0.5887]])
SOC, DOC, MBC =  tensor([[45.7026],
        [45.9352]]) , tensor([[0.0588],
        [0.0688]]) , tensor([[0.7598],
        [0.5891]])
SOC, DOC, MBC =  tensor([[45.7146],
        [45.9371]]) , tensor([[0.0588],
        [0.0688]]) , tensor([[0.7592],
        [0.5904]])
SOC, DOC, MBC =  tensor([[45.7281],
        [45.9340]]) , tensor([[0.0588],
        [0.0688]]) , tensor([[0.7586],
        [0.5923]])
SOC, DOC, MBC =  tensor([[45.7420],
        [45.9298]]) , tensor([[0.0588],
        [0.0688]]) , tensor([[0.7565],
        [0.5887]])
SOC, DOC, MBC =  tensor([[45.7319],
        [45.9309]]) , tensor([[0.0588],
        [0.0687]]) , tensor([[0.7545],
        [0.5872]])
SOC, DOC, MBC =  tensor([[45.7241],
        [45.9281]]) , tensor([[0.0588],
        [0.0687]]) , tensor([[0.7542],
        [0.5840]])
SOC, DOC, MBC =  tensor([[45.7071],
        [45.9281]]) , tens

SOC, DOC, MBC =  tensor([[45.6989],
        [45.9173]]) , tensor([[0.0479],
        [0.0682]]) , tensor([[0.7329],
        [0.5827]])
SOC, DOC, MBC =  tensor([[45.6957],
        [45.9176]]) , tensor([[0.0516],
        [0.0677]]) , tensor([[0.7329],
        [0.5843]])
SOC, DOC, MBC =  tensor([[45.6939],
        [45.9184]]) , tensor([[0.0492],
        [0.0669]]) , tensor([[0.7329],
        [0.5844]])
SOC, DOC, MBC =  tensor([[45.6936],
        [45.9159]]) , tensor([[0.0492],
        [0.0651]]) , tensor([[0.7329],
        [0.5828]])
SOC, DOC, MBC =  tensor([[45.6936],
        [45.9159]]) , tensor([[0.0528],
        [0.0644]]) , tensor([[0.7329],
        [0.5817]])
SOC, DOC, MBC =  tensor([[45.6936],
        [45.9159]]) , tensor([[0.0520],
        [0.0631]]) , tensor([[0.7329],
        [0.5823]])
SOC, DOC, MBC =  tensor([[45.6936],
        [45.9158]]) , tensor([[0.0541],
        [0.0641]]) , tensor([[0.7328],
        [0.5827]])
SOC, DOC, MBC =  tensor([[45.6936],
        [45.9158]]) , tens

        [0.0132]]) , tensor([[0.7319],
        [0.5820]])
SOC, DOC, MBC =  tensor([[45.6805],
        [45.9015]]) , tensor([[0.0631],
        [0.0241]]) , tensor([[0.7319],
        [0.5820]])
SOC, DOC, MBC =  tensor([[45.6803],
        [45.9013]]) , tensor([[0.0735],
        [0.0407]]) , tensor([[0.7319],
        [0.5820]])
SOC, DOC, MBC =  tensor([[45.6801],
        [45.9011]]) , tensor([[0.0860],
        [0.0377]]) , tensor([[0.7318],
        [0.5819]])
SOC, DOC, MBC =  tensor([[45.6799],
        [45.9009]]) , tensor([[0.0884],
        [0.0243]]) , tensor([[0.7318],
        [0.5819]])
SOC, DOC, MBC =  tensor([[45.6797],
        [45.9007]]) , tensor([[0.0812],
        [0.0039]]) , tensor([[0.7318],
        [0.5819]])
SOC, DOC, MBC =  tensor([[45.6795],
        [45.9005]]) , tensor([[0.0788],
        [0.0111]]) , tensor([[0.7318],
        [0.5819]])
SOC, DOC, MBC =  tensor([[45.6793],
        [45.9003]]) , tensor([[0.0824],
        [0.0119]]) , tensor([[0.7318],
        [0.5819]])
SOC,

SOC, DOC, MBC =  tensor([[45.6629],
        [45.9371]]) , tensor([[6.5911e-02],
        [1.0000e-09]]) , tensor([[0.7271],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6691],
        [45.9335]]) , tensor([[6.5907e-02],
        [1.0000e-09]]) , tensor([[0.7253],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6793],
        [45.9537]]) , tensor([[0.0659],
        [0.0091]]) , tensor([[0.7257],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6841],
        [45.9484]]) , tensor([[0.0659],
        [0.0094]]) , tensor([[0.7199],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6868],
        [45.9416]]) , tensor([[0.0659],
        [0.0124]]) , tensor([[0.7238],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6754],
        [45.9400]]) , tensor([[0.0659],
        [0.0113]]) , tensor([[0.7250],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6749],
        [45.9357]]) , tensor([[0.0659],
        [0.0216]]) , tensor([[0.7293],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6671],
        [4

SOC, DOC, MBC =  tensor([[45.6552],
        [45.8916]]) , tensor([[0.0689],
        [0.0311]]) , tensor([[0.7493],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6552],
        [45.8915]]) , tensor([[0.0678],
        [0.0348]]) , tensor([[0.7492],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6552],
        [45.8915]]) , tensor([[0.0658],
        [0.0332]]) , tensor([[0.7492],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6551],
        [45.8915]]) , tensor([[0.0674],
        [0.0360]]) , tensor([[0.7492],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6551],
        [45.8915]]) , tensor([[0.0614],
        [0.0401]]) , tensor([[0.7492],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6551],
        [45.8914]]) , tensor([[0.0589],
        [0.0338]]) , tensor([[0.7492],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6551],
        [45.8914]]) , tensor([[0.0553],
        [0.0276]]) , tensor([[0.7492],
        [0.5597]])
SOC, DOC, MBC =  tensor([[45.6551],
        [45.8914]]) , tens

        [45.8783]]) , tensor([[0.0340],
        [0.0858]]) , tensor([[0.7478],
        [0.5590]])
SOC, DOC, MBC =  tensor([[45.6416],
        [45.8780]]) , tensor([[0.0479],
        [0.0850]]) , tensor([[0.7478],
        [0.5589]])
SOC, DOC, MBC =  tensor([[45.6413],
        [45.8778]]) , tensor([[0.0496],
        [0.0910]]) , tensor([[0.7478],
        [0.5589]])
SOC, DOC, MBC =  tensor([[45.6411],
        [45.8776]]) , tensor([[0.0509],
        [0.0830]]) , tensor([[0.7478],
        [0.5589]])
SOC, DOC, MBC =  tensor([[45.6409],
        [45.8774]]) , tensor([[0.0674],
        [0.0810]]) , tensor([[0.7477],
        [0.5589]])
SOC, DOC, MBC =  tensor([[45.6407],
        [45.8772]]) , tensor([[0.0718],
        [0.0761]]) , tensor([[0.7477],
        [0.5589]])
SOC, DOC, MBC =  tensor([[45.6405],
        [45.8770]]) , tensor([[0.0643],
        [0.0760]]) , tensor([[0.7477],
        [0.5589]])
SOC, DOC, MBC =  tensor([[45.6403],
        [45.8768]]) , tensor([[0.0695],
        [0.0673]]) , t

SOC, DOC, MBC =  tensor([[45.6350],
        [45.9340]]) , tensor([[0.0876],
        [0.0869]]) , tensor([[0.7191],
        [0.5575]])
SOC, DOC, MBC =  tensor([[45.6414],
        [45.9362]]) , tensor([[0.0876],
        [0.0869]]) , tensor([[0.7213],
        [0.5565]])
SOC, DOC, MBC =  tensor([[45.6353],
        [45.9306]]) , tensor([[0.0876],
        [0.0869]]) , tensor([[0.7196],
        [0.5592]])
SOC, DOC, MBC =  tensor([[45.6426],
        [45.9177]]) , tensor([[0.0875],
        [0.0868]]) , tensor([[0.7180],
        [0.5656]])
SOC, DOC, MBC =  tensor([[45.6409],
        [45.9203]]) , tensor([[0.0875],
        [0.0868]]) , tensor([[0.7194],
        [0.5649]])
SOC, DOC, MBC =  tensor([[45.6378],
        [45.9318]]) , tensor([[0.0875],
        [0.0868]]) , tensor([[0.7221],
        [0.5623]])
SOC, DOC, MBC =  tensor([[45.6436],
        [45.9508]]) , tensor([[0.0875],
        [0.0868]]) , tensor([[0.7168],
        [0.5610]])
SOC, DOC, MBC =  tensor([[45.6379],
        [45.9521]]) , tens

        [45.8706]]) , tensor([[0.0866],
        [0.0858]]) , tensor([[0.7321],
        [0.5503]])
SOC, DOC, MBC =  tensor([[45.6289],
        [45.8706]]) , tensor([[0.0865],
        [0.0858]]) , tensor([[0.7321],
        [0.5541]])
SOC, DOC, MBC =  tensor([[45.6289],
        [45.8705]]) , tensor([[0.0865],
        [0.0858]]) , tensor([[0.7321],
        [0.5531]])
SOC, DOC, MBC =  tensor([[45.6288],
        [45.8705]]) , tensor([[0.0865],
        [0.0858]]) , tensor([[0.7321],
        [0.5530]])
SOC, DOC, MBC =  tensor([[45.6288],
        [45.8704]]) , tensor([[0.0865],
        [0.0858]]) , tensor([[0.7321],
        [0.5533]])
SOC, DOC, MBC =  tensor([[45.6287],
        [45.8704]]) , tensor([[0.0865],
        [0.0858]]) , tensor([[0.7321],
        [0.5559]])
SOC, DOC, MBC =  tensor([[45.6286],
        [45.8703]]) , tensor([[0.0865],
        [0.0858]]) , tensor([[0.7321],
        [0.5554]])
SOC, DOC, MBC =  tensor([[45.6285],
        [45.8702]]) , tensor([[0.0873],
        [0.0868]]) , t

        [45.8560]]) , tensor([[0.0906],
        [0.1095]]) , tensor([[0.7312],
        [0.5568]])
SOC, DOC, MBC =  tensor([[45.6141],
        [45.8558]]) , tensor([[0.0959],
        [0.1095]]) , tensor([[0.7312],
        [0.5564]])
SOC, DOC, MBC =  tensor([[45.6139],
        [45.8557]]) , tensor([[0.0979],
        [0.1095]]) , tensor([[0.7312],
        [0.5560]])
SOC, DOC, MBC =  tensor([[45.6138],
        [45.8555]]) , tensor([[0.0973],
        [0.1095]]) , tensor([[0.7312],
        [0.5547]])
SOC, DOC, MBC =  tensor([[45.6136],
        [45.8554]]) , tensor([[0.0984],
        [0.1094]]) , tensor([[0.7312],
        [0.5545]])
SOC, DOC, MBC =  tensor([[45.6135],
        [45.8552]]) , tensor([[0.0984],
        [0.1094]]) , tensor([[0.7312],
        [0.5524]])
SOC, DOC, MBC =  tensor([[45.6134],
        [45.8551]]) , tensor([[0.0984],
        [0.1094]]) , tensor([[0.7312],
        [0.5539]])
SOC, DOC, MBC =  tensor([[45.6132],
        [45.8550]]) , tensor([[0.0984],
        [0.1094]]) , t

        [0.5357]])
SOC, DOC, MBC =  tensor([[45.5868],
        [45.7997]]) , tensor([[0.0974],
        [0.1079]]) , tensor([[0.7435],
        [0.5325]])
SOC, DOC, MBC =  tensor([[45.5916],
        [45.8006]]) , tensor([[0.0974],
        [0.1079]]) , tensor([[0.7512],
        [0.5298]])
SOC, DOC, MBC =  tensor([[45.5949],
        [45.8011]]) , tensor([[0.0974],
        [0.1079]]) , tensor([[0.7490],
        [0.5303]])
SOC, DOC, MBC =  tensor([[45.5993],
        [45.8054]]) , tensor([[0.0974],
        [0.1078]]) , tensor([[0.7494],
        [0.5266]])
SOC, DOC, MBC =  tensor([[45.6088],
        [45.7983]]) , tensor([[0.0973],
        [0.1078]]) , tensor([[0.7456],
        [0.5309]])
SOC, DOC, MBC =  tensor([[45.6111],
        [45.8096]]) , tensor([[0.0973],
        [0.1078]]) , tensor([[0.7507],
        [0.5311]])
SOC, DOC, MBC =  tensor([[45.6109],
        [45.8029]]) , tensor([[0.0973],
        [0.1078]]) , tensor([[0.7483],
        [0.5315]])
SOC, DOC, MBC =  tensor([[45.6195],
       

        [0.1059]]) , tensor([[0.7567],
        [0.4641]])
SOC, DOC, MBC =  tensor([[45.5688],
        [45.7606]]) , tensor([[0.0959],
        [0.1059]]) , tensor([[0.7567],
        [0.4602]])
SOC, DOC, MBC =  tensor([[45.5688],
        [45.7606]]) , tensor([[0.0959],
        [0.1059]]) , tensor([[0.7564],
        [0.4596]])
SOC, DOC, MBC =  tensor([[45.5687],
        [45.7606]]) , tensor([[0.0958],
        [0.1059]]) , tensor([[0.7562],
        [0.4537]])
SOC, DOC, MBC =  tensor([[45.5687],
        [45.7606]]) , tensor([[0.0958],
        [0.1059]]) , tensor([[0.7562],
        [0.4586]])
SOC, DOC, MBC =  tensor([[45.5687],
        [45.7605]]) , tensor([[0.0958],
        [0.1058]]) , tensor([[0.7562],
        [0.4510]])
SOC, DOC, MBC =  tensor([[45.5686],
        [45.7605]]) , tensor([[0.0958],
        [0.1058]]) , tensor([[0.7562],
        [0.4502]])
SOC, DOC, MBC =  tensor([[45.5686],
        [45.7604]]) , tensor([[0.0958],
        [0.1058]]) , tensor([[0.7562],
        [0.4501]])
SOC,

SOC, DOC, MBC =  tensor([[45.5546],
        [45.7464]]) , tensor([[0.1116],
        [0.1091]]) , tensor([[0.7554],
        [0.4751]])
SOC, DOC, MBC =  tensor([[45.5545],
        [45.7462]]) , tensor([[0.1115],
        [0.1091]]) , tensor([[0.7554],
        [0.4755]])
SOC, DOC, MBC =  tensor([[45.5543],
        [45.7461]]) , tensor([[0.1115],
        [0.1091]]) , tensor([[0.7554],
        [0.4749]])
SOC, DOC, MBC =  tensor([[45.5541],
        [45.7459]]) , tensor([[0.1115],
        [0.1091]]) , tensor([[0.7554],
        [0.4748]])
SOC, DOC, MBC =  tensor([[45.5540],
        [45.7457]]) , tensor([[0.1115],
        [0.1091]]) , tensor([[0.7554],
        [0.4769]])
SOC, DOC, MBC =  tensor([[45.5538],
        [45.7455]]) , tensor([[0.1115],
        [0.1090]]) , tensor([[0.7554],
        [0.4763]])
SOC, DOC, MBC =  tensor([[45.5536],
        [45.7454]]) , tensor([[0.1115],
        [0.1090]]) , tensor([[0.7554],
        [0.4769]])
SOC, DOC, MBC =  tensor([[45.5535],
        [45.7452]]) , tens

SOC, DOC, MBC =  tensor([[45.6051],
        [45.7495]]) , tensor([[0.1098],
        [0.1074]]) , tensor([[0.7601],
        [0.4481]])
SOC, DOC, MBC =  tensor([[45.6138],
        [45.7500]]) , tensor([[0.1097],
        [0.1074]]) , tensor([[0.7697],
        [0.4522]])
SOC, DOC, MBC =  tensor([[45.6077],
        [45.7401]]) , tensor([[0.1097],
        [0.1073]]) , tensor([[0.7637],
        [0.4563]])
SOC, DOC, MBC =  tensor([[45.5994],
        [45.7409]]) , tensor([[0.1097],
        [0.1073]]) , tensor([[0.7640],
        [0.4567]])
SOC, DOC, MBC =  tensor([[45.6015],
        [45.7378]]) , tensor([[0.1097],
        [0.1073]]) , tensor([[0.7654],
        [0.4630]])
SOC, DOC, MBC =  tensor([[45.5968],
        [45.7450]]) , tensor([[0.1096],
        [0.1073]]) , tensor([[0.7675],
        [0.4642]])
SOC, DOC, MBC =  tensor([[45.5977],
        [45.7662]]) , tensor([[0.1096],
        [0.1072]]) , tensor([[0.7708],
        [0.4710]])
SOC, DOC, MBC =  tensor([[45.5896],
        [45.7556]]) , tens

        [45.8159]]) , tensor([[0.1079],
        [0.1056]]) , tensor([[0.8043],
        [0.5173]])
SOC, DOC, MBC =  tensor([[45.6228],
        [45.8130]]) , tensor([[0.1079],
        [0.1055]]) , tensor([[0.8046],
        [0.5172]])
SOC, DOC, MBC =  tensor([[45.6228],
        [45.8116]]) , tensor([[0.1079],
        [0.1055]]) , tensor([[0.8013],
        [0.5167]])
SOC, DOC, MBC =  tensor([[45.6230],
        [45.8106]]) , tensor([[0.1079],
        [0.1055]]) , tensor([[0.7989],
        [0.5133]])
SOC, DOC, MBC =  tensor([[45.6230],
        [45.8106]]) , tensor([[0.1078],
        [0.1055]]) , tensor([[0.7982],
        [0.5140]])
SOC, DOC, MBC =  tensor([[45.6230],
        [45.8106]]) , tensor([[0.1078],
        [0.1054]]) , tensor([[0.7979],
        [0.5167]])
SOC, DOC, MBC =  tensor([[45.6230],
        [45.8106]]) , tensor([[0.1078],
        [0.1054]]) , tensor([[0.8012],
        [0.5240]])
SOC, DOC, MBC =  tensor([[45.6230],
        [45.8106]]) , tensor([[0.1078],
        [0.1054]]) , t

SOC, DOC, MBC =  tensor([[45.6097],
        [45.7970]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8003],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6095],
        [45.7968]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8003],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6093],
        [45.7966]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8002],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6091],
        [45.7964]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8002],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6089],
        [45.7962]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8002],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6087],
        [45.7960]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8002],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6085],
        [45.7958]]) , tensor([[0.1102],
        [0.1100]]) , tensor([[0.8002],
        [0.5343]])
SOC, DOC, MBC =  tensor([[45.6083],
        [45.7956]]) , tens

SOC, DOC, MBC =  tensor([[45.6195],
        [45.8055]]) , tensor([[0.1084],
        [0.1082]]) , tensor([[0.8023],
        [0.5478]])
SOC, DOC, MBC =  tensor([[45.6108],
        [45.7994]]) , tensor([[0.1084],
        [0.1081]]) , tensor([[0.7997],
        [0.5495]])
SOC, DOC, MBC =  tensor([[45.6173],
        [45.8060]]) , tensor([[0.1083],
        [0.1081]]) , tensor([[0.7966],
        [0.5520]])
SOC, DOC, MBC =  tensor([[45.6181],
        [45.8206]]) , tensor([[0.1083],
        [0.1081]]) , tensor([[0.7954],
        [0.5531]])
SOC, DOC, MBC =  tensor([[45.6107],
        [45.8306]]) , tensor([[0.1083],
        [0.1081]]) , tensor([[0.7968],
        [0.5529]])
SOC, DOC, MBC =  tensor([[45.6256],
        [45.8215]]) , tensor([[0.1083],
        [0.1080]]) , tensor([[0.7968],
        [0.5566]])
SOC, DOC, MBC =  tensor([[45.6260],
        [45.8048]]) , tensor([[0.1082],
        [0.1080]]) , tensor([[0.7980],
        [0.5549]])
SOC, DOC, MBC =  tensor([[45.6311],
        [45.7889]]) , tens

        [0.1061]]) , tensor([[0.8255],
        [0.5820]])
SOC, DOC, MBC =  tensor([[45.6884],
        [45.8282]]) , tensor([[0.1063],
        [0.1061]]) , tensor([[0.8255],
        [0.5804]])
SOC, DOC, MBC =  tensor([[45.6883],
        [45.8281]]) , tensor([[0.1063],
        [0.1061]]) , tensor([[0.8255],
        [0.5811]])
SOC, DOC, MBC =  tensor([[45.6882],
        [45.8280]]) , tensor([[0.1063],
        [0.1061]]) , tensor([[0.8255],
        [0.5788]])
SOC, DOC, MBC =  tensor([[45.6881],
        [45.8279]]) , tensor([[0.1062],
        [0.1061]]) , tensor([[0.8255],
        [0.5779]])
SOC, DOC, MBC =  tensor([[45.6880],
        [45.8279]]) , tensor([[0.1062],
        [0.1060]]) , tensor([[0.8255],
        [0.5796]])
SOC, DOC, MBC =  tensor([[45.6879],
        [45.8278]]) , tensor([[0.1062],
        [0.1060]]) , tensor([[0.8255],
        [0.5793]])
SOC, DOC, MBC =  tensor([[45.6878],
        [45.8276]]) , tensor([[0.1062],
        [0.1060]]) , tensor([[0.8255],
        [0.5778]])
SOC,

        [45.8118]]) , tensor([[0.1126],
        [0.1074]]) , tensor([[0.8245],
        [0.5805]])
SOC, DOC, MBC =  tensor([[45.6723],
        [45.8117]]) , tensor([[0.1126],
        [0.1074]]) , tensor([[0.8245],
        [0.5805]])
SOC, DOC, MBC =  tensor([[45.6722],
        [45.8115]]) , tensor([[0.1125],
        [0.1074]]) , tensor([[0.8245],
        [0.5802]])
SOC, DOC, MBC =  tensor([[45.6721],
        [45.8114]]) , tensor([[0.1125],
        [0.1074]]) , tensor([[0.8245],
        [0.5797]])
SOC, DOC, MBC =  tensor([[45.6719],
        [45.8113]]) , tensor([[0.1125],
        [0.1074]]) , tensor([[0.8245],
        [0.5815]])
SOC, DOC, MBC =  tensor([[45.6718],
        [45.8112]]) , tensor([[0.1125],
        [0.1074]]) , tensor([[0.8245],
        [0.5819]])
SOC, DOC, MBC =  tensor([[45.6717],
        [45.8111]]) , tensor([[0.1125],
        [0.1074]]) , tensor([[0.8245],
        [0.5833]])
SOC, DOC, MBC =  tensor([[45.6716],
        [45.8110]]) , tensor([[0.1124],
        [0.1073]]) , t

SOC, DOC, MBC =  tensor([[45.7210],
        [45.7888]]) , tensor([[0.1100],
        [0.1051]]) , tensor([[0.8206],
        [0.6065]])
SOC, DOC, MBC =  tensor([[45.7303],
        [45.7857]]) , tensor([[0.1099],
        [0.1051]]) , tensor([[0.8225],
        [0.6120]])
SOC, DOC, MBC =  tensor([[45.7231],
        [45.7732]]) , tensor([[0.1099],
        [0.1051]]) , tensor([[0.8238],
        [0.6157]])
SOC, DOC, MBC =  tensor([[45.7323],
        [45.7761]]) , tensor([[0.1099],
        [0.1051]]) , tensor([[0.8245],
        [0.6162]])
SOC, DOC, MBC =  tensor([[45.7416],
        [45.7809]]) , tensor([[0.1099],
        [0.1051]]) , tensor([[0.8256],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7525],
        [45.7807]]) , tensor([[0.1098],
        [0.1050]]) , tensor([[0.8246],
        [0.6167]])
SOC, DOC, MBC =  tensor([[45.7510],
        [45.7856]]) , tensor([[0.1098],
        [0.1050]]) , tensor([[0.8265],
        [0.6202]])
SOC, DOC, MBC =  tensor([[45.7529],
        [45.7883]]) , tens

SOC, DOC, MBC =  tensor([[45.7364],
        [45.8100]]) , tensor([[0.1084],
        [0.1055]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7362],
        [45.8098]]) , tensor([[0.1084],
        [0.1053]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7360],
        [45.8095]]) , tensor([[0.1084],
        [0.1017]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7357],
        [45.8093]]) , tensor([[0.1079],
        [0.1029]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7355],
        [45.8091]]) , tensor([[0.1069],
        [0.1043]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7353],
        [45.8088]]) , tensor([[0.1128],
        [0.1025]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7350],
        [45.8086]]) , tensor([[0.1128],
        [0.1045]]) , tensor([[0.7967],
        [0.6122]])
SOC, DOC, MBC =  tensor([[45.7348],
        [45.8084]]) , tens

SOC, DOC, MBC =  tensor([[45.7220],
        [45.7953]]) , tensor([[0.1125],
        [0.1112]]) , tensor([[0.7995],
        [0.6112]])
SOC, DOC, MBC =  tensor([[45.7220],
        [45.7953]]) , tensor([[0.1124],
        [0.1112]]) , tensor([[0.7993],
        [0.6086]])
SOC, DOC, MBC =  tensor([[45.7220],
        [45.7953]]) , tensor([[0.1124],
        [0.1112]]) , tensor([[0.8028],
        [0.6088]])
SOC, DOC, MBC =  tensor([[45.7220],
        [45.7953]]) , tensor([[0.1124],
        [0.1111]]) , tensor([[0.8027],
        [0.6063]])
SOC, DOC, MBC =  tensor([[45.7220],
        [45.7953]]) , tensor([[0.1123],
        [0.1111]]) , tensor([[0.8026],
        [0.6065]])
SOC, DOC, MBC =  tensor([[45.7228],
        [45.7976]]) , tensor([[0.1123],
        [0.1111]]) , tensor([[0.8024],
        [0.6075]])
SOC, DOC, MBC =  tensor([[45.7231],
        [45.7987]]) , tensor([[0.1123],
        [0.1111]]) , tensor([[0.8020],
        [0.6056]])
SOC, DOC, MBC =  tensor([[45.7194],
        [45.7977]]) , tens

SOC, DOC, MBC =  tensor([[45.7922],
        [45.8098]]) , tensor([[0.1106],
        [0.1093]]) , tensor([[0.8168],
        [0.6220]])
SOC, DOC, MBC =  tensor([[45.7934],
        [45.8064]]) , tensor([[0.1105],
        [0.1093]]) , tensor([[0.8195],
        [0.6176]])
SOC, DOC, MBC =  tensor([[45.7971],
        [45.8170]]) , tensor([[0.1105],
        [0.1092]]) , tensor([[0.8234],
        [0.6203]])
SOC, DOC, MBC =  tensor([[45.8050],
        [45.8169]]) , tensor([[0.1105],
        [0.1092]]) , tensor([[0.8199],
        [0.6137]])
SOC, DOC, MBC =  tensor([[45.8015],
        [45.8197]]) , tensor([[0.1105],
        [0.1092]]) , tensor([[0.8200],
        [0.6099]])
SOC, DOC, MBC =  tensor([[45.8156],
        [45.8201]]) , tensor([[0.1104],
        [0.1092]]) , tensor([[0.8194],
        [0.6144]])
SOC, DOC, MBC =  tensor([[45.8101],
        [45.8312]]) , tensor([[0.1104],
        [0.1091]]) , tensor([[0.8262],
        [0.6129]])
SOC, DOC, MBC =  tensor([[45.8081],
        [45.8377]]) , tens

SOC, DOC, MBC =  tensor([[45.8001],
        [45.7695]]) , tensor([[0.1007],
        [0.1107]]) , tensor([[0.8177],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7999],
        [45.7693]]) , tensor([[0.1080],
        [0.1107]]) , tensor([[0.8177],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7996],
        [45.7690]]) , tensor([[0.1105],
        [0.1107]]) , tensor([[0.8177],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7994],
        [45.7688]]) , tensor([[0.1106],
        [0.1107]]) , tensor([[0.8176],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7991],
        [45.7685]]) , tensor([[0.1100],
        [0.1109]]) , tensor([[0.8176],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7989],
        [45.7683]]) , tensor([[0.1083],
        [0.1106]]) , tensor([[0.8176],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7986],
        [45.7680]]) , tensor([[0.1082],
        [0.1089]]) , tensor([[0.8176],
        [0.6263]])
SOC, DOC, MBC =  tensor([[45.7983],
        [45.7677]]) , tens

SOC, DOC, MBC =  tensor([[45.7843],
        [45.7571]]) , tensor([[0.1134],
        [0.1060]]) , tensor([[0.8226],
        [0.6265]])
SOC, DOC, MBC =  tensor([[45.7830],
        [45.7563]]) , tensor([[0.1134],
        [0.1059]]) , tensor([[0.8227],
        [0.6257]])
SOC, DOC, MBC =  tensor([[45.7849],
        [45.7514]]) , tensor([[0.1133],
        [0.1059]]) , tensor([[0.8235],
        [0.6243]])
SOC, DOC, MBC =  tensor([[45.7814],
        [45.7496]]) , tensor([[0.1133],
        [0.1059]]) , tensor([[0.8269],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7804],
        [45.7517]]) , tensor([[0.1133],
        [0.1059]]) , tensor([[0.8284],
        [0.6262]])
SOC, DOC, MBC =  tensor([[45.7743],
        [45.7550]]) , tensor([[0.1133],
        [0.1059]]) , tensor([[0.8239],
        [0.6247]])
SOC, DOC, MBC =  tensor([[45.7781],
        [45.7553]]) , tensor([[0.1132],
        [0.1058]]) , tensor([[0.8216],
        [0.6236]])
SOC, DOC, MBC =  tensor([[45.7711],
        [45.7456]]) , tens

        [0.1041]]) , tensor([[0.8441],
        [0.6237]])
SOC, DOC, MBC =  tensor([[45.8056],
        [45.7144]]) , tensor([[0.1113],
        [0.1041]]) , tensor([[0.8348],
        [0.6236]])
SOC, DOC, MBC =  tensor([[45.8109],
        [45.7158]]) , tensor([[0.1112],
        [0.1041]]) , tensor([[0.8334],
        [0.6240]])
SOC, DOC, MBC =  tensor([[45.8104],
        [45.7266]]) , tensor([[0.1112],
        [0.1040]]) , tensor([[0.8332],
        [0.6287]])
SOC, DOC, MBC =  tensor([[45.8121],
        [45.7216]]) , tensor([[0.1112],
        [0.1040]]) , tensor([[0.8336],
        [0.6273]])
SOC, DOC, MBC =  tensor([[45.8205],
        [45.7205]]) , tensor([[0.1112],
        [0.1040]]) , tensor([[0.8330],
        [0.6256]])
SOC, DOC, MBC =  tensor([[45.8187],
        [45.7083]]) , tensor([[0.1111],
        [0.1040]]) , tensor([[0.8343],
        [0.6218]])
SOC, DOC, MBC =  tensor([[45.8149],
        [45.7088]]) , tensor([[0.1111],
        [0.1040]]) , tensor([[0.8316],
        [0.6232]])
SOC,

SOC, DOC, MBC =  tensor([[45.7879],
        [45.7329]]) , tensor([[0.1150],
        [0.1134]]) , tensor([[0.8214],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7876],
        [45.7326]]) , tensor([[0.1150],
        [0.1127]]) , tensor([[0.8214],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7873],
        [45.7323]]) , tensor([[0.1149],
        [0.1144]]) , tensor([[0.8214],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7870],
        [45.7320]]) , tensor([[0.1149],
        [0.1144]]) , tensor([[0.8213],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7868],
        [45.7318]]) , tensor([[0.1149],
        [0.1144]]) , tensor([[0.8213],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7865],
        [45.7315]]) , tensor([[0.1149],
        [0.1144]]) , tensor([[0.8213],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7862],
        [45.7312]]) , tensor([[0.1149],
        [0.1144]]) , tensor([[0.8213],
        [0.6221]])
SOC, DOC, MBC =  tensor([[45.7859],
        [45.7309]]) , tens

        [45.7157]]) , tensor([[0.1135],
        [0.1130]]) , tensor([[0.8087],
        [0.6165]])
SOC, DOC, MBC =  tensor([[45.7847],
        [45.7167]]) , tensor([[0.1135],
        [0.1130]]) , tensor([[0.8146],
        [0.6130]])
SOC, DOC, MBC =  tensor([[45.7844],
        [45.7160]]) , tensor([[0.1135],
        [0.1130]]) , tensor([[0.8168],
        [0.6144]])
SOC, DOC, MBC =  tensor([[45.7751],
        [45.7209]]) , tensor([[0.1135],
        [0.1129]]) , tensor([[0.8134],
        [0.6239]])
SOC, DOC, MBC =  tensor([[45.7653],
        [45.7275]]) , tensor([[0.1134],
        [0.1129]]) , tensor([[0.8152],
        [0.6184]])
SOC, DOC, MBC =  tensor([[45.7691],
        [45.7277]]) , tensor([[0.1134],
        [0.1129]]) , tensor([[0.8127],
        [0.6149]])
SOC, DOC, MBC =  tensor([[45.7703],
        [45.7273]]) , tensor([[0.1134],
        [0.1128]]) , tensor([[0.8149],
        [0.6163]])
SOC, DOC, MBC =  tensor([[45.7642],
        [45.7261]]) , tensor([[0.1133],
        [0.1128]]) , t

        [0.1110]]) , tensor([[0.8441],
        [0.6104]])
SOC, DOC, MBC =  tensor([[45.7500],
        [45.6515]]) , tensor([[0.1116],
        [0.1110]]) , tensor([[0.8433],
        [0.6014]])
SOC, DOC, MBC =  tensor([[45.7551],
        [45.6599]]) , tensor([[0.1115],
        [0.1110]]) , tensor([[0.8377],
        [0.6000]])
SOC, DOC, MBC =  tensor([[45.7567],
        [45.6572]]) , tensor([[0.1115],
        [0.1109]]) , tensor([[0.8403],
        [0.6000]])
SOC, DOC, MBC =  tensor([[45.7579],
        [45.6586]]) , tensor([[0.1115],
        [0.1109]]) , tensor([[0.8439],
        [0.6010]])
SOC, DOC, MBC =  tensor([[45.7509],
        [45.6623]]) , tensor([[0.1114],
        [0.1109]]) , tensor([[0.8439],
        [0.6009]])
SOC, DOC, MBC =  tensor([[45.7601],
        [45.6622]]) , tensor([[0.1114],
        [0.1109]]) , tensor([[0.8455],
        [0.5925]])
SOC, DOC, MBC =  tensor([[45.7610],
        [45.6611]]) , tensor([[0.1114],
        [0.1109]]) , tensor([[0.8473],
        [0.5896]])
SOC,

        [45.6492]]) , tensor([[0.1110],
        [0.1125]]) , tensor([[0.8405],
        [0.5835]])
SOC, DOC, MBC =  tensor([[45.7695],
        [45.6489]]) , tensor([[0.1139],
        [0.1125]]) , tensor([[0.8405],
        [0.5835]])
SOC, DOC, MBC =  tensor([[45.7692],
        [45.6487]]) , tensor([[0.1142],
        [0.1126]]) , tensor([[0.8405],
        [0.5835]])
SOC, DOC, MBC =  tensor([[45.7689],
        [45.6484]]) , tensor([[0.1148],
        [0.1121]]) , tensor([[0.8405],
        [0.5835]])
SOC, DOC, MBC =  tensor([[45.7687],
        [45.6481]]) , tensor([[0.1145],
        [0.1138]]) , tensor([[0.8404],
        [0.5835]])
SOC, DOC, MBC =  tensor([[45.7684],
        [45.6478]]) , tensor([[0.1124],
        [0.1138]]) , tensor([[0.8404],
        [0.5835]])
SOC, DOC, MBC =  tensor([[45.7681],
        [45.6475]]) , tensor([[0.1145],
        [0.1138]]) , tensor([[0.8404],
        [0.5834]])
SOC, DOC, MBC =  tensor([[45.7678],
        [45.6472]]) , tensor([[0.1142],
        [0.1138]]) , t

SOC, DOC, MBC =  tensor([[45.7507],
        [45.6430]]) , tensor([[0.1151],
        [0.1130]]) , tensor([[0.8447],
        [0.5643]])
SOC, DOC, MBC =  tensor([[45.7506],
        [45.6379]]) , tensor([[0.1151],
        [0.1130]]) , tensor([[0.8451],
        [0.5667]])
SOC, DOC, MBC =  tensor([[45.7491],
        [45.6390]]) , tensor([[0.1151],
        [0.1129]]) , tensor([[0.8490],
        [0.5687]])
SOC, DOC, MBC =  tensor([[45.7432],
        [45.6374]]) , tensor([[0.1151],
        [0.1129]]) , tensor([[0.8520],
        [0.5694]])
SOC, DOC, MBC =  tensor([[45.7437],
        [45.6372]]) , tensor([[0.1150],
        [0.1129]]) , tensor([[0.8494],
        [0.5715]])
SOC, DOC, MBC =  tensor([[45.7431],
        [45.6476]]) , tensor([[0.1150],
        [0.1129]]) , tensor([[0.8512],
        [0.5722]])
SOC, DOC, MBC =  tensor([[45.7369],
        [45.6443]]) , tensor([[0.1150],
        [0.1128]]) , tensor([[0.8521],
        [0.5705]])
SOC, DOC, MBC =  tensor([[45.7428],
        [45.6429]]) , tens

SOC, DOC, MBC =  tensor([[45.6545],
        [45.7135]]) , tensor([[0.1130],
        [0.1109]]) , tensor([[0.8577],
        [0.5502]])
SOC, DOC, MBC =  tensor([[45.6598],
        [45.7177]]) , tensor([[0.1129],
        [0.1109]]) , tensor([[0.8572],
        [0.5520]])
SOC, DOC, MBC =  tensor([[45.6558],
        [45.7152]]) , tensor([[0.1129],
        [0.1108]]) , tensor([[0.8545],
        [0.5439]])
SOC, DOC, MBC =  tensor([[45.6579],
        [45.7074]]) , tensor([[0.1129],
        [0.1108]]) , tensor([[0.8532],
        [0.5474]])
SOC, DOC, MBC =  tensor([[45.6603],
        [45.7081]]) , tensor([[0.1129],
        [0.1108]]) , tensor([[0.8585],
        [0.5406]])
SOC, DOC, MBC =  tensor([[45.6590],
        [45.7093]]) , tensor([[0.1128],
        [0.1107]]) , tensor([[0.8559],
        [0.5394]])
SOC, DOC, MBC =  tensor([[45.6627],
        [45.6982]]) , tensor([[0.1128],
        [0.1107]]) , tensor([[0.8557],
        [0.5386]])
SOC, DOC, MBC =  tensor([[45.6630],
        [45.6961]]) , tens

        [0.1095]]) , tensor([[0.8591],
        [0.5093]])
SOC, DOC, MBC =  tensor([[45.6822],
        [45.6944]]) , tensor([[0.1115],
        [0.1095]]) , tensor([[0.8591],
        [0.5095]])
SOC, DOC, MBC =  tensor([[45.6820],
        [45.6942]]) , tensor([[0.1115],
        [0.1095]]) , tensor([[0.8591],
        [0.5095]])
SOC, DOC, MBC =  tensor([[45.6818],
        [45.6940]]) , tensor([[0.1115],
        [0.1095]]) , tensor([[0.8591],
        [0.5095]])
SOC, DOC, MBC =  tensor([[45.6816],
        [45.6938]]) , tensor([[0.1115],
        [0.1094]]) , tensor([[0.8591],
        [0.5095]])
SOC, DOC, MBC =  tensor([[45.6813],
        [45.6936]]) , tensor([[0.1115],
        [0.1094]]) , tensor([[0.8591],
        [0.5095]])
SOC, DOC, MBC =  tensor([[45.6811],
        [45.6933]]) , tensor([[0.1115],
        [0.1094]]) , tensor([[0.8590],
        [0.5095]])
SOC, DOC, MBC =  tensor([[45.6809],
        [45.6931]]) , tensor([[0.1115],
        [0.1094]]) , tensor([[0.8590],
        [0.5095]])
SOC,

SOC, DOC, MBC =  tensor([[45.6652],
        [45.6770]]) , tensor([[0.1155],
        [0.1136]]) , tensor([[0.8580],
        [0.5015]])
SOC, DOC, MBC =  tensor([[45.6651],
        [45.6770]]) , tensor([[0.1155],
        [0.1136]]) , tensor([[0.8579],
        [0.5060]])
SOC, DOC, MBC =  tensor([[45.6651],
        [45.6770]]) , tensor([[0.1155],
        [0.1136]]) , tensor([[0.8577],
        [0.5048]])
SOC, DOC, MBC =  tensor([[45.6651],
        [45.6770]]) , tensor([[0.1155],
        [0.1135]]) , tensor([[0.8573],
        [0.5060]])
SOC, DOC, MBC =  tensor([[45.6650],
        [45.6769]]) , tensor([[0.1154],
        [0.1135]]) , tensor([[0.8576],
        [0.5026]])
SOC, DOC, MBC =  tensor([[45.6650],
        [45.6769]]) , tensor([[0.1154],
        [0.1135]]) , tensor([[0.8566],
        [0.5069]])
SOC, DOC, MBC =  tensor([[45.6650],
        [45.6769]]) , tensor([[0.1154],
        [0.1135]]) , tensor([[0.8554],
        [0.5081]])
SOC, DOC, MBC =  tensor([[45.6650],
        [45.6769]]) , tens

SOC, DOC, MBC =  tensor([[45.6624],
        [45.7491]]) , tensor([[0.1136],
        [0.1117]]) , tensor([[0.8636],
        [0.5002]])
SOC, DOC, MBC =  tensor([[45.6663],
        [45.7575]]) , tensor([[0.1135],
        [0.1116]]) , tensor([[0.8589],
        [0.4995]])
SOC, DOC, MBC =  tensor([[45.6654],
        [45.7752]]) , tensor([[0.1135],
        [0.1116]]) , tensor([[0.8581],
        [0.5023]])
SOC, DOC, MBC =  tensor([[45.6596],
        [45.7766]]) , tensor([[0.1135],
        [0.1116]]) , tensor([[0.8615],
        [0.5034]])
SOC, DOC, MBC =  tensor([[45.6468],
        [45.7675]]) , tensor([[0.1135],
        [0.1116]]) , tensor([[0.8548],
        [0.5142]])
SOC, DOC, MBC =  tensor([[45.6481],
        [45.7541]]) , tensor([[0.1134],
        [0.1115]]) , tensor([[0.8541],
        [0.5269]])
SOC, DOC, MBC =  tensor([[45.6393],
        [45.7701]]) , tensor([[0.1134],
        [0.1115]]) , tensor([[0.8528],
        [0.5225]])
SOC, DOC, MBC =  tensor([[45.6437],
        [45.7625]]) , tens

SOC, DOC, MBC =  tensor([[45.6146],
        [45.7580]]) , tensor([[0.1117],
        [0.1097]]) , tensor([[0.8750],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6144],
        [45.7578]]) , tensor([[0.1117],
        [0.1097]]) , tensor([[0.8750],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6142],
        [45.7575]]) , tensor([[0.1117],
        [0.1097]]) , tensor([[0.8750],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6139],
        [45.7573]]) , tensor([[0.1117],
        [0.1097]]) , tensor([[0.8749],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6137],
        [45.7571]]) , tensor([[0.1116],
        [0.1097]]) , tensor([[0.8749],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6135],
        [45.7569]]) , tensor([[0.1116],
        [0.1097]]) , tensor([[0.8749],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6133],
        [45.7566]]) , tensor([[0.1116],
        [0.1097]]) , tensor([[0.8749],
        [0.5334]])
SOC, DOC, MBC =  tensor([[45.6130],
        [45.7564]]) , tens

SOC, DOC, MBC =  tensor([[45.5969],
        [45.7400]]) , tensor([[0.1168],
        [0.1150]]) , tensor([[0.8736],
        [0.5386]])
SOC, DOC, MBC =  tensor([[45.5969],
        [45.7399]]) , tensor([[0.1168],
        [0.1150]]) , tensor([[0.8732],
        [0.5386]])
SOC, DOC, MBC =  tensor([[45.5969],
        [45.7399]]) , tensor([[0.1168],
        [0.1150]]) , tensor([[0.8726],
        [0.5330]])
SOC, DOC, MBC =  tensor([[45.5968],
        [45.7399]]) , tensor([[0.1167],
        [0.1149]]) , tensor([[0.8728],
        [0.5336]])
SOC, DOC, MBC =  tensor([[45.5968],
        [45.7399]]) , tensor([[0.1167],
        [0.1149]]) , tensor([[0.8723],
        [0.5352]])
SOC, DOC, MBC =  tensor([[45.5968],
        [45.7399]]) , tensor([[0.1167],
        [0.1149]]) , tensor([[0.8721],
        [0.5379]])
SOC, DOC, MBC =  tensor([[45.5968],
        [45.7398]]) , tensor([[0.1166],
        [0.1148]]) , tensor([[0.8719],
        [0.5390]])
SOC, DOC, MBC =  tensor([[45.5968],
        [45.7398]]) , tens

SOC, DOC, MBC =  tensor([[45.5084],
        [45.7405]]) , tensor([[0.1144],
        [0.1127]]) , tensor([[0.8673],
        [0.4902]])
SOC, DOC, MBC =  tensor([[45.5075],
        [45.7392]]) , tensor([[0.1144],
        [0.1127]]) , tensor([[0.8647],
        [0.4939]])
SOC, DOC, MBC =  tensor([[45.5097],
        [45.7354]]) , tensor([[0.1144],
        [0.1126]]) , tensor([[0.8638],
        [0.4960]])
SOC, DOC, MBC =  tensor([[45.5076],
        [45.7260]]) , tensor([[0.1143],
        [0.1126]]) , tensor([[0.8604],
        [0.4921]])
SOC, DOC, MBC =  tensor([[45.5007],
        [45.7197]]) , tensor([[0.1143],
        [0.1126]]) , tensor([[0.8557],
        [0.4966]])
SOC, DOC, MBC =  tensor([[45.5087],
        [45.7194]]) , tensor([[0.1143],
        [0.1125]]) , tensor([[0.8568],
        [0.4940]])
SOC, DOC, MBC =  tensor([[45.4889],
        [45.7201]]) , tensor([[0.1143],
        [0.1125]]) , tensor([[0.8566],
        [0.5003]])
SOC, DOC, MBC =  tensor([[45.5021],
        [45.7207]]) , tens

SOC, DOC, MBC =  tensor([[45.5085],
        [45.6597]]) , tensor([[0.1128],
        [0.1111]]) , tensor([[0.8491],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5083],
        [45.6594]]) , tensor([[0.1128],
        [0.1111]]) , tensor([[0.8491],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5080],
        [45.6592]]) , tensor([[0.1128],
        [0.1110]]) , tensor([[0.8491],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5078],
        [45.6589]]) , tensor([[0.1128],
        [0.1111]]) , tensor([[0.8491],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5075],
        [45.6587]]) , tensor([[0.1128],
        [0.1111]]) , tensor([[0.8490],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5073],
        [45.6584]]) , tensor([[0.1128],
        [0.1110]]) , tensor([[0.8490],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5070],
        [45.6581]]) , tensor([[0.1119],
        [0.1109]]) , tensor([[0.8490],
        [0.5012]])
SOC, DOC, MBC =  tensor([[45.5067],
        [45.6579]]) , tens

        [45.6437]]) , tensor([[0.1133],
        [0.1156]]) , tensor([[0.8479],
        [0.5034]])
SOC, DOC, MBC =  tensor([[45.4925],
        [45.6436]]) , tensor([[0.1133],
        [0.1156]]) , tensor([[0.8479],
        [0.5051]])
SOC, DOC, MBC =  tensor([[45.4924],
        [45.6435]]) , tensor([[0.1133],
        [0.1155]]) , tensor([[0.8479],
        [0.5035]])
SOC, DOC, MBC =  tensor([[45.4923],
        [45.6434]]) , tensor([[0.1132],
        [0.1155]]) , tensor([[0.8479],
        [0.5004]])
SOC, DOC, MBC =  tensor([[45.4922],
        [45.6433]]) , tensor([[0.1132],
        [0.1155]]) , tensor([[0.8479],
        [0.4978]])
SOC, DOC, MBC =  tensor([[45.4921],
        [45.6432]]) , tensor([[0.1132],
        [0.1155]]) , tensor([[0.8479],
        [0.4995]])
SOC, DOC, MBC =  tensor([[45.4920],
        [45.6431]]) , tensor([[0.1132],
        [0.1154]]) , tensor([[0.8479],
        [0.5002]])
SOC, DOC, MBC =  tensor([[45.4919],
        [45.6430]]) , tensor([[0.1132],
        [0.1154]]) , t

        [45.6388]]) , tensor([[0.1117],
        [0.1139]]) , tensor([[0.8834],
        [0.5073]])
SOC, DOC, MBC =  tensor([[45.4225],
        [45.6179]]) , tensor([[0.1117],
        [0.1138]]) , tensor([[0.8846],
        [0.5085]])
SOC, DOC, MBC =  tensor([[45.4257],
        [45.6335]]) , tensor([[0.1117],
        [0.1138]]) , tensor([[0.8868],
        [0.5023]])
SOC, DOC, MBC =  tensor([[45.4337],
        [45.6286]]) , tensor([[0.1117],
        [0.1138]]) , tensor([[0.8842],
        [0.5043]])
SOC, DOC, MBC =  tensor([[45.4458],
        [45.6023]]) , tensor([[0.1116],
        [0.1138]]) , tensor([[0.8834],
        [0.5003]])
SOC, DOC, MBC =  tensor([[45.4386],
        [45.5949]]) , tensor([[0.1116],
        [0.1137]]) , tensor([[0.8795],
        [0.4992]])
SOC, DOC, MBC =  tensor([[45.4520],
        [45.5900]]) , tensor([[0.1116],
        [0.1137]]) , tensor([[0.8857],
        [0.4961]])
SOC, DOC, MBC =  tensor([[45.4485],
        [45.5935]]) , tensor([[0.1116],
        [0.1137]]) , t

SOC, DOC, MBC =  tensor([[45.4490],
        [45.5166]]) , tensor([[0.1098],
        [0.1118]]) , tensor([[0.8572],
        [0.5596]])
SOC, DOC, MBC =  tensor([[45.4490],
        [45.5166]]) , tensor([[0.1098],
        [0.1118]]) , tensor([[0.8570],
        [0.5628]])
SOC, DOC, MBC =  tensor([[45.4490],
        [45.5166]]) , tensor([[0.1098],
        [0.1118]]) , tensor([[0.8570],
        [0.5647]])
SOC, DOC, MBC =  tensor([[45.4489],
        [45.5166]]) , tensor([[0.1098],
        [0.1118]]) , tensor([[0.8570],
        [0.5656]])
SOC, DOC, MBC =  tensor([[45.4489],
        [45.5165]]) , tensor([[0.1097],
        [0.1117]]) , tensor([[0.8570],
        [0.5655]])
SOC, DOC, MBC =  tensor([[45.4489],
        [45.5165]]) , tensor([[0.1097],
        [0.1117]]) , tensor([[0.8570],
        [0.5640]])
SOC, DOC, MBC =  tensor([[45.4488],
        [45.5165]]) , tensor([[0.1097],
        [0.1117]]) , tensor([[0.8570],
        [0.5607]])
SOC, DOC, MBC =  tensor([[45.4488],
        [45.5164]]) , tens

SOC, DOC, MBC =  tensor([[45.4342],
        [45.5020]]) , tensor([[0.0695],
        [0.1174]]) , tensor([[0.8558],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4339],
        [45.5017]]) , tensor([[0.0677],
        [0.1174]]) , tensor([[0.8558],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4336],
        [45.5014]]) , tensor([[0.0844],
        [0.1174]]) , tensor([[0.8558],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4333],
        [45.5011]]) , tensor([[0.0875],
        [0.1174]]) , tensor([[0.8558],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4330],
        [45.5009]]) , tensor([[0.0888],
        [0.1174]]) , tensor([[0.8557],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4327],
        [45.5006]]) , tensor([[0.0920],
        [0.1174]]) , tensor([[0.8557],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4325],
        [45.5003]]) , tensor([[0.0942],
        [0.1174]]) , tensor([[0.8557],
        [0.5642]])
SOC, DOC, MBC =  tensor([[45.4322],
        [45.5001]]) , tens

        [45.4757]]) , tensor([[0.1091],
        [0.1159]]) , tensor([[0.8565],
        [0.5595]])
SOC, DOC, MBC =  tensor([[45.4316],
        [45.4741]]) , tensor([[0.1091],
        [0.1158]]) , tensor([[0.8576],
        [0.5556]])
SOC, DOC, MBC =  tensor([[45.4286],
        [45.4836]]) , tensor([[0.1091],
        [0.1158]]) , tensor([[0.8581],
        [0.5538]])
SOC, DOC, MBC =  tensor([[45.4321],
        [45.4769]]) , tensor([[0.1090],
        [0.1158]]) , tensor([[0.8525],
        [0.5453]])
SOC, DOC, MBC =  tensor([[45.4372],
        [45.4700]]) , tensor([[0.1090],
        [0.1157]]) , tensor([[0.8537],
        [0.5447]])
SOC, DOC, MBC =  tensor([[45.4376],
        [45.4656]]) , tensor([[0.1090],
        [0.1157]]) , tensor([[0.8569],
        [0.5447]])
SOC, DOC, MBC =  tensor([[45.4456],
        [45.4665]]) , tensor([[0.1089],
        [0.1157]]) , tensor([[0.8537],
        [0.5441]])
SOC, DOC, MBC =  tensor([[45.4389],
        [45.4622]]) , tensor([[0.1089],
        [0.1156]]) , t

In [60]:
print(C_vector_test[:, :, 0].size())
print(drift_vector_test[:, :, 0].size())
print(C_vector_test[:, :, 0])
print(drift_vector_test[:, :, 0])

torch.Size([2, 3])
torch.Size([2, 3])
tensor([[45.6603,  0.0715,  0.7147],
        [45.6603,  0.0715,  0.7147]])
tensor([[-1.1642e-10,  0.0000e+00,  0.0000e+00],
        [-1.1642e-10,  0.0000e+00,  0.0000e+00]])


In [198]:
print(C_vector_test[:, :, 10])
print(drift_vector_test[:, :, 10])
print(diffusion_matrix_sqrt_test[:, :, :, 10])

tensor([[ 4.5654e+01, -1.1616e-02,  7.1385e-01],
        [ 4.5660e+01, -1.0476e-01,  7.1448e-01]])
tensor([[-0.0005, -0.0109, -0.0002],
        [-0.0007, -0.1033, -0.0004]])
tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])


In [24]:
drift_and_diffusion_scon(10, T_span_tensor, dt, I_S, I_D, analytical_steady_state_init_con, scon_params_dict, temp_ref, path_test)


 Initial pre-perturbation SOC, DOC, MBC =  tensor([45.6603,  0.0715,  0.7147])


(tensor([[[45.6603, 45.6603, 45.6603, 45.6603, 45.6603, 45.6603, 45.6603,
           45.6603, 45.6603, 45.6602],
          [ 0.0715,  0.0715,  0.0705,  0.0700,  0.0708,  0.0707,  0.0704,
            0.0687,  0.0730,  0.0725],
          [ 0.7147,  0.7147,  0.7147,  0.7147,  0.7147,  0.7147,  0.7147,
            0.7147,  0.7147,  0.7147]],
 
         [[45.6603, 45.6603, 45.6603, 45.6603, 45.6603, 45.6603, 45.6603,
           45.6603, 45.6603, 45.6602],
          [ 0.0715,  0.0715,  0.0715,  0.0715,  0.0734,  0.0728,  0.0725,
            0.0740,  0.0749,  0.0728],
          [ 0.7147,  0.7147,  0.7147,  0.7147,  0.7147,  0.7147,  0.7147,
            0.7147,  0.7147,  0.7147]]]),
 tensor([[[-1.1642e-10, -3.1375e-05, -6.5296e-05, -9.9348e-05, -1.3197e-04,
           -1.6710e-04, -2.0344e-04, -2.4333e-04, -2.7268e-04,  0.0000e+00],
          [ 0.0000e+00,  4.6595e-06,  1.6510e-05,  2.5313e-05,  2.4672e-05,
            3.1017e-05,  3.8984e-05,  5.8079e-05,  3.0927e-05,  0.0000e+00],
          

In [200]:
drift_and_diffusion_scon(N, T_span_tensor, dt, I_S, I_D, analytical_steady_state_init_con, scon_params_dict, temp_ref, path_test)


 Initial pre-perturbation SOC, DOC, MBC =  tensor([45.6603,  0.0715,  0.7147])


(tensor([[[45.6603, 45.6603, 45.6604,  ...,     nan,     nan,     nan],
          [ 0.0715,  0.1609,  0.0507,  ...,     nan,     nan,     nan],
          [ 0.7147,  0.7147,  0.7143,  ...,     nan,     nan,     nan]],
 
         [[45.6603, 45.6603, 45.6673,  ...,     nan,     nan,     nan],
          [ 0.0715,  0.2821,  0.6116,  ...,     nan,     nan,     nan],
          [ 0.7147,  0.7147,  0.7123,  ...,     nan,     nan,     nan]]]),
 tensor([[[-1.1642e-10,  1.1914e-04, -9.9310e-05,  ...,         nan,
                   nan,  0.0000e+00],
          [ 7.1462e-02,  1.6027e-01,  5.0839e-02,  ...,         nan,
                   nan,  0.0000e+00],
          [ 0.0000e+00,  1.7604e-04, -4.7223e-05,  ...,         nan,
                   nan,  0.0000e+00]],
 
         [[-1.1642e-10,  3.2306e-04,  8.6309e-04,  ...,         nan,
                   nan,  0.0000e+00],
          [ 7.1462e-02,  2.8057e-01,  6.0766e-01,  ...,         nan,
                   nan,  0.0000e+00],
          [ 0.0000e+00, 

In [None]:
class SDEFlow(nn.Module):

    def __init__(self, cond_inputs=1):
        super().__init__()

        stride = 3 if STATE_DIM % 2==1 else 2

        self.CL_1 = CouplingLayer(cond_inputs, stride)
        self.CL_2 = CouplingLayer(cond_inputs, stride)
        self.CL_3 = CouplingLayer(cond_inputs, stride)
        self.CL_4 = CouplingLayer(cond_inputs, stride)
        self.CL_5 = CouplingLayer(cond_inputs, stride)

        self.P_1 = PermutationLayer()
        self.P_2 = PermutationLayer()
        self.P_3 = PermutationLayer()
        self.P_4 = PermutationLayer()

        self.base_dist = d.normal.Normal(loc = 0., scale = 1.0)

    def forward(self, batch_size, *args, **kwargs):

        eps = self.base_dist.sample([batch_size, 1, STATE_DIM*N]).to(device)
        log_prob = self.base_dist.log_prob(eps).sum(-1)

        times = torch.arange(0, T+dt, dt, device=eps.device)[(None,)*2].repeat(batch_size, STATE_DIM, 1).transpose(-2, -1).reshape(batch_size, 1, -1)

        CL_1, CL_1_sigma = self.CL_1(eps, times)
        P_1 = self.P_1(CL_1)

        CL_2, CL_2_sigma = self.CL_2(P_1, times)
        P_2 = self.P_2(CL_2)

        CL_3, CL_3_sigma = self.CL_3(P_2, times)
        P_3 = self.P_3(CL_3)

        CL_4, CL_4_sigma = self.CL_4(P_3, times)
        P_4 = self.P_4(CL_4)

        y, CL_5_sigma = self.CL_5(P_4, times)

        for sigma in [CL_1_sigma, CL_2_sigma, CL_3_sigma, CL_4_sigma, CL_5_sigma]:
            log_prob -= torch.log(sigma).sum(-1)

        return y.reshape(batch_size, STATE_DIM, -1), log_prob

class ObsModel(nn.Module):

    def __init__(self, times, mu, scale):
        super().__init__()

        self.idx = self._get_idx(times)
        self.times = times
        self.mu = torch.Tensor(mu).to(device)
        self.scale = scale
        
    def forward(self, x):
            obs_ll = d.normal.Normal(self.mu.permute(1, 0), self.scale).log_prob(x[:, self.idx, :])
            return torch.sum(obs_ll, [-1, -2]).mean()

    def _get_idx(self, times):
        return list((times/dt).astype(int))  