In [1]:
#Python-related imports
from datetime import datetime
import time

#Torch-related imports
import torch
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Function

#PyData imports
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

In [2]:
# Hack to import from a parent directory
import sys
path = '..'
if path not in sys.path:
    sys.path.append(path)

#Module imports
from TruncatedNormal import *
from mean_field import *
from obs_and_flow import LowerBound
from SBM_SDE_classes import temp_gen, arrhenius_temp_dep, linear_temp_dep, i_s, i_d

In [3]:
torch.manual_seed(0)
np.random.seed(0)
torch.set_printoptions(precision = 8)
device = torch.device('cpu')

## Draw $\theta \sim p(\theta)$

In [4]:
temp_ref = 283
temp_rise = 5 #High estimate of 5 celsius temperature rise by 2100.

prior_scale_factor = 0.333

#Parameter prior means
u_M_mean = 0.0016
a_SD_mean = 0.5
a_DS_mean = 0.5
a_M_mean = 0.5
a_MSC_mean = 0.5
k_S_ref_mean = 0.0005
k_D_ref_mean = 0.0008
k_M_ref_mean = 0.0007
Ea_S_mean = 55
Ea_D_mean = 48
Ea_M_mean = 48
s_SOC_mean = 0.01
s_DOC_mean = 0.01
s_MBC_mean = 0.01

#SCON theta truncated normal distribution parameter details in order of mean, sdev, lower, and upper.
u_M_details = torch.Tensor([u_M_mean, u_M_mean * prior_scale_factor, 0, 1])
a_SD_details = torch.Tensor([a_SD_mean, a_SD_mean * prior_scale_factor, 0, 1])
a_DS_details = torch.Tensor([a_DS_mean, a_DS_mean * prior_scale_factor, 0, 1])
a_M_details = torch.Tensor([a_M_mean, a_M_mean * prior_scale_factor, 0, 1])
a_MSC_details = torch.Tensor([a_MSC_mean, a_MSC_mean * prior_scale_factor, 0, 1])
k_S_ref_details = torch.Tensor([k_S_ref_mean, k_S_ref_mean * prior_scale_factor, 0, 1])
k_D_ref_details = torch.Tensor([k_D_ref_mean, k_D_ref_mean * prior_scale_factor, 0, 1])
k_M_ref_details = torch.Tensor([k_M_ref_mean, k_M_ref_mean * prior_scale_factor, 0, 1])
Ea_S_details = torch.Tensor([Ea_S_mean, Ea_S_mean * prior_scale_factor, 10, 100])
Ea_D_details = torch.Tensor([Ea_D_mean, Ea_D_mean * prior_scale_factor, 10, 100])
Ea_M_details = torch.Tensor([Ea_M_mean, Ea_M_mean * prior_scale_factor, 10, 100])

#SCON-SS diffusion matrix parameter distribution details
s_SOC_details = torch.Tensor([s_SOC_mean, s_SOC_mean * prior_scale_factor, 0, 1])
s_DOC_details = torch.Tensor([s_DOC_mean, s_DOC_mean * prior_scale_factor, 0, 1])
s_MBC_details = torch.Tensor([s_MBC_mean, s_MBC_mean * prior_scale_factor, 0, 1])

