In [1]:
import torch
from torch import nn
import torch.distributions as d
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm
import random
from torch.autograd import Function
# from torch.utils.tensorboard import SummaryWriter
import argparse
import os
import sys
from pathlib import Path
import shutil
import pandas as pd

In [2]:
#Seed every random number generator the notebook imports so that a fresh "Restart & Run All" is reproducible.
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
STATE_DIM = 5 #Including fake state CO2.
CUDA_ID = 1 #Preferred GPU ordinal. Only honoured when that GPU actually exists (see device selection below).
dt = .2 #Euler-Maruyama step size in hours.
T = 1000 #Run simulation for 1000 hours.
N = int(round(T / dt)) #Number of discretization steps. round() guards against a float quotient such as 4999.999... being truncated to one step too few.
T_span = np.linspace(0, T, N + 1)
T_span_tensor = torch.reshape(torch.Tensor(T_span), [1, N + 1, 1]) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.

#Use the requested GPU when present, fall back to GPU 0 when CUDA_ID exceeds the number of visible GPUs, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{CUDA_ID if CUDA_ID < torch.cuda.device_count() else 0}')
else:
    device = torch.device('cpu')
LR = 1e-3 #Adam learning rate.
niter = 20000 #Total number of optimization iterations.
pretrain_iter = 1000 #Number of initial iterations that use the L1 pretraining loss.
BATCH_SIZE = 1 #Number of paths sampled from the flow per iteration.

In [3]:
#Read the full synthetic AWB observation table, then keep only the rows whose 'hour' is within the simulated horizon T.
obs_df_awb_full = pd.read_csv('AWB_synthetic_sol_df.csv') #Must be link to raw Github output if in Colab.
obs_df_awb = obs_df_awb_full.loc[obs_df_awb_full['hour'].le(T)] #Test with just first T hours of data.

In [4]:
#Observation times as a standalone NumPy array (copied so it does not alias the data frame).
obs_times = obs_df_awb['hour'].to_numpy(copy = True)
#Observed values with the 'hour' column removed: one row per observation time, one column per remaining data column.
obs_means_awb = torch.Tensor(obs_df_awb.drop(columns = 'hour').to_numpy())
#Transposed view: one row per data column, one column per observation time.
obs_means_awb_T = obs_means_awb.t()
obs_error_scale_awb = 0.1 * obs_means_awb_T.mean(dim = 1) #Observation noise set at 10% of respective observation means.

In [5]:
#Both exogenous inputs are a constant baseline plus a sinusoid with a period of 24 * 365 hours, evaluated on T_span_tensor.
I_S_tensor = torch.sin((2 * math.pi / (24 * 365)) * T_span_tensor).mul(0.0005).add(0.001) #Exogenous SOC input function
I_D_tensor = torch.sin((2 * math.pi / (24 * 365)) * T_span_tensor).mul(0.00005).add(0.0001) #Exogenous DOC input function

In [6]:
#Reference temperature (K) used by the temperature response functions.
temp_ref = 283

#System parameters from deterministic model
u_Q_ref, Q, a_MSA = 0.2, 0.002, 0.5
K_D, K_U = 200, 1
V_D_ref, V_U_ref = 0.4, 0.02
Ea_V_D, Ea_V_U = 75, 50
r_M, r_E, r_L = 0.0004, 0.00001, 0.0005

#Diffusion matrix sigma scale parameters
c_SOC, c_DOC, c_MBC, c_EEC, c_CO2 = 1., 0.01, 0.1, 0.001, 0.0001

#Bundle every parameter under its own name. Key order matches the order documented in the model function docstrings.
sawb_c_params_dict = dict(
    u_Q_ref = u_Q_ref, Q = Q, a_MSA = a_MSA,
    K_D = K_D, K_U = K_U,
    V_D_ref = V_D_ref, V_U_ref = V_U_ref,
    Ea_V_D = Ea_V_D, Ea_V_U = Ea_V_U,
    r_M = r_M, r_E = r_E, r_L = r_L,
    c_SOC = c_SOC, c_DOC = c_DOC, c_MBC = c_MBC, c_EEC = c_EEC, c_CO2 = c_CO2,
)

In [7]:
############################################################
##SOIL BIOGEOCHEMICAL MODEL TEMPERATURE RESPONSE FUNCTIONS##
############################################################

def temp_gen(t, temp_ref):
    '''
    Returns the temperature forcing evaluated at times t (a tensor, in hours).
    The forcing is temp_ref plus a linear trend of one unit per 20 * 24 * 365 hours, a sinusoid of amplitude 10 with a 24 hour period, and a sinusoid of amplitude 10 with a 24 * 365 hour period.
    '''
    linear_trend = t / (20 * 24 * 365)
    daily_cycle = 10 * torch.sin((2 * np.pi / 24) * t)
    annual_cycle = 10 * torch.sin((2 * math.pi / (24 * 365)) * t)
    return temp_ref + linear_trend + daily_cycle + annual_cycle

