In [1]:
# Make modules that live one directory up importable from this notebook.
import sys

path = '..'
# Append only when no existing entry matches, so re-running the cell adds nothing.
if not any(entry == path for entry in sys.path):
    sys.path.append(path)

In [2]:
#Python-related imports
import math
import os
from typing import Dict, Tuple, Union

#Third-party imports
import numpy as np
from tqdm import tqdm

#Torch-related imports
import torch
from torch.autograd import Function
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim

#Module imports
from mean_field import *
from obs_and_flow import *
from SBM_SDE_classes import *
from training import *

In [3]:
#PyTorch settings
torch.manual_seed(0)
print('cuda device available?: ', torch.cuda.is_available())
if torch.cuda.is_available():
    #GPU present: target it and make CUDA float tensors the default tensor type.
    active_device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    active_device = torch.device('cpu')
torch.set_printoptions(precision = 8)

#IAF SSM time parameters
dt_flow = 1.0 #Increased from 0.1 to reduce memory.
t = 1000 #In hours.
n = 1 + int(t / dt_flow) #Number of points on the time grid, both endpoints included.
t_span = np.linspace(0, t, num = n)
#t_span as a [1, n, 1] tensor on the active device. The tensor form also facilitates conversion of I_S and I_D to tensor objects.
t_span_tensor = torch.Tensor(t_span).reshape(1, n, 1).to(active_device)

#SBM temperature forcing parameters
temp_ref = 283
temp_rise = 5 #High estimate of 5 celsius temperature rise by 2100.

#Training parameters
n_iter = 10
ptrain_iter = 4
train_lr = 2e-5 #ELBO learning rate
ptrain_lr = 1e-5
batch_size = 32 #32 is presently max batch_size with 16 GB VRAM at t = 5000 so far.
eval_batch_size = 32
obs_error_scale = 0.1 #Observation (y) standard deviation.
prior_scale_factor = 0.333 #Proportion of prior standard deviation to prior means.
num_layers = 5

#Specify desired SBM SDE model type and details.
state_dim_SCON = 3
SBM_SDE_class = 'SCON'
diffusion_type = 'SS'
learn_CO2 = True
#theta_dist must be the exact name of the distribution class. Options are 'TruncatedNormal' and 'RescaledLogitNormal'.
theta_dist = 'TruncatedNormal'

cuda device available?:  False


In [4]:
#state_dim_SCON, SBM_SDE_class, diffusion_type, learn_CO2 and theta_dist are already set in the
#settings cell above; they are not repeated here so the two copies cannot drift apart.
#No fix_dict entries are supplied; None is what gets passed to train2 further down.
fix_dict = None

In [5]:
#Path to the CSV of generated observations, relative to this notebook.
csv_data_path = os.path.join(
    '../generated_data/',
    'SCON-SS_CO2_trunc_2021_11_10_20_12_sample_y_t_100000_dt_0-01_sd_scale_0-25.csv',
)

In [6]:
#Read observation times, means and error from the CSV.
#The second argument is state_dim_SCON + 1 -- presumably the extra dimension is the CO2 series (learn_CO2 is True); confirm against csv_to_obs_df.
obs_times, obs_means, obs_error = csv_to_obs_df(
    csv_data_path,
    state_dim_SCON + 1,
    t,
    obs_error_scale,
)

In [7]:
#Wrap the loaded observations in an ObsModel and move it to the active device.
obs_model = ObsModel(
    active_device,
    TIMES = obs_times,
    DT = dt_flow,
    MU = obs_means,
    SCALE = obs_error,
).to(active_device)

In [8]:
#Display the mu attribute of obs_model (the model was constructed with MU = obs_means).
obs_model.mu