##SCON-SS theta rsample draws
#u_M = TruncatedNormal(loc = u_M_details[0], scale = u_M_details[1], a = u_M_details[2], b = u_M_details[3]).rsample().cpu().detach().numpy()
#a_SD = TruncatedNormal(loc = a_SD_details[0], scale = a_SD_details[1], a = a_SD_details[2], b = a_SD_details[3]).rsample().cpu().detach().numpy()
#a_DS = TruncatedNormal(loc = a_DS_details[0], scale = a_DS_details[1], a = a_DS_details[2], b = a_DS_details[3]).rsample().cpu().detach().numpy()
#a_M = TruncatedNormal(loc = a_M_details[0], scale = a_M_details[1], a = a_M_details[2], b = a_M_details[3]).rsample().cpu().detach().numpy()
#a_MSC = TruncatedNormal(loc = a_MSC_details[0], scale = a_MSC_details[1], a = a_MSC_details[2], b = a_MSC_details[3]).rsample().cpu().detach().numpy()
#k_S_ref = TruncatedNormal(loc = k_S_ref_details[0], scale = k_S_ref_details[1], a = k_S_ref_details[2], b = k_S_ref_details[3]).rsample().cpu().detach().numpy()
#k_D_ref = TruncatedNormal(loc = k_D_ref_details[0], scale = k_D_ref_details[1], a = k_D_ref_details[2], b = k_D_ref_details[3]).rsample().cpu().detach().numpy()
#k_M_ref = TruncatedNormal(loc = k_M_ref_details[0], scale = k_M_ref_details[1], a = k_M_ref_details[2], b = k_M_ref_details[3]).rsample().cpu().detach().numpy()
#Ea_S = TruncatedNormal(loc = Ea_S_details[0], scale = Ea_S_details[1], a = Ea_S_details[2], b = Ea_S_details[3]).rsample().cpu().detach().numpy()
#Ea_D = TruncatedNormal(loc = Ea_D_details[0], scale = Ea_D_details[1], a = Ea_D_details[2], b = Ea_D_details[3]).rsample().cpu().detach().numpy()
#Ea_M = TruncatedNormal(loc = Ea_M_details[0], scale = Ea_M_details[1], a = Ea_M_details[2], b = Ea_M_details[3]).rsample().cpu().detach().numpy()
#s_SOC = TruncatedNormal(loc = s_SOC_details[0], scale = s_SOC_details[1], a = s_SOC_details[2], b = s_SOC_details[3]).rsample().cpu().detach().numpy()
#s_DOC = TruncatedNormal(loc = s_DOC_details[0], scale = s_DOC_details[1], a = s_DOC_details[2], b = s_DOC_details[3]).rsample().cpu().detach().numpy()
#s_MBC = TruncatedNormal(loc = s_MBC_details[0], scale = s_MBC_details[1], a = s_MBC_details[2], b = s_MBC_details[3]).rsample().cpu().detach().numpy()

SCON_SS_priors_details = {'u_M': u_M_details, 'a_SD': a_SD_details, 'a_DS': a_DS_details, 'a_M': a_M_details, 'a_MSC': a_MSC_details, 'k_S_ref': k_S_ref_details, 'k_D_ref': k_D_ref_details, 'k_M_ref': k_M_ref_details, 'Ea_S': Ea_S_details, 'Ea_D': Ea_D_details, 'Ea_M': Ea_M_details, 'c_SOC': s_SOC_details, 'c_DOC': s_DOC_details, 'c_MBC': s_MBC_details}

In [5]:
#Convert prior details dictionary values to tensors.
param_names = SCON_SS_priors_details.keys()
prior_list = list(zip(*(SCON_SS_priors_details[k] for k in param_names)))
prior_means_tensor, prior_sds_tensor, prior_lowers_tensor, prior_uppers_tensor = torch.tensor(prior_list)
prior = TruncatedNormal(loc = prior_means_tensor, scale = prior_sds_tensor, a = prior_lowers_tensor, b = prior_uppers_tensor)
theta = prior.sample()

params_dict = {k: v for k, v in zip(param_names, theta)}
params_dict

{'u_M': tensor(0.00159590),
 'a_SD': tensor(0.62165385),
 'a_DS': tensor(0.27633050),
 'a_M': tensor(0.31480956),
 'a_MSC': tensor(0.41646621),
 'k_S_ref': tensor(0.00055727),
 'k_D_ref': tensor(0.00079384),
 'k_M_ref': tensor(0.00099425),
 'Ea_S': tensor(52.98742676),
 'Ea_D': tensor(53.52291489),
 'Ea_M': tensor(42.02923584),
 'c_SOC': tensor(0.00917803),
 'c_DOC': tensor(0.00339356),
 'c_MBC': tensor(0.00682223)}

## Draw $x \sim p(x|\theta)$

In [66]:
#Generate data from SBM SDEs
#x in order of SOC, DOC, MBC (and EEC for AWB family models)

