In [1]:
# Hack to import from a parent directory
import sys
path = '..'
if path not in sys.path:
    sys.path.append(path)

In [2]:
#Python-related imports
import math, sys
from typing import Dict, Tuple, Union
from datetime import datetime
import os.path

#Torch-related imports
import torch
from torch.autograd import Function
from torch import nn
import torch.distributions as D
import torch.nn.functional as F
import torch.optim as optim

#Module imports
from SBM_SDE_classes_optim import *
from obs_and_flow import *
from training import *
from plotting import *
from mean_field import *
from TruncatedNormal import *
from LogitNormal import *

In [3]:
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import time
from scipy.optimize import bisect

## Generate data

### Draw $\theta \sim p(\theta)$

In [4]:
# Global RNG seed and temperature-forcing settings shared by the cells below.
seed = 0
temp_ref = 283
temp_rise = 5 #High estimate of 5 celsius temperature rise by 2100.

# Each prior standard deviation is this fraction of the corresponding prior mean.
prior_scale_factor = 0.333

#Parameter prior means
u_M_mean = 0.0016
a_SD_mean = 0.5
a_DS_mean = 0.5
a_M_mean = 0.5
a_MSC_mean = 0.5
k_S_ref_mean = 0.0005
k_D_ref_mean = 0.0008
k_M_ref_mean = 0.0007
Ea_S_mean = 55
Ea_D_mean = 48
Ea_M_mean = 48
c_SOC_mean = 0.5
c_DOC_mean = 0.01
c_MBC_mean = 0.01


def _prior_details(mean, lower, upper):
    """Pack one prior as a tensor ordered (mean, sdev, lower, upper)."""
    return torch.Tensor([mean, mean * prior_scale_factor, lower, upper])


#SCON theta logit-normal distribution parameter details in order of mean, sdev, lower, and upper.
u_M_details = _prior_details(u_M_mean, 0, 1)
a_SD_details = _prior_details(a_SD_mean, 0, 1)
a_DS_details = _prior_details(a_DS_mean, 0, 1)
a_M_details = _prior_details(a_M_mean, 0, 1)
a_MSC_details = _prior_details(a_MSC_mean, 0, 1)
k_S_ref_details = _prior_details(k_S_ref_mean, 0, 1)
k_D_ref_details = _prior_details(k_D_ref_mean, 0, 1)
k_M_ref_details = _prior_details(k_M_ref_mean, 0, 1)
Ea_S_details = _prior_details(Ea_S_mean, 10, 100)
Ea_D_details = _prior_details(Ea_D_mean, 10, 100)
Ea_M_details = _prior_details(Ea_M_mean, 10, 100)

#SCON-C diffusion matrix parameter distribution details
c_SOC_details = _prior_details(c_SOC_mean, 0, 1)
c_DOC_details = _prior_details(c_DOC_mean, 0, 1)
c_MBC_details = _prior_details(c_MBC_mean, 0, 1)

# Parameter name -> (mean, sdev, lower, upper); insertion order fixes the parameter order used downstream.
priors = {
    'u_M': u_M_details,
    'a_SD': a_SD_details,
    'a_DS': a_DS_details,
    'a_M': a_M_details,
    'a_MSC': a_MSC_details,
    'k_S_ref': k_S_ref_details,
    'k_D_ref': k_D_ref_details,
    'k_M_ref': k_M_ref_details,
    'Ea_S': Ea_S_details,
    'Ea_D': Ea_D_details,
    'Ea_M': Ea_M_details,
    'c_SOC': c_SOC_details,
    'c_DOC': c_DOC_details,
    'c_MBC': c_MBC_details,
}

In [5]:
def find_scale(scale, loc, a, b, target_sd):
    """Root-finding residual: stddev of a RescaledLogitNormal minus the target.

    Builds RescaledLogitNormal(loc, scale, a, b) and returns how far its
    `stddev` is from `target_sd`; the value is zero when `scale` reproduces
    the target standard deviation. `scale` is the first argument so the
    function can be handed directly to a scalar root finder.
    """
    candidate = RescaledLogitNormal(loc, scale, a, b)
    sd_gap = candidate.stddev - target_sd
    return sd_gap

In [6]:
def sample_theta(priors):
    """Fit logit-normal hyperparameters to each prior and draw one sample.

    Args:
        priors: mapping from parameter name to a 4-element sequence
            (mean, target standard deviation, lower bound, upper bound).

    Returns:
        A pair of dicts keyed like `priors`:
        the first holds tensors (loc, scale, a, b), the second holds one
        sample drawn from RescaledLogitNormal(loc, scale, a, b).

    The torch RNG is reseeded with 0 on every call. `loc` comes from the
    `logit` helper applied to the prior mean; `scale` is found by bisection
    so that the distribution's stddev matches the target (asserted to 1e-6).
    """
    torch.manual_seed(0)
    # Bracketing interval searched by scipy's bisect for the scale.
    bracket_lo, bracket_hi = 1e-8, 100

    hyperparams_by_name = {}
    draws_by_name = {}
    for name in priors:
        sigmoid_loc, target_sd, a, b = priors[name]
        loc = logit(sigmoid_loc, a, b)
        scale = bisect(find_scale, bracket_lo, bracket_hi, args=(loc, a, b, target_sd))
        prior_dist = RescaledLogitNormal(loc, scale, a, b)
        assert torch.abs(prior_dist.stddev - target_sd) < 1e-6

        hyperparams_by_name[name] = torch.tensor((loc, scale, a, b))
        draws_by_name[name] = prior_dist.sample()

    return hyperparams_by_name, draws_by_name

In [7]:
# Convert the (mean, sd, lower, upper) priors into (loc, scale, a, b) hyperparameters
# and draw one theta per parameter to use when simulating data.
theta_hyperparams, theta_samples = sample_theta(priors)

In [9]:
# Reload hyperparameters and theta draws written by an earlier run so they can be
# compared with the values recomputed above.
# NOTE(review): torch.load unpickles arbitrary objects; only load files from a trusted source.
hyperparams_file = 'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_hyperparams_scon_c_{}.pt'.format(0)
theta_file = 'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_theta_scon_c_{}.pt'.format(0)
theta_hyperparams1 = torch.load(hyperparams_file)
theta_samples1 = torch.load(theta_file)
print(theta_hyperparams1)

{'u_M': tensor([-6.4362,  0.3104,  0.0000,  1.0000]), 'a_SD': tensor([0.0000, 0.7465, 0.0000, 1.0000]), 'a_DS': tensor([0.0000, 0.7465, 0.0000, 1.0000]), 'a_M': tensor([0.0000, 0.7465, 0.0000, 1.0000]), 'a_MSC': tensor([0.0000, 0.7465, 0.0000, 1.0000]), 'k_S_ref': tensor([-7.6004,  0.3100,  0.0000,  1.0000]), 'k_D_ref': tensor([-7.1301,  0.3101,  0.0000,  1.0000]), 'k_M_ref': tensor([-7.2637,  0.3101,  0.0000,  1.0000]), 'Ea_S': tensor([  0.0000,   0.9685,  10.0000, 100.0000]), 'Ea_D': tensor([ -0.3137,   0.8252,  10.0000, 100.0000]), 'Ea_M': tensor([ -0.3137,   0.8252,  10.0000, 100.0000]), 'c_SOC': tensor([0.0000, 0.7465, 0.0000, 1.0000]), 'c_DOC': tensor([-4.5951,  0.3137,  0.0000,  1.0000]), 'c_MBC': tensor([-4.5951,  0.3137,  0.0000,  1.0000])}


In [11]:
# Exact elementwise comparison of recomputed vs. reloaded hyperparameters (one boolean tensor per parameter).
# NOTE(review): this is bit-exact float equality; a tolerance-based check (torch.allclose) may be what is wanted — confirm.
[theta_hyperparams[k] == theta_hyperparams1[k] for k in theta_hyperparams.keys()]

[tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([ True, False,  True,  True]),
 tensor([ True, False,  True,  True]),
 tensor([ True, False,  True,  True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True]),
 tensor([True, True, True, True])]

In [13]:
# Save the prior specification (mean, sd, lower, upper per parameter) to disk.
# save_dir is used as a filename prefix here, so the file is '<...>/SCON_C_target_hyperparams.pt'.
# NOTE(review): the file name says "hyperparams" but the object saved is `priors`,
# not `theta_hyperparams` — confirm which one is intended.
save_dir = 'data/dt_01_t_100000_minibatch_logit_theta_trunc_trans/SCON_C'
target_hyperparams_file = '{}_target_hyperparams.pt'.format(save_dir)
torch.save(priors, target_hyperparams_file)

### Draw $x \sim p(x|\theta)$