def arrhenius_temp_dep(parameter, temp, Ea, temp_ref):
    '''
    Applies Arrhenius temperature dependence to a parameter and returns the transformed value.
    The reference value is scaled by exp(-Ea / R * (1 / temp - 1 / temp_ref)) with gas constant R = 0.008314.
    temp must be a tensor; temperatures are in K. At temp == temp_ref the parameter is returned unchanged.
    '''
    inverse_temp_shift = 1 / temp - 1 / temp_ref
    exponent = -Ea / 0.008314 * inverse_temp_shift
    return parameter * torch.exp(exponent)

def linear_temp_dep(parameter, temp, Q, temp_ref):
    '''
    Applies linear temperature dependence to a parameter and returns the transformed value.
    The parameter decreases by Q for every unit that temp exceeds temp_ref (Q is the slope and is a varying parameter).
    Temperatures are in K.
    '''
    temp_offset = temp - temp_ref
    return parameter - Q * temp_offset

##########################################################################
##DETERMINISTIC SOIL BIOGEOCHEMICAL MODEL INITIAL STEADY STATE ESTIMATES##
##########################################################################

def analytical_steady_state_init_awb(SOC_input, DOC_input, sawb_c_params_dict = sawb_c_params_dict):
    '''
    Returns a vector of C pool values to initialize an SAWB system, computed from closed-form steady state expressions evaluated at the reference parameter values in sawb_c_params_dict.
    Vector elements are in order of S_0, D_0, M_0, E_0, and CO2_0 (five elements, CO2 included).
    SOC_input and DOC_input are the scalar exogenous input rates at the initial time.
    Expected sawb_c_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC, 'c_CO2': c_CO2}
    '''
    #Pull the parameters that the steady state expressions use into short local names.
    p = sawb_c_params_dict
    u_Q, a_MSA = p['u_Q_ref'], p['a_MSA']
    K_D, K_U = p['K_D'], p['K_U']
    V_D, V_U = p['V_D_ref'], p['V_U_ref']
    r_M, r_E, r_L = p['r_M'], p['r_E'], p['r_L']
    #Sub-expression that appears in both the numerator and the denominator of S_0.
    msa_term = -1 + u_Q - a_MSA * u_Q
    S_numerator = K_D * r_L * (SOC_input * r_E * (u_Q - 1) - a_MSA * DOC_input * r_M * u_Q + SOC_input * r_M * msa_term)
    S_denominator = DOC_input * u_Q * (-a_MSA * r_L * r_M + r_E * V_D) + SOC_input * (r_E * r_L * (u_Q - 1) + r_L * r_M * msa_term + r_E * u_Q * V_D)
    S_0 = -(S_numerator / S_denominator)
    D_0 = -((K_U * (r_E + r_M)) / (r_E + r_M - u_Q * V_U))
    M_0 = -((SOC_input + DOC_input) * u_Q) / ((r_E + r_M) * (u_Q - 1))
    E_0 = r_E * M_0 / r_L
    #CO2 is an algebraic quantity derived from the M_0 and D_0 steady states rather than an independent pool.
    CO2_0 = (1 - u_Q) * (V_U * M_0 * D_0) / (K_U + D_0)
    return torch.as_tensor([S_0, D_0, M_0, E_0, CO2_0])

In [8]:
####################################################
##STOCHASTIC DIFFERENTIAL EQUATION MODEL FUNCTIONS##
#################################################### 