def alpha_SCON_multi(x, SCON_params_dict, I_S, I_D, current_temp, temp_ref, arrhenius_temp, linear_temp):
    #Partition SOC, DOC, and MBC values.
    state_dim = 3
    SOC, DOC, MBC = torch.chunk(x, state_dim, -1) # (batch_size, state_dim) > (batch_size, 1)
    
    #Force temperature-dependent parameters.
    k_S = arrhenius_temp(SCON_params_dict['k_S_ref'], current_temp, SCON_params_dict['Ea_S'], temp_ref)
    k_D = arrhenius_temp(SCON_params_dict['k_D_ref'], current_temp, SCON_params_dict['Ea_D'], temp_ref)
    k_M = arrhenius_temp(SCON_params_dict['k_M_ref'], current_temp, SCON_params_dict['Ea_M'], temp_ref)
    
    #Evolve drift.
    drift_SOC = I_S + SCON_params_dict['a_DS'] * k_D * DOC + SCON_params_dict['a_M'] * SCON_params_dict['a_MSC'] * k_M * MBC - k_S * SOC
    drift_DOC = I_D + SCON_params_dict['a_SD'] * k_S * SOC + SCON_params_dict['a_M'] * (1 - SCON_params_dict['a_MSC']) * k_M * MBC - (SCON_params_dict['u_M'] + k_D) * DOC
    drift_MBC = SCON_params_dict['u_M'] * DOC - k_M * MBC
    
    return torch.cat([drift_SOC, drift_DOC, drift_MBC], -1)

def beta_SCON_C_multi(x, SCON_C_params_dict):
    b11 = torch.as_tensor(SCON_C_params_dict['c_SOC'])
    b22 = torch.as_tensor(SCON_C_params_dict['c_DOC'])
    b33 = torch.as_tensor(SCON_C_params_dict['c_MBC'])
    
    return torch.diag_embed(torch.cat(torch.atleast_1d([b11, b22, b33]), -1))

def beta_SCON_SS_multi(x, SCON_SS_params_dict):
    state_dim = 3
    SOC, DOC, MBC = np.array_split(x, state_dim, 1) #Partition SOC, DOC, and MBC values.
    b11 = torch.Tensor(SCON_SS_params_dict['s_SOC'] * SOC)
    b22 = torch.Tensor(SCON_SS_params_dict['s_DOC'] * DOC)
    b33 = torch.Tensor(SCON_SS_params_dict['s_MBC'] * MBC)
    b_matrix = torch.diag_embed(torch.cat([b11, b22, b33], 1)) 
    return b_matrix

def alpha_SAWB_multi(x, SAWB_params_dict, I_S, I_D, current_temp, temp_ref, arrhenius_temp, linear_temp):
    #Partition SOC, DOC, MBC, and EEC values.
    state_dim = 4
    SOC, DOC, MBC, EEC = torch.chunk(x, state_dim, 1)
    
    #Force temperature-dependent parameters.
    u_Q = linear_temp(SAWB_params_dict['u_Q_ref'], current_temp, SAWB_params_dict['Q'], temp_ref)
    V_D = arrhenius_temp(SAWB_params_dict['V_D_ref'], current_temp, SAWB_params_dict['Ea_V_D'], temp_ref)
    V_U = arrhenius_temp(SAWB_params_dict['V_U_ref'], current_temp, SAWB_params_dict['Ea_V_U'], temp_ref)
    
    #Evolve drift.
    drift_SOC = I_S + SAWB_params_dict['a_MSA'] * SAWB_params_dict['r_M'] * MBC - ((V_D * EEC * SOC) / (SAWB_params_dict['K_D'] + SOC))
    drift_DOC = I_D + (1 - SAWB_params_dict['a_MSA']) * SAWB_params_dict['r_M'] * MBC + ((V_D * EEC * SOC) / (SAWB_params_dict['K_D'] + SOC)) + SAWB_params_dict['r_L'] * EEC - ((V_U * MBC * DOC) / (SAWB_params_dict['K_U'] + DOC))
    drift_MBC = (u_Q * (V_U * MBC * DOC) / (SAWB_params_dict['K_U'] + DOC)) - (SAWB_params_dict['r_M'] + SAWB_params_dict['r_E']) * MBC
    drift_EEC = SAWB_params_dict['r_E'] * MBC - SAWB_params_dict['r_L'] * EEC
    
    return torch.cat([drift_SOC, drift_DOC, drift_MBC, drift_EEC], 1)

