In [67]:
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.distributions import kl_divergence, Normal, Categorical

In [6]:
import sys
sys.path.insert(0, './../')

In [217]:
from base_model import compute_marker_log_likelihood, compute_point_log_likelihood, generate_marker
from utils.metric import get_marker_metric, compute_time_expectation, get_time_metric

In [134]:
torch.__version__

'1.0.1.post2'

In [135]:
# Query CUDA availability once; report it and pick the matching default device.
_cuda_available = torch.cuda.is_available()
print(_cuda_available)
device = torch.device("cuda") if _cuda_available else torch.device("cpu")

True


In [136]:
def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel(0, 1) noise of the requested shape.

    Uniform samples are drawn with ``torch.rand`` and then moved to the
    notebook-level ``device`` before the inverse-CDF transform
    ``-log(-log(u))`` is applied.

    Args:
        shape: iterable of ints, unpacked into ``torch.rand``.
        eps: small constant added to the uniform draw so the inner log stays
            finite when a draw is exactly 0.

    Returns:
        Tensor of shape ``shape`` on ``device``.
    """
    uniform_draws = torch.rand(*shape).to(device)
    inner_log = torch.log(uniform_draws + eps)
    return -torch.log(-inner_log)

def sample_gumbel_softmax(logits, temperature):
    """Sample from the Gumbel-softmax relaxation of a categorical distribution.

    Args:
        logits: tensor of log-probabilities; the last dimension indexes the
            categories (e.g. BS x k).
        temperature: scalar; as it approaches 0 the samples tend towards
            one-hot vectors.

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums to 1.
    """
    noise = sample_gumbel(logits.shape)
    assert noise.shape == logits.shape
    # Perturb the logits with Gumbel noise, sharpen by 1/temperature, normalise.
    return F.softmax((noise + logits) / temperature, dim=-1)

def reparameterize(mu, logvar):
    """Draw ``z ~ N(mu, exp(logvar))`` using the reparameterisation trick.

    The noise is created with ``torch.randn_like(mu)``, so it already has the
    shape, dtype and device of ``mu``; it is deliberately not moved to any
    global device, which keeps the function usable for tensors on any device.

    Args:
        mu: tensor of means.
        logvar: tensor of log-variances, broadcastable against ``mu``.

    Returns:
        Tensor ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``.
    """
    sigma = torch.exp(0.5 * logvar)
    epsilon = torch.randn_like(mu)
    return mu + epsilon.mul(sigma)

In [307]:
class Model1(nn.Module):
    """Recurrent latent-variable model for sequences of (marker, time) events.

    Markers ``x`` and time features ``t`` are embedded and concatenated. An
    inference GRU produces (a) one relaxed categorical cluster variable ``y``
    per sequence (Gumbel-softmax) and (b) one Gaussian latent ``z`` per step.
    A second GRU summarises the observed history, and ``(h, z, y)`` is decoded
    into marker and time likelihoods by helpers imported from ``base_model``.
    All sequence tensors are time-major: ``T x BS x feature``.

    Args:
        latent_dim: size of the per-step Gaussian latent ``z``.
        marker_dim: marker feature size; used as the number of embedding rows
            when ``marker_type == 'categorical'``.
        marker_type: one of ``'real'``, ``'binary'`` or ``'categorical'``.
        hidden_dim: hidden size of both GRUs.
        time_dim: size of the time feature vector.
        n_cluster: number of categories of ``y``.
        x_given_t: when True the marker output heads take one extra input
            feature.
        time_loss: when ``'intensity'`` the time prediction used for metrics
            comes from ``compute_time_expectation``; any further use happens
            in the ``base_model`` helpers (not visible here).
        gamma: weight of the time NLL inside the training loss.
        dropout: dropout probability of the generative pre-processing net;
            ``None`` is treated as 0 (no dropout).
        base_intensity, time_influence: accepted for interface compatibility;
            not used inside this class.
    """

    def __init__(self, latent_dim=20, marker_dim=31, marker_type='real', hidden_dim=128, time_dim=2, n_cluster=5, x_given_t=False, time_loss='normal', gamma=1., dropout=None, base_intensity=None, time_influence=None):
        super().__init__()
        self.marker_type = marker_type
        self.marker_dim = marker_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.cluster_dim = n_cluster
        self.x_given_t = x_given_t
        self.time_loss = time_loss
        # Not read inside this class; presumably consumed by the base_model
        # helpers that receive `self` -- TODO confirm.
        self.sigma_min = 1e-2
        self.gamma = gamma
        self.dropout = dropout

        # Preprocessing networks
        # Embedding network
        self.x_embedding_layer = [128]
        self.t_embedding_layer = [8]
        self.embed_x, self.embed_t = self.create_embedding_nets()
        self.shared_output_layers = [256]
        self.inf_pre_module, self.gen_pre_module = self.create_preprocess_nets()

        # Forward RNN
        self.rnn = self.create_rnn()

        # Inference network
        self.encoder_layers = [64, 64]
        self.y_encoder, self.encoder_rnn, self.z_intmd_module, self.z_mu_module, self.z_logvar_module = self.create_inference_nets()

        # Generative network
        self.time_mu, self.time_logvar, self.output_x_mu, self.output_x_logvar = self.create_output_nets()

    def create_embedding_nets(self):
        """Build the marker and time embedding modules.

        Categorical markers use an ``nn.Embedding`` lookup; every other marker
        type uses Linear + ReLU. Times always use Linear + ReLU.

        Returns:
            (x_module, t_module)
        """
        if self.marker_type == 'categorical':
            x_module = nn.Embedding(self.marker_dim, self.x_embedding_layer[0])
        else:
            x_module = nn.Sequential(
                nn.Linear(self.marker_dim, self.x_embedding_layer[0]),
                nn.ReLU(),
            )

        t_module = nn.Sequential(
            nn.Linear(self.time_dim, self.t_embedding_layer[0]),
            nn.ReLU()
        )
        return x_module, t_module

    def create_preprocess_nets(self):
        """Build the shared pre-processing layers of both networks.

        Returns:
            (inf_pre_module, gen_pre_module): a square Linear layer over the
            concatenated (h, phi_x, phi_t, y) inference input, and a
            ReLU/Dropout/Linear/ReLU/Dropout stack over the concatenated
            (h, z, y) generative input.
        """
        # Inference net preprocessing
        hxty_input_dim = self.hidden_dim+self.x_embedding_layer[-1]+self.t_embedding_layer[-1]+self.cluster_dim
        inf_pre_module = nn.Linear(hxty_input_dim, hxty_input_dim)

        # nn.Dropout rejects None, which is the constructor default. Map it to
        # p=0 (identity) instead of dropping the layers so that the Sequential
        # indices -- and therefore the state_dict keys -- do not change.
        drop_p = 0.0 if self.dropout is None else self.dropout

        # Generative net preprocessing
        hzy_input_dim = self.hidden_dim+self.latent_dim+self.cluster_dim
        gen_pre_module = nn.Sequential(
            nn.ReLU(), nn.Dropout(drop_p),
            nn.Linear(hzy_input_dim, self.shared_output_layers[-1]),
            nn.ReLU(), nn.Dropout(drop_p))
        return inf_pre_module, gen_pre_module

    def create_rnn(self):
        """Build the GRU that runs over the embedded (marker, time) sequence."""
        rnn = nn.GRU(
            input_size=self.x_embedding_layer[-1]+self.t_embedding_layer[-1],
            hidden_size=self.hidden_dim,
        )
        return rnn

    def create_inference_nets(self):
        """Build the modules of the inference (encoder) network.

        Returns:
            (y_module, encoder_rnn, z_intmd_module, z_mu_module,
            z_logvar_module): log-softmax head for the cluster variable, the
            encoder GRU, the two-layer MLP trunk for z, and the linear heads
            for the mean and log-variance of z.
        """
        y_module = nn.Sequential(
            nn.Linear(self.hidden_dim, self.cluster_dim),
            nn.LogSoftmax(dim=-1)
        )

        encoder_rnn = nn.GRU(
            input_size=self.x_embedding_layer[-1]+self.t_embedding_layer[-1],
            hidden_size=self.hidden_dim,
        )

        z_input_dim = self.hidden_dim+self.x_embedding_layer[-1]+self.t_embedding_layer[-1]+self.cluster_dim
        z_intmd_module = nn.Sequential(
            nn.Linear(z_input_dim, self.encoder_layers[0]),
            nn.ReLU(),
            nn.Linear(self.encoder_layers[0], self.encoder_layers[1]),
            nn.ReLU(),
        )
        z_mu_module = nn.Linear(self.encoder_layers[1], self.latent_dim)
        z_logvar_module = nn.Linear(self.encoder_layers[1], self.latent_dim)
        return y_module, encoder_rnn, z_intmd_module, z_mu_module, z_logvar_module

    def create_output_nets(self):
        """Build the output heads for time and marker.

        Returns:
            (t_module_mu, t_module_logvar, x_module_mu, x_module_logvar);
            ``x_module_logvar`` is None unless ``marker_type == 'real'``.

        Raises:
            ValueError: if ``marker_type`` is not 'real', 'binary' or
                'categorical'.
        """
        l = self.shared_output_layers[-1]
        t_module_mu = nn.Linear(l, 1)
        t_module_logvar = nn.Linear(l, 1)

        x_module_logvar = None
        if self.x_given_t:
            # The marker heads receive one extra input feature in this mode.
            l += 1
        if self.marker_type == 'real':
            x_module_mu = nn.Linear(l, self.marker_dim)
            x_module_logvar = nn.Linear(l, self.marker_dim)
        elif self.marker_type == 'binary':
            x_module_mu = nn.Sequential(
                nn.Linear(l, self.marker_dim),
                nn.Sigmoid())
        elif self.marker_type == 'categorical':
            # Emits unnormalised scores; no softmax is applied here.
            x_module_mu = nn.Sequential(
                nn.Linear(l, self.marker_dim)
            )
        else:
            raise ValueError(
                "Unknown marker_type {!r}; expected 'real', 'binary' or 'categorical'".format(self.marker_type))
        return t_module_mu, t_module_logvar, x_module_mu, x_module_logvar

    ### ENCODER ###
    def encoder(self, phi_xt, temp):
        """Infer the cluster variable y and the per-step latent z.

        Input:
            phi_xt: Tensor of shape T x BS x (self.x_embedding_layer[-1]+self.t_embedding_layer[-1])
            temp: scalar Gumbel-softmax temperature
        Output (a 4-tuple):
            sample_y: Tensor of shape T x BS x cluster_dim (one sample per
                sequence, expanded along T)
            sample_z: Tensor of shape T x BS x latent_dim
            logits_y: Tensor of shape 1 x BS x cluster_dim
            (mu_z, logvar_z): pair of Tensors of shape T x BS x latent_dim
        """
        T,BS,_ = phi_xt.shape

        # Compute encoder RNN hidden states. h_0 is created on the input's
        # device/dtype rather than on a notebook-level global device.
        h_0 = phi_xt.new_zeros(1, BS, self.hidden_dim)
        hidden_seq, _ = self.encoder_rnn(phi_xt, h_0)
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)

        # Encoder for y: uses only the final hidden state of the sequence.
        logits_y = self.y_encoder(hidden_seq[-1])[None, :, :] #shape(logits_y) = 1 x BS x k
        #shape(sample_y) = 1 x BS x k. Should tend to one-hot in the last dimension
        sample_y = sample_gumbel_softmax(logits_y, temp)
        repeat_vals = (T, -1,-1)
        sample_y = sample_y.expand(*repeat_vals) #T x BS x k

        # Encoder for z: step i sees the hidden state before event i, the
        # embedding of event i, and y.
        concat_hxty = torch.cat([hidden_seq[:-1], phi_xt, sample_y], dim=-1)
        phi_hxty = self.inf_pre_module(concat_hxty)
        z_intmd = self.z_intmd_module(phi_hxty)
        mu_z = self.z_mu_module(z_intmd)
        logvar_z = self.z_logvar_module(z_intmd)
        sample_z = reparameterize(mu_z, logvar_z)
        return sample_y, sample_z, logits_y, (mu_z, logvar_z)

    def forward(self, marker_seq, time_seq, anneal=1., mask=None, temp=0.5):
        """Compute the training loss and a dictionary of diagnostics.

        Args:
            marker_seq: marker tensor, time-major (T x BS x ...).
            time_seq: time feature tensor of shape T x BS x time_dim.
            anneal: accepted for interface compatibility; currently unused.
            mask: T x BS tensor weighting each event; ``None`` means every
                event counts (an all-ones mask is created).
            temp: Gumbel-softmax temperature.

        Returns:
            (loss, info): ``loss`` is ``gamma * time_nll + marker_nll + KL``.
            ``info`` merges the metric dictionary with the keys 'marker_ll',
            'time_ll', 'true_ll' (which hold the masked *negative*
            log-likelihood sums, despite their names) and 'kl'.
        """
        if mask is None:
            # The default mask=None used to crash on the multiplications below.
            mask = time_seq.new_ones(time_seq.shape[:2])

        time_log_likelihood, marker_log_likelihood, KL, metric_dict = self._forward(marker_seq, time_seq, temp, mask)

        # The first time step is excluded from both likelihood terms.
        marker_loss = (-1.* marker_log_likelihood *mask)[1:,:].sum()
        time_loss = (-1. *time_log_likelihood *mask)[1:,:].sum()

        NLL = self.gamma*time_loss + marker_loss
        loss = NLL + KL
        true_loss = time_loss + marker_loss
        meta_info = {"marker_ll":marker_loss.detach().cpu(), "time_ll":time_loss.detach().cpu(), "true_ll": true_loss.detach().cpu(), "kl": KL.detach().cpu()}
        return loss, {**meta_info, **metric_dict}

    def _forward(self, x, t, temp, mask):
        """Run inference and generation for one batch.

        Args:
            x: marker tensor, time-major.
            t: time feature tensor of shape T x BS x time_dim.
            temp: Gumbel-softmax temperature.
            mask: T x BS tensor applied to the KL terms and the metrics.

        Returns:
            (time_log_likelihood, marker_log_likelihood, KL, metric_dict);
            the two likelihoods are unmasked per-event values returned by the
            ``base_model`` helpers, ``KL`` is a scalar (masked sum over the
            batch).
        """
        # Transform markers and timesteps into the embedding spaces
        phi_x, phi_t = self.embed_x(x), self.embed_t(t)
        phi_xt = torch.cat([phi_x, phi_t], dim=-1)
        T,BS,_ = phi_x.shape

        ## Inference
        # Get the sampled value and (mean + var) latent variable
        # using the hidden state sequence
        posterior_sample_y, posterior_sample_z, posterior_logits_y, (posterior_mu_z, posterior_logvar_z) = self.encoder(phi_xt, temp)

        repeat_vals = (T, -1,-1)
        posterior_logits_y = posterior_logits_y.expand(*repeat_vals)
        # Create distributions for Posterior random vars
        posterior_dist_z = Normal(posterior_mu_z, torch.exp(posterior_logvar_z*0.5))
        posterior_dist_y = Categorical(logits=posterior_logits_y)

        # Prior is a Normal(0,1) dist for z and a uniform Categorical for y.
        # Built with *_like constructors so that NaN/inf in the posterior
        # parameters cannot leak into the prior parameters.
        prior_dist_z = Normal(torch.zeros_like(posterior_mu_z), torch.ones_like(posterior_mu_z))
        prior_dist_y = Categorical(probs=torch.full_like(posterior_logits_y, 1/self.cluster_dim))

        ## Generative Part

        # Run the generative RNN over the concatenated embedded sequence and
        # prepend h_0 so that hidden_seq[i] is the state before event i.
        h_0 = phi_xt.new_zeros(1, BS, self.hidden_dim)
        hidden_seq, _ = self.rnn(phi_xt, h_0)
        hidden_seq = torch.cat([h_0, hidden_seq], dim=0)

        # Combine (z_t, h_t, y) to form the input for the generative part
        concat_hzy = torch.cat([hidden_seq[:-1], posterior_sample_z, posterior_sample_y], dim=-1)
        phi_hzy = self.gen_pre_module(concat_hzy)
        mu_marker, logvar_marker = generate_marker(self, phi_hzy, None)
        time_log_likelihood, mu_time = compute_point_log_likelihood(self, phi_hzy, t)
        marker_log_likelihood = compute_marker_log_likelihood(self, x, mu_marker, logvar_marker)

        KL_cluster = kl_divergence(posterior_dist_y, prior_dist_y)*mask
        KL_z = kl_divergence(posterior_dist_z, prior_dist_z).sum(-1)*mask
        KL = KL_cluster.sum() + KL_z.sum()
        # A KL divergence is non-negative; a negative or NaN value signals a
        # numerical problem. Warn instead of opening a debugger, which would
        # hang any non-interactive run.
        if not bool(KL >= 0):
            import warnings
            warnings.warn(
                "KL term is negative or NaN (KL={})".format(float(KL)),
                RuntimeWarning,
            )
        metric_dict = {}
        with torch.no_grad():
            if self.time_loss == 'intensity':
                mu_time = compute_time_expectation(self, hidden_seq, t, mask)[:,:, None]
            get_marker_metric(self.marker_type, mu_marker, x, mask, metric_dict)
            get_time_metric(mu_time,  t, mask, metric_dict)

        return time_log_likelihood, marker_log_likelihood, KL, metric_dict

In [295]:
from trainer import train

In [296]:
def trainer(model, data = None, val_data=None, lr= 1e-3, l2_reg=1e-2, epoch = 400, batch_size = 32, mask=None):
    """Train ``model`` with Adam on mini-batches of sequences.

    The tensors in ``data`` are time-major (T x BS x ...), so mini-batches are
    cut along dimension 1. The per-event mask is taken from the ``mask``
    argument, else from ``data['mask']`` if present, else it defaults to all
    ones; it is not read from a notebook-level global.

    Args:
        model: module whose call ``model(x, t, mask=...)`` returns
            ``(loss, metrics)``.
        data: mapping with keys 'x' (markers) and 't' (times), optionally
            'mask'.
        val_data: accepted for interface compatibility; currently unused.
        lr: Adam learning rate.
        l2_reg: Adam weight decay.
        epoch: number of passes over the data.
        batch_size: number of sequences per mini-batch.
        mask: optional T x BS tensor weighting each event.

    Returns:
        The trained ``model`` (the same object that was passed in).

    Raises:
        ValueError: if ``data`` is None or ``batch_size`` is not positive.
    """
    if data is None:
        raise ValueError("trainer() requires `data` with 'x' and 't' tensors")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    markers, times = data['x'], data['t']
    if mask is None and 'mask' in data:
        mask = data['mask']
    if mask is None:
        mask = times.new_ones(times.shape[:2])

    # len(data) would be the number of dict keys; the number of sequences is
    # the size of the batch axis (dim 1 of the time-major tensors).
    n_sequences = times.shape[1]

    optimizer = Adam(model.parameters(), lr=lr, weight_decay=l2_reg)

    for epoch_number in range(epoch):
        for start in range(0, n_sequences, batch_size):
            stop = start + batch_size
            x_batch = markers[:, start:stop]
            t_batch = times[:, start:stop]
            mask_batch = mask[:, start:stop]

            optimizer.zero_grad()
            loss, metrics = model(x_batch, t_batch, mask=mask_batch)
            loss.backward()
            optimizer.step()

            # Report the loss per (time step, sequence) position of this batch.
            n_positions = t_batch.shape[0] * t_batch.shape[1]
            print("loss:", loss.detach().cpu()/n_positions)
            print(metrics)
    return model

In [297]:
from functools import reduce

In [298]:
def main(model, data, val_data):
    """Train an already-constructed model on already-prepared data.

    The caller is responsible for building ``model`` (and moving it to the
    right device) and for providing ``data`` / ``val_data``; this function
    only forwards them to ``trainer`` with its default hyper-parameters.

    Returns:
        None. The model is trained in place.
    """
    trainer(model, val_data=val_data, data=data)

In [308]:
# Smoke test: fit the model on uniform random data.
# Tensors are time-major: 10 time steps, 32 sequences, 22 marker features and
# 2 time features. Every event is unmasked.
model = Model1(time_dim=2, marker_dim=22).to(device)
x = torch.rand(10, 32, 22).to(device)
t = torch.rand(10, 32, 2).to(device)
data = dict(x=x, t=t)
mask = torch.ones(10, 32).to(device)
main(model, data, None)

loss: tensor(25.1965)
{'marker_ll': tensor(7687.7861), 'time_ll': tensor(344.4263), 'true_ll': tensor(8032.2124), 'kl': tensor(30.6646), 'marker_mse': array(2473.9775, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(92.71465, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(24.1855)
{'marker_ll': tensor(7378.2627), 'time_ll': tensor(334.5691), 'true_ll': tensor(7712.8315), 'kl': tensor(26.5300), 'marker_mse': array(2101.4514, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(80.322495, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(23.1243)
{'marker_ll': tensor(7056.9346), 'time_ll': tensor(318.2906), 'true_ll': tensor(7375.2251), 'kl': tensor(24.5386), 'marker_mse': array(1779.2332, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(60.76291, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(21.8511)
{'marker_ll

loss: tensor(6.5772)
{'marker_ll': tensor(2010.3193), 'time_ll': tensor(84.4396), 'true_ll': tensor(2094.7590), 'kl': tensor(9.9304), 'marker_mse': array(720.1993, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(24.075848, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(6.4287)
{'marker_ll': tensor(1961.5106), 'time_ll': tensor(85.8252), 'true_ll': tensor(2047.3358), 'kl': tensor(9.8537), 'marker_mse': array(721.12415, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(24.093678, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(6.4726)
{'marker_ll': tensor(1973.1880), 'time_ll': tensor(88.5458), 'true_ll': tensor(2061.7339), 'kl': tensor(9.5096), 'marker_mse': array(724.12585, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(24.41608, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(6.4681)
{'marker_ll': tensor(

loss: tensor(5.5542)
{'marker_ll': tensor(1696.9141), 'time_ll': tensor(78.9745), 'true_ll': tensor(1775.8887), 'kl': tensor(1.4427), 'marker_mse': array(670.91595, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(23.917393, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(5.5207)
{'marker_ll': tensor(1688.1744), 'time_ll': tensor(77.0055), 'true_ll': tensor(1765.1799), 'kl': tensor(1.4479), 'marker_mse': array(670.98425, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(23.666508, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(5.4793)
{'marker_ll': tensor(1677.7041), 'time_ll': tensor(74.2203), 'true_ll': tensor(1751.9244), 'kl': tensor(1.4499), 'marker_mse': array(668.5812, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(23.274237, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(5.4719)
{'marker_ll': tensor

loss: tensor(4.7375)
{'marker_ll': tensor(1446.5382), 'time_ll': tensor(68.3616), 'true_ll': tensor(1514.8998), 'kl': tensor(1.0885), 'marker_mse': array(622.18427, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(22.85963, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(4.7247)
{'marker_ll': tensor(1443.2684), 'time_ll': tensor(67.5215), 'true_ll': tensor(1510.7900), 'kl': tensor(1.1055), 'marker_mse': array(621.7543, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(22.720833, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(4.6726)
{'marker_ll': tensor(1424.6760), 'time_ll': tensor(69.4483), 'true_ll': tensor(1494.1243), 'kl': tensor(1.1192), 'marker_mse': array(618.12134, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(22.975147, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(4.6277)
{'marker_ll': tensor(

loss: tensor(3.5508)
{'marker_ll': tensor(1079.7539), 'time_ll': tensor(55.2822), 'true_ll': tensor(1135.0360), 'kl': tensor(1.2355), 'marker_mse': array(555.0423, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(21.38086, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(3.5435)
{'marker_ll': tensor(1079.9412), 'time_ll': tensor(52.7309), 'true_ll': tensor(1132.6720), 'kl': tensor(1.2440), 'marker_mse': array(555.12915, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(21.037956, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(3.4956)
{'marker_ll': tensor(1064.5817), 'time_ll': tensor(52.7761), 'true_ll': tensor(1117.3578), 'kl': tensor(1.2496), 'marker_mse': array(552.4732, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(21.042877, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(3.4714)
{'marker_ll': tensor(1

loss: tensor(2.0392)
{'marker_ll': tensor(618.5049), 'time_ll': tensor(31.8477), 'true_ll': tensor(650.3527), 'kl': tensor(2.1977), 'marker_mse': array(483.34424, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(19.067024, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(2.0350)
{'marker_ll': tensor(616.9078), 'time_ll': tensor(32.0738), 'true_ll': tensor(648.9816), 'kl': tensor(2.2259), 'marker_mse': array(483.2583, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(19.055275, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(1.9709)
{'marker_ll': tensor(597.0117), 'time_ll': tensor(31.4190), 'true_ll': tensor(628.4307), 'kl': tensor(2.2514), 'marker_mse': array(480.99127, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(18.949879, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(1.9020)
{'marker_ll': tensor(576.2

loss: tensor(0.4609)
{'marker_ll': tensor(135.1902), 'time_ll': tensor(9.9435), 'true_ll': tensor(145.1337), 'kl': tensor(2.3656), 'marker_mse': array(426.69775, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(17.169939, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(0.2811)
{'marker_ll': tensor(81.4432), 'time_ll': tensor(6.2445), 'true_ll': tensor(87.6877), 'kl': tensor(2.2704), 'marker_mse': array(421.1983, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(16.892048, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(0.1959)
{'marker_ll': tensor(53.8148), 'time_ll': tensor(6.6173), 'true_ll': tensor(60.4321), 'kl': tensor(2.2709), 'marker_mse': array(418.8338, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(16.905312, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(0.1867)
{'marker_ll': tensor(45.5604), 'ti

loss: tensor(-1.9304)
{'marker_ll': tensor(-603.2188), 'time_ll': tensor(-17.8923), 'true_ll': tensor(-621.1111), 'kl': tensor(3.3963), 'marker_mse': array(363.96704, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(15.108957, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-2.0274)
{'marker_ll': tensor(-640.1667), 'time_ll': tensor(-12.0559), 'true_ll': tensor(-652.2227), 'kl': tensor(3.4517), 'marker_mse': array(361.2928, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(15.843124, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-2.0867)
{'marker_ll': tensor(-659.2068), 'time_ll': tensor(-12.0583), 'true_ll': tensor(-671.2651), 'kl': tensor(3.5161), 'marker_mse': array(360.55005, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(15.744751, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-2.0901)
{'marker_ll':

loss: tensor(-4.5254)
{'marker_ll': tensor(-1418.0359), 'time_ll': tensor(-34.9222), 'true_ll': tensor(-1452.9581), 'kl': tensor(4.8255), 'marker_mse': array(309.2351, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(14.323984, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-4.4287)
{'marker_ll': tensor(-1392.0719), 'time_ll': tensor(-30.0349), 'true_ll': tensor(-1422.1068), 'kl': tensor(4.9078), 'marker_mse': array(311.7355, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(14.460331, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-4.6774)
{'marker_ll': tensor(-1472.6571), 'time_ll': tensor(-29.3651), 'true_ll': tensor(-1502.0222), 'kl': tensor(5.2589), 'marker_mse': array(305.7106, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(14.771584, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-4.7050)
{'marker_

loss: tensor(-7.5436)
{'marker_ll': tensor(-2354.4041), 'time_ll': tensor(-62.8580), 'true_ll': tensor(-2417.2620), 'kl': tensor(3.3065), 'marker_mse': array(260.44232, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(12.410884, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-7.2372)
{'marker_ll': tensor(-2258.6465), 'time_ll': tensor(-60.5607), 'true_ll': tensor(-2319.2070), 'kl': tensor(3.3141), 'marker_mse': array(263.76, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(12.437577, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-7.3361)
{'marker_ll': tensor(-2285.0281), 'time_ll': tensor(-66.2266), 'true_ll': tensor(-2351.2546), 'kl': tensor(3.7179), 'marker_mse': array(262.43375, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(12.356903, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-7.3582)
{'marker_

loss: tensor(-8.8506)
{'marker_ll': tensor(-2747.3589), 'time_ll': tensor(-87.3606), 'true_ll': tensor(-2834.7195), 'kl': tensor(2.5329), 'marker_mse': array(237.63182, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(10.939708, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-9.3232)
{'marker_ll': tensor(-2894.5781), 'time_ll': tensor(-90.9200), 'true_ll': tensor(-2985.4980), 'kl': tensor(2.0709), 'marker_mse': array(235.10757, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(10.780605, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-9.4953)
{'marker_ll': tensor(-2950.9224), 'time_ll': tensor(-89.6121), 'true_ll': tensor(-3040.5344), 'kl': tensor(2.0383), 'marker_mse': array(231.03491, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(10.758009, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-9.5953)
{'mark

loss: tensor(-11.7591)
{'marker_ll': tensor(-3639.3101), 'time_ll': tensor(-125.5744), 'true_ll': tensor(-3764.8845), 'kl': tensor(1.9791), 'marker_mse': array(205.52908, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(9.102482, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-11.9210)
{'marker_ll': tensor(-3686.2461), 'time_ll': tensor(-130.3874), 'true_ll': tensor(-3816.6335), 'kl': tensor(1.9222), 'marker_mse': array(207.84706, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(8.853317, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-12.0152)
{'marker_ll': tensor(-3719.2598), 'time_ll': tensor(-127.4332), 'true_ll': tensor(-3846.6929), 'kl': tensor(1.8292), 'marker_mse': array(208.53049, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(8.984295, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-12.0482)
{'

loss: tensor(-13.0459)
{'marker_ll': tensor(-4022.3096), 'time_ll': tensor(-154.6574), 'true_ll': tensor(-4176.9668), 'kl': tensor(2.2651), 'marker_mse': array(190.10594, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(7.9783115, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-13.3856)
{'marker_ll': tensor(-4135.7456), 'time_ll': tensor(-149.9184), 'true_ll': tensor(-4285.6641), 'kl': tensor(2.2804), 'marker_mse': array(191.7953, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(8.234463, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-13.5500)
{'marker_ll': tensor(-4177.2188), 'time_ll': tensor(-161.0334), 'true_ll': tensor(-4338.2520), 'kl': tensor(2.2378), 'marker_mse': array(191.37975, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(7.6920233, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-13.9765)
{

loss: tensor(-15.4426)
{'marker_ll': tensor(-4762.2734), 'time_ll': tensor(-181.8772), 'true_ll': tensor(-4944.1504), 'kl': tensor(2.5087), 'marker_mse': array(177.33722, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(7.1870527, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-16.2236)
{'marker_ll': tensor(-4997.1709), 'time_ll': tensor(-197.0060), 'true_ll': tensor(-5194.1768), 'kl': tensor(2.6292), 'marker_mse': array(172.40337, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(6.7250347, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-16.2373)
{'marker_ll': tensor(-5002.0234), 'time_ll': tensor(-196.5882), 'true_ll': tensor(-5198.6118), 'kl': tensor(2.6625), 'marker_mse': array(170.47758, dtype=float32), 'marker_mse_count': array(320., dtype=float32), 'time_mse': array(6.7084036, dtype=float32), 'time_mse_count': array(288., dtype=float32)}
loss: tensor(-16.5646)