def drift_diffusion_sawb_c(C_path, T_span_tensor, I_S_tensor, I_D_tensor, sawb_c_params_dict, temp_ref):
    '''
    Returns SAWB "constant diffusion parameterization" drift vectors and diffusion matrices.

    C_path holds the state values with the state variables (SOC, DOC, MBC, EEC, CO2) on its last dimension; it is indexed as a 3-dimensional tensor.
    T_span_tensor, I_S_tensor and I_D_tensor hold the times and exogenous inputs and must broadcast against the per-state slices of C_path.
    The temperature at each time is obtained from temp_gen(T_span_tensor, temp_ref).
    Expected sawb_c_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_D': K_D, 'K_U': K_U, 'V_D_ref': V_D_ref, 'V_U_ref': V_U_ref, 'Ea_V_D': Ea_V_D, 'Ea_V_U': Ea_V_U, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 'c_SOC': c_SOC, 'c_DOC': c_DOC, 'c_MBC': c_MBC, 'c_EEC': c_EEC, 'c_CO2': c_CO2}

    Returns (drift, diffusion_sqrt). drift has the same shape as C_path, with its last element holding the algebraic CO2 value rather than a rate of change. diffusion_sqrt has shape (C_path.size(0), C_path.size(1), 5, 5) and lives on the same device and in the same dtype as C_path.
    '''
    state_dim = 5 #SAWB and AWB family variants will have 5 'states' with the inclusion of CO2.
    #Partition SOC, DOC, MBC and EEC values based on the final C_path dim. The incoming CO2 slice is not used: CO2 is recomputed algebraically below.
    SOC, DOC, MBC, EEC, _ = torch.chunk(C_path, state_dim, -1)
    current_temp = temp_gen(T_span_tensor, temp_ref) #Obtain temperature function vector across span of times.
    drift = torch.empty_like(C_path) #Initiate tensor with same dims, dtype and device as C_path to assign drift.
    #Decay parameters are forced by temperature changes.
    u_Q = linear_temp_dep(sawb_c_params_dict['u_Q_ref'], current_temp, sawb_c_params_dict['Q'], temp_ref) #Apply linear temperature-dependence to u_Q.
    V_D = arrhenius_temp_dep(sawb_c_params_dict['V_D_ref'], current_temp, sawb_c_params_dict['Ea_V_D'], temp_ref) #Apply vectorized temperature-dependent transformation to V_D.
    V_U = arrhenius_temp_dep(sawb_c_params_dict['V_U_ref'], current_temp, sawb_c_params_dict['Ea_V_U'], temp_ref) #Apply vectorized temperature-dependent transformation to V_U.
    #Shared flux terms. The numerator and denominator of the uptake term are kept separate so that every product below is evaluated in the same order as before.
    decomposition = (V_D * EEC * SOC) / (sawb_c_params_dict['K_D'] + SOC)
    uptake_numerator = V_U * MBC * DOC
    uptake_denominator = sawb_c_params_dict['K_U'] + DOC
    #Drift is calculated.
    drift_SOC = I_S_tensor + sawb_c_params_dict['a_MSA'] * sawb_c_params_dict['r_M'] * MBC - decomposition
    drift_DOC = I_D_tensor + (1 - sawb_c_params_dict['a_MSA']) * sawb_c_params_dict['r_M'] * MBC + decomposition + sawb_c_params_dict['r_L'] * EEC - (uptake_numerator / uptake_denominator)
    drift_MBC = (u_Q * uptake_numerator / uptake_denominator) - (sawb_c_params_dict['r_M'] + sawb_c_params_dict['r_E']) * MBC
    drift_EEC = sawb_c_params_dict['r_E'] * MBC - sawb_c_params_dict['r_L'] * EEC
    CO2 = (1 - u_Q) * uptake_numerator / uptake_denominator
    #Assign elements to drift vector.
    drift[:, :, 0 : 1] = drift_SOC
    drift[:, :, 1 : 2] = drift_DOC
    drift[:, :, 2 : 3] = drift_MBC
    drift[:, :, 3 : 4] = drift_EEC
    drift[:, :, 4 : 5] = CO2 #CO2 is not a part of the drift. This is a hack for the explicit algebraic variable situation.
    #Diffusion matrix is assigned. The noise scales are created directly on C_path's device and in its dtype; a CPU float32 matrix cannot be combined with a CUDA or float64 path downstream.
    noise_scales = torch.as_tensor([sawb_c_params_dict['c_SOC'], sawb_c_params_dict['c_DOC'], sawb_c_params_dict['c_MBC'], sawb_c_params_dict['c_EEC'], sawb_c_params_dict['c_CO2']], dtype = C_path.dtype, device = C_path.device)
    diffusion_sqrt_single = torch.diag(torch.sqrt(LowerBound.apply(noise_scales, 1e-9))) #Create single diffusion matrix by diagonalizing constant noise scale parameters.
    diffusion_sqrt = diffusion_sqrt_single.expand(drift.size(0), drift.size(1), state_dim, state_dim) #Expand diffusion matrices across all paths and across discretized time steps. Diffusion exists for explicit algebraic variable CO2.
    return drift, diffusion_sqrt