tensor([[1.33655746e+02, 1.37599396e+02, 1.32552750e+02, 1.36447433e+02,
         1.19090103e+02, 1.27675194e+02, 1.29534744e+02, 1.34008728e+02,
         1.32003387e+02, 1.26633530e+02, 1.36240814e+02, 1.27916344e+02,
         1.18501625e+02, 1.26866882e+02, 1.14885880e+02, 1.32521927e+02,
         1.40380585e+02, 1.34149887e+02, 1.18848526e+02, 1.32837479e+02,
         1.35612686e+02, 1.33594543e+02, 1.24616821e+02, 1.37357712e+02,
         1.29805008e+02, 1.20333199e+02, 1.22046661e+02, 1.36914200e+02,
         1.12555191e+02, 1.30379700e+02, 1.35227112e+02, 1.38606781e+02,
         1.34566818e+02, 1.32145905e+02, 1.22297310e+02, 1.21963211e+02,
         1.27199722e+02, 1.26911591e+02, 1.09787529e+02, 1.30494919e+02,
         1.23749916e+02, 1.27413124e+02, 1.26225586e+02, 1.16144081e+02,
         1.21134430e+02, 1.41855591e+02, 1.27691795e+02, 1.29586624e+02,
         1.24219566e+02, 1.23242294e+02, 1.16223892e+02, 1.17546402e+02,
         1.13614815e+02, 1.27463501e+02, 1.25172691

In [9]:
#Every row of mu except the last one.
obs_model.mu[:-1]

tensor([[133.65574646, 137.59939575, 132.55274963, 136.44743347, 119.09010315,
         127.67519379, 129.53474426, 134.00872803, 132.00338745, 126.63352966,
         136.24081421, 127.91634369, 118.50162506, 126.86688232, 114.88587952,
         132.52192688, 140.38058472, 134.14988708, 118.84852600, 132.83747864,
         135.61268616, 133.59454346, 124.61682129, 137.35771179, 129.80500793,
         120.33319855, 122.04666138, 136.91419983, 112.55519104, 130.37969971,
         135.22711182, 138.60678101, 134.56681824, 132.14590454, 122.29730988,
         121.96321106, 127.19972229, 126.91159058, 109.78752899, 130.49491882,
         123.74991608, 127.41312408, 126.22558594, 116.14408112, 121.13442993,
         141.85559082, 127.69179535, 129.58662415, 124.21956635, 123.24229431,
         116.22389221, 117.54640198, 113.61481476, 127.46350098, 125.17269135,
         131.89709473, 137.45140076, 124.66497040, 111.43415070, 138.04528809,
         139.30200195, 123.78310394, 133.72279358, 1

In [10]:
#Shape of mu once its last row is dropped.
obs_model.mu[:-1].shape

torch.Size([3, 201])

In [11]:
#Shape of the full mu tensor.
obs_model.mu.shape

torch.Size([4, 201])

In [12]:
#Mean of mu over its last axis, with two leading singleton axes added; show the resulting shape.
obs_model.mu.mean(-1)[None, None, :].shape

torch.Size([1, 1, 4])

In [13]:
#Same last-axis mean as above, computed after dropping the last row of mu.
obs_model.mu[:-1].mean(-1)[None, None, :].shape

torch.Size([1, 1, 3])

In [14]:
#Build the SDEFlow network around obs_model and place it on the active device.
net = SDEFlow(
    active_device,
    obs_model,
    state_dim_SCON,
    t,
    dt_flow,
    n,
    num_layers = num_layers,
).to(active_device)

In [15]:
#Load the x0 tensor saved next to the generated data set.
#map_location remaps the stored tensors onto active_device during deserialization, so a file
#that was written from a CUDA session still loads on a CPU-only host.
x0_SCON_tensor = torch.load(
    '../generated_data/SCON-SS_CO2_trunc_2021_11_10_20_12_sample_y_t_100000_dt_0-01_sd_scale_0-25_x0_SCON_tensor.pt',
    map_location = active_device,
).to(active_device)

In [16]:
#Multivariate normal prior on x0, centred on x0_SCON_tensor.
#scale_tril is the identity scaled elementwise by obs_error_scale * x0_SCON_tensor.
#NOTE(review): the scale uses obs_error_scale rather than prior_scale_factor -- confirm this is intended.
x0_prior_SCON = D.MultivariateNormal(
    x0_SCON_tensor,
    scale_tril = torch.eye(state_dim_SCON).to(active_device) * obs_error_scale * x0_SCON_tensor,
)

In [17]:
#Exogenous inputs, each evaluated on t_span_tensor and moved to the active device.

#Temperature forcing.
temp_tensor = temp_gen(
    t_span_tensor,
    temp_ref,
    temp_rise,
).to(active_device)

#SOC and DOC pool litter inputs for use in the flow SDE functions.
i_s_tensor = i_s(t_span_tensor).to(active_device) #Exogenous SOC input
i_d_tensor = i_d(t_span_tensor).to(active_device) #Exogenous DOC input

In [18]:
#Load the saved hyperparameter dictionary and make sure every value sits on active_device.
#map_location remaps the stored tensors during deserialization, so a file written from a CUDA
#session still loads on a CPU-only host.
SCON_SS_priors_details = {
    k: v.to(active_device)
    for k, v in torch.load(
        '../generated_data/SCON-SS_CO2_trunc_2021_11_10_20_12_sample_y_t_100000_dt_0-01_sd_scale_0-25_hyperparams.pt',
        map_location = active_device,
    ).items()
}

In [None]:
#Run train2 and unpack its eight return values.
#net and obs_model are rebound here, replacing the objects built in the earlier cells.
(net, q_theta, p_theta, obs_model,
 norm_losses, ELBO_hist, list_parent_loc_scale, SBM_SDE_instance) = train2(
    active_device, train_lr, n_iter, batch_size, num_layers,
    csv_data_path, obs_error_scale, t, dt_flow, n,
    t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref,
    SBM_SDE_class, diffusion_type, x0_prior_SCON, SCON_SS_priors_details, fix_dict, learn_CO2,
    theta_dist,
    BYPASS_NAN = False,
    LR_DECAY = 0.92,
    DECAY_STEP_SIZE = 25000,
    PRINT_EVERY = 1,
    DEBUG_SAVE_DIR = None,
    PTRAIN_ITER = ptrain_iter,
    PTRAIN_LR = ptrain_lr,
    PTRAIN_ALG = 'L2',
)


Learning SDE and hidden parameters.:   0%|          | 0/10 [00:00<?, ?it/s][A

Moving average norm loss at 1 iterations is: 4141.8447265625. Best norm loss value is: 4141.8447265625.

C_PATH mean = tensor([[0.65458912, 0.66522533, 0.67140174],
        [0.65596569, 0.65282661, 0.66341084],
        [0.64543110, 0.65224940, 0.66861689],
        [0.65754622, 0.66309446, 0.66157770],
        [0.65033585, 0.65733272, 0.66318500],
        [0.65846348, 0.64615530, 0.67273003],
        [0.66331226, 0.65603834, 0.66739100],
        [0.65337247, 0.64800060, 0.67124259],
        [0.65571070, 0.65965843, 0.64987189],
        [0.65532547, 0.66026253, 0.65502691],
        [0.64724052, 0.63940984, 0.67256165],
        [0.65299541, 0.64882803, 0.66264778],
        [0.65327501, 0.66443586, 0.66182840],
        [0.65221804, 0.66435254, 0.65548718],
        [0.64684272, 0.64841437, 0.66874951],
        [0.65179056, 0.66524315, 0.65920955],
        [0.65582120, 0.65319216, 0.65315044],
        [0.66552937, 0.65008909, 0.66325021],
        [0.66550726, 0.66042656, 0.65504295],
       


Learning SDE and hidden parameters.:  10%|█         | 1/10 [00:13<01:57, 13.05s/it][A

Moving average norm loss at 2 iterations is: 4141.79052734375. Best norm loss value is: 4141.736328125.

C_PATH mean = tensor([[0.67007512, 0.64847213, 0.66200346],
        [0.65519637, 0.65922290, 0.66259450],
        [0.64997870, 0.65493566, 0.66226441],
        [0.64728332, 0.65316236, 0.66411686],
        [0.65487391, 0.66981858, 0.66905022],
        [0.65883887, 0.64397150, 0.67006713],
        [0.65145898, 0.65172273, 0.67003441],
        [0.66532826, 0.66307610, 0.66602820],
        [0.66219574, 0.65529376, 0.66477817],
        [0.65771461, 0.65727764, 0.67121464],
        [0.66201210, 0.66114628, 0.65530264],
        [0.65826255, 0.65671241, 0.66700071],
        [0.66192889, 0.66023779, 0.65784371],
        [0.66215289, 0.65844470, 0.66913325],
        [0.65711135, 0.66287297, 0.66622293],
        [0.66593772, 0.65715450, 0.66050881],
        [0.63913554, 0.67051697, 0.66131884],
        [0.66099733, 0.65371233, 0.66660601],
        [0.66033745, 0.65312225, 0.67697173],
       


Learning SDE and hidden parameters.:  20%|██        | 2/10 [00:26<01:46, 13.35s/it][A

Moving average norm loss at 3 iterations is: 4141.734375. Best norm loss value is: 4141.6220703125.

C_PATH mean = tensor([[0.65811211, 0.65836602, 0.67529356],
        [0.65846103, 0.66023058, 0.67238796],
        [0.65780705, 0.66562831, 0.67287397],
        [0.66109270, 0.66011298, 0.66088271],
        [0.66926163, 0.66546106, 0.66184354],
        [0.66856766, 0.65798450, 0.66778666],
        [0.66355157, 0.65863121, 0.66607767],
        [0.65163434, 0.65513599, 0.65964407],
        [0.65579766, 0.66346067, 0.65855700],
        [0.65541327, 0.64196253, 0.68096232],
        [0.67143768, 0.66265851, 0.66860211],
        [0.66035616, 0.66789019, 0.65434498],
        [0.65059245, 0.66324168, 0.66365117],
        [0.66483045, 0.66489172, 0.67958826],
        [0.66912580, 0.65021992, 0.66625589],
        [0.65716076, 0.66671330, 0.66598845],
        [0.65310925, 0.65813982, 0.66821945],
        [0.66323906, 0.66095245, 0.65524906],
        [0.67061836, 0.66221696, 0.67195088],
        [0.


Learning SDE and hidden parameters.:  30%|███       | 3/10 [00:40<01:36, 13.76s/it][A

Moving average norm loss at 4 iterations is: 4141.6822509765625. Best norm loss value is: 4141.52587890625.

C_PATH mean = tensor([[0.65919089, 0.65145665, 0.67590022],
        [0.66589606, 0.66583997, 0.68394256],
        [0.66163260, 0.65708977, 0.66986161],
        [0.65851468, 0.65747350, 0.67469883],
        [0.65232676, 0.67272484, 0.66583002],
        [0.65742230, 0.66497147, 0.67029613],
        [0.66307324, 0.65635931, 0.67411941],
        [0.67643887, 0.66464436, 0.67707443],
        [0.66697145, 0.66967118, 0.65724254],
        [0.67142963, 0.66523969, 0.66295856],
        [0.66598624, 0.65524054, 0.66908562],
        [0.66569793, 0.66837573, 0.66618919],
        [0.68046492, 0.66072565, 0.67754501],
        [0.66171610, 0.67060894, 0.67063677],
        [0.66220957, 0.66574550, 0.67251408],
        [0.66102910, 0.66774768, 0.66337138],
        [0.67264265, 0.67062759, 0.66474950],
        [0.65916979, 0.67266488, 0.66681635],
        [0.66650110, 0.65958238, 0.67825377],
   


Learning SDE and hidden parameters.:  40%|████      | 4/10 [00:54<01:23, 13.85s/it][A

Moving average norm loss at 5 iterations is: 4141.6287109375. Best norm loss value is: 4141.41455078125.

C_PATH mean = tensor([[0.66796350, 0.65916383, 0.67438751],
        [0.66555178, 0.65922052, 0.68210775],
        [0.67366421, 0.67923629, 0.65685016],
        [0.66590714, 0.66572028, 0.66883391],
        [0.66832662, 0.66593021, 0.66605544],
        [0.65068811, 0.66580641, 0.68313974],
        [0.66572976, 0.66921496, 0.67007881],
        [0.66163784, 0.66843218, 0.68086040],
        [0.67095006, 0.66815996, 0.67219794],
        [0.66473240, 0.66255575, 0.67501491],
        [0.67508358, 0.67053199, 0.66597366],
        [0.65193778, 0.66182363, 0.68367314],
        [0.67308569, 0.67067677, 0.66493112],
        [0.65628552, 0.67114729, 0.66734862],
        [0.67296988, 0.66894072, 0.67447090],
        [0.67860234, 0.66686839, 0.66668731],
        [0.65400314, 0.66987407, 0.68808925],
        [0.65250742, 0.67055333, 0.68075860],
        [0.68319827, 0.66153860, 0.66808796],
      


Learning SDE and hidden parameters.:  50%|█████     | 5/10 [01:08<01:09, 13.90s/it][A

drift at 6 iterations: tensor([[[ 1.02078915e-03, -6.27047411e-05,  6.11534197e-05],
         [ 1.01092609e-03,  4.65960038e-05,  1.83541106e-05],
         [ 1.03718985e-03,  1.58995608e-05, -3.41501873e-05],
         ...,
         [ 1.34033198e-03,  5.29054305e-05,  3.05370268e-05],
         [ 1.33804942e-03,  5.67155512e-05,  3.11323820e-05],
         [ 1.34059880e-03,  5.45104631e-05,  3.00216034e-05]],

        [[ 1.03544933e-03, -3.32237280e-04,  3.24632303e-04],
         [ 1.01124181e-03,  7.88893740e-05, -2.29226971e-05],
         [ 8.72739125e-04,  1.57392104e-04, -4.01028447e-05],
         ...,
         [ 1.33418955e-03,  2.11524020e-05,  7.88406178e-05],
         [ 1.33451377e-03,  2.17793640e-05,  8.08656041e-05],
         [ 1.33761286e-03, -1.27562816e-05,  1.11626025e-04]],

        [[ 1.03144266e-03, -2.07392077e-05,  1.69441446e-05],
         [ 1.01359305e-03, -4.88771184e-06,  4.74951376e-05],
         [ 1.04252389e-03, -3.17378581e-05,  1.85242534e-05],
         ...,
 


Learning SDE and hidden parameters.:  60%|██████    | 6/10 [01:23<00:56, 14.25s/it][A