In [None]:
# Settings for simulating the SDE path with generate_x.
# num_sequences: number of paths drawn; dt: Euler-Maruyama step; t: final time.
# x0_SCON: mean initial state (order SOC, DOC, MBC per the comment on the drift function);
# x0_scale: initial-state standard deviation as a fraction of x0_SCON.
# NOTE(review): t = 1000000 here, while the data files loaded elsewhere in this notebook
# live under 'dt_1_t_100000_...' — confirm the horizon matches the data in use.
num_sequences = 1
dt = 1.0
t = 1000000
x0_SCON = [65, 0.4, 2.5]
x0_scale = 0.25

In [None]:
#Generate data from SBM SDEs
#x in order of SOC, DOC, MBC (and EEC for AWB family models)

def alpha_SCON_multi(x, SCON_params_dict, I_S, I_D, current_temp, temp_ref, arrhenius_temp, linear_temp):
    """Drift of the SCON model for a batch of states.

    Args:
        x: tensor whose dimension 1 is split into three equal chunks,
            read as SOC, DOC and MBC.
        SCON_params_dict: mapping with keys 'u_M', 'a_SD', 'a_DS', 'a_M',
            'a_MSC', 'k_S_ref', 'k_D_ref', 'k_M_ref', 'Ea_S', 'Ea_D', 'Ea_M'.
        I_S, I_D: exogenous inputs added to the SOC and DOC drifts.
        current_temp, temp_ref: temperatures forwarded to `arrhenius_temp`.
        arrhenius_temp: callable (k_ref, current_temp, Ea, temp_ref) -> k.
        linear_temp: accepted for signature compatibility; not used here.

    Returns:
        Tensor with the SOC, DOC and MBC drifts concatenated along dimension 1.
    """
    theta = SCON_params_dict
    u_M = theta['u_M']
    a_SD, a_DS = theta['a_SD'], theta['a_DS']
    a_M, a_MSC = theta['a_M'], theta['a_MSC']

    # Split the state into its three pools.
    SOC, DOC, MBC = x.chunk(3, dim=1)

    # Temperature-adjusted decay rates.
    k_S = arrhenius_temp(theta['k_S_ref'], current_temp, theta['Ea_S'], temp_ref)
    k_D = arrhenius_temp(theta['k_D_ref'], current_temp, theta['Ea_D'], temp_ref)
    k_M = arrhenius_temp(theta['k_M_ref'], current_temp, theta['Ea_M'], temp_ref)

    # Per-pool drift terms.
    pool_drifts = [
        I_S + a_DS * k_D * DOC + a_M * a_MSC * k_M * MBC - k_S * SOC,
        I_D + a_SD * k_S * SOC + a_M * (1 - a_MSC) * k_M * MBC - (u_M + k_D) * DOC,
        u_M * DOC - k_M * MBC,
    ]
    return torch.cat(pool_drifts, 1)

def beta_SCON_C_multi(x, SCON_C_params_dict):
    """Constant diagonal diffusion matrix for the SCON-C model.

    The diagonal is (c_SOC, c_DOC, c_MBC) read from `SCON_C_params_dict`.
    `x` is accepted for signature compatibility and is not used.
    """
    diag_entries = [SCON_C_params_dict[name] for name in ('c_SOC', 'c_DOC', 'c_MBC')]
    return torch.diag_embed(torch.stack(diag_entries))

In [None]:
def generate_x(BATCH_SIZE, ALPHA, BETA, X0_LOC, X0_SCALE, T, DT, THETA_DICT, I_S_FUNC, I_D_FUNC, TEMP_FUNC, TEMP_REF, TEMP_RISE, lower_bound = 1e-4):
    """Simulate state paths with Euler-Maruyama steps.

    Args:
        BATCH_SIZE: number of independent paths to draw.
        ALPHA: drift function, called as
            ALPHA(x_prev, THETA_DICT, I_S, I_D, temp, TEMP_REF,
                  arrhenius_temp_dep, linear_temp_dep).
        BETA: diffusion function, called as BETA(x_prev, THETA_DICT); its
            result times DT is used as the step covariance matrix.
        X0_LOC: mean of the initial state; its length sets the state dimension.
        X0_SCALE: initial-state standard deviation as a fraction of X0_LOC.
        T, DT: final time and step size; the path has int(T / DT) + 1 points.
        THETA_DICT: parameters forwarded to ALPHA and BETA.
        I_S_FUNC, I_D_FUNC: callables mapping the time grid to input series.
        TEMP_FUNC: callable (time grid, TEMP_REF, TEMP_RISE) -> temperatures.
        TEMP_REF, TEMP_RISE: forwarded to TEMP_FUNC (TEMP_REF also to ALPHA).
        lower_bound: currently unused (the clamping lines are commented out).

    Returns:
        (x, p_x0): the simulated tensor of shape (BATCH_SIZE, N, state_dim)
        and the MultivariateNormal the initial states were drawn from.

    The torch RNG is reseeded from the notebook-level `seed` on every call.
    """
    torch.manual_seed(seed)

    # The state dimension follows from the initial-condition mean, so any drift
    # function works and no model-specific function names need to be in scope.
    X0_LOC = torch.as_tensor(X0_LOC)
    state_dim = X0_LOC.shape[-1]

    N = int(T / DT) + 1
    x = torch.zeros([BATCH_SIZE, N, state_dim])

    # Draw initial condition x0
    p_x0 = D.multivariate_normal.MultivariateNormal(loc = X0_LOC,
                                                    scale_tril = torch.diag(X0_LOC * X0_SCALE))
    x0_samples = p_x0.sample((BATCH_SIZE, )) # (batch_size, state_dim)
    #x0_samples[x0_samples < lower_bound] = lower_bound #Bound initial conditions above 0. 
    print('X0_samples = ', x0_samples)
    x[:, 0, :] = x0_samples

    # Vectorize variable calculations where possible
    hours = torch.tensor(np.linspace(0, T, N), dtype=torch.float) # 0
    I_S_tensor = I_S_FUNC(hours)
    I_D_tensor = I_D_FUNC(hours)
    temps = TEMP_FUNC(hours, TEMP_REF, TEMP_RISE)

    #Take Euler-Maruyama step. 
    for i in range(1, N):
        # Define x_i distribution
        a = ALPHA(x[:, i - 1, :], THETA_DICT, I_S_tensor[i], I_D_tensor[i], temps[i], TEMP_REF, arrhenius_temp_dep, linear_temp_dep)
        b = BETA(x[:, i - 1, :], THETA_DICT)
        p_x_i = D.multivariate_normal.MultivariateNormal(loc = x[:, i - 1, :] + a * DT, covariance_matrix = b * DT)

        # Draw sample
        x[:, i, :] = p_x_i.sample()
        #x[:, i, :][x[:, i, :] < lower_bound] = lower_bound #Bound all x above 0.

    return x, p_x0

In [None]:
# Simulate the SCON-C path from the sampled theta and report the tensor shape and wall-clock time.
# i_s, i_d and temp_gen are not defined in this notebook's visible cells; presumably they come from the star imports — verify.
t0 = time.time()
x, p_x0 = generate_x(num_sequences, alpha_SCON_multi, beta_SCON_C_multi, x0_SCON, x0_scale, t, dt, theta_samples, i_s, i_d, temp_gen, temp_ref, temp_rise)
print(x.shape, time.time() - t0)

In [None]:
# Load a previously simulated path and rebuild the initial-state prior with the same
# mean and relative scale that generate_x uses.
# Running this cell overwrites `x` and `p_x0` from the simulation cell.
# NOTE(review): the file sits under 'dt_1_t_100000_...'; confirm its length is consistent with `t`.
# NOTE(review): torch.load unpickles arbitrary objects; only load files from a trusted source.
x_file = 'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_x_scon_c_{}.pt'.format(0)
x = torch.load(x_file)
x0_loc = torch.as_tensor(x0_SCON)
p_x0 = D.multivariate_normal.MultivariateNormal(loc = x0_loc,
                                                scale_tril = torch.diag(x0_loc * x0_scale))
print(x.shape)

## Inference

Let $x$ be the observed states (i.e. SOC, DOC, and MBC in the SCON-C model) and $\theta$ be the model parameters. The generative model is:
- $\theta \sim \text{LogitNormal}(\mu_\theta, \sigma_\theta, a, b)$
- $x_1 \sim p(x_1|\theta)$
- For $i=2, ..., T$: $x_i \sim p(x_i|x_{i-1}, \theta)$