In [9]:
class LowerBound(Function):
    '''
    Autograd function that clamps its input from below at a scalar bound.
    Forward returns the element-wise maximum of inputs and bound.
    Backward lets the gradient through where the input is at or above the bound, and also where the incoming gradient is negative; everywhere else the gradient is zeroed. No gradient is returned for bound.
    '''
    @staticmethod
    def forward(ctx, inputs, bound):
        #Build a tensor filled with the bound, then match the device and dtype of the input.
        floor = (torch.ones(inputs.size()) * bound).to(inputs.device).type(inputs.dtype)
        ctx.save_for_backward(inputs, floor)
        return torch.max(inputs, floor)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, floor = ctx.saved_tensors
        at_or_above_floor = inputs >= floor
        negative_gradient = grad_output < 0
        #Gate is 1 where the gradient is allowed through and 0 elsewhere.
        gate = torch.logical_or(at_or_above_floor, negative_gradient).type(grad_output.dtype)
        return gate * grad_output, None

class MaskedConv1d(nn.Conv1d):
    '''
    1-D convolution whose kernel taps to the right of the centre are always zeroed.
    mask_type 'A' additionally zeroes the centre tap, while mask_type 'B' keeps it.
    Remaining positional and keyword arguments are forwarded to nn.Conv1d.
    '''
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}
        kernel_width = self.weight.size(-1)
        #Index of the first kernel tap that is masked out.
        first_masked_tap = kernel_width // 2 + (1 if mask_type == 'B' else 0)
        mask = torch.ones_like(self.weight.data)
        mask[:, :, first_masked_tap:] = 0
        self.register_buffer('mask', mask)

    def forward(self, x):
        #Zero the masked taps in place on every call so that optimizer updates cannot revive them.
        self.weight.data.mul_(self.mask)
        return super().forward(x)

class ResNetBlock(nn.Module):
    '''
    Residual block built from two kernel-size-3 MaskedConv1d layers with PReLU activations.
    first selects the mask of the input convolution and of the shortcut projection: 'A' when True, 'B' otherwise. The second convolution always uses mask 'B'.
    batch_norm toggles nn.BatchNorm1d after each convolution (nn.Identity when False).
    '''

    def __init__(self, inp_cha, out_cha, stride = 1, first = True, batch_norm = True):
        super().__init__()
        entry_mask = 'A' if first else 'B'
        self.conv1 = MaskedConv1d(entry_mask, inp_cha,  out_cha, 3, stride, 1, bias = False)
        self.conv2 = MaskedConv1d('B', out_cha,  out_cha, 3, 1, 1, bias = False)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        make_norm = (lambda: nn.BatchNorm1d(out_cha)) if batch_norm else nn.Identity
        self.bn1 = make_norm()
        self.bn2 = make_norm()

        # The shortcut needs its own masked convolution whenever the channel count or the stride changes.
        needs_projection = inp_cha != out_cha or stride > 1
        self.conv_skip = MaskedConv1d(entry_mask, inp_cha,  out_cha, 3, stride, 1, bias = False) if needs_projection else nn.Identity()

    def forward(self, x):
        shortcut = self.conv_skip(x)
        hidden = self.act1(self.bn1(self.conv1(x)))
        merged = self.conv2(hidden) + shortcut
        return self.act2(self.bn2(merged))

class ResNetBlockUnMasked(nn.Module):
    '''
    Residual block built from two ordinary kernel-size-3 nn.Conv1d layers with PReLU activations.
    batch_norm toggles nn.BatchNorm1d after each convolution (nn.Identity when False, the default).
    '''

    def __init__(self, inp_cha, out_cha, stride = 1, batch_norm = False):
        super().__init__()
        self.conv1 = nn.Conv1d(inp_cha,  out_cha, 3, stride, 1)
        self.conv2 = nn.Conv1d(out_cha,  out_cha, 3, 1, 1)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        make_norm = (lambda: nn.BatchNorm1d(out_cha)) if batch_norm else nn.Identity
        self.bn1 = make_norm()
        self.bn2 = make_norm()

        # The shortcut needs its own bias-free convolution whenever the channel count or the stride changes.
        needs_projection = inp_cha != out_cha or stride > 1
        self.conv_skip = nn.Conv1d(inp_cha,  out_cha, 3, stride, 1, bias=False) if needs_projection else nn.Identity()

    def forward(self, x):
        shortcut = self.conv_skip(x)
        hidden = self.act1(self.bn1(self.conv1(x)))
        merged = self.conv2(hidden) + shortcut
        return self.act2(self.bn2(merged))

