In [15]:
import torch
from torch import nn
import torch.distributions as d
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm
import random
from torch.autograd import Function
# from torch.utils.tensorboard import SummaryWriter
import argparse
import os
import sys
from pathlib import Path
import shutil
import pandas as pd

In [42]:
torch.manual_seed(0)
STATE_DIM = 5 #Including fake state CO2.
CUDA_ID = 1
dt = .2
T = 1000 #Run simulation for 1000 hours.
N = int(T / dt) 
T_span = np.linspace(0, T, N + 1)
T_span_tensor = torch.reshape(torch.Tensor(T_span), [1, N + 1, 1]) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.

device = torch.device("".join(["cuda:",f'{CUDA_ID}']) if torch.cuda.is_available() else "cpu")
LR = 1e-3
niter = 20000
pretrain_iter = 1000
BATCH_SIZE = 1

In [43]:
obs_df_awb_eca_full = pd.read_csv('awb_eca_synthetic_sol_df.csv') #Must be link to raw Github output if in Colab.
obs_df_awb_eca = obs_df_awb_eca_full[obs_df_awb_eca_full['hour'] <= T] #Test with just first T hours of data.

In [44]:
obs_times = np.array(obs_df_awb_eca['hour'])
obs_means_awb_eca = torch.Tensor(np.array(obs_df_awb_eca.drop(columns = 'hour')))
obs_means_awb_eca_T = obs_means_awb_eca.T
obs_error_scale_awb_eca = torch.mean(obs_means_awb_eca_T, 1) * 0.1 #Observation noise set at 10% of respective observation means.

In [45]:
I_S_tensor = 0.001 + 0.0005 * torch.sin((2 * math.pi / (24 * 365)) * T_span_tensor) #Exogenous SOC input function
I_D_tensor = 0.0001 + 0.00005 * torch.sin((2 * math.pi / (24 * 365)) * T_span_tensor) #Exogenous DOC input function

In [46]:
temp_ref = 283

#System parameters from deterministic model
u_Q_ref = 0.2
Q = 0.002
a_MSA = 0.5
K_DE = 200
K_UE = 1
V_DE_ref = 0.4
V_UE_ref = 0.02
Ea_V_DE = 75
Ea_V_UE = 50
r_M = 0.0004
r_E = 0.00001
r_L = 0.0005

#Diffusion matrix sigma scale parameters
s_SOC = 0.01
s_DOC = 0.01
s_MBC = 0.01
s_EEC = 0.01
s_CO2 = 0.01

sawb_eca_ss_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC, 's_CO2': s_CO2}

In [47]:
############################################################
##SOIL BIOGEOCHEMICAL MODEL TEMPERATURE RESPONSE FUNCTIONS##
############################################################

def temp_gen(t, temp_ref):
    temp = temp_ref + t / (20 * 24 * 365) + 10 * torch.sin((2 * np.pi / 24) * t) + 10 * torch.sin((2 * math.pi / (24 * 365)) * t)
    return temp

def arrhenius_temp_dep(parameter, temp, Ea, temp_ref):
    '''
    For a parameter with Arrhenius temperature dependence, returns the transformed parameter value.
    0.008314 is the gas constant. Temperatures are in K.
    '''
    decayed_parameter = parameter * torch.exp(-Ea / 0.008314 * (1 / temp - 1 / temp_ref))
    return decayed_parameter

def linear_temp_dep(parameter, temp, Q, temp_ref):
    '''
    For a parameter with linear temperature dependence, returns the transformed parameter value.
    Q is the slope of the temperature dependence and is a varying parameter.
    Temperatures are in K.
    '''
    modified_parameter = parameter - Q * (temp - temp_ref)
    return modified_parameter

##########################################################################
##DETERMINISTIC SOIL BIOGEOCHEMICAL MODEL INITIAL STEADY STATE ESTIMATES##
##########################################################################

def analytical_steady_state_init_awb_eca(SOC_input, DOC_input, sawb_eca_ss_params_dict = sawb_eca_ss_params_dict):
    '''
    Returns a vector of C pool values to initialize an SAWB-ECA system corresponding to set of parameter values using the analytical steady state equations of the deterministic CON system.
    Vector elements are in order of S_0, D_0, M_0, E_0, and CO2_0.
    Expected sawb_eca_ss_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC, 's_CO2': s_CO2}
    '''
    S_0 = ((-sawb_eca_ss_params_dict['K_DE'] * sawb_eca_ss_params_dict['r_L'] * (sawb_eca_ss_params_dict['r_E'] + sawb_eca_ss_params_dict['r_M']) * (sawb_eca_ss_params_dict['u_Q_ref'] - 1) + sawb_eca_ss_params_dict['r_E'] * sawb_eca_ss_params_dict['u_Q_ref'] * (SOC_input + DOC_input)) * (SOC_input * sawb_eca_ss_params_dict['r_E'] * (sawb_eca_ss_params_dict['u_Q_ref'] - 1) - sawb_eca_ss_params_dict['a_MSA'] * DOC_input * sawb_eca_ss_params_dict['r_M'] * sawb_eca_ss_params_dict['u_Q_ref'] + SOC_input * sawb_eca_ss_params_dict['r_M'] * (sawb_eca_ss_params_dict['u_Q_ref'] - sawb_eca_ss_params_dict['a_MSA'] * sawb_eca_ss_params_dict['u_Q_ref'] - 1))) / ((sawb_eca_ss_params_dict['r_E'] + sawb_eca_ss_params_dict['r_M']) * (sawb_eca_ss_params_dict['u_Q_ref'] - 1) * (DOC_input * sawb_eca_ss_params_dict['u_Q_ref'] * (sawb_eca_ss_params_dict['r_E'] * sawb_eca_ss_params_dict['V_DE_ref'] - sawb_eca_ss_params_dict['a_MSA'] * sawb_eca_ss_params_dict['r_L'] * sawb_eca_ss_params_dict['r_M']) + SOC_input * (sawb_eca_ss_params_dict['r_E'] * sawb_eca_ss_params_dict['r_L'] * (sawb_eca_ss_params_dict['u_Q_ref'] - 1) + sawb_eca_ss_params_dict['r_L'] * sawb_eca_ss_params_dict['r_M'] * (sawb_eca_ss_params_dict['u_Q_ref'] - sawb_eca_ss_params_dict['a_MSA'] * sawb_eca_ss_params_dict['u_Q_ref'] - 1) + sawb_eca_ss_params_dict['r_E'] * sawb_eca_ss_params_dict['u_Q_ref'] * sawb_eca_ss_params_dict['V_DE_ref'])))
    D_0 = -(sawb_eca_ss_params_dict['K_UE'] * (sawb_eca_ss_params_dict['r_E'] + sawb_eca_ss_params_dict['r_M']) * (sawb_eca_ss_params_dict['u_Q_ref'] - 1) - (SOC_input + DOC_input) * sawb_eca_ss_params_dict['u_Q_ref']) / ((sawb_eca_ss_params_dict['u_Q_ref'] - 1) * (sawb_eca_ss_params_dict['r_E'] + sawb_eca_ss_params_dict['r_M'] - sawb_eca_ss_params_dict['u_Q_ref'] * sawb_eca_ss_params_dict['V_UE_ref']))
    M_0 = -((SOC_input + DOC_input) * sawb_eca_ss_params_dict['u_Q_ref']) / ((sawb_eca_ss_params_dict['r_E'] + sawb_eca_ss_params_dict['r_M']) * (sawb_eca_ss_params_dict['u_Q_ref'] - 1))
    E_0 = sawb_eca_ss_params_dict['r_E'] * M_0 / sawb_eca_ss_params_dict['r_L']
    #E_0 = -((sawb_params_dict['r_E'] * sawb_params_dict['u_Q_ref'] * (SOC_input + DOC_input)) / (sawb_params_dict['r_L'] * (sawb_params_dict['r_E'] + sawb_params_dict['r_M']) * (sawb_params_dict['u_Q_ref'] - 1)))
    CO2_0 = (1 - sawb_eca_ss_params_dict['u_Q_ref']) * sawb_eca_ss_params_dict['V_UE_ref'] * M_0 * D_0 / (sawb_eca_ss_params_dict['K_UE'] + M_0 + D_0)
    C_0_vector = torch.as_tensor([S_0, D_0, M_0, E_0, CO2_0])
    return C_0_vector

In [48]:
####################################################
##STOCHASTIC DIFFERENTIAL EQUATION MODEL FUNCTIONS##
#################################################### 