def beta_SAWB_C_multi(x, SAWB_C_params_dict):
    b11 = torch.Tensor(SAWB_C_params_dict['c_SOC'])
    b22 = torch.Tensor(SAWB_C_params_dict['c_DOC'])
    b33 = torch.Tensor(SAWB_C_params_dict['c_MBC'])
    b44 = torch.Tensor(SAWB_C_params_dict['c_EEC'])
    b_matrix = torch.diag_embed(torch.cat([b11, b22, b33, b44], 1)) 
    return b_matrix

def beta_SAWB_SS_multi(x, SAWB_SS_params_dict):
    state_dim = 4
    SOC, DOC, MBC, EEC = np.array_split(x, state_dim, 1) #Partition SOC, DOC, MBC, and EEC values.
    b11 = torch.Tensor(SAWB_SS_params_dict['s_SOC'] * SOC)
    b22 = torch.Tensor(SAWB_SS_params_dict['s_DOC'] * DOC)
    b33 = torch.Tensor(SAWB_SS_params_dict['s_MBC'] * MBC)
    b44 = torch.Tensor(SAWB_SS_params_dict['s_EEC'] * EEC)
    b_matrix = torch.diag_embed(torch.cat([b11, b22, b33, b44], 1)) 
    return b_matrix

def alpha_SAWB_ECA_multi(x, SAWB_ECA_params_dict, I_S, I_D, current_temp, temp_ref, arrhenius_temp, linear_temp):
    #Partition SOC, DOC, MBC, and EEC values.
    state_dim = 4
    SOC, DOC, MBC, EEC = torch.chunk(x, state_dim, 1)
    
    #Force temperature-dependent parameters.
    u_Q = linear_temp(SAWB_ECA_params_dict['u_Q_ref'], current_temp, SAWB_ECA_params_dict['Q'], temp_ref)
    V_DE = arrhenius_temp(SAWB_ECA_params_dict['V_DE_ref'], current_temp, SAWB_ECA_params_dict['Ea_V_DE'], temp_ref)
    V_UE = arrhenius_temp(SAWB_ECA_params_dict['V_UE_ref'], current_temp, SAWB_ECA_params_dict['Ea_V_UE'], temp_ref)
    
    #Evolve drift.
    drift_SOC = I_S + SAWB_ECA_params_dict['a_MSA'] * SAWB_ECA_params_dict['r_M'] * MBC - ((V_DE * EEC * SOC) / (SAWB_ECA_params_dict['K_DE'] + EEC + SOC))
    drift_DOC = I_D + (1 - SAWB_ECA_params_dict['a_MSA']) * SAWB_ECA_params_dict['r_M'] * MBC + ((V_DE * EEC * SOC) / (SAWB_ECA_params_dict['K_DE'] + EEC + SOC)) + SAWB_ECA_params_dict['r_L'] * EEC - ((V_UE * MBC * DOC) / (SAWB_ECA_params_dict['K_UE'] + MBC + DOC))
    drift_MBC = (u_Q * (V_UE * MBC * DOC) / (SAWB_ECA_params_dict['K_UE'] + MBC + DOC)) - (SAWB_ECA_params_dict['r_M'] + SAWB_ECA_params_dict['r_E']) * MBC
    drift_EEC = SAWB_ECA_params_dict['r_E'] * MBC - SAWB_ECA_params_dict['r_L'] * EEC
    
    return torch.cat([drift_SOC, drift_DOC, drift_MBC, drift_EEC], 1)

def beta_SAWB_ECA_C_multi(x, SAWB_ECA_C_params_dict):
    b11 = torch.Tensor(SAWB_ECA_C_params_dict['c_SOC'])
    b22 = torch.Tensor(SAWB_ECA_C_params_dict['c_DOC'])
    b33 = torch.Tensor(SAWB_ECA_C_params_dict['c_MBC'])
    b44 = torch.Tensor(SAWB_ECA_C_params_dict['c_EEC'])
    b_matrix = torch.diag_embed(torch.cat([b11, b22, b33, b44], 1)) 
    return b_matrix

def beta_SAWB_ECA_SS_multi(x, SAWB_ECA_SS_params_dict):
    state_dim = 4
    SOC, DOC, MBC, EEC = np.array_split(x, state_dim, 1) #Partition SOC, DOC, MBC, and EEC values.
    b11 = torch.Tensor(SAWB_ECA_SS_params_dict['s_SOC'] * SOC)
    b22 = torch.Tensor(SAWB_ECA_SS_params_dict['s_DOC'] * DOC)
    b33 = torch.Tensor(SAWB_ECA_SS_params_dict['s_MBC'] * MBC)
    b44 = torch.Tensor(SAWB_ECA_SS_params_dict['s_EEC'] * EEC)
    b_matrix = torch.diag_embed(torch.cat([b11, b22, b33, b44], 1)) 
    return b_matrix