class CouplingLayer(nn.Module):
    '''
    Affine flow layer: x is mapped to mu + sigma * x, with mu and sigma produced by masked convolutions over x and a feature encoding of the conditioning inputs.
    cond_inputs (constructor) is the number of conditioning channels, stride is the stride of the final masked convolution, and h_cha is the hidden channel count.
    forward returns the transformed x together with -log(sigma).
    '''

    def __init__(self, cond_inputs, stride, h_cha = 96):
        super().__init__()
        self.first_block = ResNetBlock(1, h_cha, first = True)
        self.second_block = nn.Sequential(
            ResNetBlock(h_cha + cond_inputs, h_cha, first = False),
            MaskedConv1d('B', h_cha,  2, 3, stride, 1, bias = False),
        )
        self.feature_net = nn.Sequential(
            ResNetBlockUnMasked(cond_inputs, h_cha),
            ResNetBlockUnMasked(h_cha, cond_inputs),
        )
        # With more than one conditioning channel, forward receives an iterable of tensors to concatenate.
        self.unpack = cond_inputs > 1

    def forward(self, x, cond_inputs):
        conditioning = torch.cat(list(cond_inputs), 1) if self.unpack else cond_inputs
        conditioning = self.feature_net(conditioning)
        hidden = self.first_block(x)
        shift_and_scale = self.second_block(torch.cat([hidden, conditioning], 1))
        mu, sigma = shift_and_scale.chunk(2, dim = 1)
        # Keep the scale at or above 1e-6 so that its logarithm stays finite.
        sigma = LowerBound.apply(sigma, 1e-6)
        transformed = mu + sigma*x
        return transformed, -torch.log(sigma)

class PermutationLayer(nn.Module):
    '''
    Applies one fixed random permutation to every consecutive group of STATE_DIM entries along the last dimension of the input.
    The permutation is drawn once at construction and stored as a registered buffer. A buffer (unlike a plain tensor attribute) is included in state_dict and follows the module across devices, so a saved model restores the same permutation it was trained with.
    '''

    def __init__(self):
        super().__init__()
        self.register_buffer('index_1', torch.randperm(STATE_DIM))

    def forward(self, x):
        B, S, L = x.shape
        # View the last dimension as groups of STATE_DIM entries, reorder inside every group, then flatten back.
        x_reshape = x.reshape(B, S, -1, STATE_DIM)
        x_perm = x_reshape[:, :, :, self.index_1]
        return x_perm.reshape(B, S, L)

class SoftplusLayer(nn.Module):
    '''
    Flow layer that maps its input through softplus.
    forward returns y = softplus(x) together with -log(1 - exp(-y)), evaluated with expm1.
    '''
    def __init__(self):
        super().__init__()
        self.softplus = nn.Softplus()

    def forward(self, x):
        y = self.softplus(x)
        # -expm1(-y) equals 1 - exp(-y).
        ildj = torch.log(-torch.expm1(-y)).neg()
        return y, ildj

class BatchNormLayer(nn.Module):
    '''
    Batch normalization written as a flow layer with a learnable log-scale (log_gamma) and shift (beta).
    In training mode the statistics are taken over dimension 0 of the squeezed input and the running buffers are updated; in eval mode the running buffers are used.
    forward returns the normalized output with a singleton dimension re-inserted at position 1, together with -log_gamma + 0.5 * log(var).
    Note: statistics are taken over dimension 0 only, so an input with a single entry there has zero deviation from its own mean and the training-mode output reduces to beta.
    '''
    def __init__(self, num_inputs, momentum = 0.0, eps = 1e-5):
        super().__init__()

        self.log_gamma = nn.Parameter(torch.rand(num_inputs))
        self.beta = nn.Parameter(torch.rand(num_inputs))
        self.momentum = momentum
        self.eps = eps

        self.register_buffer('running_mean', torch.zeros(num_inputs))
        self.register_buffer('running_var', torch.ones(num_inputs))

    def forward(self, inputs):
        flat = inputs.squeeze(1)
        if not self.training:
            mean, var = self.running_mean, self.running_var
        else:
            self.batch_mean = flat.mean(0)
            self.batch_var = (flat - self.batch_mean).pow(2).mean(0) + self.eps

            # Running buffers keep a `momentum` share of their old value and take the rest from the current batch.
            retained = self.momentum
            self.running_mean.mul_(retained).add_(self.batch_mean.data * (1 - retained))
            self.running_var.mul_(retained).add_(self.batch_var.data * (1 - retained))

            mean, var = self.batch_mean, self.batch_var

        x_hat = (flat - mean) / var.sqrt()
        y = torch.exp(self.log_gamma) * x_hat + self.beta
        ildj = -self.log_gamma + 0.5 * torch.log(var)
        return y[:, None, :], ildj[None, None, :]