def drift_diffusion_sawb_eca_ss(C_path, T_span_tensor, I_S_tensor, I_D_tensor, sawb_eca_ss_params_dict, temp_ref):
    '''
    Returns SAWB-ECA "state scaling diffusion parameterization" drift vectors and diffusion matrices.
    current_temp is output from temp_gen function. 
    Expected sawb_eca_ss_params_dict = {'u_Q_ref': u_Q_ref, 'Q': Q, 'a_MSA': a_MSA, 'K_DE': K_DE, 'K_UE': K_UE, 'V_DE_ref': V_DE_ref, 'V_UE_ref': V_UE_ref, 'Ea_V_DE': Ea_V_DE, 'Ea_V_UE': Ea_V_UE, 'r_M': r_M, 'r_E': r_E, 'r_L': r_L, 's_SOC': s_SOC, 's_DOC': s_DOC, 's_MBC': s_MBC, 's_EEC': s_EEC, 's_CO2': s_CO2}
    '''
    state_dim = 5 #SAWB and AWB family variants will have 5 'states' with the inclusion of CO2.
    SOC, DOC, MBC, EEC, CO2 =  torch.chunk(C_path, state_dim, -1) #Partition SOC, DOC, MBC, EEC, and CO2 values. Split based on final C_path dim, which specifies state variables and is also indexed as dim #2 in tensor. 
    current_temp = temp_gen(T_span_tensor, temp_ref) #Obtain temperature function vector across span of times.
    drift = torch.empty_like(C_path, device=C_path.device) #Initiate tensor with same dims as C_path to assign drift.
    diffusion_sqrt = torch.zeros([drift.size(0), drift.size(1), state_dim, state_dim], device = drift.device) #Create tensor to assign diffusion matrix elements. Diffusion exists for explicit algebraic variable CO2.
    #diffusion_sqrt_diag = torch.empty_like(C_path, device=C_path.device) #Create tensor to assign diffusion matrix elements.
    #Decay parameters are forced by temperature changes.
    u_Q = linear_temp_dep(sawb_eca_ss_params_dict['u_Q_ref'], current_temp, sawb_eca_ss_params_dict['Q'], temp_ref) #Apply linear temperature-dependence to u_Q.
    V_DE = arrhenius_temp_dep(sawb_eca_ss_params_dict['V_DE_ref'], current_temp, sawb_eca_ss_params_dict['Ea_V_DE'], temp_ref) #Apply vectorized temperature-dependent transformation to V_DE.
    V_UE = arrhenius_temp_dep(sawb_eca_ss_params_dict['V_UE_ref'], current_temp, sawb_eca_ss_params_dict['Ea_V_UE'], temp_ref) #Apply vectorized temperature-dependent transformation to V_UE.
    #Drift is calculated.
    drift_SOC = I_S_tensor + sawb_eca_ss_params_dict['a_MSA'] * sawb_eca_ss_params_dict['r_M'] * MBC - ((V_DE * EEC * SOC) / (sawb_eca_ss_params_dict['K_DE'] + EEC + SOC))
    drift_DOC = I_D_tensor + (1 - sawb_eca_ss_params_dict['a_MSA']) * sawb_eca_ss_params_dict['r_M'] * MBC + ((V_DE * EEC * SOC) / (sawb_eca_ss_params_dict['K_DE'] + EEC + SOC)) + sawb_eca_ss_params_dict['r_L'] * EEC - ((V_UE * MBC * DOC) / (sawb_eca_ss_params_dict['K_UE'] + MBC + DOC))
    drift_MBC = (u_Q * (V_UE * MBC * DOC) / (sawb_eca_ss_params_dict['K_UE'] + MBC + DOC)) - (sawb_eca_ss_params_dict['r_M'] + sawb_eca_ss_params_dict['r_E']) * MBC
    drift_EEC = sawb_eca_ss_params_dict['r_E'] * MBC - sawb_eca_ss_params_dict['r_L'] * EEC
    CO2 = (1 - u_Q) * (V_UE * MBC * DOC) / (sawb_eca_ss_params_dict['K_UE'] + MBC + DOC)
    #Assign elements to drift vector.
    drift[:, :, 0 : 1] = drift_SOC
    drift[:, :, 1 : 2] = drift_DOC
    drift[:, :, 2 : 3] = drift_MBC
    drift[:, :, 3 : 4] = drift_EEC
    drift[:, :, 4 : 5] = CO2 #CO2 is not a part of the drift. This is a hack for the explicit algebraic variable situation.
    #Diffusion matrix is assigned.
    diffusion_sqrt[:, :, 0 : 1, 0] = torch.sqrt(LowerBound.apply(SOC * sawb_eca_ss_params_dict['s_SOC'], 1e-9)) #SOC diffusion standard deviation
    diffusion_sqrt[:, :, 1 : 2, 1] = torch.sqrt(LowerBound.apply(DOC * sawb_eca_ss_params_dict['s_DOC'], 1e-9)) #DOC diffusion standard deviation
    diffusion_sqrt[:, :, 2 : 3, 2] = torch.sqrt(LowerBound.apply(MBC * sawb_eca_ss_params_dict['s_MBC'], 1e-9)) #MBC diffusion standard deviation
    diffusion_sqrt[:, :, 3 : 4, 3] = torch.sqrt(LowerBound.apply(EEC * sawb_eca_ss_params_dict['s_EEC'], 1e-9)) #EEC diffusion standard deviation
    diffusion_sqrt[:, :, 4 : 5, 4] = torch.sqrt(LowerBound.apply(CO2 * sawb_eca_ss_params_dict['s_CO2'], 1e-9)) #CO2 diffusion standard deviation
    #diffusion_sqrt_diag[:, :, 0 : 1] = torch.sqrt(LowerBound.apply(SOC * sawb_eca_ss_params_dict['s_SOC'], 1e-9)) #SOC diffusion standard deviation
    #diffusion_sqrt_diag[:, :, 1 : 2] = torch.sqrt(LowerBound.apply(DOC * sawb_eca_ss_params_dict['s_DOC'], 1e-9)) #DOC diffusion standard deviation
    #diffusion_sqrt_diag[:, :, 2 : 3] = torch.sqrt(LowerBound.apply(MBC * sawb_eca_ss_params_dict['s_MBC'], 1e-9)) #MBC diffusion standard deviation
    #diffusion_sqrt_diag[:, :, 3 : 4] = torch.sqrt(LowerBound.apply(EEC * sawb_eca_ss_params_dict['s_EEC'], 1e-9)) #EEC diffusion standard deviation
    #diffusion_sqrt_diag[:, :, 4 : 5] = torch.sqrt(LowerBound.apply(CO2 * sawb_eca_ss_params_dict['s_CO2'], 1e-9)) #CO2 diffusion standard deviation
    #diffusion_sqrt = torch.diag_embed(diffusion_sqrt_diag)
    return drift, diffusion_sqrt

In [49]:
class LowerBound(Function):
    @staticmethod
    def forward(ctx, inputs, bound):
        b = torch.ones(inputs.size()) * bound
        b = b.to(inputs.device)
        b = b.type(inputs.dtype)
        ctx.save_for_backward(inputs, b)
        return torch.max(inputs, b)
    @staticmethod
    def backward(ctx, grad_output):
        inputs, b = ctx.saved_tensors

        pass_through_1 = inputs >= b
        pass_through_2 = grad_output < 0

        pass_through = pass_through_1 | pass_through_2
        return pass_through.type(grad_output.dtype) * grad_output, None

class MaskedConv1d(nn.Conv1d):
    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv1d, self).__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}
        self.register_buffer('mask', self.weight.data.clone())
        _, _, kW = self.weight.size()
        self.mask.fill_(1)
        self.mask[:, :, kW // 2 + 1*(mask_type == 'B'):] = 0

    def forward(self, x):
        self.weight.data *= self.mask
        return super(MaskedConv1d, self).forward(x)

class ResNetBlock(nn.Module):

    def __init__(self, inp_cha, out_cha, stride = 1, first = True, batch_norm = True):
        super().__init__()
        self.conv1 = MaskedConv1d('A' if first else 'B', inp_cha,  out_cha, 3, stride, 1, bias = False)
        self.conv2 = MaskedConv1d('B', out_cha,  out_cha, 3, 1, 1, bias = False)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        if batch_norm:
            self.bn1 = nn.BatchNorm1d(out_cha)
            self.bn2 = nn.BatchNorm1d(out_cha)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()

        # If dimensions change, transform shortcut with a conv layer
        if inp_cha != out_cha or stride > 1:
            self.conv_skip = MaskedConv1d('A' if first else 'B', inp_cha,  out_cha, 3, stride, 1, bias = False)
        else:
            self.conv_skip = nn.Identity()

    def forward(self, x):
        residual = x
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x) + self.conv_skip(residual)))
        return x

class ResNetBlockUnMasked(nn.Module):
    
    def __init__(self, inp_cha, out_cha, stride = 1, batch_norm = False):
        super().__init__()
        self.conv1 = nn.Conv1d(inp_cha,  out_cha, 3, stride, 1)
        self.conv2 = nn.Conv1d(out_cha,  out_cha, 3, 1, 1)

        self.act1 = nn.PReLU(out_cha, init = 0.2)
        self.act2 = nn.PReLU(out_cha, init = 0.2)

        if batch_norm:
            self.bn1 = nn.BatchNorm1d(out_cha)
            self.bn2 = nn.BatchNorm1d(out_cha)
        else:
            self.bn1 = nn.Identity()
            self.bn2 = nn.Identity()

        # If dimensions change, transform shortcut with a conv layer
        if inp_cha != out_cha or stride > 1:
            self.conv_skip = nn.Conv1d(inp_cha,  out_cha, 3, stride, 1, bias=False)
        else:
            self.conv_skip = nn.Identity()

    def forward(self, x):
        residual = x
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x) + self.conv_skip(residual)))
        return x

class CouplingLayer(nn.Module):

    def __init__(self, cond_inputs, stride, h_cha = 96):
        super().__init__()
        self.first_block = ResNetBlock(1, h_cha, first = True)
        self.second_block = nn.Sequential(ResNetBlock(h_cha + cond_inputs, h_cha, first = False),
                                          MaskedConv1d('B', h_cha,  2, 3, stride, 1, bias = False))

        self.feature_net = nn.Sequential(ResNetBlockUnMasked(cond_inputs, h_cha),
                                          ResNetBlockUnMasked(h_cha, cond_inputs))
        
        self.unpack = True if cond_inputs > 1 else False

    def forward(self, x, cond_inputs):
        if self.unpack:
            cond_inputs = torch.cat([*cond_inputs], 1)
        cond_inputs = self.feature_net(cond_inputs)
        first_block = self.first_block(x)
        feature_vec = torch.cat([first_block, cond_inputs], 1)
        output = self.second_block(feature_vec)
        mu, sigma = torch.chunk(output, 2, 1)
        sigma = LowerBound.apply(sigma, 1e-6)
        x = mu + sigma*x
        return x, -torch.log(sigma)

class PermutationLayer(nn.Module):

    def __init__(self):
        super().__init__()
        self.index_1 = torch.randperm(STATE_DIM)

    def forward(self, x):
        B, S, L = x.shape
        x_reshape = x.reshape(B, S, -1, STATE_DIM)
        x_perm = x_reshape[:, :, :, self.index_1]
        x = x_perm.reshape(B, S, L)
        return x

class SoftplusLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.softplus = nn.Softplus()
    
    def forward(self, x):
        y = self.softplus(x)
        return y, -torch.log(-torch.expm1(-y))

class BatchNormLayer(nn.Module):
    def __init__(self, num_inputs, momentum = 0.0, eps = 1e-5):
        super(BatchNormLayer, self).__init__()

        self.log_gamma = nn.Parameter(torch.rand(num_inputs))
        self.beta = nn.Parameter(torch.rand(num_inputs))
        self.momentum = momentum
        self.eps = eps

        self.register_buffer('running_mean', torch.zeros(num_inputs))
        self.register_buffer('running_var', torch.ones(num_inputs))

    def forward(self, inputs):
        inputs = inputs.squeeze(1)
        if self.training:
            self.batch_mean = inputs.mean(0)
            self.batch_var = (
                inputs - self.batch_mean).pow(2).mean(0) + self.eps

            self.running_mean.mul_(self.momentum)
            self.running_var.mul_(self.momentum)

            self.running_mean.add_(self.batch_mean.data *
                                   (1 - self.momentum))
            self.running_var.add_(self.batch_var.data *
                                  (1 - self.momentum))

            mean = self.batch_mean
            var = self.batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        x_hat = (inputs - mean) / var.sqrt()
        y = torch.exp(self.log_gamma) * x_hat + self.beta
        ildj = -self.log_gamma + 0.5 * torch.log(var)
        return y[:, None, :], ildj[None, None, :]