We use variational posterior $q_\phi(\theta)=\text{MultivariateLogitNormal}(\hat{\mu}, \hat{L}\hat{L}^T, a, b)$ with variational parameter $\phi=(\hat{\mu}, \hat{L})$. For minibatching, we partition the sequence $x$ into $B$ consecutive subsequences $x_b$. We sample a minibatch by drawing $b$ uniformly from $1, ..., B$. Let $u_b$ and $v_b$ denote the endpoints of the $b$th subsequence, $x_b = x_{u_b:v_b}$. The minibatch loss is:
$$
\mathcal{L} = -E_{q_\phi(\theta)}\left[ \log p(\theta) + \log p(x_b|\theta) - \log q_\phi(\theta) \right]
$$

where the log likelihood term is given by:
$$
\log p(x_b|\theta) = \sum_{i=u_b}^{v_b} \log p(x_i|x_{i-1}, \theta)
$$
if $u_b > 1$, otherwise: $\log p(x_b|\theta) = \log p(x_1|\theta) + \sum_{i=u_b+1}^{v_b} \log p(x_i|x_{i-1}, \theta)$.

Reparameterized gradient wrt variational parameter $\phi$:
$$
\nabla_\phi \mathcal{L} \approx -\frac{1}{S} \sum_s \nabla_\phi \left[
    \log p(g_\phi(\epsilon^{(s)})) + \log p(x_b|g_\phi(\epsilon^{(s)})) - \log q(g_\phi(\epsilon^{(s)}))
    \right]
$$

where $\epsilon^{(s)} \sim \text{MultivariateNormal}(0, I)$ and $g_\phi(\epsilon)=(b-a) \odot \sigma(\hat{\mu}+\hat{L}\epsilon)+a$, with $\odot$ to denote elementwise multiplication.

Pseudocode:
- for each iteration $n=1, ..., N$:
  - Sample minibatch: $b \sim \text{Uniform}(1, ..., B)$
  - Sample $\theta$: $\theta^{(s)} \sim q(\theta)$ for $s=1, ..., S$
  - `loss` $= -\frac{1}{S} \sum_s \left[\log p(\theta^{(s)}) + \log p(x_b|\theta^{(s)}) - \log q(\theta^{(s)})\right]$ 
  - `loss.backward()`

**Implementation note:** The covariance matrix $\Sigma$ of the multivariate normal in PyTorch is parameterized in terms of a lower-triangular matrix $L$ with positive-valued diagonal entries, such that $\Sigma = LL^T$. In practice, we optimize the unconstrained transformation of $L$. The following code shows how the mapping is done back and forth:
```
def to_constrained(self, x):
    # Takes unconstrained square matrix x and returns L, s.t. Sigma = LL^T
    return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()

def to_unconstrained(self, y):
    return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()
```

In [None]:
# Time step used during inference; it sets the number of points n = int(t / dt_flow) + 1 below.
dt_flow = 1.0 # [1.0, 0.5]

In [None]:
# Pick the compute device; when CUDA is available, new tensors default to CUDA floats.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch releases — confirm the targeted version.
active_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.set_printoptions(precision = 8)

#Neural SDE parameters
# n is the number of time points on the inference grid [0, t] with spacing dt_flow.
n = int(t / dt_flow) + 1
t_span = np.linspace(0, t, n)
t_span_tensor = torch.reshape(torch.Tensor(t_span), [1, n, 1]).to(active_device) #T_span needs to be converted to tensor object. Additionally, facilitates conversion of I_S and I_D to tensor objects.

#Specify desired SBM SDE model type and details.
SBM_SDE_class = 'SCON'
diffusion_type = 'C'
theta_dist = 'RescaledLogitNormal' #String needs to be exact name of the distribution class. Other option is 'TruncatedNormal'.
theta_post_dist = 'MultivariateLogitNormal'

#Generate exogenous input vectors.
#Obtain temperature forcing function.
temp_tensor = temp_gen(t_span_tensor, temp_ref, temp_rise).to(active_device)

#Obtain SOC and DOC pool litter input vectors for use in flow SDE functions.
i_s_tensor = i_s(t_span_tensor).to(active_device) #Exogenous SOC input function
i_d_tensor = i_d(t_span_tensor).to(active_device) #Exogenous DOC input function

In [None]:
def calc_log_lik(C_PATH, PARAMS_DICT, DT, SBM_SDE_CLASS, INIT_PRIOR, start_idx, end_idx, calc_full=False):
    """Gaussian transition log-likelihood of a path window.

    Args:
        C_PATH: state path indexed as (batch, time, state).
        PARAMS_DICT: parameters forwarded to `drift_diffusion`.
        DT: time step of the Euler-Maruyama transition.
        SBM_SDE_CLASS: object exposing
            drift_diffusion(C_PATH, PARAMS_DICT[, start_idx, end_idx])
            -> (drift, diffusion_sqrt); diffusion_sqrt is used as scale_tril.
        INIT_PRIOR: distribution of the first state; its log-prob is added
            only when the window starts at index 0.
        start_idx, end_idx: window bounds; transitions into states
            start_idx + 1 .. end_idx - 1 are scored.
        calc_full: when True, also score the whole path.

    Returns:
        (ll_minibatch, ll_full); ll_full is None unless calc_full is True.
    """
    step_sd = math.sqrt(DT)

    if not calc_full:
        # Only evaluate drift/diffusion on the requested window.
        drift, diffusion_sqrt = SBM_SDE_CLASS.drift_diffusion(C_PATH, PARAMS_DICT, start_idx, end_idx)
        prev_states = C_PATH[:, start_idx:end_idx-1, :]
        next_states = C_PATH[:, start_idx+1:end_idx, :]
        transition = D.multivariate_normal.MultivariateNormal(loc = prev_states + drift * DT,
                                                              scale_tril = diffusion_sqrt * step_sd)

        # log p(x_{u+1:v} | x_u, theta), plus log p(x_0 | theta) for the first window
        ll_minibatch = transition.log_prob(next_states).sum(-1)
        if start_idx == 0:
            ll_minibatch += INIT_PRIOR.log_prob(C_PATH[:, 0, :])
        return ll_minibatch, None

    # Full path: score every transition once, then slice out the window.
    drift, diffusion_sqrt = SBM_SDE_CLASS.drift_diffusion(C_PATH, PARAMS_DICT)
    transition = D.multivariate_normal.MultivariateNormal(loc = C_PATH[:, :-1, :] + drift * DT,
                                                          scale_tril = diffusion_sqrt * step_sd)
    step_log_probs = transition.log_prob(C_PATH[:, 1:, :]) # log p(x_i|x_{i-1}, theta), (batch_size, N - 1)
    init_log_prob = INIT_PRIOR.log_prob(C_PATH[:, 0, :]) # log p(x0|theta), (batch_size, )

    ll_minibatch = step_log_probs[:, start_idx:end_idx-1].sum(-1)
    if start_idx == 0:
        ll_minibatch += init_log_prob
    ll_full = step_log_probs.sum(-1) + init_log_prob # log p(x|theta), (batch_size, )
    return ll_minibatch, ll_full