In [10]:
class SDEFlow(nn.Module):
    '''
    Normalizing flow that produces sample paths together with their log-density.
    A standard normal draw of length STATE_DIM * N is pushed through num_layers rounds of (permutation, coupling layer), with a batch-norm layer after every round except the last, and finally through a softplus layer.
    cond_inputs is added to STATE_DIM to give the number of conditioning channels of each coupling layer.
    '''

    def __init__(self, cond_inputs = 1, num_layers = 5):
        super().__init__()
        
        self.coupling = nn.ModuleList([CouplingLayer(cond_inputs + STATE_DIM, 1) for _ in range(num_layers)])
        #Hold the permutation layers in an nn.ModuleList, not a plain list, so that they are registered as submodules and are covered by .to(), state_dict() and train()/eval().
        self.permutation = nn.ModuleList([PermutationLayer() for _ in range(num_layers)])
        self.batch_norm = nn.ModuleList([BatchNormLayer(STATE_DIM * N) for _ in range(num_layers-1)])
        self.SP = SoftplusLayer()
        
        self.base_dist = d.normal.Normal(loc = 0., scale = 1.)
        self.num_layers = num_layers
        
    def forward(self, batch_size, obs_model, *args, **kwargs):
        '''
        Draws batch_size paths. obs_model must expose its observation means as obs_model.mu.
        Returns a tensor of shape (batch_size, N, STATE_DIM), offset by 1e-9, and the accumulated log-density of shape (batch_size, 1).
        '''
        eps = self.base_dist.sample([batch_size, 1, STATE_DIM * N]).to(device)
        log_prob = self.base_dist.log_prob(eps).sum(-1)
        
        #Conditioning inputs. The first observation is dropped and the rest are tiled with a hard-coded factor of 50, so the tiled length matches STATE_DIM * N only for the observation spacing this notebook was written for.
        obs_tile = obs_model.mu[None, :, 1:, None].repeat(batch_size, STATE_DIM, 1, 50).reshape(batch_size, STATE_DIM, -1)
        #Time stamps laid out with the STATE_DIM copies of each time next to each other along the last dimension.
        times = torch.arange(dt, T + dt, dt, device = eps.device)[(None,) * 2].repeat(batch_size, STATE_DIM, 1).transpose(-2, -1).reshape(batch_size, 1, -1)
        
        ildjs = []
        
        for i in range(self.num_layers):
            eps, cl_ildj = self.coupling[i](self.permutation[i](eps), (obs_tile, times))
            if i < (self.num_layers - 1):
                eps, bn_ildj = self.batch_norm[i](eps)
                ildjs.append(bn_ildj)
            ildjs.append(cl_ildj)
                
        eps, sp_ildj = self.SP(eps)
        ildjs.append(sp_ildj)
        
        #Every layer's term is summed over the sequence dimension and added to the base log-density.
        for ildj in ildjs:
            log_prob += ildj.sum(-1)
    
        #NOTE(review): `times` and PermutationLayer treat the flat sequence as N groups of STATE_DIM entries, whereas the reshape below reads it as STATE_DIM blocks of N entries - confirm which layout is intended.
        return eps.reshape(batch_size, STATE_DIM, -1).permute(0, 2, 1) + 1e-9, log_prob

In [11]:
class ObsModel(nn.Module):
    '''
    Gaussian observation model evaluated at the observed time points of a path.
    times are the observation times, mu the observation means with one row per state and one column per observation time, and scale the observation standard deviations.
    '''

    def __init__(self, times, mu, scale):
        super().__init__()

        self.idx = self._get_idx(times)
        self.times = times
        self.mu = torch.Tensor(mu).to(device)
        #Keep the scale on the same device as mu; a CPU scale cannot be combined with a CUDA mean when the likelihood is evaluated.
        self.scale = torch.as_tensor(scale).to(device)
        
    def forward(self, x):
        '''
        Returns the observation log-likelihood of paths x, summed over observation times and states and averaged over the batch.
        x is indexed along dimension 1 with the step indices of the observation times.
        '''
        obs_ll = d.normal.Normal(self.mu.permute(1, 0), self.scale).log_prob(x[:, self.idx, :])
        return torch.sum(obs_ll, [-1, -2]).mean()

    def _get_idx(self, times):
        '''
        Converts observation times to step indices on the dt grid.
        The small tolerance added before flooring stops a quotient such as 2.9999999999999996 (e.g. 0.3 / 0.1) from being truncated to the previous step.
        '''
        return list(np.floor(np.asarray(times) / dt + 1e-6).astype(int))
    
    def plt_dat(self):
        '''Returns the observation means and times, in that order.'''
        return self.mu, self.times