In [50]:
class SDEFlow(nn.Module):

    def __init__(self, cond_inputs = 1, num_layers = 5):
        super().__init__()
        
        self.coupling = nn.ModuleList([CouplingLayer(cond_inputs + STATE_DIM, 1) for _ in range(num_layers)])
        self.permutation = [PermutationLayer() for _ in range(num_layers)]
        self.batch_norm = nn.ModuleList([BatchNormLayer(STATE_DIM * N) for _ in range(num_layers-1)])
        self.SP = SoftplusLayer()
        
        self.base_dist = d.normal.Normal(loc = 0., scale = 1.)
        self.num_layers = num_layers
        
    def forward(self, batch_size, obs_model, *args, **kwargs):

        eps = self.base_dist.sample([batch_size, 1, STATE_DIM * N]).to(device)
        log_prob = self.base_dist.log_prob(eps).sum(-1)
        
        obs_tile = obs_model.mu[None, :, 1:, None].repeat(batch_size, STATE_DIM, 1, 50).reshape(batch_size, STATE_DIM, -1)
        times = torch.arange(dt, T + dt, dt, device = eps.device)[(None,) * 2].repeat(batch_size, STATE_DIM, 1).transpose(-2, -1).reshape(batch_size, 1, -1)
        
        ildjs = []
        
        for i in range(self.num_layers):
            eps, cl_ildj = self.coupling[i](self.permutation[i](eps), (obs_tile, times))
            if i < (self.num_layers - 1):
                eps, bn_ildj = self.batch_norm[i](eps)
                ildjs.append(bn_ildj)
            ildjs.append(cl_ildj)
                
        eps, sp_ildj = self.SP(eps)
        ildjs.append(sp_ildj)
        
        for ildj in ildjs:
            log_prob += ildj.sum(-1)
    
        return eps.reshape(batch_size, STATE_DIM, -1).permute(0, 2, 1) + 1e-9, log_prob

In [51]:
class ObsModel(nn.Module):

    def __init__(self, times, mu, scale):
        super().__init__()

        self.idx = self._get_idx(times)
        self.times = times
        self.mu = torch.Tensor(mu).to(device)
        self.scale = scale
        
    def forward(self, x):
        obs_ll = d.normal.Normal(self.mu.permute(1, 0), self.scale).log_prob(x[:, self.idx, :])
        return torch.sum(obs_ll, [-1, -2]).mean()

    def _get_idx(self, times):
        return list((times / dt).astype(int))
    
    def plt_dat(self):
        return self.mu, self.times

In [52]:
def neg_log_lik(C_path, T_span_tensor, dt, I_S_tensor, I_D_tensor, drift_diffusion, params_dict, temp_ref):
    drift, diffusion_sqrt = drift_diffusion(C_path[:, :-1, :], T_span_tensor[:, :-1, :], I_S_tensor[:, :-1, :], I_D_tensor[:, :-1, :], params_dict, temp_ref)
    #print('\n drift =', drift)
    #print('\n diffusion_sqrt =', diffusion_sqrt)
    #euler_maruyama_sample = d.multivariate_normal.MultivariateNormal(loc = C_path[:, :-1, :] + drift * dt, scale_tril = diffusion_sqrt * math.sqrt(dt)) This line no longer applies because of addition of CO2 as a 'state'.
    drift_means_with_CO2 = torch.cat((C_path[:, :-1, :-1] + drift[:, :, :-1] * dt, drift[:, :, -1].unsqueeze(2)), 2) #Separate explicit algebraic variable CO2 mean from integration process.
    euler_maruyama_sample = d.multivariate_normal.MultivariateNormal(loc = drift_means_with_CO2, scale_tril = diffusion_sqrt * math.sqrt(dt))
    return -euler_maruyama_sample.log_prob(C_path[:, 1:, :]).sum(-1)

In [53]:
obs_model = ObsModel(times = obs_times, mu = obs_means_awb_eca_T, scale = obs_error_scale_awb_eca.reshape([1, STATE_DIM]))
net = SDEFlow().to(device)
optimizer = optim.Adam(net.parameters(), lr = LR)

def train(niter, pretrain_iter, BATCH_SIZE, T_span_tensor, I_S_tensor, I_D_tensor, drift_diffusion, params_dict, analytical_steady_state_init):
    if pretrain_iter >= niter:
        raise Exception("pretrain_inter must be < niter.")
    best_loss_norm = 1e10
    best_loss_ELBO = 1e20
    norm_losses = [best_loss_norm] * 10
    ELBO_losses = [best_loss_ELBO] * 10
    C0 = analytical_steady_state_init(I_S_tensor[0, 0, 0].item(), I_D_tensor[0, 0, 0].item(), params_dict) #Calculate deterministic initial conditions.
    C0 = C0[(None,) * 2].repeat(BATCH_SIZE, 1, 1).to(device) #Assign initial conditions to C_path.
    with tqdm(total = niter, desc = f'Train Diffusion', position = -1) as t:
        for iter in range(niter):
            net.train()
            optimizer.zero_grad()
            C_path, log_prob = net(BATCH_SIZE, obs_model) #Obtain paths with solutions at times after t0.
            C_path = torch.cat([C0, C_path], 1) #Append deterministic CON initial conditions conditional on parameter values to C path. 
            if iter <= pretrain_iter:
                l1_norm_element = C_path - torch.mean(obs_model.mu, -1)
                l1_norm = torch.sum(torch.abs(l1_norm_element)).mean()
                best_loss_norm = l1_norm if l1_norm < best_loss_norm else best_loss_norm
                l1_norm.backward()
                norm_losses.append(l1_norm.item())
                #l2_norm_element = C_path - torch.mean(obs_model.mu, -1)
                #l2_norm = torch.sqrt(torch.sum(torch.square(l2_norm_element))).mean()
                #best_loss_norm = l2_norm if l2_norm < best_loss_norm else best_loss_norm
                #l2_norm.backward()
                #norm_losses.append(l2_norm.item())
                if len(norm_losses) > 10:
                    norm_losses.pop(0)
                if iter % 10 == 0:
                    print(f"Moving average norm loss at {iter} iterations is: {sum(norm_losses) / len(norm_losses)}. Best norm loss value is: {best_loss_norm}.")
                    print('\nC_path mean =', C_path.mean(-2))
                    print('\nC_path =', C_path)
            else:
                log_lik = neg_log_lik(C_path, T_span_tensor.to(device), dt, I_S_tensor.to(device), I_D_tensor.to(device), drift_diffusion, params_dict, temp_ref)
                ELBO = log_prob.mean() + log_lik.mean() - obs_model(C_path)
                best_loss_ELBO = ELBO if ELBO < best_loss_ELBO else best_loss_ELBO
                ELBO.backward()
                ELBO_losses.append(ELBO.item())
                if len(ELBO_losses) > 10:
                    ELBO_losses.pop(0)
                if iter % 10 == 0:
                    print(f"Moving average ELBO loss at {iter} iterations is: {sum(ELBO_losses) / len(ELBO_losses)}. Best ELBO loss value is: {best_loss_ELBO}.")
                    print('\nC_path mean =', C_path.mean(-2))
                    print('\n C_path =', C_path)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 3.0)
            optimizer.step()
            if iter % 100000 == 0 and iter > 0:
                optimizer.param_groups[0]['lr'] *= 0.1
            t.update()

In [None]:
train(niter, pretrain_iter, BATCH_SIZE, T_span_tensor, I_S_tensor, I_D_tensor, drift_diffusion_sawb_eca_ss, sawb_eca_ss_params_dict, analytical_steady_state_init_awb_eca)