def train(DEVICE, LR, NITER, BATCH_SIZE, MINIBATCH_SIZE, X_ALL, T, DT, N,
          T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, TEMP_TENSOR, TEMP_REF,
          SBM_SDE_CLASS, DIFFUSION_TYPE, X0_PRIOR, PRIOR_DIST_DETAILS_DICT, 
          THETA_DIST = None, THETA_POST_DIST = None, THETA_POST_INIT = None,
          LR_DECAY = 0.8, DECAY_STEP_SIZE = 50000, PRINT_EVERY = 100, MINIBATCH_INDICES=None,
          CALC_LOSS_EVERY=100):
    """Fit a variational posterior over theta with minibatched negative-ELBO steps.

    Args:
        DEVICE: device for the prior tensors and the MeanField posterior.
        LR: Adamax learning rate.
        NITER: number of optimisation steps.
        BATCH_SIZE: number of theta samples drawn from q(theta) per step.
        MINIBATCH_SIZE: number of time points per minibatch; values >= N
            use the full sequence on every step.
        X_ALL: observed path; it is strided down to N time points and must
            then have shape (1, N, state_dim).
        T: unused in this function.
        DT: time step passed to the likelihood.
        N: number of time points on the inference grid.
        T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, TEMP_TENSOR, TEMP_REF,
        DIFFUSION_TYPE: forwarded to the SBM SDE constructor.
        SBM_SDE_CLASS: one of 'SCON', 'SAWB', 'SAWB-ECA'.
        X0_PRIOR: distribution of the first state.
        PRIOR_DIST_DETAILS_DICT: parameter name -> 4 values passed to the
            prior class as (loc, scale, a, b).
        THETA_DIST: name of the prior class ('TruncatedNormal',
            'RescaledLogitNormal' or 'MultivariateLogitNormal').
        THETA_POST_DIST: name of the posterior class; falls back to THETA_DIST.
        THETA_POST_INIT: posterior initialisation; defaults to the prior details.
        LR_DECAY, DECAY_STEP_SIZE: the learning rate of the first parameter
            group is multiplied by LR_DECAY every DECAY_STEP_SIZE steps.
        PRINT_EVERY: progress is printed every PRINT_EVERY steps.
        MINIBATCH_INDICES: candidate minibatch start indices; when None an
            evenly strided set is generated.
        CALC_LOSS_EVERY: the full-sequence loss is evaluated on step 1 and
            every CALC_LOSS_EVERY steps.

    Returns:
        (q_theta, p_theta, losses, losses_full, times): the trained posterior
        module, the prior, the per-step minibatch losses, the full-sequence
        losses, and the cumulative wall-clock time after each step.

    Raises:
        NotImplementedError: if SBM_SDE_CLASS is not a supported model name.

    The torch RNG is reseeded from the notebook-level `seed` on every call.
    """
    torch.manual_seed(seed)
    
    # Instantiate SBM_SDE object based on specified model and diffusion type.
    SBM_SDE_class_dict = {
            'SCON': SCON_optim,
            'SAWB': SAWB_optim,
            'SAWB-ECA': SAWB_ECA_optim
            }
    if SBM_SDE_CLASS not in SBM_SDE_class_dict:
        raise NotImplementedError('Other SBM SDEs aside from SCON, SAWB, and SAWB-ECA have not been implemented yet.')
    SBM_SDE_class = SBM_SDE_class_dict[SBM_SDE_CLASS]
    SBM_SDE = SBM_SDE_class(T_SPAN_TENSOR, I_S_TENSOR, I_D_TENSOR, TEMP_TENSOR, TEMP_REF, DIFFUSION_TYPE)

    # Load x, exclude CO2 and extra time steps 
    # (since data generation process uses smaller dt than dt_flow)
    step = (X_ALL.shape[1] - 1) // (N - 1)
    x = X_ALL[:, ::step, :SBM_SDE.state_dim]
    assert x.shape == (1, N, SBM_SDE.state_dim)

    # Convert prior details dictionary values to tensors.
    param_names = list(PRIOR_DIST_DETAILS_DICT.keys())
    prior_list = list(zip(*(PRIOR_DIST_DETAILS_DICT[k] for k in param_names))) #Unzip prior distribution details from dictionary values into individual lists.
    prior_means_tensor, prior_sds_tensor, prior_lowers_tensor, prior_uppers_tensor = torch.tensor(prior_list).to(DEVICE) #Ensure conversion of lists into tensors.

    # Retrieve desired distribution class based on string.
    dist_class_dict = {
            'TruncatedNormal': TruncatedNormal,
            'RescaledLogitNormal': RescaledLogitNormal,
            'MultivariateLogitNormal': MultivariateLogitNormal
            }
    THETA_PRIOR_CLASS = dist_class_dict[THETA_DIST]
    THETA_POST_CLASS = dist_class_dict[THETA_POST_DIST] if THETA_POST_DIST else dist_class_dict[THETA_DIST]
    
    # Define prior
    p_theta = THETA_PRIOR_CLASS(loc = prior_means_tensor, scale = prior_sds_tensor, a = prior_lowers_tensor, b = prior_uppers_tensor)

    # Initialize posterior q(theta) using its prior p(theta)
    learn_cov = (THETA_POST_DIST == 'MultivariateLogitNormal')
    if THETA_POST_INIT is None:
        THETA_POST_INIT = PRIOR_DIST_DETAILS_DICT
    q_theta = MeanField(DEVICE, param_names, THETA_POST_INIT, THETA_POST_CLASS, learn_cov)

    #Record loss throughout training (needs moving average bc of minibatching)
    losses = []
    losses_full = []
    times = []

    #Initiate optimizers.
    optimizer = optim.Adamax(list(q_theta.parameters()), lr = LR)
    
    # Sample minibatch endpoints
    if MINIBATCH_SIZE < N:
        if MINIBATCH_INDICES is None:
            # Clamp the stride to at least 1: integer division gives 0 when
            # NITER exceeds N - MINIBATCH_SIZE, and torch.arange rejects a zero step.
            step = max(1, (N - MINIBATCH_SIZE) // NITER)
            MINIBATCH_INDICES = torch.arange(0, (N - MINIBATCH_SIZE), step)
        rand = torch.randint(len(MINIBATCH_INDICES), (NITER, ))
        print(torch.min(torch.bincount(rand)))
        batch_indices = MINIBATCH_INDICES[rand]
    
    #Training loop
    t0 = time.time()
    with tqdm(total = NITER, desc = f'Learning SDE and hidden parameters.', position = -1) as tq:
        for it in range(1, NITER + 1):
            optimizer.zero_grad()    
            
            # Sample theta ~ q(theta) and compute log q(theta)
            theta_dict, theta, log_q_theta, _ = q_theta(BATCH_SIZE)
            
            # Compute log p(theta)
            log_p_theta = p_theta.log_prob(theta).sum(-1)

            # Compute log p(x_{u-1:v}|theta) (unless u = 0, then x_{u:v})
            if MINIBATCH_SIZE < N:
                start_idx = max(0, batch_indices[it-1] - 1)              # u-1 if u > 0, else 0
                end_idx = min(N, batch_indices[it-1] + MINIBATCH_SIZE)   # v
            else:
                start_idx, end_idx = 0, N
            calc_loss = (it % CALC_LOSS_EVERY == 0) or (it == 1)
            ll, ll_full = calc_log_lik(x, theta_dict, DT, SBM_SDE, X0_PRIOR,
                                       start_idx, end_idx, calc_loss)

            # Compute negative ELBO: -(log p(theta) + log p(x|theta) - log q(theta))
            # The minibatch likelihood is rescaled by N / MINIBATCH_SIZE.
            loss = -log_p_theta.mean() - N/MINIBATCH_SIZE * ll.mean() + log_q_theta.mean()
            losses.append(loss.item())
            if calc_loss:
                loss_full = -log_p_theta.mean() - ll_full.mean() + log_q_theta.mean()
                losses_full.append(loss_full.item())

            # Take a gradient step
            loss.backward()
            optimizer.step()
            
            # Record time
            times.append(time.time() - t0)
            
            if it % PRINT_EVERY == 0:
                if calc_loss:
                    print('Iteration {} loss: {}'.format(it, loss_full))
                else:
                    # Average the (up to) k losses preceding the current step.
                    # `losses` is a plain list, so the mean is computed by hand,
                    # and the slice start is clamped so it never wraps around.
                    # This branch is only reached for it >= 2, so the window is non-empty.
                    k = 1000
                    window = losses[max(0, it - k - 1):it - 1]
                    ma_loss = sum(window) / len(window)
                    print('Iteration {} moving average loss: {}'.format(it, ma_loss))
        
            if it % DECAY_STEP_SIZE == 0:
                optimizer.param_groups[0]['lr'] *= LR_DECAY

            tq.update()
    
    return q_theta, p_theta, losses, losses_full, times


In [None]:
#Training parameters
niter = 200000 # niter = epochs * (n - 1) / minibatch_size
train_lr = 0.01 #ELBO learning rate
batch_size = 40 #3 - number needed to fit UCI HPC3 RAM requirements with 16 GB RAM at t = 5000.
#minibatch_size = 1000
#minibatch_indices = torch.arange(0, t + 1 - minibatch_size, 100) + 1

In [None]:
# Minibatch length 5,000 with start indices 1, 5001, 10001, ... (stride equal to the length).
size5k = 5000
indices5k = torch.arange(0, t + 1 - size5k, size5k) + 1
#indices5k

In [None]:
# Minibatch length 1,000 with start indices 1, 1001, 2001, ... (stride equal to the length).
size1k = 1000
indices1k = torch.arange(0, t + 1 - size1k, size1k) + 1
#indices1k

In [None]:
#Call training loop function for SCON-C.
t0 = time.time()
q_theta_raw_mini5k, p_theta_mini5k, losses_noisy_mini5k, losses_mini5k, times_mini5k = train(
        active_device, train_lr, niter, batch_size, size5k, x, t, dt_flow, n, 
        t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref,
        SBM_SDE_class, diffusion_type, p_x0, theta_hyperparams,
        THETA_DIST=theta_dist, THETA_POST_DIST=theta_post_dist,
        LR_DECAY = 1.0, PRINT_EVERY = max(1, niter // 10), MINIBATCH_INDICES=indices5k)
print(time.time() - t0)

In [None]:
# Alias the 5k-minibatch learned-covariance results under the generic names used by later cells.
q_theta_raw, p_theta, losses_noisy, losses, times = q_theta_raw_mini5k, p_theta_mini5k, losses_noisy_mini5k, losses_mini5k, times_mini5k

In [None]:
#Call training loop function for SCON-C.
t0 = time.time()
q_theta_raw_mini1k, p_theta_mini1k, losses_noisy_mini1k, losses_mini1k, times_mini1k = train(
        active_device, train_lr, niter, batch_size, size1k, x, t, dt_flow, n, 
        t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref,
        SBM_SDE_class, diffusion_type, p_x0, theta_hyperparams,
        THETA_DIST=theta_dist, THETA_POST_DIST=theta_post_dist,
        LR_DECAY = 1.0, PRINT_EVERY = max(1, niter // 10), MINIBATCH_INDICES=indices1k)
print(time.time() - t0)

In [None]:
#Call training loop function for SCON-C.
t0 = time.time()
q_theta_raw_mf_mini5k, p_theta_mf_mini5k, losses_mf_noisy_mini5k, losses_mf_mini5k, times_mf_mini5k = train(
        active_device, train_lr, niter, batch_size, size5k, x, t, dt_flow, n, 
        t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref,
        SBM_SDE_class, diffusion_type, p_x0, theta_hyperparams,
        THETA_DIST=theta_dist, THETA_POST_DIST=theta_dist,
        LR_DECAY = 1.0, PRINT_EVERY = max(1, niter // 10), MINIBATCH_INDICES=indices5k)
print(time.time() - t0)

In [None]:
# Alias the 5k-minibatch mean-field results under the generic *_mf names used by later cells.
q_theta_raw_mf, p_theta_mf, losses_mf_noisy, losses_mf, times_mf = q_theta_raw_mf_mini5k, p_theta_mf_mini5k, losses_mf_noisy_mini5k, losses_mf_mini5k, times_mf_mini5k

In [None]:
#Call training loop function for SCON-C.
t0 = time.time()
q_theta_raw_mf_mini1k, p_theta_mf_mini1k, losses_mf_noisy_mini1k, losses_mf_mini1k, times_mf_mini1k = train(
        active_device, train_lr, niter, batch_size, size1k, x, t, dt_flow, n, 
        t_span_tensor, i_s_tensor, i_d_tensor, temp_tensor, temp_ref,
        SBM_SDE_class, diffusion_type, p_x0, theta_hyperparams,
        THETA_DIST=theta_dist, THETA_POST_DIST=theta_dist,
        LR_DECAY = 1.0, PRINT_EVERY = max(1, niter // 10), MINIBATCH_INDICES=indices1k)
print(time.time() - t0)

In [None]:
def moving_average(x, dim=-1, n=1000):
    """Trailing moving average with a window of up to `n` values.

    Entry i of the result is the mean of the last min(i + 1, n) values ending
    at position i, so the first n - 1 entries are running means over a
    growing window and the output has the same shape as the input.

    Args:
        x: sequence or tensor of values (anything torch.as_tensor accepts).
        dim: dimension along which to average.
        n: maximum window length; must be at least 1.

    Returns:
        A detached tensor of moving averages with the same shape as `x`.

    Raises:
        ValueError: if `n` is smaller than 1.
    """
    if n < 1:
        raise ValueError('n must be a positive integer, got {}'.format(n))

    # Work along dimension 0 internally, then move the axis back at the end.
    values = torch.as_tensor(x).movedim(dim, 0)
    cumsum = torch.cumsum(values, 0)
    length = cumsum.shape[0]

    # Warm-up part: fewer than n values are available, divide by the running count.
    head_len = min(n, length)
    count_shape = (-1,) + (1,) * (cumsum.dim() - 1)
    counts = torch.arange(1, head_len + 1, device=cumsum.device).reshape(count_shape)
    head = cumsum[:head_len] / counts

    # Steady state: cumsum[i] - cumsum[i - n] is the sum of exactly n values.
    tail = (cumsum[n:] - cumsum[:length - head_len]) / n

    return torch.cat((head, tail)).movedim(0, dim).detach()

In [None]:
# Smooth the recorded loss curves of the learned-covariance and mean-field runs.
losses_ma = moving_average(losses)
losses_mf_ma = moving_average(losses_mf)

## Visualizations

In [None]:
def plot_loss(loss_hist_list, labels, xscale='linear', ymin=None, ymax=None,
              colors=None, linestyles=None, xvals=None, xlabel='iteration'):
    """Draw several loss curves on the current matplotlib figure.

    Args:
        loss_hist_list: one sequence of loss values per curve.
        labels: legend label per curve.
        xscale: x-axis scale name passed to plt.xscale.
        ymin, ymax: y-axis limits (None leaves a limit automatic).
        colors: colour per curve; defaults to tab10 colours starting at index 1.
        linestyles: line style per curve; defaults to solid lines.
        xvals: x values per curve; a None entry plots against the index.
        xlabel: x-axis label.

    Updates the global matplotlib rcParams (font size, line width, figure size).
    """
    plt.rcParams.update({'font.size': 12, 'lines.linewidth': 2, 'figure.figsize': (8, 6)})

    # Fill in per-curve defaults, sized by the number of labels.
    num_curves = len(labels)
    colors = [cm.tab10(idx + 1) for idx in range(num_curves)] if colors is None else colors
    linestyles = ['-'] * num_curves if linestyles is None else linestyles
    xvals = [None] * num_curves if xvals is None else xvals

    for curve, xs, name, color, style in zip(loss_hist_list, xvals, labels, colors, linestyles):
        positional = (curve,) if xs is None else (xs, curve)
        plt.plot(*positional, label=name, color=color, linestyle=style)

    plt.xlabel(xlabel)
    plt.ylabel('loss')
    plt.legend()
    plt.xscale(xscale)
    plt.ylim((ymin, ymax))

In [None]:
# Full-sequence loss of the learned-covariance vs. mean-field run, plotted against the record index.
labels = ['full-rank', 'mean-field']
plot_loss([losses, losses_mf], labels, ymax=-615000, ymin=-625000)

In [None]:
# epoch = iter * minibatch_size / (n - 1)
# The x grid steps by 100 iterations — presumably matching train's default CALC_LOSS_EVERY; confirm if that default changes.
epochs = [torch.arange(0, niter + 1, 100) * size5k / (n - 1)] * 2
plot_loss([losses, losses_mf], labels, ymax=-615000, ymin=-625000, xvals=epochs, xlabel='epoch', xscale='symlog')

In [None]:
# Pick the elapsed-time entries for step 1 and every 100th step, then plot loss against wall-clock time.
time_indices = torch.cat((torch.tensor([0]), torch.arange(99, niter + 1, 100)))
time_vals = [torch.tensor(times)[time_indices], torch.tensor(times_mf)[time_indices]]
plot_loss([losses, losses_mf], labels, ymax=-615000, ymin=-625000, xvals=time_vals, xlabel='time', xscale='symlog')

In [None]:
# Compare full-sequence and minibatch runs on an epoch axis.
# NOTE(review): losses_full and losses_mf_full are not defined in the cells visible here;
# this cell relies on state from elsewhere — confirm it survives Restart & Run All.
labels = ['full-rank (full sequence)', 'mean-field (full sequence)',
          'full-rank (minibatch size = 5,000)', 'mean-field (minibatch size = 5,000)',
          'full-rank (minibatch size = 1,000)', 'mean-field (minibatch size = 1,000)']
colors =  [cm.tab10(i) for i in [1, 2, 1, 2, 1, 2]]
linestyles = ['-', '-', '--', '--', ':', ':']

# The first two entries are None, so those curves are plotted against their index.
epochs = [None] * 2 + \
         [torch.arange(0, niter + 1, 100) * size5k / (n - 1)] * 2 + \
         [torch.arange(0, niter + 1, 100) * size1k / (n - 1)] * 2
plot_loss([losses_full, losses_mf_full, losses_mini5k, losses_mf_mini5k, losses_mini1k, losses_mf_mini1k],
          labels, ymax=-61000, ymin=-62200, 
          colors=colors, linestyles=linestyles, xvals=epochs, xscale='symlog', xlabel='epoch')

In [None]:
# Same comparison on a wall-clock axis.
# NOTE(review): times_full, times_mf_full, losses_full and losses_mf_full are not defined in the cells visible here.
time_indices = torch.cat((torch.tensor([0]), torch.arange(99, niter + 1, 100) ))
time_vals = [times_full, times_mf_full] + [torch.tensor(times)[time_indices] for times in [times_mini5k, times_mf_mini5k, times_mini1k, times_mf_mini1k]]

plot_loss([losses_full, losses_mf_full, losses_mini5k, losses_mf_mini5k, losses_mini1k, losses_mf_mini1k],
          labels, ymax=-61000, ymin=-62200, xlabel='time', xscale='symlog',
          colors=colors, linestyles=linestyles, xvals=time_vals)

In [None]:
# Compare overlapping vs. non-overlapping minibatch runs.
# NOTE(review): losses_mini5k_over and losses_mf_mini5k_over are not defined in the cells visible here.
# This cell rebinds the notebook-level `labels`, `colors` and `linestyles` to 4-element lists.
labels = ['full-rank (minibatch overlap)', 'mean-field (minibatch overlap)',
          'full-rank (minibatch no overlap)', 'mean-field (minibatch no overlap)']
colors =  [cm.tab10(i) for i in [1, 2, 1, 2]]
linestyles = ['-', '-', '--', '--']

xvals = [torch.arange(0, niter + 1, 100)] * 4
plot_loss([losses_mini5k_over, losses_mf_mini5k_over, losses_mini5k, losses_mf_mini5k],
          labels, ymax=-61000, ymin=-62200,
          colors=colors, linestyles=linestyles, xvals=xvals)

In [None]:
# Extracts the distribution from a MeanField object
# Extracts the distribution from a MeanField object
def extract_dist(q):
    """Instantiate the distribution described by a fitted posterior object.

    Reads `lowers`, `uppers`, `means`, `sds`, `learn_cov` and `dist` from `q`.
    With `learn_cov` set, `sds` is mapped through the constraint transform of
    the distribution's `scale_tril` argument and passed as `scale_tril`;
    otherwise `sds` is floored elementwise at 1e-8 and passed as `scale`.
    """
    lower, upper = q.lowers, q.uppers
    center = q.means

    if q.learn_cov:
        to_scale_tril = D.transform_to(q.dist.arg_constraints['scale_tril'])
        return q.dist(center, scale_tril=to_scale_tril(q.sds), a=lower, b=upper)

    floor = torch.ones_like(q.sds) * 1e-8
    return q.dist(center, scale=torch.max(q.sds, floor), a=lower, b=upper)

In [None]:
# Turn each trained posterior module into a concrete distribution object for plotting and saving.
q_theta_mini5k = extract_dist(q_theta_raw_mini5k)
q_theta_mf_mini5k = extract_dist(q_theta_raw_mf_mini5k)
q_theta_mini1k = extract_dist(q_theta_raw_mini1k)
q_theta_mf_mini1k = extract_dist(q_theta_raw_mf_mini1k)

In [None]:
def plot_theta(p_theta, q_theta_list, theta, labels, param_names, num_pts=1000, eps=1e-5, ncols=4,
               colors=None, linestyles=None, device=active_device):
    """Plot prior and posterior marginal densities for every parameter.

    Args:
        p_theta: prior distribution; `a`, `b`, `mean`, `stddev` and `log_prob` are used.
        q_theta_list: posterior distributions. RescaledLogitNormal instances are used
            as-is; for any other object a RescaledLogitNormal marginal is built from
            its `loc` and the square root of the diagonal of its `covariance_matrix`.
        theta: mapping from parameter name to the value drawn as a vertical line.
        labels: one label per posterior, used in the legend.
        param_names: parameter names in the order of the distributions' last axis.
        num_pts: number of grid points per parameter.
        eps: unused in this function.
        ncols: number of subplot columns.
        colors, linestyles: per-posterior styling; default to tab10 colours and solid lines.
        device: unused in this function (its default is bound when the function is defined).

    Updates global matplotlib rcParams and shows the figure.
    """
    plt.rcParams.update({'font.size': 16, 'lines.linewidth': 2})
    
    # Load posterior and define plot boundaries
    # Start from mean +/- 4 sd of the prior and widen to cover every posterior marginal.
    a, b = p_theta.a, p_theta.b
    x0 = p_theta.mean - 4*p_theta.stddev
    x1 = p_theta.mean + 4*p_theta.stddev
    q_marginals = []
    for q_theta in q_theta_list:
        if isinstance(q_theta, RescaledLogitNormal):
            q_marginal = q_theta
        else:
            scale_post = torch.diag(q_theta.covariance_matrix).sqrt()
            q_marginal = RescaledLogitNormal(q_theta.loc, scale_post, a=a, b=b)
        q_marginals.append(q_marginal)
        x0 = torch.fmin(x0, q_marginal.mean - 4*q_marginal.stddev)
        x1 = torch.fmax(x1, q_marginal.mean + 4*q_marginal.stddev)
        #print(x0, x1)
    # Clip the plotting range to the support [a, b].
    x0 = torch.fmax(x0, a).detach()
    x1 = torch.fmin(x1, b).detach()
    # One column of grid points per parameter.
    x = torch.from_numpy(np.linspace(x0, x1, num_pts))
    
    # Load true theta
    #theta = torch.load(theta_file, map_location=device)
    
    # Compute pdfs
    #print(x[0, :], x[-1, :])
    prior_pdf = torch.exp(p_theta.log_prob(x)).detach()
    post_pdfs = []
    for q_theta in q_marginals:
        post_pdf = torch.exp(q_theta.log_prob(x)).detach()
        post_pdfs.append(post_pdf)
    
    # Plot prior v posterior v true theta
    num_params = len(param_names)
    # The "+ 1" guarantees at least one spare axis, which is used for the legend.
    nrows = int(num_params / ncols) + 1
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))
    axes = np.atleast_2d(axes)
    k = 0
    if colors is None: colors = [cm.tab10(i+1) for i in range(len(post_pdfs))]
    if linestyles is None: linestyles = ['-'] * len(post_pdfs)
    for i, row in enumerate(axes):
        for j, ax in enumerate(row):
            if k < num_params:
                key = param_names[k]
                ax.plot(x[:, k], prior_pdf[:, k], label='Prior', color='tab:blue')
                # zip stops at the shortest of post_pdfs, labels, colors and linestyles.
                for post_pdf, post_dist, c, l in zip(post_pdfs, labels, colors, linestyles):
                    label = 'Posterior {}'.format(post_dist)
                    ax.plot(x[:, k], post_pdf[:, k], label=label, color=c, linestyle=l)
                ax.axvline(theta[key], color='gray', label='True $\\theta$')
                ax.set_xlabel(key)
                if j == 0: ax.set_ylabel('density')
            elif k == num_params:
                # First spare axis hosts the shared legend.
                # Note: this rebinds the `labels` parameter; harmless here because
                # every parameter panel has already been drawn by this point.
                handles, labels = axes[0, 0].get_legend_handles_labels()
                ax.legend(handles, labels, loc='center')
                ax.axis('off')
            else:
                fig.delaxes(axes[i, j])
            k += 1  
    plt.tight_layout()
    plt.suptitle('Marginal distributions')
    plt.subplots_adjust(top=0.95)
    plt.show()

In [None]:
# Marginals of all six posteriors against the prior.
# NOTE(review): q_theta_full and q_theta_mf_full are not defined in the cells visible here.
# NOTE(review): `colors` and `linestyles` are whatever an earlier cell last bound; if they are the
# 4-element lists from the overlap-comparison cell, plot_theta's zip draws only four of the six posteriors.
labels = ['full-rank (full)', 'mean-field (full)',
          'full-rank (mini = 5,000)', 'mean-field (mini = 5,000)',
          'full-rank (mini = 1,000)', 'mean-field (mini = 1,000)']
plot_theta(p_theta_mini5k, [q_theta_full, q_theta_mf_full, q_theta_mini5k, q_theta_mf_mini5k, q_theta_mini1k, q_theta_mf_mini1k],
           theta_samples, labels, q_theta_raw_mini5k.keys, colors=colors, linestyles=linestyles)

In [None]:
# Marginals of one learned-covariance and one mean-field posterior.
# NOTE(review): q_theta and q_theta_mf are not assigned in the cells visible here; only the first two
# entries of the current `labels` list are used because plot_theta zips labels with the posteriors.
plot_theta(p_theta, [q_theta, q_theta_mf],
           theta_samples, labels, q_theta_raw.keys)

In [None]:
def plot_corr(q_theta_list, labels, param_names, num_samples=100000):
    """Show Monte Carlo correlation matrices of posterior samples side by side.

    Args:
        q_theta_list: MultivariateLogitNormal posteriors (asserted), one panel each.
        labels: panel titles, indexed like `q_theta_list`.
        param_names: tick labels for both axes.
        num_samples: number of samples drawn from each posterior.

    Updates the global matplotlib font size and shows the figure.
    """
    plt.rcParams.update({'font.size': 12})

    # Empirical correlation from samples of each posterior.
    corr_list = []
    for posterior in q_theta_list:
        assert isinstance(posterior, MultivariateLogitNormal)
        draws = posterior.sample((num_samples, )) # (N, D)
        corr_list.append(np.corrcoef(draws.T))

    # One heat map per posterior, sharing a single colour bar.
    num_cols = len(q_theta_list)
    fig, axes = plt.subplots(1, num_cols, figsize=(8*num_cols, 8))
    axes = np.atleast_1d(axes)
    tick_positions = range(len(q_theta_list[0].loc))

    for col, (ax, corr) in enumerate(zip(axes, corr_list)):
        plot = ax.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(param_names, rotation='vertical')
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(param_names)
        ax.set_title(labels[col])

    plt.tight_layout()
    plt.colorbar(plot, ax=axes, shrink=0.8)
    plt.suptitle('Correlation between parameters')
    plt.show()

In [None]:
# Parameter correlations of the learned-covariance posteriors.
# NOTE(review): q_theta_full is not defined in the cells visible here.
plot_corr([q_theta_full, q_theta_mini5k, q_theta_mini1k],
          ['full', 'minibatch size = 5,000', 'minibatch size = 1,000'], priors.keys())

In [None]:
# Parameter correlations of a single posterior.
# NOTE(review): q_theta is not assigned in the cells visible here.
plot_corr([q_theta],
          ['minibatch size = 5,000'], q_theta_raw.keys)

In [None]:
def plot_x(x, t_span, n):
    """Plot every state trajectory in ``x``, subsampled along the time axis.

    One subplot is drawn per state dimension and every sequence is overlaid in it.

    Args:
        x: Array or tensor of shape (num_sequences, time_steps, state_dim).
        t_span: Time values for the subsampled points; its length must equal the
            number of points kept by the subsampling, otherwise ``ax.plot`` raises.
        n: Number of time points the stride is derived from; must be at least 2
            and no larger than ``time_steps``.

    Raises:
        ValueError: If ``n`` is smaller than 2 or larger than ``time_steps``.
    """
    num_sequences, time_steps, state_dim = x.shape

    # n == 1 used to surface as a ZeroDivisionError and n > time_steps as a
    # "slice step cannot be zero" error; report both as a clear ValueError instead.
    if n < 2:
        raise ValueError('n must be at least 2')
    step = (time_steps - 1) // (n - 1)
    if step < 1:
        raise ValueError('n cannot exceed the number of time steps')
    x_plot = x[:, ::step, :]

    fig, axes = plt.subplots(state_dim, figsize=(15, 15))
    # plt.subplots returns a bare Axes when state_dim == 1; normalise it so the
    # loop below works for any number of states (same approach as plot_corr).
    axes = np.atleast_1d(axes)

    labels = ['SOC', 'DOC', 'MBC', 'EEC']
    for i, ax in enumerate(axes):
        for j in range(num_sequences):
            ax.plot(t_span, x_plot[j, :, i])
        # Fall back to a generic label if there are more states than known names.
        ax.set_ylabel(labels[i] if i < len(labels) else 'state {}'.format(i))

In [None]:
# Plot the state trajectories in `x`; `x`, `t_span` and `n` come from earlier cells.
plot_x(x, t_span, n)

In [None]:
# Persist the inputs and results of the `dt_1_t_1000000_n_1_minibatch` run.
save_dir = 'data/dt_1_t_1000000_n_1_minibatch'
# torch.save does not create missing parent directories, so make sure the target
# exists before writing anything (otherwise the results of a long run are lost).
os.makedirs(save_dir, exist_ok=True)

hyperparams_file = '{}/theta_from_x_hyperparams_scon_c_{}.pt'.format(save_dir, seed)
theta_file = '{}/theta_from_x_theta_scon_c_{}.pt'.format(save_dir, seed)
x_file = '{}/theta_from_x_x_scon_c_{}.pt'.format(save_dir, seed)
loss_file = '{}/theta_from_x_loss_scon_c_{}.pt'.format(save_dir, seed)
q_file = '{}/theta_from_x_q_scon_c_{}.pt'.format(save_dir, seed)
time_file = '{}/theta_from_x_time_scon_c_{}.pt'.format(save_dir, seed)

torch.save(theta_hyperparams, hyperparams_file)
torch.save(theta_samples, theta_file)
torch.save(x, x_file)
# Results are stored as two-element lists; the `*_mf` object is always second.
torch.save([losses, losses_mf], loss_file)
torch.save([q_theta, q_theta_mf], q_file)
torch.save([times, times_mf], time_file)

In [None]:
# Persist the mini1k and mini5k results of the `dt_1_t_100000_n_1_minibatch` run.
save_dir = 'data/dt_1_t_100000_n_1_minibatch'
# torch.save does not create missing parent directories, so make sure the target
# exists before writing anything.
os.makedirs(save_dir, exist_ok=True)

# Results are stored as two-element lists; the `*_mf` object is always second.
loss_file = '{}/theta_from_multi_x_loss_mini1k_scon_c_{}.pt'.format(save_dir, seed)
q_file = '{}/theta_from_multi_x_q_mini1k_scon_c_{}.pt'.format(save_dir, seed)
time_file = '{}/theta_from_multi_x_time_mini1k_scon_c_{}.pt'.format(save_dir, seed)
torch.save([losses_mini1k, losses_mf_mini1k], loss_file)
torch.save([q_theta_mini1k, q_theta_mf_mini1k], q_file)
torch.save([times_mini1k, times_mf_mini1k], time_file)

# The same path variables are reused, so after this cell they point at the mini5k files.
loss_file = '{}/theta_from_multi_x_loss_mini5k_scon_c_{}.pt'.format(save_dir, seed)
q_file = '{}/theta_from_multi_x_q_mini5k_scon_c_{}.pt'.format(save_dir, seed)
time_file = '{}/theta_from_multi_x_time_mini5k_scon_c_{}.pt'.format(save_dir, seed)
torch.save([losses_mini5k, losses_mf_mini5k], loss_file)
torch.save([q_theta_mini5k, q_theta_mf_mini5k], q_file)
torch.save([times_mini5k, times_mf_mini5k], time_file)

In [None]:
# Reload the "full" results from disk; every file holds a pair whose second
# element is the `*_mf` counterpart.
q_theta_full, q_theta_mf_full = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_q_full_scon_c_{seed}.pt'
)
losses_full, losses_mf_full = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_loss_full_scon_c_{seed}.pt'
)
times_full, times_mf_full = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_time_full_scon_c_{seed}.pt'
)


In [None]:
# Reload the mini5k results from disk; every file holds a pair whose second
# element is the `*_mf` counterpart.
q_theta_mini5k, q_theta_mf_mini5k = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_q_mini5k_scon_c_{seed}.pt'
)
losses_mini5k, losses_mf_mini5k = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_loss_mini5k_scon_c_{seed}.pt'
)
times_mini5k, times_mf_mini5k = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_time_mini5k_scon_c_{seed}.pt'
)

In [None]:
# Reload the mini1k results from disk; every file holds a pair whose second
# element is the `*_mf` counterpart.
q_theta_mini1k, q_theta_mf_mini1k = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_q_mini1k_scon_c_{seed}.pt'
)
losses_mini1k, losses_mf_mini1k = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_loss_mini1k_scon_c_{seed}.pt'
)
times_mini1k, times_mf_mini1k = torch.load(
    f'data/dt_1_t_100000_n_1_minibatch/theta_from_multi_x_time_mini1k_scon_c_{seed}.pt'
)

