In [1]:
# Make modules that live one directory up importable from this notebook.
import sys
path = '..'
# Only append when absent, so re-running the cell does not add duplicate entries.
if sys.path.count(path) == 0:
    sys.path.append(path)

In [2]:
#Python-related imports
import math
import os
from typing import Dict, Tuple, Union

#Third-party imports
import numpy as np
from tqdm import tqdm

#Torch-related imports
import torch
from torch.autograd import Function
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim

#Module imports
from mean_field import *
from obs_and_flow import *
from SBM_SDE_classes import *
from training import *

In [3]:
#PyTorch settings
torch.manual_seed(0) #Fixed seed for the torch random number generator.
cuda_available = torch.cuda.is_available()
print('cuda device available?: ', cuda_available)
active_device = torch.device('cuda' if cuda_available else 'cpu')
if cuda_available:
    #NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases — revisit when upgrading.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.set_printoptions(precision = 8)

#IAF SSM time parameters
dt_flow = 1.0 #Flow time step; increased from 0.1 to reduce memory.
t = 1000 #Time horizon in hours.
n = int(t / dt_flow) + 1 #Number of grid points, both endpoints included.
t_span = np.linspace(0, t, n)
#Tensor version of t_span with shape [1, n, 1]; also used to build the I_S and I_D input tensors.
t_span_tensor = torch.Tensor(t_span).reshape(1, n, 1).to(active_device)

#SBM temperature forcing parameters
temp_ref = 283
temp_rise = 5 #High estimate of 5 celsius temperature rise by 2100.

#Training parameters
n_iter = 10
ptrain_iter = 4
train_lr = 2e-5 #Learning rate for ELBO training.
ptrain_lr = 1e-4
batch_size = 32 #32 is presently the max batch_size with 16 GB VRAM at t = 5000 so far.
eval_batch_size = 32
obs_error_scale = 0.1 #Standard deviation of the observations (y).
prior_scale_factor = 0.333 #Prior standard deviation as a proportion of the prior means.
num_layers = 5

#Desired SBM SDE model type and details.
state_dim_SCON = 3
SBM_SDE_class = 'SCON'
diffusion_type = 'SS'
learn_CO2 = True
#Must be the exact name of the distribution class: 'TruncatedNormal' or 'RescaledLogitNormal'.
theta_dist = 'TruncatedNormal'

cuda device available?:  False


In [4]:
#state_dim_SCON, SBM_SDE_class, diffusion_type, learn_CO2 and theta_dist are defined once in the
#settings cell above. They used to be re-assigned here with identical values, which meant any edit
#made in the settings cell was silently overridden by this cell.
#fix_dict is passed straight through to train2 below; None is the value used for this run.
fix_dict = None

In [5]:
#Path of the observation CSV consumed by csv_to_obs_df and train2 below.
csv_data_dir = '../generated_data/'
csv_file_name = 'SCON-SS_CO2_trunc_2021_11_10_20_12_sample_y_t_100000_dt_0-01_sd_scale_0-25.csv'
csv_data_path = os.path.join(csv_data_dir, csv_file_name)

In [6]:
#csv_to_obs_df returns three values, bound here as observation times, means, and error.
#One more dimension than the SCON state is requested (obs_model.mu is 4 x 201 further below);
#presumably the extra row is the CO2 series since learn_CO2 is True — confirm against csv_to_obs_df.
obs_dim = state_dim_SCON + 1
obs_times, obs_means, obs_error = csv_to_obs_df(csv_data_path, obs_dim, t, obs_error_scale)

In [7]:
#Build the observation model from the values loaded from the CSV, then move it to the active device.
obs_model = ObsModel(
    active_device,
    TIMES = obs_times,
    DT = dt_flow,
    MU = obs_means,
    SCALE = obs_error,
)
obs_model = obs_model.to(active_device)

In [8]:
#Display the mu tensor held by the observation model (size [4, 201] per the size check further below).
obs_model.mu