In [12]:
def neg_log_lik(C_path, T_span_tensor, dt, I_S_tensor, I_D_tensor, drift_diffusion, params_dict, temp_ref):
    '''
    Returns the negative log-density of every path in C_path under Euler-Maruyama transitions, summed over time steps.
    drift_diffusion is called on all but the last time point and must return (drift, diffusion_sqrt).
    Every state except the last has transition mean state + drift * dt. The last state (explicit algebraic variable CO2) is not integrated: its mean is the last drift element itself.
    The transition covariance factor is diffusion_sqrt * sqrt(dt).
    '''
    states_from = C_path[:, :-1, :]
    states_to = C_path[:, 1:, :]
    drift, diffusion_sqrt = drift_diffusion(states_from, T_span_tensor[:, :-1, :], I_S_tensor[:, :-1, :], I_D_tensor[:, :-1, :], params_dict, temp_ref)
    #Separate explicit algebraic variable CO2 mean from integration process.
    integrated_means = states_from[:, :, :-1] + drift[:, :, :-1] * dt
    co2_means = drift[:, :, -1].unsqueeze(2)
    transition_means = torch.cat((integrated_means, co2_means), 2)
    transition = d.multivariate_normal.MultivariateNormal(loc = transition_means, scale_tril = diffusion_sqrt * math.sqrt(dt))
    step_log_probs = transition.log_prob(states_to)
    return -step_log_probs.sum(-1)

In [13]:
#Observation model over the AWB data, with the noise scales arranged as a single row of STATE_DIM entries.
obs_model = ObsModel(times = obs_times, mu = obs_means_awb_T, scale = obs_error_scale_awb.reshape(1, STATE_DIM))
#Flow network with default settings, moved to the selected device, and its Adam optimizer.
net = SDEFlow()
net = net.to(device)
optimizer = optim.Adam(params = net.parameters(), lr = LR)

def train(niter, pretrain_iter, BATCH_SIZE, T_span_tensor, I_S_tensor, I_D_tensor, drift_diffusion, params_dict, analytical_steady_state_init):
    '''
    Trains the module-level flow `net` with the module-level `optimizer` against the module-level `obs_model`.

    The first pretrain_iter iterations minimize the L1 distance between the sampled paths and the per-state mean of the observations. The remaining iterations minimize the mean flow log-density plus the mean SDE negative log-likelihood minus the observation log-likelihood (printed as the "ELBO loss").
    drift_diffusion and params_dict are forwarded to neg_log_lik. analytical_steady_state_init(SOC_input, DOC_input, params_dict) supplies the deterministic initial state that is prepended to every sampled path.
    Gradients are clipped to a norm of 3.0 at every step.

    Raises ValueError when pretrain_iter is not smaller than niter.
    '''
    if pretrain_iter >= niter:
        raise ValueError("pretrain_iter must be < niter.")
    #Best losses are tracked as plain floats so that no autograd graph is kept alive between iterations.
    best_loss_norm = 1e10
    best_loss_ELBO = 1e20
    norm_losses = [best_loss_norm] * 10
    ELBO_losses = [best_loss_ELBO] * 10
    C0 = analytical_steady_state_init(I_S_tensor[0, 0, 0].item(), I_D_tensor[0, 0, 0].item(), params_dict) #Calculate deterministic initial conditions.
    C0 = C0[(None,) * 2].repeat(BATCH_SIZE, 1, 1).to(device) #Assign initial conditions to C_path.
    #The time grid and exogenous inputs do not change between iterations, so they are moved to the device once.
    T_span_device = T_span_tensor.to(device)
    I_S_device = I_S_tensor.to(device)
    I_D_device = I_D_tensor.to(device)
    with tqdm(total = niter, desc = 'Train Diffusion', position = -1) as t:
        for iteration in range(niter):
            net.train()
            optimizer.zero_grad()
            C_path, log_prob = net(BATCH_SIZE, obs_model) #Obtain paths with solutions at times after t0.
            C_path = torch.cat([C0, C_path], 1) #Append deterministic initial conditions conditional on parameter values to C path.
            #Strict inequality: exactly pretrain_iter pretraining iterations, which leaves at least one ELBO iteration whenever pretrain_iter < niter.
            if iteration < pretrain_iter:
                l1_norm_element = C_path - torch.mean(obs_model.mu, -1)
                l1_norm = torch.sum(torch.abs(l1_norm_element))
                l1_norm.backward()
                l1_norm_value = l1_norm.item()
                best_loss_norm = min(best_loss_norm, l1_norm_value)
                norm_losses.append(l1_norm_value)
                if len(norm_losses) > 10:
                    norm_losses.pop(0)
                if iteration % 10 == 0:
                    print(f"Moving average norm loss at {iteration} iterations is: {sum(norm_losses) / len(norm_losses)}. Best norm loss value is: {best_loss_norm}.")
                    print('\nC_path mean =', C_path.mean(-2))
                    print('\nC_path =', C_path)
            else:
                path_nll = neg_log_lik(C_path, T_span_device, dt, I_S_device, I_D_device, drift_diffusion, params_dict, temp_ref)
                #Quantity being minimized; it is reported under the name "ELBO loss".
                ELBO = log_prob.mean() + path_nll.mean() - obs_model(C_path)
                ELBO.backward()
                ELBO_value = ELBO.item()
                best_loss_ELBO = min(best_loss_ELBO, ELBO_value)
                ELBO_losses.append(ELBO_value)
                if len(ELBO_losses) > 10:
                    ELBO_losses.pop(0)
                if iteration % 10 == 0:
                    print(f"Moving average ELBO loss at {iteration} iterations is: {sum(ELBO_losses) / len(ELBO_losses)}. Best ELBO loss value is: {best_loss_ELBO}.")
                    print('\nC_path mean =', C_path.mean(-2))
                    print('\n C_path =', C_path)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 3.0)
            optimizer.step()
            #Divide the learning rate by ten every 100000 iterations.
            if iteration % 100000 == 0 and iteration > 0:
                optimizer.param_groups[0]['lr'] *= 0.1
            t.update()