In [None]:
# Reload the `dt_1_t_1000000_n_1_minibatch` results from disk; every file holds a
# pair whose second element is the `*_mf` counterpart.
q_theta, q_theta_mf = torch.load(
    f'data/dt_1_t_1000000_n_1_minibatch/theta_from_x_q_scon_c_{seed}.pt'
)
losses, losses_mf = torch.load(
    f'data/dt_1_t_1000000_n_1_minibatch/theta_from_x_loss_scon_c_{seed}.pt'
)
times, times_mf = torch.load(
    f'data/dt_1_t_1000000_n_1_minibatch/theta_from_x_time_scon_c_{seed}.pt'
)

## Compare with unobserved $x$

In [None]:
# Checkpoints for the two `q_theta_y*` distributions; both are mapped onto the CPU
# while loading and then turned into distribution objects by `extract_dist`.
q_theta_y_file = '../training_pt_outputs/q_theta_iter_310000_t_1000_dt_1.0_batch_40_layers_5_lr_0.0005_sd_scale_0.333_SCON-C_logit_multi_2021_09_22_11_48_46.pt'
q_theta_y_mf_file = '../training_pt_outputs/q_theta_iter_250000_t_1000_dt_1.0_batch_45_layers_5_lr_0.0005_sd_scale_0.333_SCON-C_no_CO2_logit_alt_2021_09_23_07_08_19.pt'
device = torch.device('cpu')
q_theta_y, q_theta_y_mf = (
    extract_dist(torch.load(ckpt_path, map_location=device))
    for ckpt_path in (q_theta_y_file, q_theta_y_mf_file)
)