Train Diffusion:   0%|          | 0/20000 [00:00<?, ?it/s][A
Train Diffusion:   0%|          | 1/20000 [00:03<21:24:51,  3.85s/it][A

Moving average norm loss at 0 iterations is: 9000026948.89375. Best norm loss value is: 269488.9375.

C_path mean = tensor([[0.9220, 0.6072, 0.6381, 0.5925, 0.5082]], grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [7.9971e-01, 7.1143e-01, 6.3176e-01, 5.6106e-01, 5.3278e-01],
         [8.4864e-01, 6.5093e-01, 6.1483e-01, 6.3111e-01, 4.6989e-01],
         ...,
         [8.1543e-01, 6.3381e-01, 6.3075e-01, 4.7001e-01, 5.0232e-01],
         [9.2011e-01, 6.3572e-01, 6.5004e-01, 6.3174e-01, 5.7745e-01],
         [6.9366e-01, 6.7901e-01, 6.9061e-01, 6.2450e-01, 8.0403e-01]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 2/20000 [00:08<22:50:20,  4.11s/it][A
Train Diffusion:   0%|          | 3/20000 [00:12<22:29:31,  4.05s/it][A
Train Diffusion:   0%|          | 4/20000 [00:16<22:27:27,  4.04s/it][A
Train Diffusion:   0%|          | 5/20000 [00:19<21:32:13,  3.88s/it][A
Train Diffusion:   0%|          | 6/20000 [00:23<20:53:12,  3.76s/it][A
Train Diffusion:   0%|          | 7/20000 [00:27<21:42:14,  3.91s/it][A
Train Diffusion:   0%|          | 8/20000 [00:31<22:21:59,  4.03s/it][A
Train Diffusion:   0%|          | 9/20000 [00:36<22:57:12,  4.13s/it][A
Train Diffusion:   0%|          | 10/20000 [00:40<22:59:40,  4.14s/it][A
Train Diffusion:   0%|          | 11/20000 [00:44<22:15:22,  4.01s/it][A

Moving average norm loss at 10 iterations is: 256154.1609375. Best norm loss value is: 245591.890625.

C_path mean = tensor([[5.0919, 0.5965, 0.4443, 0.2204, 0.0907]], grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.3560e+00, 3.1710e+00, 5.3221e-01, 3.3265e-01, 1.4144e-01],
         [3.7443e+00, 2.1107e+00, 5.5668e-01, 3.1682e-01, 1.4090e-01],
         ...,
         [3.4383e+00, 5.4734e-01, 3.2404e-01, 1.4100e-01, 5.4285e-02],
         [4.7260e+00, 5.3687e-01, 3.2807e-01, 1.3682e-01, 6.4090e-02],
         [3.7674e+00, 5.3702e-01, 3.2240e-01, 1.3616e-01, 9.2043e-02]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 12/20000 [00:47<22:04:02,  3.97s/it][A
Train Diffusion:   0%|          | 13/20000 [00:51<22:02:38,  3.97s/it][A
Train Diffusion:   0%|          | 14/20000 [00:55<21:21:15,  3.85s/it][A
Train Diffusion:   0%|          | 15/20000 [00:58<20:47:26,  3.75s/it][A
Train Diffusion:   0%|          | 16/20000 [01:02<20:55:19,  3.77s/it][A
Train Diffusion:   0%|          | 17/20000 [01:07<22:39:40,  4.08s/it][A
Train Diffusion:   0%|          | 18/20000 [01:12<24:02:19,  4.33s/it][A
Train Diffusion:   0%|          | 19/20000 [01:17<24:31:39,  4.42s/it][A
Train Diffusion:   0%|          | 20/20000 [01:21<25:05:37,  4.52s/it][A
Train Diffusion:   0%|          | 21/20000 [01:25<24:15:15,  4.37s/it][A

Moving average norm loss at 20 iterations is: 233697.5015625. Best norm loss value is: 223737.34375.

C_path mean = tensor([[9.2516, 0.2274, 0.1833, 0.1175, 0.0805]], grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [9.3993e-01, 6.1739e+00, 2.0215e-01, 1.4572e-01, 8.4693e-02],
         [3.9620e+00, 3.2090e+00, 2.0229e-01, 1.3789e-01, 8.2933e-02],
         ...,
         [6.9916e+00, 2.0476e-01, 1.6487e-01, 9.9728e-02, 7.3377e-02],
         [9.4829e+00, 2.1240e-01, 1.6721e-01, 1.0106e-01, 1.3770e-01],
         [7.7576e+00, 2.0767e-01, 1.5269e-01, 9.4625e-02, 1.6575e-01]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 22/20000 [01:30<24:23:19,  4.39s/it][A
Train Diffusion:   0%|          | 23/20000 [01:35<25:03:37,  4.52s/it][A
Train Diffusion:   0%|          | 24/20000 [01:39<24:21:28,  4.39s/it][A
Train Diffusion:   0%|          | 25/20000 [01:43<23:34:41,  4.25s/it][A
Train Diffusion:   0%|          | 26/20000 [01:47<24:30:31,  4.42s/it][A
Train Diffusion:   0%|          | 27/20000 [01:52<24:00:02,  4.33s/it][A
Train Diffusion:   0%|          | 28/20000 [01:56<23:52:20,  4.30s/it][A
Train Diffusion:   0%|          | 29/20000 [02:00<24:27:39,  4.41s/it][A
Train Diffusion:   0%|          | 30/20000 [02:05<24:32:12,  4.42s/it][A
Train Diffusion:   0%|          | 31/20000 [02:09<23:38:06,  4.26s/it][A

Moving average norm loss at 30 iterations is: 212832.028125. Best norm loss value is: 203792.328125.

C_path mean = tensor([[13.2094,  0.2030,  0.1230,  0.0588,  0.0311]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.3411e+00, 7.4657e+00, 1.5269e-01, 8.6965e-02, 3.9962e-02],
         [6.7518e+00, 3.2582e+00, 1.6271e-01, 8.0809e-02, 3.9876e-02],
         ...,
         [9.1239e+00, 1.5949e-01, 9.4091e-02, 4.3845e-02, 2.3731e-02],
         [1.2467e+01, 1.6268e-01, 9.3346e-02, 4.0997e-02, 4.2116e-02],
         [1.0060e+01, 1.5566e-01, 8.8135e-02, 3.9118e-02, 6.3264e-02]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 32/20000 [02:13<23:11:45,  4.18s/it][A
Train Diffusion:   0%|          | 33/20000 [02:18<25:02:48,  4.52s/it][A
Train Diffusion:   0%|          | 34/20000 [02:23<25:20:26,  4.57s/it][A
Train Diffusion:   0%|          | 35/20000 [02:27<25:03:51,  4.52s/it][A
Train Diffusion:   0%|          | 36/20000 [02:31<24:35:18,  4.43s/it][A
Train Diffusion:   0%|          | 37/20000 [02:36<24:19:22,  4.39s/it][A
Train Diffusion:   0%|          | 38/20000 [02:40<24:02:39,  4.34s/it][A
Train Diffusion:   0%|          | 39/20000 [02:44<23:22:25,  4.22s/it][A
Train Diffusion:   0%|          | 40/20000 [02:48<22:51:08,  4.12s/it][A
Train Diffusion:   0%|          | 41/20000 [02:51<22:06:16,  3.99s/it][A

Moving average norm loss at 40 iterations is: 192124.678125. Best norm loss value is: 182333.34375.

C_path mean = tensor([[1.7474e+01, 2.1421e-01, 1.0467e-01, 3.0487e-02, 1.0014e-02]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.6564e+00, 1.1006e+01, 1.5054e-01, 5.5868e-02, 1.4688e-02],
         [8.7876e+00, 5.2662e+00, 1.5684e-01, 5.1258e-02, 1.4695e-02],
         ...,
         [1.2775e+01, 1.5700e-01, 6.6924e-02, 1.7097e-02, 6.2588e-03],
         [1.7618e+01, 1.5722e-01, 6.5081e-02, 1.6403e-02, 1.7707e-02],
         [1.4376e+01, 1.5150e-01, 6.0570e-02, 1.5226e-02, 6.7750e-02]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 42/20000 [02:55<22:02:15,  3.98s/it][A
Train Diffusion:   0%|          | 43/20000 [02:59<21:52:52,  3.95s/it][A
Train Diffusion:   0%|          | 44/20000 [03:04<22:26:21,  4.05s/it][A
Train Diffusion:   0%|          | 45/20000 [03:08<22:34:51,  4.07s/it][A
Train Diffusion:   0%|          | 46/20000 [03:12<22:48:05,  4.11s/it][A
Train Diffusion:   0%|          | 47/20000 [03:16<22:38:52,  4.09s/it][A
Train Diffusion:   0%|          | 48/20000 [03:21<23:33:36,  4.25s/it][A
Train Diffusion:   0%|          | 49/20000 [03:26<25:02:46,  4.52s/it][A
Train Diffusion:   0%|          | 50/20000 [03:31<25:34:03,  4.61s/it][A
Train Diffusion:   0%|          | 51/20000 [03:36<26:31:35,  4.79s/it][A

Moving average norm loss at 50 iterations is: 169673.5265625. Best norm loss value is: 159114.25.

C_path mean = tensor([[2.2116e+01, 1.9485e-01, 9.2719e-02, 1.6310e-02, 3.9476e-03]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.6790e+00, 1.3959e+01, 1.3697e-01, 3.9604e-02, 6.1129e-03],
         [1.0477e+01, 6.7407e+00, 1.5056e-01, 3.2779e-02, 6.1776e-03],
         ...,
         [1.6527e+01, 1.4477e-01, 5.5677e-02, 7.7678e-03, 2.0575e-03],
         [2.2890e+01, 1.4923e-01, 5.2333e-02, 7.0355e-03, 9.4717e-03],
         [1.8611e+01, 1.3817e-01, 4.5595e-02, 6.1894e-03, 5.5308e-02]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 52/20000 [03:41<26:45:12,  4.83s/it][A
Train Diffusion:   0%|          | 53/20000 [03:45<26:36:07,  4.80s/it][A
Train Diffusion:   0%|          | 54/20000 [03:50<26:17:57,  4.75s/it][A
Train Diffusion:   0%|          | 55/20000 [03:54<25:46:42,  4.65s/it][A
Train Diffusion:   0%|          | 56/20000 [03:59<25:14:51,  4.56s/it][A
Train Diffusion:   0%|          | 57/20000 [04:02<23:45:39,  4.29s/it][A
Train Diffusion:   0%|          | 58/20000 [04:06<22:29:03,  4.06s/it][A
Train Diffusion:   0%|          | 59/20000 [04:10<22:30:46,  4.06s/it][A
Train Diffusion:   0%|          | 60/20000 [04:15<23:35:00,  4.26s/it][A
Train Diffusion:   0%|          | 61/20000 [04:20<24:30:37,  4.43s/it][A

Moving average norm loss at 60 iterations is: 145565.65. Best norm loss value is: 134335.6875.

C_path mean = tensor([[2.7057e+01, 2.1091e-01, 1.0369e-01, 8.3107e-03, 1.3108e-03]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [2.1479e+00, 1.6039e+01, 1.5425e-01, 3.0930e-02, 1.9752e-03],
         [1.3381e+01, 7.1448e+00, 1.6509e-01, 2.2360e-02, 1.9324e-03],
         ...,
         [1.9948e+01, 1.5948e-01, 6.1306e-02, 3.1635e-03, 5.7489e-04],
         [2.7707e+01, 1.7356e-01, 5.5989e-02, 2.8126e-03, 3.5761e-03],
         [2.2367e+01, 1.5964e-01, 4.4317e-02, 2.2444e-03, 3.2677e-02]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 62/20000 [04:24<24:03:48,  4.34s/it][A
Train Diffusion:   0%|          | 63/20000 [04:28<23:33:30,  4.25s/it][A
Train Diffusion:   0%|          | 64/20000 [04:31<22:40:35,  4.09s/it][A
Train Diffusion:   0%|          | 65/20000 [04:35<21:43:52,  3.92s/it][A
Train Diffusion:   0%|          | 66/20000 [04:39<22:08:46,  4.00s/it][A
Train Diffusion:   0%|          | 67/20000 [04:43<22:25:49,  4.05s/it][A
Train Diffusion:   0%|          | 68/20000 [04:47<21:55:22,  3.96s/it][A
Train Diffusion:   0%|          | 69/20000 [04:51<21:35:01,  3.90s/it][A
Train Diffusion:   0%|          | 70/20000 [04:55<21:15:46,  3.84s/it][A
Train Diffusion:   0%|          | 71/20000 [04:58<20:49:10,  3.76s/it][A

Moving average norm loss at 70 iterations is: 119851.25546875. Best norm loss value is: 108109.90625.

C_path mean = tensor([[3.2303e+01, 2.4194e-01, 1.0987e-01, 4.2685e-03, 4.5734e-04]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [2.9282e+00, 1.9305e+01, 1.8421e-01, 1.8758e-02, 6.7951e-04],
         [1.7002e+01, 9.2825e+00, 1.9688e-01, 1.2541e-02, 6.6725e-04],
         ...,
         [2.3898e+01, 1.9999e-01, 4.7435e-02, 1.2356e-03, 1.5198e-04],
         [3.3257e+01, 2.1453e-01, 4.1916e-02, 1.1137e-03, 1.5032e-03],
         [2.6876e+01, 1.9218e-01, 2.9550e-02, 8.2823e-04, 2.4551e-02]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 72/20000 [05:02<20:27:19,  3.70s/it][A
Train Diffusion:   0%|          | 73/20000 [05:06<20:43:51,  3.75s/it][A
Train Diffusion:   0%|          | 74/20000 [05:10<21:14:52,  3.84s/it][A
Train Diffusion:   0%|          | 75/20000 [05:13<20:50:33,  3.77s/it][A
Train Diffusion:   0%|          | 76/20000 [05:17<20:19:37,  3.67s/it][A
Train Diffusion:   0%|          | 77/20000 [05:21<20:50:31,  3.77s/it][A
Train Diffusion:   0%|          | 78/20000 [05:25<21:13:44,  3.84s/it][A
Train Diffusion:   0%|          | 79/20000 [05:29<22:54:15,  4.14s/it][A
Train Diffusion:   0%|          | 80/20000 [05:35<24:29:59,  4.43s/it][A
Train Diffusion:   0%|          | 81/20000 [05:39<24:17:04,  4.39s/it][A

Moving average norm loss at 80 iterations is: 92362.50625. Best norm loss value is: 79383.0625.

C_path mean = tensor([[3.8079e+01, 1.8538e-01, 9.8610e-02, 2.0694e-03, 3.3700e-04]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [3.2049e+00, 2.1610e+01, 1.4394e-01, 1.0075e-02, 2.5482e-04],
         [1.9773e+01, 9.8741e+00, 1.6283e-01, 5.4628e-03, 2.7186e-04],
         ...,
         [2.8059e+01, 1.8397e-01, 5.1148e-02, 7.5398e-04, 1.0558e-04],
         [3.9091e+01, 2.2914e-01, 4.1610e-02, 6.0074e-04, 6.7373e-04],
         [3.1323e+01, 1.7602e-01, 2.1258e-02, 3.4093e-04, 9.7317e-03]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 82/20000 [05:44<24:43:08,  4.47s/it][A
Train Diffusion:   0%|          | 83/20000 [05:49<26:18:18,  4.75s/it][A
Train Diffusion:   0%|          | 84/20000 [05:54<27:29:55,  4.97s/it][A
Train Diffusion:   0%|          | 85/20000 [06:00<27:44:34,  5.02s/it][A
Train Diffusion:   0%|          | 86/20000 [06:04<26:28:45,  4.79s/it][A
Train Diffusion:   0%|          | 87/20000 [06:08<25:51:04,  4.67s/it][A
Train Diffusion:   0%|          | 88/20000 [06:13<25:41:03,  4.64s/it][A
Train Diffusion:   0%|          | 89/20000 [06:17<25:15:22,  4.57s/it][A
Train Diffusion:   0%|          | 90/20000 [06:21<24:26:39,  4.42s/it][A
Train Diffusion:   0%|          | 91/20000 [06:26<24:34:47,  4.44s/it][A

Moving average norm loss at 90 iterations is: 63678.681640625. Best norm loss value is: 52799.84765625.

C_path mean = tensor([[4.4138e+01, 2.0129e-01, 7.7244e-02, 1.2258e-03, 1.3705e-04]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [4.0385e+00, 2.5706e+01, 1.4614e-01, 5.7341e-03, 2.0373e-04],
         [2.3224e+01, 1.3002e+01, 1.6692e-01, 3.7639e-03, 2.0333e-04],
         ...,
         [3.2760e+01, 1.7115e-01, 1.8864e-02, 3.3556e-04, 2.8923e-05],
         [4.5696e+01, 2.2411e-01, 1.5318e-02, 2.7539e-04, 2.0356e-04],
         [3.6897e+01, 1.7059e-01, 9.9386e-03, 2.0154e-04, 4.9570e-03]]],
       grad_fn=<CatBackward>)



Train Diffusion:   0%|          | 92/20000 [06:30<24:11:23,  4.37s/it][A
Train Diffusion:   0%|          | 93/20000 [06:35<25:11:48,  4.56s/it][A
Train Diffusion:   0%|          | 94/20000 [06:40<26:33:56,  4.80s/it][A
Train Diffusion:   0%|          | 95/20000 [06:44<25:16:47,  4.57s/it][A
Train Diffusion:   0%|          | 96/20000 [06:49<24:44:43,  4.48s/it][A
Train Diffusion:   0%|          | 97/20000 [06:53<24:09:40,  4.37s/it][A
Train Diffusion:   0%|          | 98/20000 [06:57<23:40:28,  4.28s/it][A
Train Diffusion:   0%|          | 99/20000 [07:01<23:19:01,  4.22s/it][A
Train Diffusion:   0%|          | 100/20000 [07:05<23:00:33,  4.16s/it][A
Train Diffusion:   1%|          | 101/20000 [07:09<22:21:17,  4.04s/it][A

Moving average norm loss at 100 iterations is: 45527.064453125. Best norm loss value is: 41372.53125.

C_path mean = tensor([[4.9542e+01, 3.1620e-01, 2.7532e-01, 5.9459e-04, 4.1137e-05]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [5.3874e+00, 2.7555e+01, 3.2911e-01, 1.9439e-03, 7.5704e-06],
         [2.8583e+01, 1.4164e+01, 3.6402e-01, 7.6577e-04, 9.2744e-06],
         ...,
         [3.7109e+01, 5.0683e-01, 1.0398e-01, 7.4141e-05, 3.5225e-06],
         [5.1534e+01, 7.1133e-01, 6.4775e-02, 5.2126e-05, 2.5480e-05],
         [4.1278e+01, 4.9283e-01, 1.0475e-02, 1.3396e-05, 8.5127e-04]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 102/20000 [07:13<22:45:28,  4.12s/it][A
Train Diffusion:   1%|          | 103/20000 [07:17<23:20:54,  4.22s/it][A
Train Diffusion:   1%|          | 104/20000 [07:22<24:40:29,  4.46s/it][A
Train Diffusion:   1%|          | 105/20000 [07:27<24:51:35,  4.50s/it][A
Train Diffusion:   1%|          | 106/20000 [07:31<24:41:29,  4.47s/it][A
Train Diffusion:   1%|          | 107/20000 [07:37<25:55:19,  4.69s/it][A
Train Diffusion:   1%|          | 108/20000 [07:41<25:34:45,  4.63s/it][A
Train Diffusion:   1%|          | 109/20000 [07:46<25:39:53,  4.65s/it][A
Train Diffusion:   1%|          | 110/20000 [07:50<25:26:13,  4.60s/it][A
Train Diffusion:   1%|          | 111/20000 [07:55<25:59:46,  4.71s/it][A

Moving average norm loss at 110 iterations is: 38640.464453125. Best norm loss value is: 36596.625.

C_path mean = tensor([[5.1349e+01, 4.3498e-01, 5.1931e-01, 1.8785e-03, 2.3513e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [5.7738e+00, 2.6069e+01, 5.1909e-01, 3.0032e-02, 4.8006e-06],
         [3.1877e+01, 1.3060e+01, 6.2233e-01, 8.2679e-03, 4.0139e-06],
         ...,
         [3.8648e+01, 5.9568e-01, 3.7264e-01, 1.4373e-05, 2.0455e-08],
         [5.2131e+01, 9.7462e-01, 2.7164e-01, 9.2378e-06, 1.8331e-07],
         [4.1516e+01, 6.8848e-01, 1.4404e-01, 4.0640e-06, 1.9259e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 112/20000 [07:59<24:43:40,  4.48s/it][A
Train Diffusion:   1%|          | 113/20000 [08:03<22:58:21,  4.16s/it][A
Train Diffusion:   1%|          | 114/20000 [08:06<21:44:14,  3.94s/it][A
Train Diffusion:   1%|          | 115/20000 [08:09<20:52:34,  3.78s/it][A
Train Diffusion:   1%|          | 116/20000 [08:15<24:09:48,  4.37s/it][A
Train Diffusion:   1%|          | 117/20000 [08:19<23:30:51,  4.26s/it][A
Train Diffusion:   1%|          | 118/20000 [08:23<23:06:40,  4.18s/it][A
Train Diffusion:   1%|          | 119/20000 [08:27<23:00:54,  4.17s/it][A
Train Diffusion:   1%|          | 120/20000 [08:32<23:48:39,  4.31s/it][A
Train Diffusion:   1%|          | 121/20000 [08:36<23:33:18,  4.27s/it][A

Moving average norm loss at 120 iterations is: 33950.79453125. Best norm loss value is: 31896.216796875.

C_path mean = tensor([[5.1293e+01, 4.0215e-01, 6.9005e-01, 8.7143e-03, 1.2007e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [5.9823e+00, 2.4138e+01, 5.7370e-01, 2.0477e-01, 1.1446e-05],
         [3.2417e+01, 1.1145e+01, 6.3997e-01, 8.5060e-02, 7.2477e-06],
         ...,
         [3.9263e+01, 5.0846e-01, 5.6526e-01, 9.2906e-06, 1.7545e-09],
         [5.1386e+01, 7.0837e-01, 4.8451e-01, 7.0224e-06, 1.5487e-08],
         [4.0361e+01, 5.9806e-01, 4.6962e-01, 5.0633e-06, 5.1053e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 122/20000 [08:41<23:44:30,  4.30s/it][A
Train Diffusion:   1%|          | 123/20000 [08:44<22:31:09,  4.08s/it][A
Train Diffusion:   1%|          | 124/20000 [08:47<21:13:25,  3.84s/it][A
Train Diffusion:   1%|          | 125/20000 [08:51<20:36:05,  3.73s/it][A
Train Diffusion:   1%|          | 126/20000 [08:54<20:21:59,  3.69s/it][A
Train Diffusion:   1%|          | 127/20000 [08:58<19:45:38,  3.58s/it][A
Train Diffusion:   1%|          | 128/20000 [09:01<19:35:12,  3.55s/it][A
Train Diffusion:   1%|          | 129/20000 [09:05<19:51:25,  3.60s/it][A
Train Diffusion:   1%|          | 130/20000 [09:09<20:37:10,  3.74s/it][A
Train Diffusion:   1%|          | 131/20000 [09:13<21:24:09,  3.88s/it][A

Moving average norm loss at 130 iterations is: 30783.284765625. Best norm loss value is: 28924.4296875.

C_path mean = tensor([[5.1264e+01, 2.6151e-01, 5.9031e-01, 7.7462e-03, 1.6729e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [6.5667e+00, 2.4976e+01, 4.3174e-01, 1.2417e-01, 9.5302e-06],
         [3.2803e+01, 1.1414e+01, 4.9760e-01, 7.4444e-02, 9.1178e-06],
         ...,
         [4.0432e+01, 4.2587e-01, 3.3387e-01, 1.4239e-05, 2.6950e-09],
         [5.2265e+01, 4.7658e-01, 2.8587e-01, 1.4551e-05, 2.8388e-08],
         [4.1410e+01, 4.1189e-01, 2.2945e-01, 9.4019e-06, 4.9346e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 132/20000 [09:17<21:33:15,  3.91s/it][A
Train Diffusion:   1%|          | 133/20000 [09:22<22:51:35,  4.14s/it][A
Train Diffusion:   1%|          | 134/20000 [09:27<24:41:31,  4.47s/it][A
Train Diffusion:   1%|          | 135/20000 [09:31<23:49:09,  4.32s/it][A
Train Diffusion:   1%|          | 136/20000 [09:35<23:47:56,  4.31s/it][A
Train Diffusion:   1%|          | 137/20000 [09:40<24:50:37,  4.50s/it][A
Train Diffusion:   1%|          | 138/20000 [09:45<24:25:32,  4.43s/it][A
Train Diffusion:   1%|          | 139/20000 [09:49<24:02:24,  4.36s/it][A
Train Diffusion:   1%|          | 140/20000 [09:53<23:53:53,  4.33s/it][A
Train Diffusion:   1%|          | 141/20000 [09:58<24:11:10,  4.38s/it][A

Moving average norm loss at 140 iterations is: 28330.3611328125. Best norm loss value is: 26621.927734375.

C_path mean = tensor([[5.0906e+01, 2.0892e-01, 7.1150e-01, 1.2284e-02, 2.3031e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [6.2340e+00, 2.5334e+01, 4.1140e-01, 2.1130e-01, 5.3794e-06],
         [3.1704e+01, 1.2011e+01, 4.7050e-01, 9.2262e-02, 5.4302e-06],
         ...,
         [4.1236e+01, 3.9680e-01, 8.2801e-01, 2.1256e-05, 3.8600e-09],
         [5.2659e+01, 4.7619e-01, 6.7178e-01, 2.7017e-05, 4.7485e-08],
         [4.1928e+01, 4.3342e-01, 4.5061e-01, 1.0734e-05, 1.4934e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 142/20000 [10:01<22:37:43,  4.10s/it][A
Train Diffusion:   1%|          | 143/20000 [10:05<21:59:25,  3.99s/it][A
Train Diffusion:   1%|          | 144/20000 [10:09<22:02:07,  4.00s/it][A
Train Diffusion:   1%|          | 145/20000 [10:13<22:07:30,  4.01s/it][A
Train Diffusion:   1%|          | 146/20000 [10:17<22:04:05,  4.00s/it][A
Train Diffusion:   1%|          | 147/20000 [10:21<22:18:00,  4.04s/it][A
Train Diffusion:   1%|          | 148/20000 [10:25<21:57:36,  3.98s/it][A
Train Diffusion:   1%|          | 149/20000 [10:29<21:41:48,  3.93s/it][A
Train Diffusion:   1%|          | 150/20000 [10:32<20:56:45,  3.80s/it][A
Train Diffusion:   1%|          | 151/20000 [10:35<20:14:12,  3.67s/it][A

Moving average norm loss at 150 iterations is: 25509.7046875. Best norm loss value is: 24285.419921875.

C_path mean = tensor([[5.1353e+01, 1.8204e-01, 6.2930e-01, 8.1494e-03, 2.2082e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [6.7488e+00, 2.5097e+01, 3.7959e-01, 2.0170e-01, 7.4650e-06],
         [3.3069e+01, 1.1569e+01, 4.0342e-01, 8.6315e-02, 7.3277e-06],
         ...,
         [4.2037e+01, 3.1133e-01, 6.3545e-01, 2.0346e-05, 4.6918e-09],
         [5.2695e+01, 3.6986e-01, 5.4263e-01, 2.1543e-05, 5.9621e-08],
         [4.2095e+01, 3.8709e-01, 4.1395e-01, 9.9193e-06, 2.4717e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 152/20000 [10:39<19:29:02,  3.53s/it][A
Train Diffusion:   1%|          | 153/20000 [10:42<18:56:36,  3.44s/it][A
Train Diffusion:   1%|          | 154/20000 [10:45<19:05:55,  3.46s/it][A
Train Diffusion:   1%|          | 155/20000 [10:49<19:41:30,  3.57s/it][A
Train Diffusion:   1%|          | 156/20000 [10:53<19:56:10,  3.62s/it][A
Train Diffusion:   1%|          | 157/20000 [10:57<20:20:18,  3.69s/it][A
Train Diffusion:   1%|          | 158/20000 [11:00<20:15:51,  3.68s/it][A
Train Diffusion:   1%|          | 159/20000 [11:05<21:00:53,  3.81s/it][A
Train Diffusion:   1%|          | 160/20000 [11:11<24:46:58,  4.50s/it][A
Train Diffusion:   1%|          | 161/20000 [11:16<26:19:20,  4.78s/it][A

Moving average norm loss at 160 iterations is: 23046.02890625. Best norm loss value is: 22054.900390625.

C_path mean = tensor([[5.1533e+01, 2.7249e-01, 8.4703e-01, 4.9003e-03, 1.3547e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [7.5534e+00, 2.4956e+01, 5.4323e-01, 2.2245e-01, 3.3466e-06],
         [3.4249e+01, 1.1479e+01, 6.2297e-01, 6.3087e-02, 4.0509e-06],
         ...,
         [4.2681e+01, 4.6552e-01, 9.8449e-01, 1.2184e-05, 3.3004e-09],
         [5.2742e+01, 5.9794e-01, 7.5217e-01, 9.6659e-06, 7.7248e-08],
         [4.2255e+01, 5.6288e-01, 5.4436e-01, 3.9546e-06, 6.6433e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 162/20000 [11:22<28:41:57,  5.21s/it][A
Train Diffusion:   1%|          | 163/20000 [11:28<29:37:10,  5.38s/it][A
Train Diffusion:   1%|          | 164/20000 [11:34<30:45:10,  5.58s/it][A
Train Diffusion:   1%|          | 165/20000 [11:37<26:57:51,  4.89s/it][A
Train Diffusion:   1%|          | 166/20000 [11:41<24:17:11,  4.41s/it][A
Train Diffusion:   1%|          | 167/20000 [11:44<23:02:38,  4.18s/it][A
Train Diffusion:   1%|          | 168/20000 [11:48<21:59:28,  3.99s/it][A
Train Diffusion:   1%|          | 169/20000 [11:53<23:02:17,  4.18s/it][A
Train Diffusion:   1%|          | 170/20000 [11:58<24:52:35,  4.52s/it][A
Train Diffusion:   1%|          | 171/20000 [12:01<23:27:56,  4.26s/it][A

Moving average norm loss at 170 iterations is: 20535.29765625. Best norm loss value is: 19497.857421875.

C_path mean = tensor([[5.1642e+01, 2.9393e-01, 8.2791e-01, 4.7970e-03, 1.3125e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [8.4594e+00, 2.5010e+01, 5.5783e-01, 2.3992e-01, 2.2410e-06],
         [3.5095e+01, 1.1675e+01, 6.5796e-01, 5.8044e-02, 2.5189e-06],
         ...,
         [4.3752e+01, 4.5246e-01, 9.7420e-01, 1.1492e-05, 2.9404e-09],
         [5.3137e+01, 6.2299e-01, 7.5284e-01, 1.0373e-05, 7.1113e-08],
         [4.2624e+01, 5.7670e-01, 6.8351e-01, 3.7518e-06, 9.2676e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 172/20000 [12:05<22:06:22,  4.01s/it][A
Train Diffusion:   1%|          | 173/20000 [12:08<20:54:21,  3.80s/it][A
Train Diffusion:   1%|          | 174/20000 [12:11<19:58:04,  3.63s/it][A
Train Diffusion:   1%|          | 175/20000 [12:15<19:22:53,  3.52s/it][A
Train Diffusion:   1%|          | 176/20000 [12:21<23:52:41,  4.34s/it][A
Train Diffusion:   1%|          | 177/20000 [12:24<22:32:15,  4.09s/it][A
Train Diffusion:   1%|          | 178/20000 [12:28<21:44:54,  3.95s/it][A
Train Diffusion:   1%|          | 179/20000 [12:32<21:04:43,  3.83s/it][A
Train Diffusion:   1%|          | 180/20000 [12:35<20:36:05,  3.74s/it][A
Train Diffusion:   1%|          | 181/20000 [12:38<19:48:26,  3.60s/it][A

Moving average norm loss at 180 iterations is: 18229.9982421875. Best norm loss value is: 17339.783203125.

C_path mean = tensor([[5.1919e+01, 1.6491e-01, 7.3127e-01, 3.6166e-03, 1.4905e-06]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [8.3963e+00, 2.4586e+01, 4.1880e-01, 1.5917e-01, 3.4731e-06],
         [3.5111e+01, 1.0593e+01, 4.8372e-01, 4.1168e-02, 3.7334e-06],
         ...,
         [4.4991e+01, 3.2422e-01, 8.5647e-01, 1.1630e-05, 4.3964e-09],
         [5.3487e+01, 4.1881e-01, 6.7614e-01, 1.1890e-05, 6.3345e-08],
         [4.2859e+01, 4.1920e-01, 4.3217e-01, 5.2517e-06, 3.1824e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 182/20000 [12:42<19:11:11,  3.49s/it][A
Train Diffusion:   1%|          | 183/20000 [12:46<19:53:04,  3.61s/it][A
Train Diffusion:   1%|          | 184/20000 [12:50<22:00:23,  4.00s/it][A
Train Diffusion:   1%|          | 185/20000 [12:54<21:33:03,  3.92s/it][A
Train Diffusion:   1%|          | 186/20000 [12:57<20:22:51,  3.70s/it][A
Train Diffusion:   1%|          | 187/20000 [13:01<20:51:38,  3.79s/it][A
Train Diffusion:   1%|          | 188/20000 [13:06<21:49:39,  3.97s/it][A
Train Diffusion:   1%|          | 189/20000 [13:10<22:42:26,  4.13s/it][A
Train Diffusion:   1%|          | 190/20000 [13:15<24:00:43,  4.36s/it][A
Train Diffusion:   1%|          | 191/20000 [13:19<22:18:50,  4.06s/it][A

Moving average norm loss at 190 iterations is: 16132.27578125. Best norm loss value is: 15173.8671875.

C_path mean = tensor([[5.2038e+01, 2.7639e-01, 7.5100e-01, 2.3420e-03, 9.4284e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [9.6225e+00, 2.4330e+01, 5.2478e-01, 1.3676e-01, 2.2781e-06],
         [3.6757e+01, 1.0825e+01, 6.0173e-01, 3.0519e-02, 2.4298e-06],
         ...,
         [4.5178e+01, 4.0167e-01, 7.9244e-01, 6.4587e-06, 2.7079e-09],
         [5.2725e+01, 5.0559e-01, 6.3261e-01, 6.3332e-06, 2.8097e-08],
         [4.2286e+01, 5.1158e-01, 4.0209e-01, 3.1355e-06, 1.6354e-05]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 192/20000 [13:22<21:17:42,  3.87s/it][A
Train Diffusion:   1%|          | 193/20000 [13:25<20:42:59,  3.77s/it][A
Train Diffusion:   1%|          | 194/20000 [13:29<20:19:29,  3.69s/it][A
Train Diffusion:   1%|          | 195/20000 [13:32<19:33:57,  3.56s/it][A
Train Diffusion:   1%|          | 196/20000 [13:35<18:59:21,  3.45s/it][A
Train Diffusion:   1%|          | 197/20000 [13:39<18:36:24,  3.38s/it][A
Train Diffusion:   1%|          | 198/20000 [13:42<18:29:03,  3.36s/it][A
Train Diffusion:   1%|          | 199/20000 [13:45<18:37:09,  3.39s/it][A
Train Diffusion:   1%|          | 200/20000 [13:49<18:45:04,  3.41s/it][A
Train Diffusion:   1%|          | 201/20000 [13:52<18:26:21,  3.35s/it][A

Moving average norm loss at 200 iterations is: 14564.86962890625. Best norm loss value is: 13639.283203125.

C_path mean = tensor([[5.2355e+01, 2.2454e-01, 6.4364e-01, 1.8558e-03, 9.7894e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [9.9985e+00, 2.4518e+01, 4.4951e-01, 1.0590e-01, 2.2048e-06],
         [3.6773e+01, 1.0913e+01, 5.1409e-01, 2.2293e-02, 2.1590e-06],
         ...,
         [4.6355e+01, 3.2860e-01, 6.4793e-01, 6.3134e-06, 3.3021e-09],
         [5.3309e+01, 4.1708e-01, 5.3204e-01, 6.3300e-06, 2.3522e-08],
         [4.2752e+01, 4.3009e-01, 3.4514e-01, 3.3499e-06, 7.4382e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 202/20000 [13:55<18:13:34,  3.31s/it][A
Train Diffusion:   1%|          | 203/20000 [13:59<18:27:03,  3.36s/it][A
Train Diffusion:   1%|          | 204/20000 [14:02<18:11:38,  3.31s/it][A
Train Diffusion:   1%|          | 205/20000 [14:05<18:01:36,  3.28s/it][A
Train Diffusion:   1%|          | 206/20000 [14:09<18:18:15,  3.33s/it][A
Train Diffusion:   1%|          | 207/20000 [14:12<18:25:36,  3.35s/it][A
Train Diffusion:   1%|          | 208/20000 [14:16<18:43:54,  3.41s/it][A
Train Diffusion:   1%|          | 209/20000 [14:19<18:46:49,  3.42s/it][A
Train Diffusion:   1%|          | 210/20000 [14:24<20:45:42,  3.78s/it][A
Train Diffusion:   1%|          | 211/20000 [14:29<22:50:58,  4.16s/it][A

Moving average norm loss at 210 iterations is: 13054.13076171875. Best norm loss value is: 12421.927734375.

C_path mean = tensor([[5.2047e+01, 3.2706e-01, 8.3261e-01, 1.9293e-03, 6.1756e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.0628e+01, 2.4152e+01, 6.4397e-01, 1.4721e-01, 1.0273e-06],
         [3.7568e+01, 1.0767e+01, 7.3023e-01, 2.5043e-02, 1.0716e-06],
         ...,
         [4.6375e+01, 4.7539e-01, 8.8766e-01, 3.7107e-06, 1.6153e-09],
         [5.2415e+01, 6.1547e-01, 6.9799e-01, 3.4071e-06, 7.2945e-09],
         [4.2168e+01, 6.1374e-01, 5.7035e-01, 1.5616e-06, 2.8878e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 212/20000 [14:34<24:43:31,  4.50s/it][A
Train Diffusion:   1%|          | 213/20000 [14:37<22:53:51,  4.17s/it][A
Train Diffusion:   1%|          | 214/20000 [14:41<21:47:54,  3.97s/it][A
Train Diffusion:   1%|          | 215/20000 [14:44<20:49:54,  3.79s/it][A
Train Diffusion:   1%|          | 216/20000 [14:48<20:10:15,  3.67s/it][A
Train Diffusion:   1%|          | 217/20000 [14:51<19:54:14,  3.62s/it][A
Train Diffusion:   1%|          | 218/20000 [14:55<20:38:06,  3.76s/it][A
Train Diffusion:   1%|          | 219/20000 [15:00<22:54:04,  4.17s/it][A
Train Diffusion:   1%|          | 220/20000 [15:04<21:28:26,  3.91s/it][A
Train Diffusion:   1%|          | 221/20000 [15:07<21:01:05,  3.83s/it][A

Moving average norm loss at 220 iterations is: 11608.92236328125. Best norm loss value is: 10885.1884765625.

C_path mean = tensor([[5.2430e+01, 1.9207e-01, 6.5311e-01, 1.6543e-03, 8.8504e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.0053e+01, 2.4133e+01, 4.1451e-01, 1.2516e-01, 1.9874e-06],
         [3.6053e+01, 1.0456e+01, 5.0407e-01, 2.1926e-02, 2.0797e-06],
         ...,
         [4.7589e+01, 2.9851e-01, 7.5730e-01, 5.6370e-06, 2.6300e-09],
         [5.3023e+01, 3.6614e-01, 6.0200e-01, 4.6329e-06, 1.4657e-08],
         [4.2553e+01, 3.6584e-01, 5.0170e-01, 2.5718e-06, 3.1599e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 222/20000 [15:11<20:25:00,  3.72s/it][A
Train Diffusion:   1%|          | 223/20000 [15:14<19:50:20,  3.61s/it][A
Train Diffusion:   1%|          | 224/20000 [15:18<20:00:43,  3.64s/it][A
Train Diffusion:   1%|          | 225/20000 [15:21<19:50:32,  3.61s/it][A
Train Diffusion:   1%|          | 226/20000 [15:25<19:50:45,  3.61s/it][A
Train Diffusion:   1%|          | 227/20000 [15:28<19:16:55,  3.51s/it][A
Train Diffusion:   1%|          | 228/20000 [15:32<19:35:03,  3.57s/it][A
Train Diffusion:   1%|          | 229/20000 [15:35<19:01:54,  3.47s/it][A
Train Diffusion:   1%|          | 230/20000 [15:39<19:32:19,  3.56s/it][A
Train Diffusion:   1%|          | 231/20000 [15:42<18:57:17,  3.45s/it][A

Moving average norm loss at 230 iterations is: 10381.80244140625. Best norm loss value is: 9774.470703125.

C_path mean = tensor([[5.2555e+01, 1.9265e-01, 6.6079e-01, 1.4299e-03, 8.9386e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.0422e+01, 2.3744e+01, 4.3176e-01, 1.4722e-01, 1.4077e-06],
         [3.6448e+01, 1.0034e+01, 5.2847e-01, 1.9123e-02, 1.3385e-06],
         ...,
         [4.8109e+01, 2.5794e-01, 7.5134e-01, 5.2226e-06, 2.8139e-09],
         [5.2885e+01, 3.2547e-01, 6.2468e-01, 4.9443e-06, 1.4729e-08],
         [4.2321e+01, 3.6474e-01, 5.9189e-01, 2.3298e-06, 3.8318e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 232/20000 [15:46<18:55:37,  3.45s/it][A
Train Diffusion:   1%|          | 233/20000 [15:49<19:07:12,  3.48s/it][A
Train Diffusion:   1%|          | 234/20000 [15:52<18:38:34,  3.40s/it][A
Train Diffusion:   1%|          | 235/20000 [15:56<18:26:49,  3.36s/it][A
Train Diffusion:   1%|          | 236/20000 [15:59<18:47:51,  3.42s/it][A
Train Diffusion:   1%|          | 237/20000 [16:02<18:34:18,  3.38s/it][A
Train Diffusion:   1%|          | 238/20000 [16:06<18:45:02,  3.42s/it][A
Train Diffusion:   1%|          | 239/20000 [16:09<18:40:17,  3.40s/it][A
Train Diffusion:   1%|          | 240/20000 [16:13<18:49:47,  3.43s/it][A
Train Diffusion:   1%|          | 241/20000 [16:16<18:41:59,  3.41s/it][A

Moving average norm loss at 240 iterations is: 9197.3779296875. Best norm loss value is: 8707.2529296875.

C_path mean = tensor([[5.2587e+01, 1.8292e-01, 7.5473e-01, 1.2059e-03, 7.7908e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.0971e+01, 2.3115e+01, 4.7452e-01, 1.6053e-01, 9.1028e-07],
         [3.7419e+01, 9.2644e+00, 5.8487e-01, 1.7132e-02, 9.2303e-07],
         ...,
         [4.8723e+01, 2.7232e-01, 9.4864e-01, 4.2122e-06, 2.4524e-09],
         [5.2841e+01, 3.5585e-01, 7.6424e-01, 3.6264e-06, 9.6728e-09],
         [4.2022e+01, 4.0337e-01, 7.4806e-01, 1.5186e-06, 1.5593e-06]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|          | 242/20000 [16:20<18:34:33,  3.38s/it][A
Train Diffusion:   1%|          | 243/20000 [16:23<18:27:25,  3.36s/it][A
Train Diffusion:   1%|          | 244/20000 [16:26<18:21:30,  3.35s/it][A
Train Diffusion:   1%|          | 245/20000 [16:30<18:33:13,  3.38s/it][A
Train Diffusion:   1%|          | 246/20000 [16:33<19:02:38,  3.47s/it][A
Train Diffusion:   1%|          | 247/20000 [16:37<18:44:01,  3.41s/it][A
Train Diffusion:   1%|          | 248/20000 [16:40<19:00:55,  3.47s/it][A
Train Diffusion:   1%|          | 249/20000 [16:44<18:51:08,  3.44s/it][A
Train Diffusion:   1%|▏         | 250/20000 [16:47<18:49:28,  3.43s/it][A
Train Diffusion:   1%|▏         | 251/20000 [16:50<18:49:19,  3.43s/it][A

Moving average norm loss at 250 iterations is: 8295.503662109375. Best norm loss value is: 7786.54736328125.

C_path mean = tensor([[5.2666e+01, 1.5729e-01, 7.9636e-01, 9.3684e-04, 7.5822e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.1664e+01, 2.2799e+01, 4.9174e-01, 1.5492e-01, 5.7565e-07],
         [3.7840e+01, 8.8463e+00, 5.9683e-01, 1.2035e-02, 6.0173e-07],
         ...,
         [4.8934e+01, 2.3378e-01, 1.0203e+00, 3.7839e-06, 2.3414e-09],
         [5.2357e+01, 3.4693e-01, 8.2089e-01, 3.1194e-06, 7.1481e-09],
         [4.1618e+01, 4.1944e-01, 8.1118e-01, 1.1047e-06, 8.7689e-07]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|▏         | 252/20000 [16:54<18:27:52,  3.37s/it][A
Train Diffusion:   1%|▏         | 253/20000 [16:57<18:38:28,  3.40s/it][A
Train Diffusion:   1%|▏         | 254/20000 [17:00<18:24:20,  3.36s/it][A
Train Diffusion:   1%|▏         | 255/20000 [17:04<18:47:51,  3.43s/it][A
Train Diffusion:   1%|▏         | 256/20000 [17:07<18:50:52,  3.44s/it][A
Train Diffusion:   1%|▏         | 257/20000 [17:11<18:39:47,  3.40s/it][A
Train Diffusion:   1%|▏         | 258/20000 [17:14<18:53:13,  3.44s/it][A
Train Diffusion:   1%|▏         | 259/20000 [17:18<19:15:32,  3.51s/it][A
Train Diffusion:   1%|▏         | 260/20000 [17:22<19:37:24,  3.58s/it][A
Train Diffusion:   1%|▏         | 261/20000 [17:25<19:33:44,  3.57s/it][A

Moving average norm loss at 260 iterations is: 6929.360400390625. Best norm loss value is: 6602.751953125.

C_path mean = tensor([[5.2602e+01, 2.8593e-01, 7.7865e-01, 8.9620e-04, 4.4884e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.2633e+01, 2.2981e+01, 5.9527e-01, 1.2021e-01, 4.3105e-07],
         [3.8873e+01, 9.2853e+00, 7.1114e-01, 1.3509e-02, 5.2689e-07],
         ...,
         [4.9305e+01, 3.8010e-01, 8.5510e-01, 2.1052e-06, 1.3119e-09],
         [5.1819e+01, 5.0168e-01, 6.9455e-01, 1.4573e-06, 2.5004e-09],
         [4.1240e+01, 5.3109e-01, 6.4371e-01, 6.1464e-07, 2.0428e-07]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|▏         | 262/20000 [17:29<19:29:04,  3.55s/it][A
Train Diffusion:   1%|▏         | 263/20000 [17:32<19:24:00,  3.54s/it][A
Train Diffusion:   1%|▏         | 264/20000 [17:36<19:38:12,  3.58s/it][A
Train Diffusion:   1%|▏         | 265/20000 [17:40<20:16:20,  3.70s/it][A
Train Diffusion:   1%|▏         | 266/20000 [17:44<21:43:01,  3.96s/it][A
Train Diffusion:   1%|▏         | 267/20000 [17:48<20:34:38,  3.75s/it][A
Train Diffusion:   1%|▏         | 268/20000 [17:53<22:42:30,  4.14s/it][A
Train Diffusion:   1%|▏         | 269/20000 [17:57<23:24:04,  4.27s/it][A
Train Diffusion:   1%|▏         | 270/20000 [18:01<22:48:58,  4.16s/it][A
Train Diffusion:   1%|▏         | 271/20000 [18:08<26:28:09,  4.83s/it][A

Moving average norm loss at 270 iterations is: 5950.7455078125. Best norm loss value is: 5645.5390625.

C_path mean = tensor([[5.2868e+01, 1.8376e-01, 7.1560e-01, 5.6888e-04, 5.1396e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.2480e+01, 2.2804e+01, 4.7408e-01, 8.0754e-02, 5.7028e-07],
         [3.8099e+01, 8.6629e+00, 5.7585e-01, 7.8656e-03, 6.4896e-07],
         ...,
         [5.0436e+01, 2.5970e-01, 8.3196e-01, 2.2537e-06, 1.6224e-09],
         [5.2305e+01, 3.5384e-01, 6.8081e-01, 1.8352e-06, 2.3127e-09],
         [4.1446e+01, 4.1097e-01, 5.1543e-01, 8.4150e-07, 5.5479e-08]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|▏         | 272/20000 [18:14<29:41:04,  5.42s/it][A
Train Diffusion:   1%|▏         | 273/20000 [18:19<27:36:33,  5.04s/it][A
Train Diffusion:   1%|▏         | 274/20000 [18:23<26:20:44,  4.81s/it][A
Train Diffusion:   1%|▏         | 275/20000 [18:26<24:12:58,  4.42s/it][A
Train Diffusion:   1%|▏         | 276/20000 [18:32<26:55:41,  4.91s/it][A
Train Diffusion:   1%|▏         | 277/20000 [18:37<25:40:48,  4.69s/it][A
Train Diffusion:   1%|▏         | 278/20000 [18:41<24:32:01,  4.48s/it][A
Train Diffusion:   1%|▏         | 279/20000 [18:44<23:26:43,  4.28s/it][A
Train Diffusion:   1%|▏         | 280/20000 [18:48<22:32:40,  4.12s/it][A
Train Diffusion:   1%|▏         | 281/20000 [18:52<22:30:10,  4.11s/it][A

Moving average norm loss at 280 iterations is: 5110.672216796875. Best norm loss value is: 4797.53125.

C_path mean = tensor([[5.2789e+01, 2.5744e-01, 7.8926e-01, 4.2247e-04, 4.6835e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.3462e+01, 2.2581e+01, 5.9868e-01, 1.1799e-01, 3.0911e-07],
         [3.9092e+01, 8.7597e+00, 7.0879e-01, 7.1284e-03, 2.9012e-07],
         ...,
         [5.0416e+01, 3.1282e-01, 8.6242e-01, 1.6005e-06, 1.4835e-09],
         [5.1518e+01, 4.2486e-01, 7.3308e-01, 1.2417e-06, 2.0775e-09],
         [4.0778e+01, 4.9855e-01, 7.4234e-01, 5.4121e-07, 7.4886e-08]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|▏         | 282/20000 [18:56<21:40:17,  3.96s/it][A
Train Diffusion:   1%|▏         | 283/20000 [19:00<22:12:37,  4.06s/it][A
Train Diffusion:   1%|▏         | 284/20000 [19:04<21:18:16,  3.89s/it][A
Train Diffusion:   1%|▏         | 285/20000 [19:07<20:21:22,  3.72s/it][A
Train Diffusion:   1%|▏         | 286/20000 [19:10<19:46:01,  3.61s/it][A
Train Diffusion:   1%|▏         | 287/20000 [19:14<19:20:41,  3.53s/it][A
Train Diffusion:   1%|▏         | 288/20000 [19:17<19:48:30,  3.62s/it][A
Train Diffusion:   1%|▏         | 289/20000 [19:21<19:50:35,  3.62s/it][A
Train Diffusion:   1%|▏         | 290/20000 [19:25<19:56:31,  3.64s/it][A
Train Diffusion:   1%|▏         | 291/20000 [19:28<19:18:52,  3.53s/it][A

Moving average norm loss at 290 iterations is: 4251.831103515625. Best norm loss value is: 3937.6962890625.

C_path mean = tensor([[5.3017e+01, 2.0602e-01, 7.2597e-01, 3.9857e-04, 4.3072e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.3627e+01, 2.2606e+01, 5.2177e-01, 7.6192e-02, 3.6945e-07],
         [3.9262e+01, 8.4619e+00, 6.1863e-01, 5.2729e-03, 3.6331e-07],
         ...,
         [5.1152e+01, 2.8157e-01, 7.9616e-01, 1.4448e-06, 1.4604e-09],
         [5.1697e+01, 3.6929e-01, 6.6751e-01, 1.1678e-06, 2.1389e-09],
         [4.0859e+01, 4.4120e-01, 5.4242e-01, 5.6306e-07, 5.9901e-08]]],
       grad_fn=<CatBackward>)



Train Diffusion:   1%|▏         | 292/20000 [19:32<19:15:24,  3.52s/it][A
Train Diffusion:   1%|▏         | 293/20000 [19:35<19:03:27,  3.48s/it][A
Train Diffusion:   1%|▏         | 294/20000 [19:39<19:49:21,  3.62s/it][A
Train Diffusion:   1%|▏         | 295/20000 [19:42<19:29:32,  3.56s/it][A
Train Diffusion:   1%|▏         | 296/20000 [19:46<19:24:05,  3.54s/it][A
Train Diffusion:   1%|▏         | 297/20000 [19:49<18:56:14,  3.46s/it][A
Train Diffusion:   1%|▏         | 298/20000 [19:52<18:32:57,  3.39s/it][A
Train Diffusion:   1%|▏         | 299/20000 [19:56<18:19:28,  3.35s/it][A
Train Diffusion:   2%|▏         | 300/20000 [19:59<18:08:03,  3.31s/it][A
Train Diffusion:   2%|▏         | 301/20000 [20:02<17:59:02,  3.29s/it][A

Moving average norm loss at 300 iterations is: 3621.3041748046876. Best norm loss value is: 3247.361572265625.

C_path mean = tensor([[5.2927e+01, 2.3883e-01, 7.5421e-01, 3.2900e-04, 4.5890e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.3937e+01, 2.2025e+01, 5.5439e-01, 1.1353e-01, 2.0864e-07],
         [3.9417e+01, 8.1360e+00, 6.6293e-01, 5.2801e-03, 1.9444e-07],
         ...,
         [5.1284e+01, 2.4871e-01, 8.1964e-01, 1.4158e-06, 1.5967e-09],
         [5.1126e+01, 3.4567e-01, 7.1127e-01, 1.2350e-06, 1.9322e-09],
         [4.0159e+01, 4.5012e-01, 7.0006e-01, 4.5972e-07, 4.4929e-08]]],
       grad_fn=<CatBackward>)



Train Diffusion:   2%|▏         | 302/20000 [20:05<17:57:09,  3.28s/it][A
Train Diffusion:   2%|▏         | 303/20000 [20:08<17:51:55,  3.27s/it][A
Train Diffusion:   2%|▏         | 304/20000 [20:12<17:47:42,  3.25s/it][A
Train Diffusion:   2%|▏         | 305/20000 [20:15<17:48:39,  3.26s/it][A
Train Diffusion:   2%|▏         | 306/20000 [20:18<17:46:57,  3.25s/it][A
Train Diffusion:   2%|▏         | 307/20000 [20:21<17:44:41,  3.24s/it][A
Train Diffusion:   2%|▏         | 308/20000 [20:25<17:45:43,  3.25s/it][A
Train Diffusion:   2%|▏         | 309/20000 [20:28<17:44:35,  3.24s/it][A
Train Diffusion:   2%|▏         | 310/20000 [20:31<17:43:10,  3.24s/it][A
Train Diffusion:   2%|▏         | 311/20000 [20:34<17:42:01,  3.24s/it][A

Moving average norm loss at 310 iterations is: 2973.33447265625. Best norm loss value is: 2674.88623046875.

C_path mean = tensor([[5.3131e+01, 2.4302e-01, 7.3491e-01, 3.5438e-04, 3.7859e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.4864e+01, 2.1966e+01, 5.4765e-01, 1.0000e-01, 2.4465e-07],
         [4.0514e+01, 7.9518e+00, 6.5195e-01, 5.6814e-03, 2.4622e-07],
         ...,
         [5.1774e+01, 2.8649e-01, 7.8418e-01, 1.0797e-06, 1.3087e-09],
         [5.1140e+01, 3.8176e-01, 6.8113e-01, 8.3433e-07, 1.6774e-09],
         [4.0095e+01, 4.6138e-01, 6.2250e-01, 3.7894e-07, 4.2048e-08]]],
       grad_fn=<CatBackward>)



Train Diffusion:   2%|▏         | 312/20000 [20:38<17:42:24,  3.24s/it][A
Train Diffusion:   2%|▏         | 313/20000 [20:41<17:42:08,  3.24s/it][A
Train Diffusion:   2%|▏         | 314/20000 [20:44<17:41:06,  3.23s/it][A
Train Diffusion:   2%|▏         | 315/20000 [20:47<17:39:45,  3.23s/it][A
Train Diffusion:   2%|▏         | 316/20000 [20:51<17:43:19,  3.24s/it][A
Train Diffusion:   2%|▏         | 317/20000 [20:54<17:42:24,  3.24s/it][A
Train Diffusion:   2%|▏         | 318/20000 [20:57<17:42:31,  3.24s/it][A
Train Diffusion:   2%|▏         | 319/20000 [21:00<17:40:10,  3.23s/it][A
Train Diffusion:   2%|▏         | 320/20000 [21:04<17:55:27,  3.28s/it][A
Train Diffusion:   2%|▏         | 321/20000 [21:07<18:04:36,  3.31s/it][A

Moving average norm loss at 320 iterations is: 2430.7921508789063. Best norm loss value is: 1995.1014404296875.

C_path mean = tensor([[5.3129e+01, 2.2076e-01, 7.2761e-01, 2.8919e-04, 4.0987e-07]],
       grad_fn=<MeanBackward1>)

C_path = tensor([[[5.3606e+01, 1.9081e-01, 6.7073e-01, 1.3415e-02, 1.1000e-03],
         [1.5559e+01, 2.1467e+01, 5.2392e-01, 1.1259e-01, 2.2170e-07],
         [4.1695e+01, 7.5248e+00, 6.3514e-01, 5.0993e-03, 2.1598e-07],
         ...,
         [5.2018e+01, 2.2838e-01, 7.9066e-01, 1.0901e-06, 1.4303e-09],
         [5.0703e+01, 3.1822e-01, 6.9303e-01, 8.0555e-07, 1.7105e-09],
         [3.9550e+01, 4.1800e-01, 6.7990e-01, 3.5554e-07, 3.4804e-08]]],
       grad_fn=<CatBackward>)



Train Diffusion:   2%|▏         | 322/20000 [21:11<18:59:37,  3.47s/it][A