In [14]:
#Run the full optimization on the SAWB "constant diffusion" model.
train(
    niter = niter,
    pretrain_iter = pretrain_iter,
    BATCH_SIZE = BATCH_SIZE,
    T_span_tensor = T_span_tensor,
    I_S_tensor = I_S_tensor,
    I_D_tensor = I_D_tensor,
    drift_diffusion = drift_diffusion_sawb_c,
    params_dict = sawb_c_params_dict,
    analytical_steady_state_init = analytical_steady_state_init_awb,
)


Train Diffusion:   0%|          | 0/20000 [00:00<?, ?it/s][A
Train Diffusion:   0%|          | 1/20000 [00:04<23:17:36,  4.19s/it][A

Moving average norm loss at 0 iterations is: 9000026993.86875. Best norm loss value is: 269938.6875.

C_path mean = tensor([[0.9219, 0.6074, 0.6382, 0.5924, 0.5081]], grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3602e+01, 1.1421e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [7.9968e-01, 7.1134e-01, 6.3181e-01, 5.6100e-01, 5.3264e-01],
         [8.4875e-01, 6.5099e-01, 6.1520e-01, 6.3112e-01, 4.6978e-01],
         ...,
         [8.1544e-01, 6.3414e-01, 6.3061e-01, 4.6991e-01, 5.0222e-01],
         [9.2016e-01, 6.3638e-01, 6.4990e-01, 6.3152e-01, 5.7690e-01],
         [6.9342e-01, 6.7947e-01, 6.9055e-01, 6.2434e-01, 8.0328e-01]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 2/20000 [00:08<23:38:39,  4.26s/it][A
Train Diffusion:   0%|          | 3/20000 [00:12<22:42:43,  4.09s/it][A
Train Diffusion:   0%|          | 4/20000 [00:16<23:07:09,  4.16s/it][A
Train Diffusion:   0%|          | 5/20000 [00:22<25:51:03,  4.65s/it][A
Train Diffusion:   0%|          | 6/20000 [00:27<26:27:49,  4.76s/it][A
Train Diffusion:   0%|          | 7/20000 [00:32<27:09:09,  4.89s/it][A
Train Diffusion:   0%|          | 8/20000 [00:37<27:25:28,  4.94s/it][A
Train Diffusion:   0%|          | 9/20000 [00:41<26:03:04,  4.69s/it][A
Train Diffusion:   0%|          | 10/20000 [00:46<26:53:44,  4.84s/it][A
Train Diffusion:   0%|          | 11/20000 [00:51<26:03:43,  4.69s/it][A

Moving average norm loss at 10 iterations is: 256604.334375. Best norm loss value is: 246045.390625.

C_path mean = tensor([[5.0930, 0.6012, 0.4458, 0.2196, 0.0899]], grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3602e+01, 1.1421e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.3688e+00, 3.1588e+00, 5.3021e-01, 3.3386e-01, 1.4043e-01],
         [3.8144e+00, 2.1069e+00, 5.6176e-01, 3.1555e-01, 1.4001e-01],
         ...,
         [3.4356e+00, 5.4972e-01, 3.2355e-01, 1.3975e-01, 5.3773e-02],
         [4.7199e+00, 5.3895e-01, 3.2745e-01, 1.3621e-01, 6.3541e-02],
         [3.7616e+00, 5.3937e-01, 3.2128e-01, 1.3499e-01, 9.2099e-02]]],
       grad_fn=<CatBackward>)


Train Diffusion:   0%|          | 11/20000 [00:54<27:29:34,  4.95s/it]


KeyboardInterrupt: 