In [None]:
# Overlay the fits obtained when observing y with those obtained when observing x.
# Colour encodes the first pair vs the second pair; line style alternates '-' / '--'
# within each pair, following the order of the distribution list.
# NOTE(review): this rebinds the global `theta_file` that the save cell also uses, and
# passes a file path where other plot_theta calls pass `theta_samples` - confirm that
# plot_theta accepts a path here.
theta_file = '../generated_data/SCON-C_CO2_logit_alt_sample_y_t_1000_dt_0-01_sd_scale_0-333_rsample.pt'
labels = ['full-rank, observe $y$', 'full-rank, observe $x$', 'mean-field, observe $y$', 'mean-field observe $x$']
plot_theta(p_theta, [q_theta_y, q_theta, q_theta_y_mf, q_theta_mf], theta_file, labels, q_theta_raw.keys,
           colors=[cm.tab10(1), cm.tab10(1), cm.tab10(2), cm.tab10(2)], linestyles=['-', '--', '-', '--'])

In [None]:
# Side-by-side correlation heatmaps for `q_theta_y` and `q_theta`.
plot_corr([q_theta_y, q_theta], ['Observe $y$', 'Observe $x$'], q_theta_raw.keys)

**TODO:** `SBM_SDE` also needs to change to support inferring $\theta$ from multiple $x$ sequences. We need to compute the negative ELBO:
$$\mathcal{L} = - \frac{1}{S} \sum_s \left[ \log p(\theta^{(s)}) + \log p(x|\theta^{(s)}) - \log q(\theta^{(s)}) \right]$$
where $S$ is the batch size.