In [67]:
def generate_x(BATCH_SIZE, ALPHA, BETA, X0_LOC, X0_SCALE, T, DT, THETA_DICT, I_S_FUNC, I_D_FUNC, TEMP_FUNC, TEMP_REF, TEMP_RISE, OBS_EVERY, OBS_ERROR_SCALE, lower_bound = 1e-4):
    if ALPHA == alpha_SCON_multi:
        state_dim = 3
    elif ALPHA == alpha_SAWB_multi:
        state_dim = 4
    elif ALPHA == alpha_SAWB_ECA_multi:
        state_dim = 4
        
    N = int(T / DT) + 1
    M = int(T / OBS_EVERY) + 1
    x = torch.zeros([BATCH_SIZE, N, state_dim])
    
    # Draw initial condition x0
    X0_LOC = torch.as_tensor(X0_LOC)
    x0_dist = D.normal.Normal(loc = X0_LOC, scale = X0_LOC * X0_SCALE)
    x0_samples = x0_dist.sample((batch_size, )) # (batch_size, state_dim)
    x0_samples[x0_samples < lower_bound] = lower_bound #Bound initial conditions above 0. 
    print('X0_samples = ', x0_samples)
    x[:, 0, :] = x0_samples
    
    # Vectorize variable calculations where possible
    hours = torch.tensor(np.linspace(0, T, N), dtype=torch.float) # 0
    I_S_tensor = I_S_FUNC(hours)
    I_D_tensor = I_D_FUNC(hours)
    temps = TEMP_FUNC(hours, TEMP_REF, TEMP_RISE)
    
    #Take Euler-Maruyama step. 
    for i in range(1, N):
        a = ALPHA(x[:, i - 1, :], THETA_DICT, I_S_tensor[i], I_D_tensor[i], temps[i], TEMP_REF, arrhenius_temp_dep, linear_temp_dep)
        b = BETA(x[:, i - 1, :], THETA_DICT)
        loc = x[:, i - 1, :] + a * DT
        covariance_matrix = b * DT
        print(loc.shape, covariance_matrix.shape) # (batch_size, state_dim), (batch_size, state_dim, state_dim)
        
        x_i_dist = D.multivariate_normal.MultivariateNormal(loc = x[:, i - 1, :] + a * DT, covariance_matrix = b * DT)
        #print(x_i_dist.sample((batch_size, )).shape, x[:, i, :].shape) # x_i to be (batch_size, state_dim)
        x[:, i, :] = x_i_dist.sample()
        #for batch in range(batch_size):
        #    x[batch, i, :] = np.random.multivariate_normal(mean = x[batch, i - 1, :] + a[batch, :] * DT, cov = b[batch, :, :] * DT)
        
        x[:, i, :][x[:, i, :] < lower_bound] = lower_bound #Bound all x above 0.
    
    return x

In [68]:
batch_size = 5
dt = 0.01
t = 10
x0_SCON = [65, 0.4, 2.5]
obs_every = 5
obs_error_scale = 0.1
x0_scale = 0.25

In [69]:
t0 = time.time()
x = generate_x(batch_size, alpha_SCON_multi, beta_SCON_C_multi, x0_SCON, x0_scale, t, dt, params_dict, i_s, i_d, temp_gen, temp_ref, temp_rise, obs_every, obs_error_scale)
print(x.shape, time.time() - t0)

X0_samples =  tensor([[80.11885071,  0.44007134,  3.02441311],
        [85.93804932,  0.50682271,  1.40449238],
        [32.07121658,  0.59080827,  2.62880635],
        [51.16740036,  0.28266990,  2.91556501],
        [83.51375580,  0.45361611,  2.89769650]])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3])

torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([

torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([

torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([5, 3]) torch.Size([3, 3])
torch.Size([

In [None]:
print('y shape: ', y_dict['y'].shape)
print('t_y shape: ', y_dict['t_y'].shape)
print('x shape: ', y_dict['x'].shape)

In [None]:
print(y_dict['x'][:, :, 0].shape)
print(y_dict['x'][:, :, 0].transpose().shape)
print(y_dict['y'][:, :, 0].shape)
print(y_dict['y'][:, :, 0].transpose().shape)

In [None]:
fig, axs = plt.subplots(4, sharex = True)
axs[0].plot(y_dict['t_x'], y_dict['x'][:, :, 0].transpose(), color = "m", label = 'SOC x')
axs[0].scatter(np.repeat(y_dict['t_y'][:, None], batch_size, axis = 1), y_dict['y'][:, :, 0].transpose(), color = "m", alpha = 0.3, label = 'SOC y')
axs[1].plot(y_dict['t_x'], y_dict['x'][:, :, 1].transpose(), color = "c", label = 'DOC x')
axs[1].scatter(np.repeat(y_dict['t_y'][:, None], batch_size, axis = 1), y_dict['y'][:, :, 1].transpose(), color = "c", alpha = 0.3, label = 'DOC y')
axs[2].plot(y_dict['t_x'], y_dict['x'][:, :, 2].transpose(), color = "g", label = 'MBC x')
axs[2].scatter(np.repeat(y_dict['t_y'][:, None], batch_size, axis = 1), y_dict['y'][:, :, 2].transpose(), color = "g", alpha = 0.3, label = 'MBC y')
axs[3].plot(y_dict['t_x'], y_dict['x'][:, :, 3].transpose(), color = "orange", label = 'CO2')
axs[3].scatter(np.repeat(y_dict['t_y'][:, None], batch_size, axis = 1), y_dict['y'][:, :, 3].transpose(), color = "orange", alpha = 0.3, label = 'CO2 y')

In [None]:
now = datetime.now()
sbm_model = 'SCON-SS_CO2_trunc' + now.strftime('_%Y_%m_%d_%H_%M')
dir_path = '../generated_data/'
save_string = dir_path + f'{sbm_model}_sample_y_t_{t}_dt_{dt}_sd_scale_{prior_scale_factor}'.replace('.','-')
save_string_x = dir_path + f'{sbm_model}_sample_x_t_{t}_dt_{dt}_sd_scale_{prior_scale_factor}'.replace('.','-')
fig.savefig(save_string + '.png', dpi = 300)

In [None]:
#Save CSV of stochastic path.
df_y = pd.DataFrame(data = {'hour': y_dict['t_y'], 'SOC': y_dict['y'][0, :], 'DOC': y_dict['y'][1, :], 'MBC': y_dict['y'][2, :], 'CO2': y_dict['y'][3, :]})
df_y.to_csv(save_string + '.csv', index = False)

In [None]:
#Save rsampled theta values.
torch.save(SCON_SS_params_dict, save_string + '_rsample.pt')

#Save priors dict.
torch.save(SCON_SS_priors_dict, save_string + '_hyperparams.pt')

In [None]:
print(df_y)

In [None]:
y_det_dict = get_SBM_SDE_euler_maruyama_y_det(batch_size, alpha_SCON_multi, x0_SCON, x0_scale, t, dt, params_dict, I_S_func, I_D_func, temp_func, temp_ref, temp_rise, obs_every, obs_error_scale)

print('y: ', y_det_dict['y'])
print('x: ', y_det_dict['x'])

In [None]:
fig2, axs2 = plt.subplots(4, sharex = True)
axs2[0].plot(y_det_dict['t_x'], y_det_dict['x'][0, :], color = "m", label = 'SOC x')
axs2[0].scatter(y_det_dict['t_y'], y_det_dict['y'][0, :], color = "m", alpha = 0.3, label = 'SOC y')
axs2[1].plot(y_det_dict['t_x'], y_det_dict['x'][1, :], color = "c", label = 'DOC x')
axs2[1].scatter(y_det_dict['t_y'], y_det_dict['y'][1, :], color = "c", alpha = 0.3, label = 'DOC y')
axs2[2].plot(y_det_dict['t_x'], y_det_dict['x'][2, :], color = "g", label = 'MBC x')
axs2[2].scatter(y_det_dict['t_y'], y_det_dict['y'][2, :], color = "g", alpha = 0.3, label = 'MBC y')
axs2[3].plot(y_det_dict['t_x'], y_det_dict['x'][3, :], color = "orange", label = 'CO2')
axs2[3].scatter(y_det_dict['t_y'], y_det_dict['y'][3, :], color = "orange", alpha = 0.3, label = 'CO2 y')

In [None]:
save_string_det = dir_path + f'{sbm_model}_sample_det_y_t_{t}_dt_{dt}_sd_scale_{prior_scale_factor}'.replace('.','-')
fig2.savefig(save_string_det + '.png', dpi = 300)

In [None]:
df_y_det = pd.DataFrame(data = {'hour': y_det_dict['t_y'], 'SOC': y_det_dict['y'][0, :], 'DOC': y_det_dict['y'][1, :], 'MBC': y_det_dict['y'][2, :], 'CO2': y_det_dict['y'][3, :]})
df_y_det.to_csv(save_string_det + '.csv', index = False)

In [None]:
print(df_y_det)

## Archive

In [None]:
X0_test = np.array([65, 0.4, 2.5])
X0_test_samples = np.random.normal(loc = X0_test, scale = 0.25 * X0_test, size = np.array([batch_size, 3]))
print(X0_test_samples)
X0_test_samples.shape

In [None]:
x_test[:, 0, :] = X0_test_samples
print(x_test[:, 0, :])
print(x_test[:, 0, :].shape)
x0 = x_test[:, 0, :]
print(x0[:, 0])
print(x0[:, 1])
print(x0[:, 2])
print(x0[:, 0].shape)

x0_stack = np.stack((x0[:, 0], x0[:, 1], x0[:, 2]), 1)
print(x0_stack.shape)

x0_cat = np.concatenate((np.expand_dims(x0[:, 0], axis = 1), np.expand_dims(x0[:, 1], axis = 1), np.expand_dims(x0[:, 2], axis = 1)), 1)
print(x0_cat.shape)

In [None]:
print(x_test[:, 1, :].shape)
SOC, DOC, MBC = np.array_split(x0, 3, 1)
print(SOC)
print(SOC.shape)

In [None]:
params_dict['a_SD'].shape

In [None]:
CO2_test = np.zeros([batch_size, n, 1])
print(CO2_test.shape)
CO2_test[:, 0, :] = get_CO2_CON_gen_y_multi(x_test[:, 0, :], params_dict, temp_func(0, temp_ref, temp_rise), temp_ref)

In [None]:
I_S = I_S_func(0)
I_D = I_D_func(0)
print(x_test[:, 1, :].shape)
blah = alpha_SCON_multi(x_test[:, 0, :], params_dict, I_S, I_D, temp_func(0.1, temp_ref, temp_rise), temp_ref, arrhenius_temp, linear_temp)
print(blah.shape)
x_test[:, 1, :] = alpha_SCON_multi(x_test[:, 0, :], params_dict, I_S, I_D, temp_func(0.1, temp_ref, temp_rise), temp_ref, arrhenius_temp, linear_temp)
print(x_test[:, 1, :])
print(x_test[:, 1, 0])

In [None]:
a = np.random.rand(batch_size, 3)
b = np.random.rand(batch_size, 3, 3)
c = np.zeros([batch_size, 3])

#loop
for i in range(batch_size):
    c[i, :] = np.random.multivariate_normal(mean = a[i, :], cov = b[i, :, :])

print(c)
print(c.shape)

a_melt = a.ravel()
print(a_melt.shape)

b_melt = b.ravel()
print(b_melt.shape)

b_reshape = b.reshape([30, 3])
print(b_reshape.shape)

c = np.random.multivariate_normal(mean = a_melt, cov = b_reshape)

x_test_2 = x_test
x_test_2[:, 1, :] = c[1, :]

In [None]:
blah2 = beta_SCON_SS_multi(x_test[:, 0, :], params_dict)
print(blah2)

In [51]:
a, b, c = torch.arange(3)
a, b, c

(tensor(0), tensor(1), tensor(2))

In [52]:
torch.cat(torch.atleast_1d([a, b, c]), -1)

tensor([0, 1, 2])

In [53]:
torch.cat([a, b, c], -1)

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

In [76]:
x = torch.arange(3)
x1, x2, x3 = torch.chunk(x, 3, -1)
x1.shape

torch.Size([1])