tensor([[1.33655746e+02, 1.37599396e+02, 1.32552750e+02, 1.36447433e+02,
         1.19090103e+02, 1.27675194e+02, 1.29534744e+02, 1.34008728e+02,
         1.32003387e+02, 1.26633530e+02, 1.36240814e+02, 1.27916344e+02,
         1.18501625e+02, 1.26866882e+02, 1.14885880e+02, 1.32521927e+02,
         1.40380585e+02, 1.34149887e+02, 1.18848526e+02, 1.32837479e+02,
         1.35612686e+02, 1.33594543e+02, 1.24616821e+02, 1.37357712e+02,
         1.29805008e+02, 1.20333199e+02, 1.22046661e+02, 1.36914200e+02,
         1.12555191e+02, 1.30379700e+02, 1.35227112e+02, 1.38606781e+02,
         1.34566818e+02, 1.32145905e+02, 1.22297310e+02, 1.21963211e+02,
         1.27199722e+02, 1.26911591e+02, 1.09787529e+02, 1.30494919e+02,
         1.23749916e+02, 1.27413124e+02, 1.26225586e+02, 1.16144081e+02,
         1.21134430e+02, 1.41855591e+02, 1.27691795e+02, 1.29586624e+02,
         1.24219566e+02, 1.23242294e+02, 1.16223892e+02, 1.17546402e+02,
         1.13614815e+02, 1.27463501e+02, 1.25172691

In [9]:
#Display every row of mu except the last one (3 of the 4 rows, per the size checks below).
obs_model.mu[:-1, :]

tensor([[133.65574646, 137.59939575, 132.55274963, 136.44743347, 119.09010315,
         127.67519379, 129.53474426, 134.00872803, 132.00338745, 126.63352966,
         136.24081421, 127.91634369, 118.50162506, 126.86688232, 114.88587952,
         132.52192688, 140.38058472, 134.14988708, 118.84852600, 132.83747864,
         135.61268616, 133.59454346, 124.61682129, 137.35771179, 129.80500793,
         120.33319855, 122.04666138, 136.91419983, 112.55519104, 130.37969971,
         135.22711182, 138.60678101, 134.56681824, 132.14590454, 122.29730988,
         121.96321106, 127.19972229, 126.91159058, 109.78752899, 130.49491882,
         123.74991608, 127.41312408, 126.22558594, 116.14408112, 121.13442993,
         141.85559082, 127.69179535, 129.58662415, 124.21956635, 123.24229431,
         116.22389221, 117.54640198, 113.61481476, 127.46350098, 125.17269135,
         131.89709473, 137.45140076, 124.66497040, 111.43415070, 138.04528809,
         139.30200195, 123.78310394, 133.72279358, 1

In [10]:
#Size of mu once the last row is dropped.
obs_model.mu[:-1, :].shape

torch.Size([3, 201])

In [11]:
#Size of the full mu tensor.
obs_model.mu.shape

torch.Size([4, 201])

In [12]:
#Average mu over its last axis, add two leading singleton axes, and show the resulting size.
obs_model.mu.mean(dim = -1)[None, None, :].shape

torch.Size([1, 1, 4])

In [13]:
#Same as the previous check, but with the last row of mu dropped before averaging.
obs_model.mu[:-1, :].mean(dim = -1)[None, None, :].shape

torch.Size([1, 1, 3])

In [14]:
#Instantiate the flow from the observation model and time grid settings, then move it to the active device.
#Note that train2 further below also returns a net, which rebinds this name.
net = SDEFlow(
    active_device, obs_model, state_dim_SCON, t, dt_flow, n,
    num_layers = num_layers,
)
net = net.to(active_device)

In [15]:
#Load the saved x0 tensor that accompanies the observation CSV.
#map_location remaps storages onto active_device while loading, so a file that was saved from a
#CUDA device still loads on a CPU-only machine instead of raising inside torch.load.
#torch.load unpickles the file, so only point this at files you generated or otherwise trust.
x0_SCON_path = '../generated_data/SCON-SS_CO2_trunc_2021_11_10_20_12_sample_y_t_100000_dt_0-01_sd_scale_0-25_x0_SCON_tensor.pt'
x0_SCON_tensor = torch.load(x0_SCON_path, map_location = active_device).to(active_device)

In [16]:
#Multivariate normal prior on the initial state, centred on x0_SCON_tensor.
#scale_tril is the identity multiplied by obs_error_scale and then by x0_SCON_tensor; assuming
#x0_SCON_tensor is 1-D of length state_dim_SCON this is diagonal with entries obs_error_scale * x0 — TODO confirm.
#NOTE(review): obs_error_scale (not prior_scale_factor) sets the prior spread here — confirm this is intended.
x0_prior_scale_tril = torch.eye(state_dim_SCON).to(active_device) * obs_error_scale * x0_SCON_tensor
x0_prior_SCON = D.MultivariateNormal(x0_SCON_tensor, scale_tril = x0_prior_scale_tril)

In [17]:
#Exogenous input tensors, all evaluated on the t_span_tensor time grid and placed on the active device.
#Temperature forcing built from the reference temperature and the assumed rise.
temp_tensor = temp_gen(t_span_tensor, temp_ref, temp_rise).to(active_device)

#Litter inputs to the SOC (i_s) and DOC (i_d) pools for use in the flow SDE functions.
#The generator is consumed in order, so i_s is evaluated before i_d.
i_s_tensor, i_d_tensor = (
    litter_input(t_span_tensor).to(active_device) for litter_input in (i_s, i_d)
)

In [18]:
#Load the saved prior hyperparameters: a dict whose values are moved to the active device one by one.
#map_location remaps storages onto active_device while loading, so a file that was saved from a
#CUDA device still loads on a CPU-only machine instead of raising inside torch.load.
#torch.load unpickles the file, so only point this at files you generated or otherwise trust.
SCON_SS_priors_path = '../generated_data/SCON-SS_CO2_trunc_2021_11_10_20_12_sample_y_t_100000_dt_0-01_sd_scale_0-25_hyperparams.pt'
SCON_SS_priors_details = {
    k: v.to(active_device)
    for k, v in torch.load(SCON_SS_priors_path, map_location = active_device).items()
}

In [19]:
#Keyword options forwarded to train2, collected in one place so they are easy to scan and tweak.
train2_options = dict(
    BYPASS_NAN = False,
    LR_DECAY = 0.92,
    DECAY_STEP_SIZE = 25000,
    PRINT_EVERY = 1,
    DEBUG_SAVE_DIR = None,
    PTRAIN_ITER = ptrain_iter,
    PTRAIN_LR = ptrain_lr,
    PTRAIN_ALG = 'L2',
)

#Run training. The positional arguments keep the exact order train2 is called with in this notebook.
#train2 returns its own net and obs_model, which rebind the objects constructed in the cells above.
(net, q_theta, p_theta, obs_model,
 norm_losses, ELBO_hist, list_parent_loc_scale, SBM_SDE_instance) = train2(
    active_device, train_lr, n_iter, batch_size, num_layers,
    csv_data_path, obs_error_scale, t, dt_flow, n,
    t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref,
    SBM_SDE_class, diffusion_type, x0_prior_SCON, SCON_SS_priors_details, fix_dict, learn_CO2,
    theta_dist,
    **train2_options)


Learning SDE and hidden parameters.:   0%|          | 0/10 [00:00<?, ?it/s][A

Moving average norm loss at 1 iterations is: 4141.8447265625. Best norm loss value is: 4141.8447265625.

C_PATH mean = tensor([[0.65458912, 0.66522533, 0.67140174],
        [0.65596569, 0.65282661, 0.66341084],
        [0.64543110, 0.65224940, 0.66861689],
        [0.65754622, 0.66309446, 0.66157770],
        [0.65033585, 0.65733272, 0.66318500],
        [0.65846348, 0.64615530, 0.67273003],
        [0.66331226, 0.65603834, 0.66739100],
        [0.65337247, 0.64800060, 0.67124259],
        [0.65571070, 0.65965843, 0.64987189],
        [0.65532547, 0.66026253, 0.65502691],
        [0.64724052, 0.63940984, 0.67256165],
        [0.65299541, 0.64882803, 0.66264778],
        [0.65327501, 0.66443586, 0.66182840],
        [0.65221804, 0.66435254, 0.65548718],
        [0.64684272, 0.64841437, 0.66874951],
        [0.65179056, 0.66524315, 0.65920955],
        [0.65582120, 0.65319216, 0.65315044],
        [0.66552937, 0.65008909, 0.66325021],
        [0.66550726, 0.66042656, 0.65504295],
       


Learning SDE and hidden parameters.:  10%|█         | 1/10 [00:14<02:12, 14.73s/it][A

Moving average norm loss at 2 iterations is: 4141.314208984375. Best norm loss value is: 4140.78369140625.

C_PATH mean = tensor([[0.68943465, 0.67741245, 0.69549745],
        [0.68350190, 0.68204880, 0.69251472],
        [0.67248285, 0.67615485, 0.68936986],
        [0.68338877, 0.66523141, 0.69221199],
        [0.67719215, 0.69105482, 0.69040489],
        [0.66602021, 0.66690367, 0.70158815],
        [0.68554258, 0.67505640, 0.68691528],
        [0.68718153, 0.67273158, 0.70019555],
        [0.69306308, 0.67579538, 0.67804170],
        [0.67725343, 0.67842746, 0.69448340],
        [0.68547177, 0.68882996, 0.68956971],
        [0.67611271, 0.66824406, 0.69064265],
        [0.68292767, 0.67355508, 0.67091173],
        [0.68505496, 0.67956614, 0.68316227],
        [0.68454081, 0.68769562, 0.69405782],
        [0.68885869, 0.67140675, 0.70596242],
        [0.66251528, 0.68980730, 0.70402354],
        [0.67931217, 0.66782176, 0.69631886],
        [0.67728597, 0.67951918, 0.70818990],
    


Learning SDE and hidden parameters.:  20%|██        | 2/10 [00:26<01:45, 13.23s/it][A

Moving average norm loss at 3 iterations is: 4140.756184895833. Best norm loss value is: 4139.64013671875.

C_PATH mean = tensor([[0.69800901, 0.68273628, 0.74134469],
        [0.72484529, 0.70404106, 0.71125907],
        [0.71880037, 0.69252515, 0.70996028],
        [0.70870233, 0.71712697, 0.73898667],
        [0.73447675, 0.69378376, 0.70760012],
        [0.70719653, 0.71654677, 0.71998948],
        [0.74815309, 0.69273144, 0.72319812],
        [0.69693971, 0.70080191, 0.70274413],
        [0.68875098, 0.68947136, 0.69042808],
        [0.71511871, 0.69084311, 0.73087025],
        [0.73059452, 0.68547702, 0.76065898],
        [0.71620566, 0.70587045, 0.70383316],
        [0.70333207, 0.69173503, 0.70609730],
        [0.71354985, 0.70543164, 0.72676724],
        [0.72402596, 0.69545496, 0.73182458],
        [0.70978463, 0.69387096, 0.71507263],
        [0.68912357, 0.70307332, 0.72964460],
        [0.70892847, 0.71834505, 0.70495498],
        [0.73384440, 0.72290462, 0.77672756],
    


Learning SDE and hidden parameters.:  30%|███       | 3/10 [00:39<01:30, 12.87s/it][A

Moving average norm loss at 4 iterations is: 4140.1990966796875. Best norm loss value is: 4138.52783203125.

C_PATH mean = tensor([[0.76915407, 0.68967170, 0.75508773],
        [0.74567384, 0.71772200, 0.72659540],
        [0.72881246, 0.71565849, 0.75524187],
        [0.71481037, 0.74370468, 0.73222530],
        [0.73892808, 0.74630064, 0.74261457],
        [0.73608273, 0.73358220, 0.81294167],
        [0.71189672, 0.72398680, 0.72986299],
        [0.76463854, 0.74818689, 0.75874281],
        [0.76854092, 0.73540902, 0.76340592],
        [0.73489314, 0.70384288, 0.73058254],
        [0.77957267, 0.75259781, 0.75386596],
        [0.73503715, 0.72096694, 0.74303126],
        [0.77594936, 0.71837896, 0.76723689],
        [0.74547356, 0.71998978, 0.75388098],
        [0.73614651, 0.74889785, 0.74919450],
        [0.71862423, 0.73517948, 0.72549623],
        [0.74670655, 0.76355678, 0.76247251],
        [0.75514221, 0.72556734, 0.73565859],
        [0.76124227, 0.73080707, 0.71090710],
   


Learning SDE and hidden parameters.:  40%|████      | 4/10 [00:53<01:20, 13.43s/it][A

Moving average norm loss at 5 iterations is: 4139.82783203125. Best norm loss value is: 4138.3427734375.

C_PATH mean = tensor([[0.70978570, 0.68951631, 0.72655642],
        [0.79616970, 0.74937975, 0.76487142],
        [0.76184952, 0.76078790, 0.78103787],
        [0.73682094, 0.72109681, 0.74007291],
        [0.76829463, 0.74424762, 0.74911076],
        [0.71980906, 0.70227689, 0.70911568],
        [0.77599478, 0.73266017, 0.75158900],
        [0.72358948, 0.72667193, 0.72137094],
        [0.75423199, 0.74755740, 0.76378816],
        [0.72332782, 0.72032291, 0.71519548],
        [0.77269387, 0.75386918, 0.78599304],
        [0.77711415, 0.75247127, 0.76735055],
        [0.74264252, 0.71334124, 0.72711408],
        [0.73549718, 0.72835529, 0.70609993],
        [0.71187782, 0.71231383, 0.73406976],
        [0.70050502, 0.71377790, 0.71429157],
        [0.76875710, 0.75593871, 0.74768686],
        [0.73774922, 0.73674846, 0.73632902],
        [0.72887474, 0.72708315, 0.75772518],
      


Learning SDE and hidden parameters.:  50%|█████     | 5/10 [01:07<01:07, 13.56s/it][A

drift at 6 iterations: tensor([[[ 9.86035331e-04,  1.04432715e-04, -4.38653333e-05],
         [ 1.01213099e-03, -9.42040060e-07,  4.81934767e-05],
         [ 1.00459624e-03, -3.21553671e-07,  4.13679954e-05],
         ...,
         [ 1.34633260e-03,  3.13137352e-05,  4.43865538e-05],
         [ 1.34136016e-03,  4.99351809e-05,  3.53522737e-05],
         [ 1.33681821e-03,  6.56500342e-05,  2.69253633e-05]],

        [[ 1.06867438e-03, -2.74523336e-04,  2.34034873e-04],
         [ 1.21955690e-03, -4.65891615e-04, -3.09236348e-05],
         [ 9.92470654e-04, -7.03850164e-05,  1.35654671e-04],
         ...,
         [ 1.33520830e-03,  3.46012966e-05,  6.68762514e-05],
         [ 1.32992398e-03,  4.69113729e-05,  5.92416982e-05],
         [ 1.33942650e-03, -1.41640776e-05,  1.12896043e-04]],

        [[ 1.01098616e-03,  7.85036609e-05, -4.85545825e-05],
         [ 1.01775303e-03,  5.13831765e-06,  3.29340764e-05],
         [ 1.01662753e-03,  2.66787974e-05,  1.23988739e-05],
         ...,
 


Learning SDE and hidden parameters.:  60%|██████    | 6/10 [01:20<00:53, 13.27s/it][A

drift at 7 iterations: tensor([[[ 9.65265790e-04,  1.09764587e-04, -7.04793347e-05],
         [ 9.68373439e-04,  1.60265918e-05,  1.18039898e-05],
         [ 1.08345656e-03, -1.82010117e-04, -1.24801678e-04],
         ...,
         [ 1.33588491e-03,  3.71668939e-05,  3.16059413e-05],
         [ 1.33531052e-03,  4.06203981e-05,  3.29321447e-05],
         [ 1.32876437e-03,  5.80502820e-05,  2.56397143e-05]],

        [[ 9.69625369e-04,  9.92799178e-05, -5.56802079e-05],
         [ 1.00250391e-03, -6.85127889e-05,  1.26692743e-04],
         [ 6.29336340e-04, -7.06553110e-05,  2.64767732e-04],
         ...,
         [ 1.33205927e-03,  3.18061939e-05,  5.23847812e-05],
         [ 1.34865637e-03, -2.75221973e-05,  9.61569749e-05],
         [ 1.35453220e-03, -6.46642147e-05,  1.32127665e-04]],

        [[ 1.02234376e-03,  2.30127480e-05, -3.90284949e-05],
         [ 1.00730266e-03,  4.87579309e-05, -1.88986851e-05],
         [ 1.04878214e-03, -7.50664913e-06,  8.02900104e-06],
         ...,
 


Learning SDE and hidden parameters.:  70%|███████   | 7/10 [01:32<00:39, 13.03s/it][A

drift at 8 iterations: tensor([[[ 9.04795714e-04,  1.69744628e-04, -1.08492903e-04],
         [ 7.31760170e-04,  6.84046099e-05,  7.85100128e-05],
         [ 1.15802640e-03, -3.99424927e-04,  2.09177160e-04],
         ...,
         [ 1.33142667e-03,  5.24034913e-05,  2.16008884e-05],
         [ 1.34255341e-03,  1.87027035e-05,  4.82563264e-05],
         [ 1.33936387e-03,  2.29954167e-05,  4.75928027e-05]],

        [[ 9.69423563e-04,  3.45011358e-07, -1.04439096e-05],
         [ 1.04632869e-03,  1.16964657e-05, -1.65600810e-04],
         [ 1.01159269e-03, -1.05526109e-04,  1.08111875e-04],
         ...,
         [ 1.32420589e-03,  2.00583745e-05,  4.71665371e-05],
         [ 1.32606633e-03,  1.69819978e-05,  5.41116679e-05],
         [ 1.32260518e-03,  2.31052691e-05,  5.42371163e-05]],

        [[ 9.75054572e-04,  6.39513019e-06, -1.89895800e-05],
         [ 1.02467311e-03,  4.60521260e-05, -1.61842865e-04],
         [ 1.02120754e-03, -4.59231378e-05,  6.60815131e-05],
         ...,
 


Learning SDE and hidden parameters.:  80%|████████  | 8/10 [01:45<00:26, 13.05s/it][A

drift at 9 iterations: tensor([[[ 9.42579703e-04,  1.23679070e-04, -9.38825542e-05],
         [ 9.89288790e-04, -9.77258605e-06,  5.78969775e-06],
         [ 9.64893377e-04, -2.20663060e-05,  2.50469966e-05],
         ...,
         [ 1.32246967e-03, -2.77866493e-05,  8.39643908e-05],
         [ 1.31848501e-03, -1.00387697e-05,  7.71612176e-05],
         [ 1.28854974e-03,  5.65362716e-05,  4.14383903e-05]],

        [[ 9.65270505e-04,  1.32012574e-04, -8.90254596e-05],
         [ 1.09751464e-03, -3.12911870e-05, -6.76007257e-05],
         [ 1.29142764e-03, -4.01866157e-04, -4.25861799e-05],
         ...,
         [ 1.32820627e-03,  3.82962025e-05,  2.19286885e-05],
         [ 1.33076578e-03,  4.80250383e-05,  1.78999435e-05],
         [ 1.33095693e-03,  5.56763480e-05,  7.80224218e-06]],

        [[ 1.05828198e-03, -4.83591226e-04,  2.62556976e-04],
         [ 6.76072319e-04,  1.99561386e-04,  1.74873439e-05],
         [ 1.00919150e-03, -4.76795714e-04,  3.76364245e-04],
         ...,
 


Learning SDE and hidden parameters.:  90%|█████████ | 9/10 [01:58<00:12, 12.89s/it][A

drift at 10 iterations: tensor([[[ 1.00466097e-03, -2.75099068e-04,  1.39525306e-04],
         [ 8.95518111e-04,  2.11121223e-05,  5.64193397e-05],
         [ 9.25571134e-04, -5.92377182e-05,  1.13624075e-04],
         ...,
         [ 1.30957866e-03,  4.47211642e-05,  2.08652782e-05],
         [ 1.31866557e-03,  1.57788163e-05,  4.71317871e-05],
         [ 1.30857329e-03,  4.03518643e-05,  2.90026728e-05]],

        [[ 9.82625992e-04, -1.39927986e-04,  3.13424389e-05],
         [ 1.02663680e-03, -2.69201773e-05,  5.95392266e-05],
         [ 1.28452876e-03, -6.47925015e-04,  3.01201944e-05],
         ...,
         [ 1.32323999e-03,  3.43362190e-05,  1.72453365e-05],
         [ 1.31771772e-03,  2.99532403e-05,  1.70693093e-05],
         [ 1.31420884e-03,  2.35544721e-05,  3.52610914e-05]],

        [[ 9.09030903e-04,  1.46656545e-04, -1.17552387e-04],
         [ 9.41254897e-04,  6.26946130e-05, -1.98529699e-04],
         [ 9.16709541e-04,  1.49271305e-04, -9.61248370e-05],
         ...,



Learning SDE and hidden parameters.: 100%|██████████| 10/10 [02:12<00:00, 13.21s/it][A