With $M$ sequences of $x_i, i = 1, ..., M$, the log likelihood term is now:
$$\log p(x|\theta^{(s)}) = \sum_i \log p(x_i|\theta^{(s)})$$

So the drift/diffusion computation needs to take `C_PATH` of size `(S, M, N, D)` (currently `(S, N, D)`) and return a log-likelihood tensor of size `(S, M)` (currently `(S, )`), which is then summed over the $M$ sequences. Here $N$ is the number of time steps and $D$ is the number of state dimensions.


In [None]:
# Three stand-in terms, each of shape (batch_size, 1, 1) with batch_size = 5,
# holding the integers 0-4, 5-9 and 10-14 respectively.
diff1, diff2, diff3 = diff_list = [
    torch.arange(start, start + 5).reshape(-1, 1, 1) for start in (0, 5, 10)
]
diff_list

In [None]:
# FIXME(review): the assignment below has no right-hand side, so this cell is a
# SyntaxError as written; fill in the intended expression or delete the cell.
tmp = 
tmp.shape

In [None]:
# Concatenate the per-state terms along dim 2, pass them through LowerBound with 1e-8,
# take the square root and place the result on the diagonal of one matrix per entry.
# NOTE(review): LowerBound is defined outside this cell - presumably it clamps values
# from below so the square root stays well-defined; confirm.
diff_tensor = torch.diag_embed(torch.sqrt(LowerBound.apply(torch.cat(diff_list, 2), 1e-8))) # (batch_size, 1, state_dim, state_dim)
diff_tensor.shape

In [None]:
# Elementwise comparison of batch entry 1 with the expected diagonal matrix
# sqrt(diag([1, 6, 11])); the result is a boolean tensor, not a single True/False.
diff_tensor[1] == torch.sqrt(torch.diag(torch.tensor([1, 6, 11])))

In [None]:
# Shape reminder only. NOTE(review): as code this builds a tuple and therefore needs
# `batch_size`, `N` and `state_dim` to exist in the kernel; consider moving it into a
# comment or a markdown cell.
(batch_size, N-1, state_dim)

In [None]:
# Random (2, 3) tensor with a singleton middle axis inserted -> shape (2, 1, 3).
a = torch.randn(2, 3)[:, None, :]
a.shape

In [None]:
# diag_embed places the last dimension of `a` on the diagonal of a new square matrix;
# display the resulting shape.
torch.diag_embed(a).shape

In [None]:
# Small integer operands for a broadcasting experiment: `a` is (2, 2, 3), `b` is (2, 2).
a = torch.arange(2 * 2 * 3).view(2, 2, 3)
b = torch.arange(2 * 2).view(2, 2)

In [None]:
# NOTE(review): with `a` of shape (2, 2, 3) and `b` of shape (2, 2) from the preceding
# cell, the trailing dimensions (3 vs 2) do not broadcast, so this raises a
# RuntimeError. Something like `a * b.unsqueeze(-1)` would be needed to scale each
# (i, j) row of `a` - confirm the intent.
a * b

In [None]:
# Number of minibatches used by the indexing experiments in the following cells.
num_minibatches = 4

# NOTE(review): this displays `minibatch_size`, which this cell does not define - it
# relies on a value already present in the kernel.
minibatch_size

In [None]:
# Scratch arithmetic: 5000 is not evenly divisible by 3.
5000/3

In [None]:
# Boundaries at multiples of minibatch_size: 0, m, 2m, ...
# NOTE(review): the size is derived from n // 3 while the number of boundaries comes
# from `num_minibatches` - confirm the two are meant to differ.
minibatch_size = n // 3
batch_indices = torch.arange(num_minibatches) * minibatch_size
batch_indices

In [None]:
# Inspect the shape of `x`.
x.shape

In [None]:
# Take the first entry of `x` along dim 0 and inspect its shape.
x0 = x[0]
x0.shape

In [None]:
# Draw `niter` random integers in [0, num_minibatches).
batch = torch.randint(num_minibatches, (niter, ))
batch

In [None]:
# Slice every sequence from one step before the selected boundary up to the next one.
# NOTE(review): `batch` is created above as a tensor of `niter` values, but `max(...)`
# and the slice bounds need a single integer, so this likely fails unless `batch` has
# been rebound to a scalar index - confirm.
[x_i[max(0, batch_indices[batch] - 1):batch_indices[batch + 1]] for x_i in x]

In [None]:
# NOTE(review): `batch0` is not defined in the visible cells. Unlike the list
# comprehension variant, there is no max(0, ...) guard here, so a boundary of 0 gives a
# slice start of -1 (counted from the end) - confirm this is intended.
x0[batch_indices[batch0]-1:batch_indices[batch0+1]].shape

In [None]:
# FIXME(review): torch.gather requires an `index` tensor as its third argument; as
# written this call raises a TypeError. Supply the index or delete the cell.
torch.gather(x, 0, )

In [None]:
# Try a minibatch size of 1000 and display t - minibatch_size
# (presumably the largest valid start offset - confirm).
minibatch_size = 1000
t - minibatch_size

In [None]:
# FIXME(review): the assignment below has no right-hand side, so this cell is a
# SyntaxError as written; the following cells rely on `start_indices` being defined.
start_indices = 

In [None]:
# Sample `niter` entries of `start_indices` uniformly at random, with replacement.
start_indices[torch.randint(len(start_indices), (niter, ))]

In [None]:
# Check that every position up to the largest one drawn was hit at least once.
# NOTE(review): bincount only has max(drawn) + 1 bins, so positions above the largest
# drawn value are not covered by this check; pass `minlength=len(start_indices)` if
# full coverage is what should be verified.
torch.all(torch.bincount(torch.randint(len(start_indices), (niter, ))) > 0)

In [None]:
# Draw `niter` start offsets uniformly from [0, 100000].
# NOTE(review): 100001 is hard-coded - presumably t + 1; confirm and derive it from `t`.
start_idx = torch.randint(100001, (niter, ))

In [None]:
# Count how often each start offset was drawn.
torch.bincount(start_idx)

In [None]:
# Candidate start offsets on a stride of 100, shifted by one: 1, 101, 201, ...
minibatch_indices = torch.arange(0, t + 1 - minibatch_size, 100) + 1

In [None]:
# Inspect the last candidate start offset.
minibatch_indices[-1]

In [None]:
# Same construction with the exclusive upper bound t - minibatch_size instead of
# t + 1 - minibatch_size, to compare the last offset it produces.
(torch.arange(0, t - minibatch_size, 100) + 1)[-1]

In [None]:
# Scratch check: a list can be populated directly from a set.
s = set((1, 2, 3))
l = [*s]

In [None]:
# Scratch check that float() parses a decimal string.
float('123.45')