In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import time
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import torch

import pickle

In [2]:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd

In [4]:
device = torch.device("cuda")

In [5]:
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianCNNSingleFC(PyroModule):
    def __init__(
        self,
        num_classes,
        device,
        activation='sigmoid',
        prior_dist='gaussian',
        mu=0,
        b=10.0,
        prior_params=None
    ):
        super().__init__()

        # Store device
        self.device = device

        # Activation setup: accept string or callable
        if isinstance(activation, str):
            act_map = {
                'relu': F.relu,
                'tanh': F.tanh,
                'wg': self.actWG,
                'rwg': self.actRWG,
                'sigmoid': F.sigmoid,
                'sinusoidal': torch.sin,
                'relu6': F.relu6,
                'leaky_relu': F.leaky_relu,
                'selu': F.selu,
            }
            try:
                self.activation_fn = act_map[activation]
            except KeyError:
                raise ValueError(f"Unsupported activation: {activation}")
        elif callable(activation):
            self.activation_fn = activation
        else:
            raise ValueError("activation must be a string or callable")

        # Prior distribution setup
        self.prior_dist = prior_dist
        # Set default prior parameters if not provided
        default_params = {'mu': 0.0, 'b': 10.0}
        params = default_params if prior_params is None else prior_params
        self.prior_mu = torch.tensor(params.get('mu', params['mu']), device=device)
        self.prior_b  = torch.tensor(params.get('b', params['b']), device=device)

        print(f"Using prior distribution: {self.prior_dist} with mu={self.prior_mu.item()} and b={self.prior_b.item()}")

        # Layer definitions with priors
        self.conv1 = PyroModule[nn.Conv2d](3, 32, kernel_size=5, stride=1, padding=2)
        self.conv1.weight = PyroSample(self._make_prior([32, 3, 5, 5]))
        self.conv1.bias   = PyroSample(self._make_prior([32]))
        self.bn1 = nn.BatchNorm2d(32)  # Use nn.BatchNorm2d directly, no PyroModule needed

        self.conv2 = PyroModule[nn.Conv2d](32, 64, kernel_size=5, stride=1, padding=2)
        self.conv2.weight = PyroSample(self._make_prior([64, 32, 5, 5]))
        self.conv2.bias   = PyroSample(self._make_prior([64]))
        self.bn2 =  nn.BatchNorm2d(64)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.gap  = nn.AdaptiveAvgPool2d((1, 1))

        self.fc1 = PyroModule[nn.Linear](64, num_classes)
        self.fc1.weight = PyroSample(self._make_prior([num_classes, 64]))
        self.fc1.bias   = PyroSample(self._make_prior([num_classes]))

    def actWG(self, x, alpha=1.0):
        # Weight-gradient activation
        return x * torch.exp(-alpha * x ** 2)
    
    def actRWG(self, x, alpha=1.0):
        wg = x * torch.exp(-alpha * x ** 2)
        # compare elementwise with zero
        return torch.max(torch.zeros_like(wg), wg)

    def _make_prior(self, shape):
        """
        Construct a prior distribution based on self.prior_dist and parameters.
        """
        if self.prior_dist == 'gaussian':
            base = dist.Normal(self.prior_mu, self.prior_b)
        elif self.prior_dist == 'laplace':
            base = dist.Laplace(self.prior_mu, self.prior_b)
        elif self.prior_dist == 'uniform':
            # uniform over [-b, b]
            base = dist.Uniform(-self.prior_b, self.prior_b)
        else:
            raise ValueError(f"Unsupported prior distribution: {self.prior_dist}")
        return base.expand(shape).to_event(len(shape))

    def forward(self, x, y=None):
        # x: [B, 3, 64, 64]
        #x = self.activation_fn(self.conv1(x).to(self.device))
        x = self.conv1(x).to(self.device)
        x = self.bn1(x)
        x = self.activation_fn(x)
        x = self.pool(x)
        #x = self.activation_fn(self.conv2(x))
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation_fn(x)
        x = self.pool(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        logits = self.fc1(x)

        if y is not None:
            with pyro.plate("data", x.size(0)):
                pyro.sample("obs", dist.Categorical(logits=logits), obs=y)

        return logits

In [6]:
eurosat_mean = [0.344, 0.380, 0.408]
eurosat_std  = [0.190, 0.137, 0.115]

old_mean = [0.3444, 0.3803, 0.4078]
old_std = [0.0914, 0.0651, 0.0552]


def load_data(batch_size=54):
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=eurosat_mean, 
                             std=eurosat_std)
    ])

    dataset = datasets.EuroSAT(root='./data', transform=transform, download=True)

    torch.manual_seed(42)

    #train_size = int(0.8 * len(dataset))
    #test_size = len(dataset) - train_size
    #train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    with open('datasplit/split_indices.pkl', 'rb') as f:
        split = pickle.load(f)
        train_dataset = Subset(dataset, split['train'])
        test_dataset = Subset(dataset, split['test'])

    # Add num_workers and pin_memory for faster data loading
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, 
                             num_workers=4, pin_memory=True, persistent_workers=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                            num_workers=4, pin_memory=True, persistent_workers=True)
    return train_loader, test_loader

In [7]:
import torch
import pyro.poutine as poutine
from pyro.infer import Trace_ELBO

elbo = Trace_ELBO()


import torch
import pyro.poutine as poutine

def compute_kl_nll_full(model, guide, x, y):
    # 1) Trace the guide (latent-only)
    guide_trace = poutine.trace(guide).get_trace(x, y)
    # 2) Replay the model with those latent samples
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)
    ).get_trace(x, y)

    kl  = torch.tensor(0., device=x.device)
    nll = torch.tensor(0., device=x.device)

    def sum_lp(site):
        # Try built-in sums first
        if "log_prob_sum" in site:
            return site["log_prob_sum"]
        if "log_prob" in site:
            return site["log_prob"].sum()
        # Fallback to calling fn.log_prob
        lp = site["fn"].log_prob(site["value"])
        return lp.sum()

    # KL from all latent sites in the guide
    for name, site in guide_trace.nodes.items():
        if site["type"] != "sample" or site["is_observed"]:
            continue
        if name not in model_trace.nodes:
            continue
        q_lp = sum_lp(site)
        p_lp = sum_lp(model_trace.nodes[name])
        kl   = kl + (q_lp - p_lp)

    # NLL from the observed "obs" site in the model
    for name, site in model_trace.nodes.items():
        if site["type"] == "sample" and site["is_observed"]:
            # negative log-likelihood = -sum log p(y|x,w)
            nll = nll - sum_lp(site)

    return kl, nll


In [8]:
# training SVI function

import os
import torch
import pyro
from tqdm import tqdm
import numpy as np

def train_svi_with_stats(
    model,
    guide,
    svi,
    train_loader,
    device,
    num_epochs=10,
    save_epochs=None,
    save_dir='results',
    model_filename_pattern='model_{activation}_{prior}_epoch_{epoch}_{timestamp}.pth',
    guide_filename_pattern='guide_{activation}_{prior}_epoch_{epoch}_{timestamp}.pth',
    param_store_filename_pattern='param_store_{activation}_{prior}_epoch_{epoch}_{timestamp}.pkl',
    accuracies_filename_pattern='accuracy_results_{activation}_{prior}_{timestamp}.csv',
    losses_filename_pattern='losses_{activation}_{prior}_{timestamp}.csv',
    model_config_filename_pattern='config_{activation}_{prior}_{timestamp}.json'
):
    """
    Train the SVI model, track losses/accuracies, and
    save artifacts only when accuracy improves, naming files
    like `model_relu_gaussian_epoch_3.pth`.
    """
    
    # Pull names off the model if available, else fall back
    #act_name  = getattr(model, 'activation', getattr(model, 'activation_name', 'act'))
    act_name = model.activation_fn.__name__ if hasattr(model.activation_fn, '__name__') else str(model.activation_fn)
    prior_name = getattr(model, 'prior_dist', 'prior')
    timestamp = time.strftime("%Y%m%d_%H%M%S")

    os.makedirs(save_dir, exist_ok=True)
    save_epochs = set(save_epochs or range(1, num_epochs+1))

    pyro.clear_param_store()
    model.to(device)

    epoch_losses, epoch_accuracies, accuracy_epochs = [], [], []
    weight_stats = {'epochs': [], 'means': [], 'stds': []}
    bias_stats   = {'epochs': [], 'means': [], 'stds': []}
    best_acc = 0.0

    for epoch in range(1, num_epochs+1):
        model.train()
        total_loss = 0.0
        batches = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
            images, labels = images.to(device), labels.to(device)
            total_loss += svi.step(images, labels)

            batch_loss = svi.step(images, labels)

            with torch.no_grad():
                kl, nll = compute_kl_nll_full(model, guide, images, labels)

            # 3) print them side by side
            total_kl_nll = kl.item() + nll.item()
            print(f" batch_loss: {batch_loss:.2f}  |  KL+NLL: {total_kl_nll:.2f}"
                f"  (KL={kl.item():.2f}, NLL={nll.item():.2f})")
            
            batches += 1

        avg_loss = total_loss / batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch} - ELBO Loss: {avg_loss:.4f}")


        if epoch == 1 or epoch % 10 == 0 or epoch == num_epochs:
            model.eval(); guide.eval()
            correct, total = 0, 0
            with torch.no_grad():
                for images, labels in tqdm(train_loader, desc=f"Acc check epoch {epoch}"):
                    images, labels = images.to(device), labels.to(device)
                    trace = pyro.poutine.trace(guide).get_trace(images)
                    replayed = pyro.poutine.replay(model, trace=trace)
                    logits = replayed(images)
                    preds = torch.argmax(logits, dim=1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

            acc = correct/total
            epoch_accuracies.append(acc); accuracy_epochs.append(epoch)
            print(f"Epoch {epoch} - Train Acc: {acc*100:.2f}%")

            # record stats...
            w_means, w_stds, b_means, b_stds = [], [], [], []
            for name, param in pyro.get_param_store().items():
                if 'loc' in name:
                    w_means.append(param.mean().item()); w_stds.append(param.std().item())
                elif 'scale' in name:
                    b_means.append(param.mean().item()); b_stds.append(param.std().item())
            weight_stats['epochs'].append(epoch)
            weight_stats['means'].append(w_means)
            weight_stats['stds'].append(w_stds)
            bias_stats['epochs'].append(epoch)
            bias_stats['means'].append(b_means)
            bias_stats['stds'].append(b_stds)

            for name, param in pyro.get_param_store().items():
                if 'loc' in name or 'scale' in name:
                    print(f"{name}: {param.detach().cpu().numpy()}")

            # only save when accuracy improves
            if acc > best_acc:
                best_acc = acc
                fname_model = model_filename_pattern.format(activation=act_name, prior=prior_name, epoch="best", timestamp=timestamp)
                fname_guide = guide_filename_pattern.format(activation=act_name, prior=prior_name, epoch="best", timestamp=timestamp)
                fname_ps    = param_store_filename_pattern.format(activation=act_name, prior=prior_name, epoch="best", timestamp=timestamp)

                torch.save(model.state_dict(), os.path.join(save_dir, fname_model))
                torch.save(guide.state_dict(), os.path.join(save_dir, fname_guide))
                pyro.get_param_store().save(os.path.join(save_dir, fname_ps))
                print(f"  ↳ Saved: {fname_model}, {fname_guide}, {fname_ps}")

    # save losses per epoch in a csv file, with consistent file naming
    accuracies_df = pd.DataFrame({
        'epoch': accuracy_epochs,
        'accuracy': epoch_accuracies
    })
    accuracies_df.to_csv(os.path.join(save_dir,accuracies_filename_pattern.format(activation=act_name, prior=prior_name, timestamp=timestamp)), index=False)

    loss_df = pd.DataFrame({
        'epoch': list(range(1, epoch + 1)),
        'loss': epoch_losses
    })
    loss_df.to_csv(os.path.join(save_dir,losses_filename_pattern.format(activation=act_name, prior=prior_name, timestamp=timestamp)), index=False)
            
    # save model configuration in a json file
    config = {
        'activation': act_name,
        'prior': prior_name,
        'num_epochs': num_epochs,
        'best_accuracy_at_epoch': accuracy_epochs[np.argmax(epoch_accuracies)],
        'best_accuracy': best_acc,
        'batch_size': train_loader.batch_size,
        'train_size': len(train_loader.dataset),
        'prior_params': {
            'mu': model.prior_mu.item(),
            'b': model.prior_b.item()
        },
    }
    config_filename = model_config_filename_pattern.format(activation=act_name, prior=prior_name, timestamp=timestamp)

    with open(os.path.join(save_dir, config_filename), 'w') as f:
        import json
        json.dump(config, f, indent=4)
        print(f"Configuration saved to {config_filename}")

    return epoch_losses, epoch_accuracies, accuracy_epochs, weight_stats, bias_stats, os.path.join(save_dir, fname_model), os.path.join(save_dir, fname_guide), os.path.join(save_dir, fname_ps), timestamp


In [9]:
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
from tqdm import tqdm
import pandas as pd

In [10]:
num_classes = 10

bayesian_model = BayesianCNNSingleFC(num_classes,
        device,
        activation='relu',
        prior_dist='gaussian',
        prior_params={'mu': 0.0, 'b': 2.0})

#act_name  = getattr(bayesian_model, 'activation', getattr(bayesian_model, 'activation', 'act'))
#prior_name = getattr(bayesian_model, 'prior_dist', 'prior')

#guide = AutoDiagonalNormal(bayesian_model)

from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer.autoguide.initialization import init_to_median
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

# 1) construct your guide so its locs start at p(w).mean=0
guide = AutoDiagonalNormal(
    bayesian_model,
    init_loc_fn=init_to_median(num_samples=1),   # all μ_q ← prior mean (0)
    init_scale=0.1               # set initial σ_q=0.1
)

optimizer = Adam({"lr": 1e-3,"weight_decay": 1e-4})  # Increased from 1e-4 to 1e-3, and added weight decay
svi = pyro.infer.SVI(model=bayesian_model,
                     guide=guide,
                     optim=optimizer,
                     #loss=pyro.infer.Trace_ELBO(num_particles=1,
                     #                           )) #TODO
                     loss=elbo)


Using prior distribution: gaussian with mu=0.0 and b=2.0


In [11]:
pyro.clear_param_store()

# Ensure model and guide are on the correct device
bayesian_model.to(device)
guide.to(device)

train_loader, test_loader = load_data(batch_size=54)

In [12]:
losses, accuracies, accuracy_epochs, weight_stats, bias_stats, best_model_path, best_guide_path, best_param_store_path, experiment_timestamp = train_svi_with_stats(
    bayesian_model,
    guide,
    svi,
    train_loader,
    device,
    num_epochs=10,
    save_epochs=None,
    save_dir='results_GP_eurosat_TEST')

Epoch 1/10:   1%|          | 4/360 [00:01<01:16,  4.68it/s]

 batch_loss: 163475.29  |  KL+NLL: 115386.82  (KL=114798.72, NLL=588.10)
 batch_loss: 163571.22  |  KL+NLL: 115480.65  (KL=114753.08, NLL=727.57)
 batch_loss: 163388.36  |  KL+NLL: 115250.76  (KL=114703.83, NLL=546.93)
 batch_loss: 162786.83  |  KL+NLL: 115260.36  (KL=114680.80, NLL=579.56)
 batch_loss: 162547.89  |  KL+NLL: 115141.22  (KL=114632.80, NLL=508.42)


Epoch 1/10:   2%|▎         | 9/360 [00:01<00:35, 10.01it/s]

 batch_loss: 162520.00  |  KL+NLL: 115112.70  (KL=114609.34, NLL=503.36)
 batch_loss: 162420.62  |  KL+NLL: 115051.90  (KL=114574.30, NLL=477.60)
 batch_loss: 162384.98  |  KL+NLL: 115198.89  (KL=114550.73, NLL=648.15)
 batch_loss: 162163.46  |  KL+NLL: 114842.22  (KL=114481.93, NLL=360.29)
 batch_loss: 161907.58  |  KL+NLL: 114784.74  (KL=114451.66, NLL=333.09)


Epoch 1/10:   4%|▍         | 14/360 [00:01<00:24, 14.25it/s]

 batch_loss: 161985.17  |  KL+NLL: 114764.76  (KL=114435.16, NLL=329.61)
 batch_loss: 161978.89  |  KL+NLL: 114746.97  (KL=114398.24, NLL=348.73)
 batch_loss: 161470.04  |  KL+NLL: 114758.35  (KL=114342.87, NLL=415.48)
 batch_loss: 161185.98  |  KL+NLL: 114598.46  (KL=114311.91, NLL=286.55)
 batch_loss: 161392.11  |  KL+NLL: 114631.03  (KL=114280.27, NLL=350.76)


Epoch 1/10:   6%|▌         | 20/360 [00:01<00:19, 17.63it/s]

 batch_loss: 161116.53  |  KL+NLL: 114609.47  (KL=114265.64, NLL=343.83)
 batch_loss: 160974.24  |  KL+NLL: 114524.26  (KL=114202.34, NLL=321.92)
 batch_loss: 160569.16  |  KL+NLL: 114477.97  (KL=114175.27, NLL=302.70)
 batch_loss: 160326.73  |  KL+NLL: 114495.31  (KL=114152.30, NLL=343.01)
 batch_loss: 160280.89  |  KL+NLL: 114287.16  (KL=114120.42, NLL=166.73)


Epoch 1/10:   6%|▋         | 23/360 [00:01<00:18, 18.63it/s]

 batch_loss: 160321.69  |  KL+NLL: 114419.17  (KL=114075.63, NLL=343.54)
 batch_loss: 160060.10  |  KL+NLL: 114348.76  (KL=114024.93, NLL=323.83)
 batch_loss: 160005.29  |  KL+NLL: 114342.19  (KL=114000.12, NLL=342.06)
 batch_loss: 159964.44  |  KL+NLL: 114251.15  (KL=113947.45, NLL=303.71)
 batch_loss: 159792.04  |  KL+NLL: 114159.32  (KL=113904.12, NLL=255.20)


Epoch 1/10:   8%|▊         | 29/360 [00:02<00:16, 19.72it/s]

 batch_loss: 159596.67  |  KL+NLL: 114168.68  (KL=113878.94, NLL=289.75)
 batch_loss: 159606.31  |  KL+NLL: 114095.11  (KL=113815.27, NLL=279.84)
 batch_loss: 159429.24  |  KL+NLL: 114103.78  (KL=113819.36, NLL=284.42)
 batch_loss: 159120.99  |  KL+NLL: 113950.03  (KL=113756.91, NLL=193.12)
 batch_loss: 159154.01  |  KL+NLL: 114002.19  (KL=113726.69, NLL=275.51)


Epoch 1/10:  10%|▉         | 35/360 [00:02<00:15, 20.52it/s]

 batch_loss: 158803.65  |  KL+NLL: 113875.71  (KL=113690.98, NLL=184.73)
 batch_loss: 158612.49  |  KL+NLL: 113815.21  (KL=113636.24, NLL=178.97)
 batch_loss: 158542.72  |  KL+NLL: 113769.34  (KL=113615.73, NLL=153.61)
 batch_loss: 158289.36  |  KL+NLL: 113853.16  (KL=113578.61, NLL=274.55)
 batch_loss: 158113.47  |  KL+NLL: 113723.50  (KL=113533.64, NLL=189.86)


Epoch 1/10:  11%|█         | 38/360 [00:02<00:15, 20.71it/s]

 batch_loss: 158009.06  |  KL+NLL: 113685.96  (KL=113500.01, NLL=185.95)
 batch_loss: 158171.96  |  KL+NLL: 113659.93  (KL=113478.13, NLL=181.80)
 batch_loss: 157900.25  |  KL+NLL: 113640.38  (KL=113422.14, NLL=218.24)
 batch_loss: 157505.88  |  KL+NLL: 113559.79  (KL=113397.00, NLL=162.79)
 batch_loss: 157649.89  |  KL+NLL: 113555.98  (KL=113361.80, NLL=194.18)


Epoch 1/10:  12%|█▏        | 44/360 [00:02<00:15, 20.68it/s]

 batch_loss: 157387.47  |  KL+NLL: 113491.21  (KL=113326.66, NLL=164.55)
 batch_loss: 157270.62  |  KL+NLL: 113551.69  (KL=113293.01, NLL=258.68)
 batch_loss: 157118.56  |  KL+NLL: 113398.80  (KL=113248.12, NLL=150.68)
 batch_loss: 156858.79  |  KL+NLL: 113344.52  (KL=113203.34, NLL=141.18)
 batch_loss: 156920.51  |  KL+NLL: 113294.53  (KL=113155.98, NLL=138.56)


Epoch 1/10:  14%|█▍        | 50/360 [00:03<00:14, 20.79it/s]

 batch_loss: 156667.31  |  KL+NLL: 113286.38  (KL=113136.13, NLL=150.25)
 batch_loss: 156690.07  |  KL+NLL: 113332.23  (KL=113106.62, NLL=225.61)
 batch_loss: 156247.37  |  KL+NLL: 113205.51  (KL=113058.75, NLL=146.76)
 batch_loss: 156309.38  |  KL+NLL: 113217.98  (KL=113045.36, NLL=172.62)
 batch_loss: 156262.61  |  KL+NLL: 113199.49  (KL=112994.76, NLL=204.73)


Epoch 1/10:  15%|█▍        | 53/360 [00:03<00:14, 20.95it/s]

 batch_loss: 155739.49  |  KL+NLL: 113127.72  (KL=112955.50, NLL=172.22)
 batch_loss: 155965.32  |  KL+NLL: 113067.33  (KL=112909.56, NLL=157.77)
 batch_loss: 155869.55  |  KL+NLL: 113009.27  (KL=112875.62, NLL=133.64)
 batch_loss: 155513.72  |  KL+NLL: 113022.48  (KL=112843.96, NLL=178.52)
 batch_loss: 155318.78  |  KL+NLL: 112949.94  (KL=112803.42, NLL=146.52)


Epoch 1/10:  16%|█▋        | 59/360 [00:03<00:14, 20.86it/s]

 batch_loss: 155116.10  |  KL+NLL: 112894.75  (KL=112766.95, NLL=127.80)
 batch_loss: 155054.42  |  KL+NLL: 112924.17  (KL=112727.60, NLL=196.57)
 batch_loss: 154970.17  |  KL+NLL: 112842.77  (KL=112714.98, NLL=127.78)
 batch_loss: 154748.44  |  KL+NLL: 112841.34  (KL=112644.27, NLL=197.07)
 batch_loss: 154799.25  |  KL+NLL: 112780.13  (KL=112641.60, NLL=138.53)


Epoch 1/10:  18%|█▊        | 65/360 [00:04<00:14, 20.24it/s]

 batch_loss: 154199.06  |  KL+NLL: 112726.47  (KL=112600.72, NLL=125.75)
 batch_loss: 154606.81  |  KL+NLL: 112698.85  (KL=112546.38, NLL=152.47)
 batch_loss: 154551.64  |  KL+NLL: 112663.39  (KL=112519.88, NLL=143.50)
 batch_loss: 154346.98  |  KL+NLL: 112690.53  (KL=112512.23, NLL=178.29)
 batch_loss: 153980.85  |  KL+NLL: 112595.19  (KL=112456.09, NLL=139.10)


Epoch 1/10:  19%|█▉        | 68/360 [00:04<00:14, 20.52it/s]

 batch_loss: 153851.66  |  KL+NLL: 112633.71  (KL=112419.98, NLL=213.74)
 batch_loss: 153774.23  |  KL+NLL: 112487.36  (KL=112372.91, NLL=114.45)
 batch_loss: 153571.52  |  KL+NLL: 112522.79  (KL=112355.84, NLL=166.94)
 batch_loss: 153724.15  |  KL+NLL: 112474.00  (KL=112303.83, NLL=170.17)
 batch_loss: 153632.71  |  KL+NLL: 112457.69  (KL=112284.98, NLL=172.71)


Epoch 1/10:  21%|██        | 74/360 [00:04<00:13, 20.99it/s]

 batch_loss: 153526.96  |  KL+NLL: 112383.16  (KL=112256.61, NLL=126.55)
 batch_loss: 153167.21  |  KL+NLL: 112372.38  (KL=112214.30, NLL=158.08)
 batch_loss: 152684.03  |  KL+NLL: 112275.10  (KL=112170.37, NLL=104.73)
 batch_loss: 152711.77  |  KL+NLL: 112249.17  (KL=112132.05, NLL=117.13)
 batch_loss: 152496.27  |  KL+NLL: 112183.55  (KL=112068.01, NLL=115.54)


Epoch 1/10:  22%|██▏       | 80/360 [00:04<00:13, 20.48it/s]

 batch_loss: 152490.82  |  KL+NLL: 112169.05  (KL=112054.20, NLL=114.85)
 batch_loss: 152547.63  |  KL+NLL: 112155.76  (KL=112044.71, NLL=111.05)
 batch_loss: 151959.34  |  KL+NLL: 112088.67  (KL=112003.52, NLL=85.15)
 batch_loss: 152147.81  |  KL+NLL: 112097.36  (KL=111958.77, NLL=138.59)
 batch_loss: 152007.57  |  KL+NLL: 112043.37  (KL=111925.76, NLL=117.61)


Epoch 1/10:  23%|██▎       | 83/360 [00:04<00:13, 20.68it/s]

 batch_loss: 151663.53  |  KL+NLL: 112034.66  (KL=111911.66, NLL=122.99)
 batch_loss: 151765.34  |  KL+NLL: 111976.32  (KL=111845.68, NLL=130.64)
 batch_loss: 151081.15  |  KL+NLL: 111977.48  (KL=111835.04, NLL=142.44)
 batch_loss: 150912.39  |  KL+NLL: 111877.21  (KL=111784.04, NLL=93.17)
 batch_loss: 151271.15  |  KL+NLL: 111876.18  (KL=111734.88, NLL=141.31)


Epoch 1/10:  25%|██▍       | 89/360 [00:05<00:13, 20.56it/s]

 batch_loss: 150987.29  |  KL+NLL: 111852.81  (KL=111724.17, NLL=128.64)
 batch_loss: 151038.51  |  KL+NLL: 111832.76  (KL=111696.26, NLL=136.50)
 batch_loss: 150879.41  |  KL+NLL: 111793.22  (KL=111653.19, NLL=140.03)
 batch_loss: 150606.45  |  KL+NLL: 111718.47  (KL=111594.39, NLL=124.08)
 batch_loss: 150578.13  |  KL+NLL: 111708.06  (KL=111590.88, NLL=117.19)


Epoch 1/10:  26%|██▋       | 95/360 [00:05<00:12, 20.93it/s]

 batch_loss: 150327.56  |  KL+NLL: 111652.80  (KL=111523.77, NLL=129.04)
 batch_loss: 150254.91  |  KL+NLL: 111676.30  (KL=111513.88, NLL=162.42)
 batch_loss: 150157.63  |  KL+NLL: 111568.27  (KL=111482.04, NLL=86.23)
 batch_loss: 149967.11  |  KL+NLL: 111553.09  (KL=111462.17, NLL=90.92)
 batch_loss: 149850.81  |  KL+NLL: 111528.81  (KL=111405.06, NLL=123.75)


Epoch 1/10:  27%|██▋       | 98/360 [00:05<00:12, 20.79it/s]

 batch_loss: 149692.15  |  KL+NLL: 111465.89  (KL=111364.01, NLL=101.88)
 batch_loss: 149357.45  |  KL+NLL: 111431.64  (KL=111346.12, NLL=85.51)
 batch_loss: 149133.70  |  KL+NLL: 111380.80  (KL=111282.43, NLL=98.37)
 batch_loss: 149596.27  |  KL+NLL: 111382.44  (KL=111260.18, NLL=122.26)
 batch_loss: 149259.45  |  KL+NLL: 111354.05  (KL=111257.38, NLL=96.66)


Epoch 1/10:  29%|██▉       | 104/360 [00:05<00:12, 20.88it/s]

 batch_loss: 149118.21  |  KL+NLL: 111338.16  (KL=111208.93, NLL=129.23)
 batch_loss: 148777.70  |  KL+NLL: 111275.54  (KL=111165.96, NLL=109.58)
 batch_loss: 148988.67  |  KL+NLL: 111281.90  (KL=111154.84, NLL=127.05)
 batch_loss: 148652.17  |  KL+NLL: 111170.18  (KL=111101.16, NLL=69.01)
 batch_loss: 148521.94  |  KL+NLL: 111213.85  (KL=111067.34, NLL=146.52)


Epoch 1/10:  31%|███       | 110/360 [00:06<00:11, 21.06it/s]

 batch_loss: 148468.05  |  KL+NLL: 111137.54  (KL=111037.73, NLL=99.81)
 batch_loss: 148118.23  |  KL+NLL: 111159.10  (KL=111009.48, NLL=149.61)
 batch_loss: 148200.05  |  KL+NLL: 111039.83  (KL=110954.64, NLL=85.19)
 batch_loss: 147979.76  |  KL+NLL: 111080.06  (KL=110940.99, NLL=139.07)
 batch_loss: 147679.57  |  KL+NLL: 111037.27  (KL=110913.20, NLL=124.08)


Epoch 1/10:  31%|███▏      | 113/360 [00:06<00:11, 21.27it/s]

 batch_loss: 147600.66  |  KL+NLL: 110968.97  (KL=110866.80, NLL=102.17)
 batch_loss: 148225.92  |  KL+NLL: 110956.43  (KL=110838.42, NLL=118.01)
 batch_loss: 147482.20  |  KL+NLL: 110883.23  (KL=110801.91, NLL=81.31)
 batch_loss: 147652.51  |  KL+NLL: 110852.23  (KL=110772.89, NLL=79.34)
 batch_loss: 146939.02  |  KL+NLL: 110852.83  (KL=110765.95, NLL=86.88)


Epoch 1/10:  33%|███▎      | 119/360 [00:06<00:11, 20.80it/s]

 batch_loss: 146977.40  |  KL+NLL: 110801.00  (KL=110713.86, NLL=87.14)
 batch_loss: 146985.43  |  KL+NLL: 110710.25  (KL=110646.45, NLL=63.80)
 batch_loss: 146829.44  |  KL+NLL: 110740.92  (KL=110653.16, NLL=87.77)
 batch_loss: 146666.82  |  KL+NLL: 110670.35  (KL=110590.54, NLL=79.82)
 batch_loss: 146711.38  |  KL+NLL: 110668.66  (KL=110570.27, NLL=98.40)


Epoch 1/10:  35%|███▍      | 125/360 [00:06<00:11, 20.77it/s]

 batch_loss: 146398.03  |  KL+NLL: 110645.71  (KL=110525.46, NLL=120.25)
 batch_loss: 146146.78  |  KL+NLL: 110596.74  (KL=110479.68, NLL=117.06)
 batch_loss: 146085.98  |  KL+NLL: 110548.38  (KL=110481.07, NLL=67.31)
 batch_loss: 146098.77  |  KL+NLL: 110537.77  (KL=110433.45, NLL=104.33)
 batch_loss: 145693.03  |  KL+NLL: 110467.20  (KL=110399.78, NLL=67.42)


Epoch 1/10:  36%|███▌      | 128/360 [00:07<00:11, 20.68it/s]

 batch_loss: 145846.40  |  KL+NLL: 110473.85  (KL=110373.98, NLL=99.87)
 batch_loss: 145740.12  |  KL+NLL: 110511.00  (KL=110352.25, NLL=158.75)
 batch_loss: 145548.94  |  KL+NLL: 110436.94  (KL=110319.07, NLL=117.87)
 batch_loss: 145212.26  |  KL+NLL: 110407.81  (KL=110284.64, NLL=123.17)
 batch_loss: 145317.99  |  KL+NLL: 110336.78  (KL=110237.23, NLL=99.55)


Epoch 1/10:  37%|███▋      | 134/360 [00:07<00:10, 20.69it/s]

 batch_loss: 145053.71  |  KL+NLL: 110335.35  (KL=110239.28, NLL=96.07)
 batch_loss: 145133.70  |  KL+NLL: 110260.47  (KL=110186.12, NLL=74.34)
 batch_loss: 144778.13  |  KL+NLL: 110219.75  (KL=110130.25, NLL=89.50)
 batch_loss: 144486.51  |  KL+NLL: 110196.83  (KL=110129.12, NLL=67.71)
 batch_loss: 144427.36  |  KL+NLL: 110194.23  (KL=110093.98, NLL=100.26)


Epoch 1/10:  39%|███▉      | 140/360 [00:07<00:10, 20.68it/s]

 batch_loss: 144229.44  |  KL+NLL: 110112.56  (KL=110035.44, NLL=77.13)
 batch_loss: 144226.53  |  KL+NLL: 110101.77  (KL=110002.70, NLL=99.07)
 batch_loss: 144075.24  |  KL+NLL: 110068.80  (KL=109993.01, NLL=75.79)
 batch_loss: 143986.69  |  KL+NLL: 110039.17  (KL=109936.91, NLL=102.25)
 batch_loss: 143676.91  |  KL+NLL: 109994.71  (KL=109925.44, NLL=69.27)


Epoch 1/10:  40%|███▉      | 143/360 [00:07<00:10, 20.43it/s]

 batch_loss: 143682.97  |  KL+NLL: 109959.51  (KL=109888.43, NLL=71.08)
 batch_loss: 143642.72  |  KL+NLL: 110006.77  (KL=109878.05, NLL=128.72)
 batch_loss: 143546.08  |  KL+NLL: 109922.68  (KL=109828.05, NLL=94.62)
 batch_loss: 143223.93  |  KL+NLL: 109880.05  (KL=109789.85, NLL=90.20)
 batch_loss: 143172.24  |  KL+NLL: 109814.33  (KL=109757.04, NLL=57.29)


Epoch 1/10:  41%|████▏     | 149/360 [00:08<00:10, 20.22it/s]

 batch_loss: 143261.95  |  KL+NLL: 109781.68  (KL=109712.37, NLL=69.32)
 batch_loss: 142880.07  |  KL+NLL: 109842.75  (KL=109722.87, NLL=119.88)
 batch_loss: 142800.98  |  KL+NLL: 109763.25  (KL=109686.85, NLL=76.40)
 batch_loss: 142543.26  |  KL+NLL: 109669.62  (KL=109615.43, NLL=54.19)
 batch_loss: 142355.46  |  KL+NLL: 109683.45  (KL=109604.83, NLL=78.62)


Epoch 1/10:  43%|████▎     | 155/360 [00:08<00:09, 20.64it/s]

 batch_loss: 142661.13  |  KL+NLL: 109645.16  (KL=109564.26, NLL=80.90)
 batch_loss: 142246.64  |  KL+NLL: 109592.58  (KL=109541.23, NLL=51.35)
 batch_loss: 141816.27  |  KL+NLL: 109568.83  (KL=109510.56, NLL=58.27)
 batch_loss: 142195.30  |  KL+NLL: 109566.71  (KL=109483.49, NLL=83.22)
 batch_loss: 141835.32  |  KL+NLL: 109556.94  (KL=109483.34, NLL=73.60)


Epoch 1/10:  44%|████▍     | 158/360 [00:08<00:09, 20.35it/s]

 batch_loss: 141531.06  |  KL+NLL: 109504.04  (KL=109407.41, NLL=96.64)
 batch_loss: 141572.03  |  KL+NLL: 109445.49  (KL=109372.95, NLL=72.54)
 batch_loss: 141421.03  |  KL+NLL: 109449.23  (KL=109370.66, NLL=78.57)
 batch_loss: 141269.98  |  KL+NLL: 109407.76  (KL=109334.72, NLL=73.04)
 batch_loss: 141239.37  |  KL+NLL: 109393.36  (KL=109310.64, NLL=82.72)


Epoch 1/10:  46%|████▌     | 164/360 [00:08<00:09, 20.41it/s]

 batch_loss: 140902.62  |  KL+NLL: 109392.97  (KL=109268.97, NLL=124.00)
 batch_loss: 140746.95  |  KL+NLL: 109302.71  (KL=109240.69, NLL=62.02)
 batch_loss: 140770.49  |  KL+NLL: 109284.00  (KL=109213.66, NLL=70.34)
 batch_loss: 140569.92  |  KL+NLL: 109241.67  (KL=109179.20, NLL=62.47)
 batch_loss: 140564.82  |  KL+NLL: 109230.42  (KL=109134.72, NLL=95.70)


Epoch 1/10:  47%|████▋     | 170/360 [00:09<00:09, 20.33it/s]

 batch_loss: 140333.44  |  KL+NLL: 109174.74  (KL=109111.73, NLL=63.00)
 batch_loss: 140484.11  |  KL+NLL: 109147.96  (KL=109075.98, NLL=71.98)
 batch_loss: 140001.55  |  KL+NLL: 109184.42  (KL=109083.15, NLL=101.28)
 batch_loss: 139828.25  |  KL+NLL: 109100.68  (KL=109033.10, NLL=67.58)
 batch_loss: 140037.24  |  KL+NLL: 109047.68  (KL=108987.25, NLL=60.43)


Epoch 1/10:  48%|████▊     | 173/360 [00:09<00:09, 20.48it/s]

 batch_loss: 139869.91  |  KL+NLL: 109089.21  (KL=108984.31, NLL=104.89)
 batch_loss: 139710.05  |  KL+NLL: 109033.67  (KL=108939.23, NLL=94.44)
 batch_loss: 139705.19  |  KL+NLL: 108997.18  (KL=108907.06, NLL=90.12)
 batch_loss: 139546.03  |  KL+NLL: 108966.40  (KL=108869.97, NLL=96.43)
 batch_loss: 138977.00  |  KL+NLL: 108893.00  (KL=108819.40, NLL=73.60)


Epoch 1/10:  50%|████▉     | 179/360 [00:09<00:08, 20.65it/s]

 batch_loss: 139102.37  |  KL+NLL: 108896.10  (KL=108832.38, NLL=63.73)
 batch_loss: 138892.69  |  KL+NLL: 108853.93  (KL=108793.70, NLL=60.23)
 batch_loss: 138827.15  |  KL+NLL: 108837.09  (KL=108767.66, NLL=69.42)
 batch_loss: 138733.05  |  KL+NLL: 108796.16  (KL=108734.96, NLL=61.20)
 batch_loss: 138305.26  |  KL+NLL: 108777.07  (KL=108678.41, NLL=98.65)


Epoch 1/10:  51%|█████▏    | 185/360 [00:09<00:08, 20.60it/s]

 batch_loss: 138686.03  |  KL+NLL: 108790.41  (KL=108647.84, NLL=142.57)
 batch_loss: 138158.63  |  KL+NLL: 108723.75  (KL=108657.17, NLL=66.57)
 batch_loss: 138176.42  |  KL+NLL: 108679.77  (KL=108597.11, NLL=82.66)
 batch_loss: 138280.10  |  KL+NLL: 108662.39  (KL=108586.82, NLL=75.57)
 batch_loss: 137528.48  |  KL+NLL: 108635.10  (KL=108544.34, NLL=90.75)


Epoch 1/10:  52%|█████▏    | 188/360 [00:09<00:08, 20.51it/s]

 batch_loss: 137571.86  |  KL+NLL: 108590.83  (KL=108512.50, NLL=78.33)
 batch_loss: 137755.17  |  KL+NLL: 108577.62  (KL=108485.02, NLL=92.61)
 batch_loss: 137599.44  |  KL+NLL: 108498.09  (KL=108429.89, NLL=68.20)
 batch_loss: 137570.22  |  KL+NLL: 108522.93  (KL=108433.12, NLL=89.81)
 batch_loss: 137210.63  |  KL+NLL: 108490.47  (KL=108394.02, NLL=96.45)


Epoch 1/10:  54%|█████▍    | 194/360 [00:10<00:07, 21.02it/s]

 batch_loss: 137226.20  |  KL+NLL: 108418.66  (KL=108350.34, NLL=68.31)
 batch_loss: 137139.08  |  KL+NLL: 108401.61  (KL=108331.25, NLL=70.36)
 batch_loss: 136898.53  |  KL+NLL: 108372.03  (KL=108293.99, NLL=78.04)
 batch_loss: 136778.32  |  KL+NLL: 108367.19  (KL=108276.76, NLL=90.44)
 batch_loss: 136622.83  |  KL+NLL: 108309.29  (KL=108245.66, NLL=63.62)


Epoch 1/10:  56%|█████▌    | 200/360 [00:10<00:07, 21.32it/s]

 batch_loss: 136542.85  |  KL+NLL: 108314.63  (KL=108239.30, NLL=75.34)
 batch_loss: 136284.90  |  KL+NLL: 108258.38  (KL=108180.55, NLL=77.82)
 batch_loss: 136338.58  |  KL+NLL: 108227.29  (KL=108138.36, NLL=88.93)
 batch_loss: 136184.56  |  KL+NLL: 108204.25  (KL=108135.75, NLL=68.50)
 batch_loss: 136215.55  |  KL+NLL: 108150.60  (KL=108087.88, NLL=62.72)


Epoch 1/10:  56%|█████▋    | 203/360 [00:10<00:07, 21.12it/s]

 batch_loss: 136182.68  |  KL+NLL: 108159.62  (KL=108100.75, NLL=58.87)
 batch_loss: 135634.68  |  KL+NLL: 108116.73  (KL=108040.27, NLL=76.45)
 batch_loss: 135460.03  |  KL+NLL: 108098.61  (KL=108024.90, NLL=73.71)
 batch_loss: 135338.37  |  KL+NLL: 108054.74  (KL=107976.29, NLL=78.45)
 batch_loss: 135380.12  |  KL+NLL: 108022.96  (KL=107947.67, NLL=75.29)


Epoch 1/10:  58%|█████▊    | 209/360 [00:10<00:07, 21.30it/s]

 batch_loss: 135248.54  |  KL+NLL: 108023.26  (KL=107936.92, NLL=86.34)
 batch_loss: 135242.74  |  KL+NLL: 108037.55  (KL=107912.77, NLL=124.78)
 batch_loss: 135383.81  |  KL+NLL: 107917.77  (KL=107850.76, NLL=67.01)
 batch_loss: 134914.67  |  KL+NLL: 107940.11  (KL=107841.63, NLL=98.48)
 batch_loss: 134707.72  |  KL+NLL: 107922.42  (KL=107812.53, NLL=109.89)


Epoch 1/10:  60%|█████▉    | 215/360 [00:11<00:06, 21.39it/s]

 batch_loss: 134637.62  |  KL+NLL: 107842.13  (KL=107765.39, NLL=76.74)
 batch_loss: 134718.08  |  KL+NLL: 107820.64  (KL=107748.70, NLL=71.94)
 batch_loss: 134327.33  |  KL+NLL: 107785.67  (KL=107723.95, NLL=61.72)
 batch_loss: 134168.76  |  KL+NLL: 107794.31  (KL=107714.05, NLL=80.26)
 batch_loss: 133775.93  |  KL+NLL: 107767.33  (KL=107694.76, NLL=72.58)


Epoch 1/10:  61%|██████    | 218/360 [00:11<00:06, 21.46it/s]

 batch_loss: 133579.76  |  KL+NLL: 107749.74  (KL=107646.20, NLL=103.53)
 batch_loss: 133773.79  |  KL+NLL: 107669.93  (KL=107604.48, NLL=65.45)
 batch_loss: 133813.41  |  KL+NLL: 107624.12  (KL=107570.08, NLL=54.04)
 batch_loss: 133386.50  |  KL+NLL: 107602.78  (KL=107541.45, NLL=61.33)
 batch_loss: 133324.89  |  KL+NLL: 107636.16  (KL=107537.30, NLL=98.85)


Epoch 1/10:  62%|██████▏   | 224/360 [00:11<00:06, 20.52it/s]

 batch_loss: 133006.63  |  KL+NLL: 107599.57  (KL=107513.60, NLL=85.97)
 batch_loss: 133077.14  |  KL+NLL: 107606.15  (KL=107481.70, NLL=124.45)
 batch_loss: 132836.61  |  KL+NLL: 107510.18  (KL=107435.97, NLL=74.21)
 batch_loss: 132846.66  |  KL+NLL: 107526.15  (KL=107413.20, NLL=112.96)


Epoch 1/10:  63%|██████▎   | 227/360 [00:11<00:06, 20.54it/s]

 batch_loss: 132599.12  |  KL+NLL: 107466.95  (KL=107375.41, NLL=91.54)
 batch_loss: 132602.76  |  KL+NLL: 107427.07  (KL=107344.47, NLL=82.60)
 batch_loss: 132587.02  |  KL+NLL: 107389.43  (KL=107332.28, NLL=57.15)
 batch_loss: 132153.10  |  KL+NLL: 107353.17  (KL=107296.84, NLL=56.32)
 batch_loss: 132394.09  |  KL+NLL: 107369.33  (KL=107298.48, NLL=70.86)


Epoch 1/10:  65%|██████▍   | 233/360 [00:12<00:06, 20.21it/s]

 batch_loss: 132186.28  |  KL+NLL: 107347.65  (KL=107263.86, NLL=83.79)
 batch_loss: 132061.02  |  KL+NLL: 107325.63  (KL=107227.71, NLL=97.92)
 batch_loss: 131757.33  |  KL+NLL: 107281.79  (KL=107175.08, NLL=106.72)
 batch_loss: 131870.17  |  KL+NLL: 107222.86  (KL=107134.86, NLL=88.01)


Epoch 1/10:  66%|██████▌   | 236/360 [00:12<00:06, 20.13it/s]

 batch_loss: 131725.97  |  KL+NLL: 107195.62  (KL=107110.84, NLL=84.78)
 batch_loss: 131430.30  |  KL+NLL: 107154.18  (KL=107091.39, NLL=62.79)
 batch_loss: 131431.22  |  KL+NLL: 107170.41  (KL=107080.13, NLL=90.28)
 batch_loss: 131080.86  |  KL+NLL: 107115.21  (KL=107038.97, NLL=76.24)
 batch_loss: 131179.68  |  KL+NLL: 107091.44  (KL=107018.93, NLL=72.52)


Epoch 1/10:  67%|██████▋   | 242/360 [00:12<00:05, 19.99it/s]

 batch_loss: 131033.27  |  KL+NLL: 107060.59  (KL=107007.25, NLL=53.34)
 batch_loss: 131284.43  |  KL+NLL: 107034.30  (KL=106966.03, NLL=68.27)
 batch_loss: 130693.10  |  KL+NLL: 107035.32  (KL=106940.91, NLL=94.41)
 batch_loss: 130279.01  |  KL+NLL: 107013.47  (KL=106938.66, NLL=74.81)
 batch_loss: 130276.21  |  KL+NLL: 106948.91  (KL=106886.07, NLL=62.84)


Epoch 1/10:  69%|██████▉   | 248/360 [00:12<00:05, 20.16it/s]

 batch_loss: 130223.53  |  KL+NLL: 106964.88  (KL=106895.91, NLL=68.97)
 batch_loss: 130279.72  |  KL+NLL: 106889.52  (KL=106833.78, NLL=55.74)
 batch_loss: 129970.68  |  KL+NLL: 106858.15  (KL=106802.84, NLL=55.31)
 batch_loss: 129983.99  |  KL+NLL: 106893.40  (KL=106796.02, NLL=97.39)
 batch_loss: 129459.57  |  KL+NLL: 106791.71  (KL=106731.91, NLL=59.79)


Epoch 1/10:  70%|██████▉   | 251/360 [00:13<00:05, 19.99it/s]

 batch_loss: 129482.31  |  KL+NLL: 106817.43  (KL=106727.37, NLL=90.07)
 batch_loss: 129594.97  |  KL+NLL: 106782.57  (KL=106709.73, NLL=72.83)
 batch_loss: 129564.32  |  KL+NLL: 106703.01  (KL=106653.34, NLL=49.66)
 batch_loss: 129305.31  |  KL+NLL: 106732.63  (KL=106649.03, NLL=83.60)
 batch_loss: 129192.66  |  KL+NLL: 106751.75  (KL=106619.01, NLL=132.74)


Epoch 1/10:  71%|███████▏  | 257/360 [00:13<00:05, 20.39it/s]

 batch_loss: 128925.04  |  KL+NLL: 106648.52  (KL=106583.51, NLL=65.01)
 batch_loss: 129179.55  |  KL+NLL: 106627.73  (KL=106544.21, NLL=83.52)
 batch_loss: 129027.89  |  KL+NLL: 106573.12  (KL=106520.44, NLL=52.68)
 batch_loss: 128785.49  |  KL+NLL: 106574.56  (KL=106504.45, NLL=70.11)
 batch_loss: 128387.62  |  KL+NLL: 106587.16  (KL=106520.59, NLL=66.58)


Epoch 1/10:  73%|███████▎  | 263/360 [00:13<00:04, 20.78it/s]

 batch_loss: 128197.17  |  KL+NLL: 106529.44  (KL=106443.78, NLL=85.66)
 batch_loss: 128221.02  |  KL+NLL: 106499.47  (KL=106421.84, NLL=77.63)
 batch_loss: 128393.19  |  KL+NLL: 106536.52  (KL=106460.56, NLL=75.96)
 batch_loss: 128200.81  |  KL+NLL: 106469.75  (KL=106377.41, NLL=92.34)
 batch_loss: 128089.65  |  KL+NLL: 106410.34  (KL=106347.77, NLL=62.58)


Epoch 1/10:  74%|███████▍  | 266/360 [00:13<00:04, 20.94it/s]

 batch_loss: 127796.16  |  KL+NLL: 106407.92  (KL=106320.49, NLL=87.43)
 batch_loss: 127910.73  |  KL+NLL: 106401.61  (KL=106319.80, NLL=81.80)
 batch_loss: 127648.70  |  KL+NLL: 106323.35  (KL=106252.51, NLL=70.84)
 batch_loss: 127471.62  |  KL+NLL: 106338.17  (KL=106246.70, NLL=91.47)
 batch_loss: 127343.55  |  KL+NLL: 106314.92  (KL=106221.35, NLL=93.56)


Epoch 1/10:  76%|███████▌  | 272/360 [00:14<00:04, 20.72it/s]

 batch_loss: 127315.65  |  KL+NLL: 106271.03  (KL=106207.27, NLL=63.76)
 batch_loss: 127061.90  |  KL+NLL: 106242.45  (KL=106166.72, NLL=75.73)
 batch_loss: 127111.30  |  KL+NLL: 106249.76  (KL=106123.91, NLL=125.86)
 batch_loss: 126705.37  |  KL+NLL: 106154.40  (KL=106093.37, NLL=61.03)


Epoch 1/10:  76%|███████▋  | 275/360 [00:14<00:04, 20.35it/s]

 batch_loss: 126590.51  |  KL+NLL: 106148.44  (KL=106098.15, NLL=50.29)
 batch_loss: 126812.88  |  KL+NLL: 106163.28  (KL=106063.69, NLL=99.59)
 batch_loss: 126536.56  |  KL+NLL: 106128.37  (KL=106019.53, NLL=108.84)
 batch_loss: 126391.66  |  KL+NLL: 106077.37  (KL=105999.43, NLL=77.94)


Epoch 1/10:  78%|███████▊  | 281/360 [00:14<00:03, 20.55it/s]

 batch_loss: 125925.05  |  KL+NLL: 106047.68  (KL=105978.98, NLL=68.69)
 batch_loss: 125957.97  |  KL+NLL: 106042.73  (KL=105966.78, NLL=75.95)
 batch_loss: 125941.32  |  KL+NLL: 105968.90  (KL=105904.66, NLL=64.25)
 batch_loss: 125941.22  |  KL+NLL: 105945.77  (KL=105903.75, NLL=42.02)
 batch_loss: 125983.26  |  KL+NLL: 106022.58  (KL=105927.84, NLL=94.74)


Epoch 1/10:  79%|███████▉  | 284/360 [00:14<00:03, 20.61it/s]

 batch_loss: 125642.52  |  KL+NLL: 105936.82  (KL=105858.87, NLL=77.96)
 batch_loss: 125653.55  |  KL+NLL: 105936.23  (KL=105848.63, NLL=87.60)
 batch_loss: 125060.62  |  KL+NLL: 105878.72  (KL=105805.20, NLL=73.51)
 batch_loss: 125258.31  |  KL+NLL: 105866.60  (KL=105798.45, NLL=68.16)
 batch_loss: 125121.54  |  KL+NLL: 105842.85  (KL=105738.82, NLL=104.03)


Epoch 1/10:  81%|████████  | 290/360 [00:14<00:03, 20.75it/s]

 batch_loss: 124849.50  |  KL+NLL: 105818.96  (KL=105733.71, NLL=85.25)
 batch_loss: 125098.13  |  KL+NLL: 105786.58  (KL=105689.72, NLL=96.86)
 batch_loss: 124980.58  |  KL+NLL: 105765.04  (KL=105688.73, NLL=76.31)
 batch_loss: 124911.45  |  KL+NLL: 105695.60  (KL=105639.13, NLL=56.47)
 batch_loss: 124613.34  |  KL+NLL: 105738.61  (KL=105623.62, NLL=114.98)


Epoch 1/10:  82%|████████▏ | 296/360 [00:15<00:03, 20.29it/s]

 batch_loss: 124412.23  |  KL+NLL: 105723.48  (KL=105619.69, NLL=103.79)
 batch_loss: 124311.16  |  KL+NLL: 105621.61  (KL=105552.70, NLL=68.92)
 batch_loss: 124067.11  |  KL+NLL: 105652.29  (KL=105582.60, NLL=69.69)
 batch_loss: 123946.64  |  KL+NLL: 105595.41  (KL=105527.60, NLL=67.80)
 batch_loss: 124055.75  |  KL+NLL: 105583.45  (KL=105519.41, NLL=64.03)


Epoch 1/10:  83%|████████▎ | 299/360 [00:15<00:02, 20.54it/s]

 batch_loss: 123645.56  |  KL+NLL: 105591.28  (KL=105493.77, NLL=97.51)
 batch_loss: 123420.74  |  KL+NLL: 105511.36  (KL=105442.76, NLL=68.61)
 batch_loss: 123471.60  |  KL+NLL: 105488.98  (KL=105436.98, NLL=52.00)
 batch_loss: 123235.65  |  KL+NLL: 105482.64  (KL=105419.39, NLL=63.25)
 batch_loss: 123336.53  |  KL+NLL: 105459.90  (KL=105371.30, NLL=88.60)


Epoch 1/10:  85%|████████▍ | 305/360 [00:15<00:02, 20.34it/s]

 batch_loss: 123560.47  |  KL+NLL: 105433.30  (KL=105357.70, NLL=75.60)
 batch_loss: 123027.47  |  KL+NLL: 105407.47  (KL=105332.98, NLL=74.49)
 batch_loss: 123230.00  |  KL+NLL: 105387.11  (KL=105333.66, NLL=53.44)
 batch_loss: 122928.79  |  KL+NLL: 105352.14  (KL=105274.02, NLL=78.13)


Epoch 1/10:  86%|████████▌ | 308/360 [00:15<00:02, 20.02it/s]

 batch_loss: 122175.49  |  KL+NLL: 105345.96  (KL=105262.33, NLL=83.63)
 batch_loss: 122800.09  |  KL+NLL: 105331.91  (KL=105254.84, NLL=77.07)
 batch_loss: 122151.72  |  KL+NLL: 105279.95  (KL=105184.87, NLL=95.08)
 batch_loss: 122196.07  |  KL+NLL: 105250.28  (KL=105187.03, NLL=63.25)
 batch_loss: 121901.03  |  KL+NLL: 105222.91  (KL=105155.59, NLL=67.33)


Epoch 1/10:  87%|████████▋ | 314/360 [00:16<00:02, 20.09it/s]

 batch_loss: 122376.79  |  KL+NLL: 105242.29  (KL=105133.33, NLL=108.96)
 batch_loss: 121833.15  |  KL+NLL: 105217.45  (KL=105133.79, NLL=83.66)
 batch_loss: 121645.65  |  KL+NLL: 105230.17  (KL=105093.26, NLL=136.91)
 batch_loss: 121727.97  |  KL+NLL: 105139.90  (KL=105054.17, NLL=85.73)
 batch_loss: 121578.75  |  KL+NLL: 105071.83  (KL=105011.44, NLL=60.39)


Epoch 1/10:  89%|████████▉ | 320/360 [00:16<00:01, 20.91it/s]

 batch_loss: 121343.39  |  KL+NLL: 105092.00  (KL=105009.87, NLL=82.13)
 batch_loss: 120978.50  |  KL+NLL: 105076.50  (KL=104991.73, NLL=84.77)
 batch_loss: 121158.07  |  KL+NLL: 105042.62  (KL=104963.09, NLL=79.53)
 batch_loss: 121499.94  |  KL+NLL: 105012.04  (KL=104940.20, NLL=71.84)
 batch_loss: 120955.55  |  KL+NLL: 104994.98  (KL=104923.18, NLL=71.80)


Epoch 1/10:  90%|████████▉ | 323/360 [00:16<00:01, 20.75it/s]

 batch_loss: 120878.64  |  KL+NLL: 104991.35  (KL=104903.14, NLL=88.21)
 batch_loss: 120932.97  |  KL+NLL: 104937.49  (KL=104872.56, NLL=64.93)
 batch_loss: 120549.44  |  KL+NLL: 104920.97  (KL=104825.51, NLL=95.46)
 batch_loss: 120866.05  |  KL+NLL: 104907.74  (KL=104826.98, NLL=80.75)
 batch_loss: 120709.29  |  KL+NLL: 104885.76  (KL=104807.33, NLL=78.43)


Epoch 1/10:  91%|█████████▏| 329/360 [00:16<00:01, 20.89it/s]

 batch_loss: 120249.45  |  KL+NLL: 104869.99  (KL=104807.00, NLL=62.99)
 batch_loss: 120103.98  |  KL+NLL: 104868.89  (KL=104777.88, NLL=91.01)
 batch_loss: 120046.87  |  KL+NLL: 104836.92  (KL=104758.21, NLL=78.71)
 batch_loss: 119907.46  |  KL+NLL: 104803.03  (KL=104704.98, NLL=98.05)
 batch_loss: 119545.21  |  KL+NLL: 104761.73  (KL=104681.63, NLL=80.10)


Epoch 1/10:  93%|█████████▎| 335/360 [00:17<00:01, 20.71it/s]

 batch_loss: 119539.05  |  KL+NLL: 104756.58  (KL=104667.42, NLL=89.16)
 batch_loss: 119545.68  |  KL+NLL: 104750.19  (KL=104643.05, NLL=107.14)
 batch_loss: 119400.84  |  KL+NLL: 104682.30  (KL=104606.41, NLL=75.89)
 batch_loss: 119290.44  |  KL+NLL: 104666.85  (KL=104596.76, NLL=70.09)
 batch_loss: 119056.91  |  KL+NLL: 104650.42  (KL=104542.46, NLL=107.95)


Epoch 1/10:  94%|█████████▍| 338/360 [00:17<00:01, 20.97it/s]

 batch_loss: 118709.87  |  KL+NLL: 104619.55  (KL=104565.04, NLL=54.52)
 batch_loss: 119311.09  |  KL+NLL: 104578.05  (KL=104505.09, NLL=72.95)
 batch_loss: 118779.77  |  KL+NLL: 104566.77  (KL=104498.52, NLL=68.24)
 batch_loss: 118817.94  |  KL+NLL: 104545.11  (KL=104466.04, NLL=79.07)
 batch_loss: 118636.39  |  KL+NLL: 104525.95  (KL=104448.91, NLL=77.04)


Epoch 1/10:  96%|█████████▌| 344/360 [00:17<00:00, 20.81it/s]

 batch_loss: 118575.09  |  KL+NLL: 104478.04  (KL=104422.75, NLL=55.29)
 batch_loss: 118534.55  |  KL+NLL: 104451.18  (KL=104378.93, NLL=72.25)
 batch_loss: 118023.34  |  KL+NLL: 104490.70  (KL=104411.23, NLL=79.46)
 batch_loss: 118137.91  |  KL+NLL: 104438.14  (KL=104371.03, NLL=67.11)
 batch_loss: 118059.87  |  KL+NLL: 104375.25  (KL=104316.19, NLL=59.06)


Epoch 1/10:  97%|█████████▋| 350/360 [00:17<00:00, 20.57it/s]

 batch_loss: 117867.54  |  KL+NLL: 104376.92  (KL=104310.61, NLL=66.31)
 batch_loss: 117660.12  |  KL+NLL: 104377.38  (KL=104313.29, NLL=64.09)
 batch_loss: 117655.82  |  KL+NLL: 104377.77  (KL=104277.11, NLL=100.66)
 batch_loss: 117710.23  |  KL+NLL: 104290.71  (KL=104233.68, NLL=57.03)
 batch_loss: 117342.89  |  KL+NLL: 104301.49  (KL=104229.16, NLL=72.33)


Epoch 1/10:  98%|█████████▊| 353/360 [00:17<00:00, 20.73it/s]

 batch_loss: 117314.25  |  KL+NLL: 104263.12  (KL=104200.66, NLL=62.47)
 batch_loss: 117044.52  |  KL+NLL: 104252.91  (KL=104155.86, NLL=97.05)
 batch_loss: 117358.15  |  KL+NLL: 104223.71  (KL=104144.90, NLL=78.81)
 batch_loss: 116619.34  |  KL+NLL: 104185.37  (KL=104092.88, NLL=92.49)
 batch_loss: 116646.12  |  KL+NLL: 104162.59  (KL=104079.36, NLL=83.23)


Epoch 1/10: 100%|██████████| 360/360 [00:18<00:00, 19.70it/s]


 batch_loss: 116498.06  |  KL+NLL: 104162.26  (KL=104095.48, NLL=66.77)
 batch_loss: 116669.36  |  KL+NLL: 104136.48  (KL=104065.20, NLL=71.28)
 batch_loss: 116302.49  |  KL+NLL: 104117.11  (KL=104055.44, NLL=61.67)
 batch_loss: 116231.11  |  KL+NLL: 104113.16  (KL=104034.79, NLL=78.37)
 batch_loss: 116385.15  |  KL+NLL: 104036.03  (KL=103967.17, NLL=68.86)
Epoch 1 - ELBO Loss: 138951.5106


Acc check epoch 1: 100%|██████████| 360/360 [00:05<00:00, 70.27it/s]


Epoch 1 - Train Acc: 51.28%
AutoDiagonalNormal.loc: [ 0.27853602  3.9507675  -0.36595303 ... -2.3827617  -2.791243
 -0.7042019 ]
AutoDiagonalNormal.scale: [0.19050516 0.18906902 0.19086038 ... 0.18131262 0.19339208 0.1617544 ]
  ↳ Saved: model_relu_gaussian_epoch_best_20250715_100559.pth, guide_relu_gaussian_epoch_best_20250715_100559.pth, param_store_relu_gaussian_epoch_best_20250715_100559.pkl


Epoch 2/10:   0%|          | 1/360 [00:00<00:48,  7.43it/s]

 batch_loss: 116447.83  |  KL+NLL: 104086.89  (KL=103998.48, NLL=88.41)


Epoch 2/10:   1%|          | 3/360 [00:00<00:26, 13.52it/s]

 batch_loss: 116066.82  |  KL+NLL: 104019.98  (KL=103945.55, NLL=74.42)
 batch_loss: 115750.99  |  KL+NLL: 103989.35  (KL=103933.32, NLL=56.03)
 batch_loss: 115549.97  |  KL+NLL: 104008.73  (KL=103919.77, NLL=88.95)


Epoch 2/10:   1%|▏         | 5/360 [00:00<00:21, 16.18it/s]

 batch_loss: 115461.32  |  KL+NLL: 103981.83  (KL=103886.98, NLL=94.85)
 batch_loss: 115617.10  |  KL+NLL: 103889.96  (KL=103823.01, NLL=66.95)


Epoch 2/10:   2%|▏         | 8/360 [00:00<00:18, 18.61it/s]

 batch_loss: 115204.60  |  KL+NLL: 103916.31  (KL=103831.59, NLL=84.72)
 batch_loss: 115352.60  |  KL+NLL: 103842.09  (KL=103772.45, NLL=69.64)
 batch_loss: 114960.16  |  KL+NLL: 103877.66  (KL=103801.18, NLL=76.48)


Epoch 2/10:   3%|▎         | 10/360 [00:00<00:18, 18.76it/s]

 batch_loss: 115096.89  |  KL+NLL: 103838.84  (KL=103775.20, NLL=63.64)
 batch_loss: 114643.29  |  KL+NLL: 103826.71  (KL=103755.16, NLL=71.55)


Epoch 2/10:   4%|▎         | 13/360 [00:00<00:17, 19.88it/s]

 batch_loss: 114793.99  |  KL+NLL: 103786.82  (KL=103700.86, NLL=85.96)
 batch_loss: 114835.55  |  KL+NLL: 103766.53  (KL=103671.30, NLL=95.23)
 batch_loss: 114703.01  |  KL+NLL: 103784.63  (KL=103712.62, NLL=72.01)


Epoch 2/10:   4%|▍         | 15/360 [00:00<00:17, 19.86it/s]

 batch_loss: 114499.77  |  KL+NLL: 103762.69  (KL=103674.72, NLL=87.97)
 batch_loss: 114011.13  |  KL+NLL: 103690.72  (KL=103605.45, NLL=85.28)


Epoch 2/10:   5%|▍         | 17/360 [00:00<00:17, 19.84it/s]

 batch_loss: 114262.29  |  KL+NLL: 103706.35  (KL=103624.90, NLL=81.45)
 batch_loss: 114305.77  |  KL+NLL: 103665.95  (KL=103584.46, NLL=81.49)
 batch_loss: 113809.04  |  KL+NLL: 103655.98  (KL=103584.91, NLL=71.07)


Epoch 2/10:   6%|▌         | 20/360 [00:01<00:16, 20.09it/s]

 batch_loss: 113688.76  |  KL+NLL: 103674.41  (KL=103534.05, NLL=140.36)
 batch_loss: 113642.60  |  KL+NLL: 103595.38  (KL=103533.08, NLL=62.30)


Epoch 2/10:   6%|▋         | 23/360 [00:01<00:16, 20.10it/s]

 batch_loss: 113805.91  |  KL+NLL: 103585.07  (KL=103500.05, NLL=85.02)
 batch_loss: 113482.59  |  KL+NLL: 103550.73  (KL=103462.77, NLL=87.96)


Epoch 2/10:   7%|▋         | 26/360 [00:01<00:16, 20.27it/s]

 batch_loss: 113456.80  |  KL+NLL: 103518.51  (KL=103438.37, NLL=80.15)
 batch_loss: 113274.46  |  KL+NLL: 103513.41  (KL=103430.23, NLL=83.18)
 batch_loss: 112921.07  |  KL+NLL: 103476.79  (KL=103415.02, NLL=61.76)
 batch_loss: 113336.88  |  KL+NLL: 103476.17  (KL=103390.16, NLL=86.02)
 batch_loss: 112768.00  |  KL+NLL: 103452.40  (KL=103388.98, NLL=63.43)


Epoch 2/10:   8%|▊         | 29/360 [00:01<00:16, 19.68it/s]

 batch_loss: 112512.41  |  KL+NLL: 103433.77  (KL=103363.09, NLL=70.68)
 batch_loss: 112580.25  |  KL+NLL: 103405.36  (KL=103333.74, NLL=71.61)
 batch_loss: 112543.92  |  KL+NLL: 103374.17  (KL=103276.15, NLL=98.02)


Epoch 2/10:   9%|▉         | 32/360 [00:01<00:16, 20.09it/s]

 batch_loss: 112594.72  |  KL+NLL: 103391.98  (KL=103308.71, NLL=83.27)
 batch_loss: 112419.22  |  KL+NLL: 103329.67  (KL=103260.62, NLL=69.05)


Epoch 2/10:  10%|▉         | 35/360 [00:01<00:16, 19.98it/s]

 batch_loss: 112260.61  |  KL+NLL: 103344.47  (KL=103237.84, NLL=106.64)
 batch_loss: 112384.54  |  KL+NLL: 103263.14  (KL=103209.74, NLL=53.40)
 batch_loss: 112184.27  |  KL+NLL: 103266.34  (KL=103190.04, NLL=76.30)


Epoch 2/10:  11%|█         | 38/360 [00:01<00:15, 20.32it/s]

 batch_loss: 111849.41  |  KL+NLL: 103265.56  (KL=103192.62, NLL=72.94)
 batch_loss: 111262.97  |  KL+NLL: 103281.75  (KL=103167.62, NLL=114.13)


Epoch 2/10:  11%|█▏        | 41/360 [00:02<00:15, 20.57it/s]

 batch_loss: 111559.62  |  KL+NLL: 103252.81  (KL=103151.57, NLL=101.24)
 batch_loss: 111466.95  |  KL+NLL: 103235.03  (KL=103146.65, NLL=88.38)
 batch_loss: 111505.29  |  KL+NLL: 103188.18  (KL=103109.44, NLL=78.75)
 batch_loss: 111251.36  |  KL+NLL: 103166.61  (KL=103094.43, NLL=72.18)
 batch_loss: 111181.44  |  KL+NLL: 103117.34  (KL=103049.05, NLL=68.29)


Epoch 2/10:  12%|█▏        | 44/360 [00:02<00:15, 20.77it/s]

 batch_loss: 110908.76  |  KL+NLL: 103136.20  (KL=103042.66, NLL=93.54)
 batch_loss: 110950.30  |  KL+NLL: 103085.41  (KL=103022.59, NLL=62.82)
 batch_loss: 110481.42  |  KL+NLL: 103013.81  (KL=102960.80, NLL=53.01)


Epoch 2/10:  13%|█▎        | 47/360 [00:02<00:15, 20.56it/s]

 batch_loss: 110734.61  |  KL+NLL: 103075.52  (KL=102993.18, NLL=82.34)
 batch_loss: 110649.04  |  KL+NLL: 103150.56  (KL=102951.34, NLL=199.21)


Epoch 2/10:  14%|█▍        | 50/360 [00:02<00:15, 20.53it/s]

 batch_loss: 110495.76  |  KL+NLL: 103062.29  (KL=102944.46, NLL=117.83)
 batch_loss: 110076.37  |  KL+NLL: 103011.80  (KL=102925.06, NLL=86.74)


Epoch 2/10:  15%|█▍        | 53/360 [00:02<00:15, 20.26it/s]

 batch_loss: 110279.63  |  KL+NLL: 103000.23  (KL=102904.75, NLL=95.48)
 batch_loss: 110211.61  |  KL+NLL: 102967.20  (KL=102852.84, NLL=114.36)
 batch_loss: 110014.31  |  KL+NLL: 102961.03  (KL=102890.05, NLL=70.99)
 batch_loss: 109571.81  |  KL+NLL: 102924.31  (KL=102866.16, NLL=58.15)
 batch_loss: 109564.08  |  KL+NLL: 102940.36  (KL=102836.80, NLL=103.56)


Epoch 2/10:  16%|█▌        | 56/360 [00:02<00:14, 20.50it/s]

 batch_loss: 109552.26  |  KL+NLL: 102878.38  (KL=102799.32, NLL=79.06)
 batch_loss: 109491.59  |  KL+NLL: 102830.00  (KL=102763.01, NLL=66.99)
 batch_loss: 109379.63  |  KL+NLL: 102819.57  (KL=102755.73, NLL=63.84)


Epoch 2/10:  16%|█▋        | 59/360 [00:02<00:14, 20.92it/s]

 batch_loss: 109322.60  |  KL+NLL: 102832.90  (KL=102760.05, NLL=72.85)
 batch_loss: 109260.13  |  KL+NLL: 102783.36  (KL=102715.59, NLL=67.77)


Epoch 2/10:  17%|█▋        | 62/360 [00:03<00:14, 20.63it/s]

 batch_loss: 109167.68  |  KL+NLL: 102735.80  (KL=102672.11, NLL=63.69)
 batch_loss: 108904.23  |  KL+NLL: 102751.47  (KL=102656.63, NLL=94.83)
 batch_loss: 108954.09  |  KL+NLL: 102708.05  (KL=102644.79, NLL=63.26)


Epoch 2/10:  18%|█▊        | 65/360 [00:03<00:14, 20.58it/s]

 batch_loss: 108737.46  |  KL+NLL: 102700.35  (KL=102614.27, NLL=86.08)
 batch_loss: 108507.14  |  KL+NLL: 102705.03  (KL=102629.19, NLL=75.84)


Epoch 2/10:  19%|█▉        | 68/360 [00:03<00:14, 20.85it/s]

 batch_loss: 108620.99  |  KL+NLL: 102712.62  (KL=102607.23, NLL=105.39)
 batch_loss: 108528.63  |  KL+NLL: 102661.23  (KL=102584.30, NLL=76.93)
 batch_loss: 108369.67  |  KL+NLL: 102632.88  (KL=102543.33, NLL=89.55)
 batch_loss: 108022.80  |  KL+NLL: 102605.72  (KL=102524.24, NLL=81.48)
 batch_loss: 108019.81  |  KL+NLL: 102553.33  (KL=102490.96, NLL=62.37)


Epoch 2/10:  20%|█▉        | 71/360 [00:03<00:13, 20.87it/s]

 batch_loss: 107756.01  |  KL+NLL: 102566.53  (KL=102480.45, NLL=86.09)
 batch_loss: 107837.34  |  KL+NLL: 102572.08  (KL=102492.66, NLL=79.42)
 batch_loss: 107423.84  |  KL+NLL: 102529.91  (KL=102469.49, NLL=60.41)


Epoch 2/10:  21%|██        | 74/360 [00:03<00:13, 20.96it/s]

 batch_loss: 107858.85  |  KL+NLL: 102536.90  (KL=102419.20, NLL=117.70)
 batch_loss: 107680.93  |  KL+NLL: 102485.08  (KL=102373.35, NLL=111.73)


Epoch 2/10:  21%|██▏       | 77/360 [00:03<00:13, 21.06it/s]

 batch_loss: 107389.61  |  KL+NLL: 102495.87  (KL=102414.83, NLL=81.04)
 batch_loss: 107092.87  |  KL+NLL: 102448.12  (KL=102359.16, NLL=88.95)
 batch_loss: 107102.76  |  KL+NLL: 102396.53  (KL=102342.04, NLL=54.49)


Epoch 2/10:  22%|██▏       | 80/360 [00:03<00:13, 20.90it/s]

 batch_loss: 106890.77  |  KL+NLL: 102378.37  (KL=102293.89, NLL=84.48)
 batch_loss: 107202.96  |  KL+NLL: 102391.10  (KL=102323.34, NLL=67.76)
 batch_loss: 106691.37  |  KL+NLL: 102353.06  (KL=102293.66, NLL=59.40)
 batch_loss: 106467.36  |  KL+NLL: 102369.18  (KL=102278.43, NLL=90.75)


Epoch 2/10:  23%|██▎       | 83/360 [00:04<00:13, 20.40it/s]

 batch_loss: 106522.08  |  KL+NLL: 102327.98  (KL=102242.01, NLL=85.97)
 batch_loss: 106176.05  |  KL+NLL: 102354.93  (KL=102242.52, NLL=112.41)
 batch_loss: 106344.55  |  KL+NLL: 102386.85  (KL=102263.27, NLL=123.57)


Epoch 2/10:  24%|██▍       | 86/360 [00:04<00:13, 20.66it/s]

 batch_loss: 106414.40  |  KL+NLL: 102276.76  (KL=102169.76, NLL=107.00)
 batch_loss: 105954.48  |  KL+NLL: 102272.67  (KL=102193.38, NLL=79.29)


Epoch 2/10:  25%|██▍       | 89/360 [00:04<00:13, 20.40it/s]

 batch_loss: 106119.57  |  KL+NLL: 102250.35  (KL=102168.80, NLL=81.54)
 batch_loss: 105915.90  |  KL+NLL: 102236.18  (KL=102146.29, NLL=89.89)
 batch_loss: 105676.21  |  KL+NLL: 102203.16  (KL=102146.70, NLL=56.45)


Epoch 2/10:  26%|██▌       | 92/360 [00:04<00:13, 20.45it/s]

 batch_loss: 105481.43  |  KL+NLL: 102185.94  (KL=102103.23, NLL=82.71)
 batch_loss: 105495.49  |  KL+NLL: 102145.04  (KL=102068.24, NLL=76.80)


Epoch 2/10:  26%|██▋       | 95/360 [00:04<00:13, 20.24it/s]

 batch_loss: 105597.13  |  KL+NLL: 102116.43  (KL=102054.92, NLL=61.51)
 batch_loss: 105709.71  |  KL+NLL: 102168.20  (KL=102074.16, NLL=94.04)
 batch_loss: 105189.82  |  KL+NLL: 102100.87  (KL=102025.88, NLL=75.00)
 batch_loss: 105381.31  |  KL+NLL: 102104.23  (KL=102020.10, NLL=84.12)
 batch_loss: 104986.71  |  KL+NLL: 102074.39  (KL=102024.03, NLL=50.36)


Epoch 2/10:  27%|██▋       | 98/360 [00:04<00:12, 20.33it/s]

 batch_loss: 104758.67  |  KL+NLL: 102006.31  (KL=101937.59, NLL=68.72)
 batch_loss: 104762.58  |  KL+NLL: 102061.64  (KL=101978.88, NLL=82.76)
 batch_loss: 104731.29  |  KL+NLL: 102025.23  (KL=101913.48, NLL=111.75)


Epoch 2/10:  28%|██▊       | 101/360 [00:05<00:12, 20.67it/s]

 batch_loss: 104604.41  |  KL+NLL: 101947.17  (KL=101869.41, NLL=77.77)
 batch_loss: 104496.46  |  KL+NLL: 101957.21  (KL=101892.88, NLL=64.34)


Epoch 2/10:  29%|██▉       | 104/360 [00:05<00:12, 20.41it/s]

 batch_loss: 104511.41  |  KL+NLL: 101940.82  (KL=101884.04, NLL=56.78)
 batch_loss: 104348.80  |  KL+NLL: 101931.97  (KL=101834.58, NLL=97.39)
 batch_loss: 104115.18  |  KL+NLL: 101909.96  (KL=101822.70, NLL=87.26)


Epoch 2/10:  30%|██▉       | 107/360 [00:05<00:12, 20.49it/s]

 batch_loss: 104016.84  |  KL+NLL: 101871.72  (KL=101791.29, NLL=80.43)
 batch_loss: 103886.30  |  KL+NLL: 101902.74  (KL=101811.16, NLL=91.58)
 batch_loss: 103873.81  |  KL+NLL: 101878.08  (KL=101788.62, NLL=89.46)
 batch_loss: 103828.25  |  KL+NLL: 101859.38  (KL=101771.61, NLL=87.77)


Epoch 2/10:  31%|███       | 110/360 [00:05<00:12, 20.45it/s]

 batch_loss: 103505.38  |  KL+NLL: 101801.05  (KL=101740.55, NLL=60.51)
 batch_loss: 103254.01  |  KL+NLL: 101799.93  (KL=101743.50, NLL=56.43)
 batch_loss: 103160.41  |  KL+NLL: 101751.69  (KL=101678.62, NLL=73.07)


Epoch 2/10:  31%|███▏      | 113/360 [00:05<00:12, 20.47it/s]

 batch_loss: 103294.07  |  KL+NLL: 101775.79  (KL=101707.64, NLL=68.15)
 batch_loss: 102960.95  |  KL+NLL: 101792.86  (KL=101678.52, NLL=114.35)


Epoch 2/10:  32%|███▏      | 116/360 [00:05<00:11, 20.42it/s]

 batch_loss: 102874.08  |  KL+NLL: 101750.23  (KL=101644.15, NLL=106.08)
 batch_loss: 102967.47  |  KL+NLL: 101761.45  (KL=101675.53, NLL=85.92)
 batch_loss: 102535.15  |  KL+NLL: 101665.97  (KL=101601.99, NLL=63.98)


Epoch 2/10:  33%|███▎      | 119/360 [00:05<00:11, 20.81it/s]

 batch_loss: 102621.09  |  KL+NLL: 101717.39  (KL=101622.65, NLL=94.75)
 batch_loss: 102443.89  |  KL+NLL: 101667.13  (KL=101575.56, NLL=91.56)


Epoch 2/10:  34%|███▍      | 122/360 [00:06<00:11, 20.38it/s]

 batch_loss: 102586.17  |  KL+NLL: 101635.91  (KL=101533.58, NLL=102.33)
 batch_loss: 102127.27  |  KL+NLL: 101629.17  (KL=101530.69, NLL=98.48)
 batch_loss: 101936.10  |  KL+NLL: 101597.76  (KL=101524.98, NLL=72.77)
 batch_loss: 102065.96  |  KL+NLL: 101581.18  (KL=101511.05, NLL=70.14)
 batch_loss: 101746.80  |  KL+NLL: 101601.55  (KL=101506.09, NLL=95.47)


Epoch 2/10:  35%|███▍      | 125/360 [00:06<00:11, 20.37it/s]

 batch_loss: 101864.70  |  KL+NLL: 101547.02  (KL=101485.44, NLL=61.58)
 batch_loss: 101654.94  |  KL+NLL: 101525.96  (KL=101438.61, NLL=87.35)
 batch_loss: 101662.47  |  KL+NLL: 101572.90  (KL=101417.11, NLL=155.79)


Epoch 2/10:  36%|███▌      | 128/360 [00:06<00:11, 20.28it/s]

 batch_loss: 101399.42  |  KL+NLL: 101515.38  (KL=101434.05, NLL=81.32)
 batch_loss: 101511.14  |  KL+NLL: 101458.93  (KL=101381.64, NLL=77.29)


Epoch 2/10:  36%|███▋      | 131/360 [00:06<00:11, 20.64it/s]

 batch_loss: 101102.98  |  KL+NLL: 101438.32  (KL=101370.45, NLL=67.87)
 batch_loss: 101402.90  |  KL+NLL: 101463.81  (KL=101375.10, NLL=88.71)
 batch_loss: 100801.50  |  KL+NLL: 101425.92  (KL=101353.97, NLL=71.95)


Epoch 2/10:  37%|███▋      | 134/360 [00:06<00:11, 20.32it/s]

 batch_loss: 101022.57  |  KL+NLL: 101440.17  (KL=101317.52, NLL=122.64)
 batch_loss: 100819.59  |  KL+NLL: 101417.34  (KL=101323.25, NLL=94.09)


Epoch 2/10:  38%|███▊      | 137/360 [00:06<00:11, 20.27it/s]

 batch_loss: 100722.52  |  KL+NLL: 101346.76  (KL=101259.73, NLL=87.03)
 batch_loss: 101017.42  |  KL+NLL: 101373.50  (KL=101264.95, NLL=108.56)
 batch_loss: 100630.94  |  KL+NLL: 101326.07  (KL=101245.96, NLL=80.11)
 batch_loss: 100433.16  |  KL+NLL: 101317.64  (KL=101215.10, NLL=102.54)
 batch_loss: 100364.35  |  KL+NLL: 101287.32  (KL=101195.95, NLL=91.38)


Epoch 2/10:  39%|███▉      | 140/360 [00:06<00:10, 20.33it/s]

 batch_loss: 100205.05  |  KL+NLL: 101281.20  (KL=101216.42, NLL=64.77)
 batch_loss: 99981.47  |  KL+NLL: 101229.78  (KL=101170.30, NLL=59.48)
 batch_loss: 99966.25  |  KL+NLL: 101231.42  (KL=101170.98, NLL=60.44)


Epoch 2/10:  40%|███▉      | 143/360 [00:07<00:10, 20.23it/s]

 batch_loss: 99790.22  |  KL+NLL: 101230.56  (KL=101159.54, NLL=71.02)
 batch_loss: 99628.97  |  KL+NLL: 101225.01  (KL=101146.53, NLL=78.47)


Epoch 2/10:  41%|████      | 146/360 [00:07<00:10, 20.41it/s]

 batch_loss: 99796.35  |  KL+NLL: 101179.80  (KL=101093.84, NLL=85.97)
 batch_loss: 99679.40  |  KL+NLL: 101155.41  (KL=101071.71, NLL=83.70)
 batch_loss: 99136.25  |  KL+NLL: 101144.41  (KL=101061.01, NLL=83.40)


Epoch 2/10:  41%|████▏     | 149/360 [00:07<00:10, 20.43it/s]

 batch_loss: 99203.60  |  KL+NLL: 101158.98  (KL=101068.77, NLL=90.21)
 batch_loss: 99281.51  |  KL+NLL: 101142.62  (KL=101064.01, NLL=78.61)


Epoch 2/10:  42%|████▏     | 152/360 [00:07<00:10, 20.47it/s]

 batch_loss: 98891.64  |  KL+NLL: 101121.69  (KL=101041.45, NLL=80.23)
 batch_loss: 98606.32  |  KL+NLL: 101066.28  (KL=100975.66, NLL=90.62)
 batch_loss: 98583.02  |  KL+NLL: 101008.61  (KL=100940.08, NLL=68.53)
 batch_loss: 98883.24  |  KL+NLL: 101072.53  (KL=100984.62, NLL=87.92)
 batch_loss: 98485.78  |  KL+NLL: 101021.10  (KL=100935.45, NLL=85.65)


Epoch 2/10:  43%|████▎     | 155/360 [00:07<00:09, 20.50it/s]

 batch_loss: 98836.64  |  KL+NLL: 101018.63  (KL=100914.14, NLL=104.49)
 batch_loss: 98618.63  |  KL+NLL: 101027.61  (KL=100968.87, NLL=58.74)
 batch_loss: 98277.84  |  KL+NLL: 101027.53  (KL=100917.42, NLL=110.11)


Epoch 2/10:  44%|████▍     | 158/360 [00:07<00:09, 20.47it/s]

 batch_loss: 98389.80  |  KL+NLL: 100990.22  (KL=100892.98, NLL=97.24)
 batch_loss: 98173.01  |  KL+NLL: 100979.42  (KL=100905.20, NLL=74.22)


Epoch 2/10:  45%|████▍     | 161/360 [00:07<00:09, 20.77it/s]

 batch_loss: 98266.92  |  KL+NLL: 100933.17  (KL=100869.75, NLL=63.42)
 batch_loss: 97762.19  |  KL+NLL: 100860.42  (KL=100800.74, NLL=59.67)
 batch_loss: 97759.51  |  KL+NLL: 100975.63  (KL=100880.73, NLL=94.89)


Epoch 2/10:  46%|████▌     | 164/360 [00:08<00:09, 20.81it/s]

 batch_loss: 97351.00  |  KL+NLL: 100840.77  (KL=100771.78, NLL=68.98)
 batch_loss: 97603.31  |  KL+NLL: 100873.34  (KL=100778.99, NLL=94.35)


Epoch 2/10:  46%|████▋     | 167/360 [00:08<00:09, 20.68it/s]

 batch_loss: 97318.91  |  KL+NLL: 100879.04  (KL=100784.23, NLL=94.80)
 batch_loss: 97493.57  |  KL+NLL: 100817.60  (KL=100752.43, NLL=65.17)
 batch_loss: 97296.80  |  KL+NLL: 100774.19  (KL=100710.18, NLL=64.01)
 batch_loss: 97308.76  |  KL+NLL: 100828.29  (KL=100738.86, NLL=89.43)
 batch_loss: 97187.39  |  KL+NLL: 100822.02  (KL=100719.59, NLL=102.44)


Epoch 2/10:  47%|████▋     | 170/360 [00:08<00:09, 20.25it/s]

 batch_loss: 96950.63  |  KL+NLL: 100772.67  (KL=100707.19, NLL=65.48)
 batch_loss: 96914.74  |  KL+NLL: 100799.93  (KL=100693.30, NLL=106.63)
 batch_loss: 96852.72  |  KL+NLL: 100715.94  (KL=100644.76, NLL=71.18)


Epoch 2/10:  48%|████▊     | 173/360 [00:08<00:09, 20.45it/s]

 batch_loss: 96634.62  |  KL+NLL: 100725.95  (KL=100645.14, NLL=80.81)
 batch_loss: 96718.25  |  KL+NLL: 100659.16  (KL=100584.34, NLL=74.82)


Epoch 2/10:  49%|████▉     | 176/360 [00:08<00:09, 20.09it/s]

 batch_loss: 96920.74  |  KL+NLL: 100628.55  (KL=100585.95, NLL=42.60)
 batch_loss: 96363.29  |  KL+NLL: 100611.68  (KL=100535.44, NLL=76.24)


Epoch 2/10:  50%|████▉     | 179/360 [00:08<00:09, 20.10it/s]

 batch_loss: 96264.88  |  KL+NLL: 100599.00  (KL=100519.59, NLL=79.41)
 batch_loss: 96183.76  |  KL+NLL: 100672.02  (KL=100579.95, NLL=92.07)
 batch_loss: 95852.91  |  KL+NLL: 100623.92  (KL=100527.45, NLL=96.47)
 batch_loss: 95787.37  |  KL+NLL: 100607.28  (KL=100521.19, NLL=86.10)
 batch_loss: 95828.32  |  KL+NLL: 100537.06  (KL=100485.53, NLL=51.53)


Epoch 2/10:  51%|█████     | 182/360 [00:08<00:08, 20.33it/s]

 batch_loss: 95685.54  |  KL+NLL: 100644.88  (KL=100468.77, NLL=176.11)
 batch_loss: 95672.45  |  KL+NLL: 100547.08  (KL=100482.20, NLL=64.88)
 batch_loss: 95874.25  |  KL+NLL: 100512.82  (KL=100459.16, NLL=53.66)


Epoch 2/10:  51%|█████▏    | 185/360 [00:09<00:08, 20.72it/s]

 batch_loss: 95361.79  |  KL+NLL: 100503.68  (KL=100434.95, NLL=68.73)
 batch_loss: 95520.34  |  KL+NLL: 100530.73  (KL=100441.91, NLL=88.82)


Epoch 2/10:  52%|█████▏    | 188/360 [00:09<00:08, 20.33it/s]

 batch_loss: 95229.67  |  KL+NLL: 100478.20  (KL=100411.98, NLL=66.22)
 batch_loss: 95129.12  |  KL+NLL: 100478.87  (KL=100356.73, NLL=122.14)
 batch_loss: 94917.23  |  KL+NLL: 100454.13  (KL=100370.77, NLL=83.37)
 batch_loss: 94709.72  |  KL+NLL: 100351.64  (KL=100297.02, NLL=54.62)


Epoch 2/10:  53%|█████▎    | 191/360 [00:09<00:08, 20.01it/s]

 batch_loss: 94720.71  |  KL+NLL: 100419.99  (KL=100335.52, NLL=84.47)
 batch_loss: 94772.66  |  KL+NLL: 100434.28  (KL=100314.02, NLL=120.26)


Epoch 2/10:  54%|█████▍    | 194/360 [00:09<00:08, 20.31it/s]

 batch_loss: 94660.07  |  KL+NLL: 100358.88  (KL=100289.98, NLL=68.90)
 batch_loss: 94218.62  |  KL+NLL: 100443.14  (KL=100324.82, NLL=118.32)
 batch_loss: 94148.59  |  KL+NLL: 100342.59  (KL=100246.88, NLL=95.71)


Epoch 2/10:  55%|█████▍    | 197/360 [00:09<00:08, 20.18it/s]

 batch_loss: 93951.49  |  KL+NLL: 100350.67  (KL=100225.73, NLL=124.93)
 batch_loss: 94112.78  |  KL+NLL: 100337.27  (KL=100259.62, NLL=77.64)


Epoch 2/10:  56%|█████▌    | 200/360 [00:09<00:07, 20.45it/s]

 batch_loss: 93838.02  |  KL+NLL: 100273.59  (KL=100206.47, NLL=67.12)
 batch_loss: 93648.69  |  KL+NLL: 100288.30  (KL=100194.70, NLL=93.61)
 batch_loss: 93923.50  |  KL+NLL: 100277.46  (KL=100188.55, NLL=88.91)
 batch_loss: 93709.61  |  KL+NLL: 100236.13  (KL=100169.91, NLL=66.22)
 batch_loss: 93466.25  |  KL+NLL: 100307.01  (KL=100203.45, NLL=103.57)


Epoch 2/10:  56%|█████▋    | 203/360 [00:10<00:07, 20.40it/s]

 batch_loss: 93015.68  |  KL+NLL: 100302.19  (KL=100152.20, NLL=149.99)
 batch_loss: 93283.83  |  KL+NLL: 100212.07  (KL=100111.88, NLL=100.19)
 batch_loss: 93174.13  |  KL+NLL: 100213.85  (KL=100127.81, NLL=86.04)


Epoch 2/10:  57%|█████▋    | 206/360 [00:10<00:07, 20.80it/s]

 batch_loss: 92832.27  |  KL+NLL: 100207.05  (KL=100133.47, NLL=73.58)
 batch_loss: 92705.35  |  KL+NLL: 100196.85  (KL=100080.55, NLL=116.30)


Epoch 2/10:  58%|█████▊    | 209/360 [00:10<00:07, 20.99it/s]

 batch_loss: 93010.93  |  KL+NLL: 100209.23  (KL=100089.29, NLL=119.95)
 batch_loss: 92929.10  |  KL+NLL: 100195.91  (KL=100046.27, NLL=149.63)
 batch_loss: 92936.00  |  KL+NLL: 100158.07  (KL=100047.52, NLL=110.55)


Epoch 2/10:  59%|█████▉    | 212/360 [00:10<00:07, 21.03it/s]

 batch_loss: 92731.53  |  KL+NLL: 100077.81  (KL=100002.51, NLL=75.31)
 batch_loss: 92916.21  |  KL+NLL: 100160.70  (KL=100034.05, NLL=126.64)


Epoch 2/10:  60%|█████▉    | 215/360 [00:10<00:06, 21.00it/s]

 batch_loss: 92647.46  |  KL+NLL: 100136.13  (KL=100006.20, NLL=129.94)
 batch_loss: 92445.05  |  KL+NLL: 100077.01  (KL=99977.38, NLL=99.63)
 batch_loss: 92244.74  |  KL+NLL: 100051.23  (KL=99958.59, NLL=92.64)
 batch_loss: 92211.17  |  KL+NLL: 99985.21  (KL=99909.92, NLL=75.29)
 batch_loss: 92102.16  |  KL+NLL: 100012.36  (KL=99932.98, NLL=79.38)


Epoch 2/10:  61%|██████    | 218/360 [00:10<00:06, 20.72it/s]

 batch_loss: 91716.23  |  KL+NLL: 100034.95  (KL=99952.65, NLL=82.30)
 batch_loss: 91858.19  |  KL+NLL: 100012.86  (KL=99915.86, NLL=97.00)
 batch_loss: 91539.45  |  KL+NLL: 99955.41  (KL=99886.61, NLL=68.81)


Epoch 2/10:  61%|██████▏   | 221/360 [00:10<00:06, 20.64it/s]

 batch_loss: 91737.68  |  KL+NLL: 99956.82  (KL=99892.56, NLL=64.26)
 batch_loss: 91143.64  |  KL+NLL: 99954.74  (KL=99887.02, NLL=67.72)


Epoch 2/10:  62%|██████▏   | 224/360 [00:11<00:06, 20.58it/s]

 batch_loss: 91110.34  |  KL+NLL: 99907.48  (KL=99844.30, NLL=63.19)
 batch_loss: 91305.61  |  KL+NLL: 99912.62  (KL=99828.62, NLL=83.99)
 batch_loss: 91071.16  |  KL+NLL: 99882.64  (KL=99824.41, NLL=58.23)


Epoch 2/10:  63%|██████▎   | 227/360 [00:11<00:06, 20.66it/s]

 batch_loss: 91154.16  |  KL+NLL: 99864.43  (KL=99777.13, NLL=87.30)
 batch_loss: 90845.52  |  KL+NLL: 99799.70  (KL=99741.55, NLL=58.15)


Epoch 2/10:  64%|██████▍   | 230/360 [00:11<00:06, 20.56it/s]

 batch_loss: 90549.09  |  KL+NLL: 99880.46  (KL=99798.66, NLL=81.80)
 batch_loss: 90527.11  |  KL+NLL: 99824.49  (KL=99738.86, NLL=85.63)
 batch_loss: 90779.59  |  KL+NLL: 99824.89  (KL=99724.40, NLL=100.49)
 batch_loss: 90545.31  |  KL+NLL: 99820.50  (KL=99732.45, NLL=88.06)


Epoch 2/10:  65%|██████▍   | 233/360 [00:11<00:06, 20.06it/s]

 batch_loss: 90399.29  |  KL+NLL: 99844.52  (KL=99729.74, NLL=114.78)
 batch_loss: 90545.09  |  KL+NLL: 99768.39  (KL=99657.32, NLL=111.07)
 batch_loss: 90480.12  |  KL+NLL: 99769.67  (KL=99664.98, NLL=104.70)
 batch_loss: 89823.49  |  KL+NLL: 99779.62  (KL=99696.76, NLL=82.86)


Epoch 2/10:  66%|██████▌   | 236/360 [00:11<00:06, 20.11it/s]

 batch_loss: 90051.30  |  KL+NLL: 99718.70  (KL=99647.88, NLL=70.81)
 batch_loss: 89786.97  |  KL+NLL: 99730.47  (KL=99641.77, NLL=88.70)
 batch_loss: 89687.53  |  KL+NLL: 99693.73  (KL=99613.77, NLL=79.96)


Epoch 2/10:  66%|██████▋   | 239/360 [00:11<00:06, 19.87it/s]

 batch_loss: 89710.14  |  KL+NLL: 99756.80  (KL=99607.37, NLL=149.43)
 batch_loss: 89725.43  |  KL+NLL: 99703.48  (KL=99608.65, NLL=94.83)


Epoch 2/10:  67%|██████▋   | 242/360 [00:11<00:05, 20.27it/s]

 batch_loss: 89736.83  |  KL+NLL: 99682.86  (KL=99563.98, NLL=118.89)
 batch_loss: 89351.14  |  KL+NLL: 99690.09  (KL=99565.16, NLL=124.93)
 batch_loss: 89310.92  |  KL+NLL: 99661.94  (KL=99579.12, NLL=82.82)
 batch_loss: 89304.60  |  KL+NLL: 99637.29  (KL=99532.53, NLL=104.76)


Epoch 2/10:  69%|██████▉   | 248/360 [00:12<00:05, 20.26it/s]

 batch_loss: 88858.47  |  KL+NLL: 99559.93  (KL=99507.22, NLL=52.71)
 batch_loss: 88835.77  |  KL+NLL: 99612.63  (KL=99480.13, NLL=132.49)
 batch_loss: 88829.45  |  KL+NLL: 99551.13  (KL=99456.26, NLL=94.88)
 batch_loss: 88964.24  |  KL+NLL: 99536.38  (KL=99467.85, NLL=68.53)
 batch_loss: 88444.28  |  KL+NLL: 99556.39  (KL=99469.59, NLL=86.81)


Epoch 2/10:  70%|██████▉   | 251/360 [00:12<00:05, 20.69it/s]

 batch_loss: 88539.57  |  KL+NLL: 99478.36  (KL=99420.30, NLL=58.07)
 batch_loss: 88627.72  |  KL+NLL: 99562.70  (KL=99473.51, NLL=89.19)
 batch_loss: 88192.19  |  KL+NLL: 99525.90  (KL=99426.69, NLL=99.22)
 batch_loss: 88193.40  |  KL+NLL: 99454.13  (KL=99375.60, NLL=78.53)


Epoch 2/10:  71%|███████   | 254/360 [00:12<00:05, 20.36it/s]

 batch_loss: 88043.75  |  KL+NLL: 99421.20  (KL=99344.09, NLL=77.11)


Epoch 2/10:  71%|███████▏  | 257/360 [00:12<00:05, 20.56it/s]

 batch_loss: 87991.05  |  KL+NLL: 99426.12  (KL=99351.52, NLL=74.60)
 batch_loss: 87819.89  |  KL+NLL: 99429.53  (KL=99341.95, NLL=87.58)
 batch_loss: 87959.37  |  KL+NLL: 99433.71  (KL=99345.79, NLL=87.92)
 batch_loss: 87820.87  |  KL+NLL: 99428.78  (KL=99322.10, NLL=106.68)
 batch_loss: 87650.72  |  KL+NLL: 99397.88  (KL=99313.77, NLL=84.11)


Epoch 2/10:  73%|███████▎  | 263/360 [00:12<00:04, 20.85it/s]

 batch_loss: 87854.06  |  KL+NLL: 99394.52  (KL=99283.60, NLL=110.92)
 batch_loss: 87419.11  |  KL+NLL: 99473.05  (KL=99302.20, NLL=170.85)
 batch_loss: 87111.01  |  KL+NLL: 99373.64  (KL=99289.42, NLL=84.22)
 batch_loss: 87200.56  |  KL+NLL: 99362.25  (KL=99290.74, NLL=71.50)
 batch_loss: 87384.09  |  KL+NLL: 99309.81  (KL=99234.04, NLL=75.77)


Epoch 2/10:  74%|███████▍  | 266/360 [00:13<00:04, 20.65it/s]

 batch_loss: 87102.43  |  KL+NLL: 99305.21  (KL=99241.62, NLL=63.59)
 batch_loss: 86751.57  |  KL+NLL: 99333.22  (KL=99212.58, NLL=120.64)
 batch_loss: 86864.72  |  KL+NLL: 99314.63  (KL=99196.94, NLL=117.69)
 batch_loss: 86741.55  |  KL+NLL: 99392.54  (KL=99251.67, NLL=140.87)


Epoch 2/10:  76%|███████▌  | 272/360 [00:13<00:04, 20.20it/s]

 batch_loss: 86513.61  |  KL+NLL: 99240.91  (KL=99181.32, NLL=59.59)
 batch_loss: 86405.92  |  KL+NLL: 99279.83  (KL=99173.46, NLL=106.37)
 batch_loss: 86316.29  |  KL+NLL: 99245.81  (KL=99119.45, NLL=126.36)
 batch_loss: 86319.59  |  KL+NLL: 99262.25  (KL=99129.75, NLL=132.50)
 batch_loss: 86125.10  |  KL+NLL: 99206.38  (KL=99134.71, NLL=71.67)


Epoch 2/10:  76%|███████▋  | 275/360 [00:13<00:04, 19.83it/s]

 batch_loss: 86278.30  |  KL+NLL: 99205.45  (KL=99093.80, NLL=111.65)
 batch_loss: 86193.78  |  KL+NLL: 99237.36  (KL=99121.82, NLL=115.54)
 batch_loss: 85783.48  |  KL+NLL: 99172.54  (KL=99107.73, NLL=64.80)


Epoch 2/10:  77%|███████▋  | 277/360 [00:13<00:04, 19.86it/s]

 batch_loss: 86091.60  |  KL+NLL: 99193.17  (KL=99048.42, NLL=144.75)


Epoch 2/10:  78%|███████▊  | 280/360 [00:13<00:03, 20.08it/s]

 batch_loss: 85522.18  |  KL+NLL: 99149.70  (KL=99040.14, NLL=109.56)
 batch_loss: 85580.43  |  KL+NLL: 99152.81  (KL=99070.14, NLL=82.66)
 batch_loss: 85522.25  |  KL+NLL: 99132.90  (KL=99010.80, NLL=122.11)
 batch_loss: 85389.29  |  KL+NLL: 99084.63  (KL=99021.92, NLL=62.71)
 batch_loss: 85356.53  |  KL+NLL: 99114.94  (KL=98987.00, NLL=127.94)


Epoch 2/10:  79%|███████▉  | 286/360 [00:14<00:03, 20.44it/s]

 batch_loss: 84990.27  |  KL+NLL: 99056.99  (KL=98980.30, NLL=76.69)
 batch_loss: 85201.89  |  KL+NLL: 99073.06  (KL=98971.39, NLL=101.67)
 batch_loss: 85220.10  |  KL+NLL: 99016.19  (KL=98959.46, NLL=56.73)
 batch_loss: 85131.99  |  KL+NLL: 99018.59  (KL=98935.98, NLL=82.62)
 batch_loss: 84755.54  |  KL+NLL: 99026.49  (KL=98912.64, NLL=113.84)


Epoch 2/10:  80%|████████  | 289/360 [00:14<00:03, 20.81it/s]

 batch_loss: 84721.45  |  KL+NLL: 99062.01  (KL=98966.20, NLL=95.81)
 batch_loss: 84568.75  |  KL+NLL: 99073.60  (KL=98943.16, NLL=130.44)
 batch_loss: 84520.55  |  KL+NLL: 99001.45  (KL=98888.87, NLL=112.58)
 batch_loss: 84540.49  |  KL+NLL: 98968.07  (KL=98860.13, NLL=107.94)


Epoch 2/10:  81%|████████  | 292/360 [00:14<00:03, 20.76it/s]

 batch_loss: 83977.85  |  KL+NLL: 98973.85  (KL=98873.18, NLL=100.67)


Epoch 2/10:  82%|████████▏ | 295/360 [00:14<00:03, 20.90it/s]

 batch_loss: 84290.45  |  KL+NLL: 98899.28  (KL=98849.09, NLL=50.19)
 batch_loss: 84362.71  |  KL+NLL: 98906.29  (KL=98840.22, NLL=66.07)
 batch_loss: 83520.59  |  KL+NLL: 98897.76  (KL=98835.05, NLL=62.71)
 batch_loss: 84078.78  |  KL+NLL: 98918.62  (KL=98825.12, NLL=93.51)
 batch_loss: 83830.54  |  KL+NLL: 98815.63  (KL=98750.52, NLL=65.10)


Epoch 2/10:  84%|████████▎ | 301/360 [00:14<00:02, 20.98it/s]

 batch_loss: 83599.81  |  KL+NLL: 98865.71  (KL=98778.04, NLL=87.68)
 batch_loss: 83900.97  |  KL+NLL: 98859.39  (KL=98776.30, NLL=83.09)
 batch_loss: 83701.43  |  KL+NLL: 98837.75  (KL=98746.33, NLL=91.42)
 batch_loss: 83313.50  |  KL+NLL: 98859.26  (KL=98773.46, NLL=85.80)
 batch_loss: 83350.24  |  KL+NLL: 98830.04  (KL=98725.05, NLL=104.99)


Epoch 2/10:  84%|████████▍ | 304/360 [00:14<00:02, 20.65it/s]

 batch_loss: 83494.04  |  KL+NLL: 98809.94  (KL=98702.72, NLL=107.22)
 batch_loss: 82966.38  |  KL+NLL: 98785.89  (KL=98711.54, NLL=74.35)
 batch_loss: 83149.46  |  KL+NLL: 98807.91  (KL=98707.02, NLL=100.89)
 batch_loss: 83077.15  |  KL+NLL: 98813.58  (KL=98691.50, NLL=122.08)


Epoch 2/10:  85%|████████▌ | 307/360 [00:15<00:02, 20.58it/s]

 batch_loss: 82611.86  |  KL+NLL: 98816.81  (KL=98672.48, NLL=144.34)


Epoch 2/10:  86%|████████▌ | 310/360 [00:15<00:02, 20.39it/s]

 batch_loss: 82412.27  |  KL+NLL: 98742.60  (KL=98622.16, NLL=120.44)
 batch_loss: 82629.36  |  KL+NLL: 98742.61  (KL=98652.02, NLL=90.58)
 batch_loss: 82280.18  |  KL+NLL: 98698.76  (KL=98594.81, NLL=103.95)
 batch_loss: 82311.55  |  KL+NLL: 98747.11  (KL=98646.81, NLL=100.30)


Epoch 2/10:  87%|████████▋ | 313/360 [00:15<00:02, 20.46it/s]

 batch_loss: 82341.24  |  KL+NLL: 98769.03  (KL=98613.59, NLL=155.44)
 batch_loss: 82384.52  |  KL+NLL: 98715.65  (KL=98609.87, NLL=105.78)
 batch_loss: 81981.88  |  KL+NLL: 98682.63  (KL=98588.49, NLL=94.13)
 batch_loss: 81896.55  |  KL+NLL: 98615.90  (KL=98567.25, NLL=48.65)


Epoch 2/10:  88%|████████▊ | 316/360 [00:15<00:02, 20.90it/s]

 batch_loss: 81967.43  |  KL+NLL: 98653.89  (KL=98558.97, NLL=94.92)


Epoch 2/10:  89%|████████▊ | 319/360 [00:15<00:01, 20.91it/s]

 batch_loss: 81681.60  |  KL+NLL: 98634.90  (KL=98553.17, NLL=81.73)
 batch_loss: 81482.96  |  KL+NLL: 98702.53  (KL=98579.80, NLL=122.73)
 batch_loss: 81424.00  |  KL+NLL: 98710.18  (KL=98554.25, NLL=155.93)
 batch_loss: 81474.23  |  KL+NLL: 98625.12  (KL=98524.47, NLL=100.66)
 batch_loss: 81374.89  |  KL+NLL: 98571.45  (KL=98461.54, NLL=109.91)


Epoch 2/10:  90%|█████████ | 325/360 [00:15<00:01, 20.79it/s]

 batch_loss: 81229.61  |  KL+NLL: 98585.33  (KL=98495.48, NLL=89.86)
 batch_loss: 81355.32  |  KL+NLL: 98637.39  (KL=98504.38, NLL=133.02)
 batch_loss: 81104.86  |  KL+NLL: 98562.25  (KL=98490.56, NLL=71.69)
 batch_loss: 81076.08  |  KL+NLL: 98544.43  (KL=98479.95, NLL=64.49)
 batch_loss: 81021.25  |  KL+NLL: 98519.55  (KL=98436.89, NLL=82.66)


Epoch 2/10:  91%|█████████ | 328/360 [00:16<00:01, 20.87it/s]

 batch_loss: 81107.41  |  KL+NLL: 98599.71  (KL=98464.40, NLL=135.31)
 batch_loss: 80764.09  |  KL+NLL: 98506.15  (KL=98414.53, NLL=91.62)
 batch_loss: 80839.43  |  KL+NLL: 98443.56  (KL=98370.12, NLL=73.44)
 batch_loss: 80640.65  |  KL+NLL: 98519.91  (KL=98409.63, NLL=110.27)


Epoch 2/10:  92%|█████████▏| 331/360 [00:16<00:01, 20.70it/s]

 batch_loss: 80230.10  |  KL+NLL: 98493.31  (KL=98405.04, NLL=88.27)


Epoch 2/10:  93%|█████████▎| 334/360 [00:16<00:01, 20.66it/s]

 batch_loss: 80457.48  |  KL+NLL: 98450.56  (KL=98349.55, NLL=101.01)
 batch_loss: 80144.08  |  KL+NLL: 98436.35  (KL=98352.37, NLL=83.98)
 batch_loss: 80275.97  |  KL+NLL: 98421.22  (KL=98327.77, NLL=93.45)
 batch_loss: 80270.15  |  KL+NLL: 98388.27  (KL=98303.40, NLL=84.87)
 batch_loss: 80210.80  |  KL+NLL: 98429.42  (KL=98314.81, NLL=114.60)


Epoch 2/10:  94%|█████████▎| 337/360 [00:16<00:01, 20.55it/s]

 batch_loss: 79964.88  |  KL+NLL: 98410.85  (KL=98325.15, NLL=85.70)
 batch_loss: 79860.20  |  KL+NLL: 98357.68  (KL=98255.91, NLL=101.77)
 batch_loss: 79370.56  |  KL+NLL: 98388.23  (KL=98306.08, NLL=82.15)


Epoch 2/10:  94%|█████████▍| 340/360 [00:16<00:00, 20.16it/s]

 batch_loss: 79630.85  |  KL+NLL: 98391.94  (KL=98310.91, NLL=81.03)


Epoch 2/10:  95%|█████████▌| 343/360 [00:16<00:00, 20.40it/s]

 batch_loss: 79384.31  |  KL+NLL: 98385.19  (KL=98274.30, NLL=110.89)
 batch_loss: 79369.82  |  KL+NLL: 98385.75  (KL=98251.55, NLL=134.20)
 batch_loss: 79440.73  |  KL+NLL: 98308.35  (KL=98251.95, NLL=56.40)
 batch_loss: 79151.78  |  KL+NLL: 98317.71  (KL=98225.06, NLL=92.65)
 batch_loss: 78889.52  |  KL+NLL: 98362.28  (KL=98239.95, NLL=122.33)


Epoch 2/10:  97%|█████████▋| 349/360 [00:17<00:00, 20.52it/s]

 batch_loss: 79196.27  |  KL+NLL: 98352.03  (KL=98187.88, NLL=164.15)
 batch_loss: 78956.49  |  KL+NLL: 98359.14  (KL=98224.87, NLL=134.27)
 batch_loss: 78869.09  |  KL+NLL: 98253.50  (KL=98175.11, NLL=78.39)
 batch_loss: 78812.81  |  KL+NLL: 98312.94  (KL=98209.78, NLL=103.16)
 batch_loss: 78765.03  |  KL+NLL: 98331.88  (KL=98232.06, NLL=99.82)


Epoch 2/10:  98%|█████████▊| 352/360 [00:17<00:00, 20.40it/s]

 batch_loss: 78334.82  |  KL+NLL: 98245.24  (KL=98164.52, NLL=80.72)
 batch_loss: 78574.28  |  KL+NLL: 98228.46  (KL=98128.30, NLL=100.16)
 batch_loss: 78478.29  |  KL+NLL: 98245.61  (KL=98143.94, NLL=101.68)
 batch_loss: 78334.72  |  KL+NLL: 98164.27  (KL=98095.16, NLL=69.11)


Epoch 2/10:  99%|█████████▊| 355/360 [00:17<00:00, 20.93it/s]

 batch_loss: 77969.41  |  KL+NLL: 98197.84  (KL=98119.74, NLL=78.10)


Epoch 2/10:  99%|█████████▉| 358/360 [00:17<00:00, 21.30it/s]

 batch_loss: 78094.46  |  KL+NLL: 98165.98  (KL=98105.20, NLL=60.79)
 batch_loss: 77913.56  |  KL+NLL: 98203.35  (KL=98108.71, NLL=94.64)
 batch_loss: 78160.05  |  KL+NLL: 98180.32  (KL=98090.05, NLL=90.28)
 batch_loss: 77941.57  |  KL+NLL: 98191.76  (KL=98081.65, NLL=110.11)


Epoch 2/10: 100%|██████████| 360/360 [00:17<00:00, 20.42it/s]


 batch_loss: 78030.49  |  KL+NLL: 98106.85  (KL=97999.40, NLL=107.45)
Epoch 2 - ELBO Loss: 96243.1898


Epoch 3/10:   0%|          | 1/360 [00:00<00:48,  7.45it/s]

 batch_loss: 77641.80  |  KL+NLL: 98113.83  (KL=98036.68, NLL=77.16)
 batch_loss: 77409.46  |  KL+NLL: 98123.55  (KL=98030.08, NLL=93.47)


Epoch 3/10:   1%|          | 4/360 [00:00<00:23, 15.31it/s]

 batch_loss: 77307.01  |  KL+NLL: 98139.32  (KL=98035.17, NLL=104.14)
 batch_loss: 77044.37  |  KL+NLL: 98134.10  (KL=98055.58, NLL=78.52)
 batch_loss: 77162.40  |  KL+NLL: 98092.29  (KL=97997.43, NLL=94.86)


Epoch 3/10:   2%|▏         | 7/360 [00:00<00:19, 18.21it/s]

 batch_loss: 77055.65  |  KL+NLL: 98061.79  (KL=97979.40, NLL=82.40)
 batch_loss: 76870.93  |  KL+NLL: 98078.78  (KL=97973.86, NLL=104.93)


Epoch 3/10:   2%|▎         | 9/360 [00:00<00:18, 18.73it/s]

 batch_loss: 77022.48  |  KL+NLL: 98130.92  (KL=97940.66, NLL=190.26)
 batch_loss: 76789.80  |  KL+NLL: 98004.36  (KL=97909.73, NLL=94.62)
 batch_loss: 76749.17  |  KL+NLL: 98075.28  (KL=97966.22, NLL=109.06)


Epoch 3/10:   3%|▎         | 12/360 [00:00<00:17, 19.75it/s]

 batch_loss: 76613.87  |  KL+NLL: 98088.65  (KL=97918.14, NLL=170.51)
 batch_loss: 76396.65  |  KL+NLL: 98012.15  (KL=97928.86, NLL=83.29)


Epoch 3/10:   4%|▍         | 14/360 [00:00<00:17, 19.75it/s]

 batch_loss: 76412.20  |  KL+NLL: 97976.38  (KL=97901.66, NLL=74.72)
 batch_loss: 76661.31  |  KL+NLL: 98040.39  (KL=97879.30, NLL=161.09)
 batch_loss: 76315.66  |  KL+NLL: 98015.99  (KL=97867.19, NLL=148.80)


Epoch 3/10:   4%|▍         | 16/360 [00:00<00:17, 19.66it/s]

 batch_loss: 76271.25  |  KL+NLL: 97929.89  (KL=97840.05, NLL=89.84)
 batch_loss: 76114.85  |  KL+NLL: 97954.94  (KL=97853.11, NLL=101.83)


Epoch 3/10:   5%|▌         | 19/360 [00:01<00:16, 20.14it/s]

 batch_loss: 75801.02  |  KL+NLL: 98024.03  (KL=97853.21, NLL=170.82)
 batch_loss: 75914.08  |  KL+NLL: 97954.72  (KL=97849.59, NLL=105.13)
 batch_loss: 75597.10  |  KL+NLL: 97916.31  (KL=97835.32, NLL=80.99)


Epoch 3/10:   6%|▌         | 22/360 [00:01<00:16, 20.64it/s]

 batch_loss: 75983.48  |  KL+NLL: 97897.70  (KL=97806.81, NLL=90.88)
 batch_loss: 75485.79  |  KL+NLL: 97924.17  (KL=97818.95, NLL=105.22)


Epoch 3/10:   7%|▋         | 25/360 [00:01<00:16, 20.59it/s]

 batch_loss: 75748.92  |  KL+NLL: 97931.76  (KL=97841.84, NLL=89.92)
 batch_loss: 75028.41  |  KL+NLL: 97847.48  (KL=97795.66, NLL=51.81)
 batch_loss: 75073.76  |  KL+NLL: 97857.41  (KL=97777.01, NLL=80.40)
 batch_loss: 75206.66  |  KL+NLL: 97835.73  (KL=97761.78, NLL=73.95)
 batch_loss: 75147.29  |  KL+NLL: 97834.01  (KL=97774.69, NLL=59.33)


Epoch 3/10:   8%|▊         | 28/360 [00:01<00:16, 20.72it/s]

 batch_loss: 75486.41  |  KL+NLL: 97844.18  (KL=97714.60, NLL=129.57)
 batch_loss: 75333.98  |  KL+NLL: 97865.07  (KL=97731.95, NLL=133.11)
 batch_loss: 74616.00  |  KL+NLL: 97799.81  (KL=97705.48, NLL=94.33)


Epoch 3/10:   9%|▊         | 31/360 [00:01<00:15, 20.58it/s]

 batch_loss: 74546.66  |  KL+NLL: 97858.08  (KL=97731.46, NLL=126.62)
 batch_loss: 74893.47  |  KL+NLL: 97777.30  (KL=97687.02, NLL=90.29)


Epoch 3/10:   9%|▉         | 34/360 [00:01<00:16, 20.32it/s]

 batch_loss: 74438.20  |  KL+NLL: 97812.26  (KL=97725.35, NLL=86.91)
 batch_loss: 74285.81  |  KL+NLL: 97729.74  (KL=97670.91, NLL=58.84)


Epoch 3/10:  10%|█         | 37/360 [00:01<00:15, 20.50it/s]

 batch_loss: 74475.09  |  KL+NLL: 97787.02  (KL=97660.69, NLL=126.33)
 batch_loss: 74569.89  |  KL+NLL: 97754.52  (KL=97665.52, NLL=89.00)
 batch_loss: 74238.10  |  KL+NLL: 97859.39  (KL=97682.17, NLL=177.22)
 batch_loss: 73945.43  |  KL+NLL: 97688.44  (KL=97602.54, NLL=85.90)
 batch_loss: 74255.51  |  KL+NLL: 97743.66  (KL=97644.61, NLL=99.05)


Epoch 3/10:  11%|█         | 40/360 [00:02<00:15, 20.80it/s]

 batch_loss: 73927.36  |  KL+NLL: 97823.97  (KL=97622.16, NLL=201.81)
 batch_loss: 73726.30  |  KL+NLL: 97660.44  (KL=97574.00, NLL=86.44)
 batch_loss: 73546.72  |  KL+NLL: 97698.49  (KL=97616.92, NLL=81.57)


Epoch 3/10:  12%|█▏        | 43/360 [00:02<00:15, 20.67it/s]

 batch_loss: 73586.28  |  KL+NLL: 97680.21  (KL=97609.00, NLL=71.21)


Epoch 3/10:  13%|█▎        | 46/360 [00:02<00:15, 20.49it/s]

 batch_loss: 73461.86  |  KL+NLL: 97704.44  (KL=97613.53, NLL=90.91)
 batch_loss: 73601.57  |  KL+NLL: 97718.36  (KL=97624.40, NLL=93.97)
 batch_loss: 73250.29  |  KL+NLL: 97666.39  (KL=97567.36, NLL=99.03)
 batch_loss: 73228.17  |  KL+NLL: 97649.70  (KL=97558.80, NLL=90.90)
 batch_loss: 72864.83  |  KL+NLL: 97642.99  (KL=97531.52, NLL=111.48)


Epoch 3/10:  14%|█▎        | 49/360 [00:02<00:15, 20.43it/s]

 batch_loss: 72811.94  |  KL+NLL: 97632.94  (KL=97526.79, NLL=106.16)
 batch_loss: 73149.61  |  KL+NLL: 97579.54  (KL=97494.31, NLL=85.23)
 batch_loss: 73091.04  |  KL+NLL: 97559.91  (KL=97477.68, NLL=82.23)


Epoch 3/10:  14%|█▍        | 52/360 [00:02<00:15, 20.46it/s]

 batch_loss: 72723.28  |  KL+NLL: 97537.22  (KL=97469.14, NLL=68.07)


Epoch 3/10:  15%|█▌        | 55/360 [00:02<00:15, 19.99it/s]

 batch_loss: 73149.41  |  KL+NLL: 97595.57  (KL=97496.32, NLL=99.25)
 batch_loss: 72396.47  |  KL+NLL: 97599.87  (KL=97486.91, NLL=112.96)
 batch_loss: 72660.19  |  KL+NLL: 97586.46  (KL=97474.67, NLL=111.79)
 batch_loss: 72493.51  |  KL+NLL: 97668.08  (KL=97505.01, NLL=163.07)


Epoch 3/10:  16%|█▌        | 58/360 [00:02<00:15, 19.77it/s]

 batch_loss: 72289.00  |  KL+NLL: 97578.55  (KL=97485.78, NLL=92.77)
 batch_loss: 72145.90  |  KL+NLL: 97558.59  (KL=97445.59, NLL=113.00)
 batch_loss: 72222.85  |  KL+NLL: 97564.28  (KL=97398.34, NLL=165.95)


Epoch 3/10:  17%|█▋        | 61/360 [00:03<00:14, 19.98it/s]

 batch_loss: 72037.13  |  KL+NLL: 97530.54  (KL=97422.09, NLL=108.46)
 batch_loss: 71900.77  |  KL+NLL: 97476.10  (KL=97383.13, NLL=92.97)
 batch_loss: 71669.41  |  KL+NLL: 97524.06  (KL=97403.02, NLL=121.03)
 batch_loss: 71814.74  |  KL+NLL: 97503.73  (KL=97392.45, NLL=111.28)


Epoch 3/10:  18%|█▊        | 64/360 [00:03<00:14, 19.75it/s]

 batch_loss: 71676.51  |  KL+NLL: 97479.16  (KL=97380.60, NLL=98.56)
 batch_loss: 71615.07  |  KL+NLL: 97404.72  (KL=97328.55, NLL=76.17)


Epoch 3/10:  19%|█▊        | 67/360 [00:03<00:14, 20.26it/s]

 batch_loss: 71518.58  |  KL+NLL: 97462.29  (KL=97366.96, NLL=95.33)
 batch_loss: 71444.03  |  KL+NLL: 97390.70  (KL=97299.29, NLL=91.41)
 batch_loss: 71189.26  |  KL+NLL: 97453.22  (KL=97355.62, NLL=97.59)


Epoch 3/10:  19%|█▉        | 70/360 [00:03<00:14, 20.69it/s]

 batch_loss: 71158.51  |  KL+NLL: 97468.80  (KL=97368.88, NLL=99.92)
 batch_loss: 71135.44  |  KL+NLL: 97441.41  (KL=97340.45, NLL=100.96)


Epoch 3/10:  20%|██        | 73/360 [00:03<00:14, 20.45it/s]

 batch_loss: 71104.95  |  KL+NLL: 97380.06  (KL=97299.90, NLL=80.16)
 batch_loss: 70833.07  |  KL+NLL: 97432.50  (KL=97307.31, NLL=125.19)
 batch_loss: 70990.45  |  KL+NLL: 97480.71  (KL=97311.87, NLL=168.84)
 batch_loss: 70981.64  |  KL+NLL: 97416.06  (KL=97328.07, NLL=87.99)


Epoch 3/10:  21%|██        | 76/360 [00:03<00:13, 20.30it/s]

 batch_loss: 70581.71  |  KL+NLL: 97472.80  (KL=97299.95, NLL=172.85)
 batch_loss: 70465.54  |  KL+NLL: 97359.26  (KL=97296.34, NLL=62.92)
 batch_loss: 70519.90  |  KL+NLL: 97354.13  (KL=97275.70, NLL=78.44)
 batch_loss: 70575.07  |  KL+NLL: 97382.02  (KL=97254.84, NLL=127.18)


Epoch 3/10:  22%|██▏       | 79/360 [00:03<00:13, 20.48it/s]

 batch_loss: 70185.65  |  KL+NLL: 97343.12  (KL=97269.86, NLL=73.26)


Epoch 3/10:  23%|██▎       | 82/360 [00:04<00:13, 20.49it/s]

 batch_loss: 70318.22  |  KL+NLL: 97381.15  (KL=97270.97, NLL=110.18)
 batch_loss: 70391.68  |  KL+NLL: 97345.06  (KL=97266.03, NLL=79.02)
 batch_loss: 70081.31  |  KL+NLL: 97297.52  (KL=97215.41, NLL=82.11)
 batch_loss: 69895.46  |  KL+NLL: 97286.21  (KL=97186.61, NLL=99.60)
 batch_loss: 69712.25  |  KL+NLL: 97335.03  (KL=97213.74, NLL=121.28)


Epoch 3/10:  24%|██▍       | 88/360 [00:04<00:13, 20.43it/s]

 batch_loss: 69932.80  |  KL+NLL: 97245.40  (KL=97169.05, NLL=76.35)
 batch_loss: 69522.78  |  KL+NLL: 97277.24  (KL=97177.81, NLL=99.43)
 batch_loss: 69650.19  |  KL+NLL: 97266.28  (KL=97137.71, NLL=128.57)
 batch_loss: 69614.90  |  KL+NLL: 97340.52  (KL=97199.59, NLL=140.93)


Epoch 3/10:  25%|██▌       | 91/360 [00:04<00:13, 20.50it/s]

 batch_loss: 69501.18  |  KL+NLL: 97342.55  (KL=97184.79, NLL=157.76)
 batch_loss: 69367.98  |  KL+NLL: 97261.12  (KL=97166.69, NLL=94.43)
 batch_loss: 69115.36  |  KL+NLL: 97254.53  (KL=97151.02, NLL=103.51)
 batch_loss: 69093.47  |  KL+NLL: 97200.17  (KL=97134.06, NLL=66.11)
 batch_loss: 69156.45  |  KL+NLL: 97243.80  (KL=97159.43, NLL=84.37)


Epoch 3/10:  27%|██▋       | 97/360 [00:04<00:12, 20.76it/s]

 batch_loss: 69213.31  |  KL+NLL: 97268.70  (KL=97121.91, NLL=146.79)
 batch_loss: 69002.06  |  KL+NLL: 97172.21  (KL=97075.94, NLL=96.28)
 batch_loss: 68906.54  |  KL+NLL: 97211.88  (KL=97108.93, NLL=102.95)
 batch_loss: 68935.22  |  KL+NLL: 97187.76  (KL=97075.96, NLL=111.80)


Epoch 3/10:  28%|██▊       | 100/360 [00:04<00:12, 20.56it/s]

 batch_loss: 68801.18  |  KL+NLL: 97195.53  (KL=97084.63, NLL=110.89)
 batch_loss: 68778.49  |  KL+NLL: 97236.62  (KL=97120.75, NLL=115.87)
 batch_loss: 68429.43  |  KL+NLL: 97150.18  (KL=97072.44, NLL=77.74)
 batch_loss: 68248.58  |  KL+NLL: 97138.79  (KL=97012.93, NLL=125.86)
 batch_loss: 68428.65  |  KL+NLL: 97253.38  (KL=97124.07, NLL=129.31)


Epoch 3/10:  29%|██▉       | 106/360 [00:05<00:12, 20.19it/s]

 batch_loss: 68202.73  |  KL+NLL: 97100.08  (KL=97046.98, NLL=53.11)
 batch_loss: 68030.72  |  KL+NLL: 97056.83  (KL=96994.20, NLL=62.63)
 batch_loss: 67945.94  |  KL+NLL: 97064.55  (KL=96989.84, NLL=74.70)
 batch_loss: 67638.79  |  KL+NLL: 97094.51  (KL=97027.22, NLL=67.29)
 batch_loss: 67625.79  |  KL+NLL: 97081.53  (KL=96996.95, NLL=84.57)


Epoch 3/10:  30%|███       | 109/360 [00:05<00:12, 20.41it/s]

 batch_loss: 68107.18  |  KL+NLL: 97074.66  (KL=96969.98, NLL=104.68)
 batch_loss: 67466.53  |  KL+NLL: 97048.53  (KL=96993.41, NLL=55.12)
 batch_loss: 67500.70  |  KL+NLL: 97084.66  (KL=97017.58, NLL=67.08)
 batch_loss: 67556.23  |  KL+NLL: 97034.20  (KL=96968.05, NLL=66.16)


Epoch 3/10:  32%|███▏      | 115/360 [00:05<00:12, 19.73it/s]

 batch_loss: 67240.87  |  KL+NLL: 97029.21  (KL=96971.80, NLL=57.41)
 batch_loss: 66985.07  |  KL+NLL: 97102.50  (KL=96997.09, NLL=105.40)
 batch_loss: 67095.99  |  KL+NLL: 97113.13  (KL=96955.42, NLL=157.71)
 batch_loss: 67391.95  |  KL+NLL: 97047.20  (KL=96947.66, NLL=99.54)


Epoch 3/10:  33%|███▎      | 118/360 [00:05<00:11, 20.34it/s]

 batch_loss: 67345.71  |  KL+NLL: 97039.94  (KL=96920.60, NLL=119.34)
 batch_loss: 66920.79  |  KL+NLL: 96984.71  (KL=96904.78, NLL=79.93)
 batch_loss: 66760.67  |  KL+NLL: 97027.27  (KL=96938.46, NLL=88.81)
 batch_loss: 66803.00  |  KL+NLL: 97016.97  (KL=96916.55, NLL=100.41)
 batch_loss: 66933.34  |  KL+NLL: 97028.29  (KL=96918.17, NLL=110.11)


Epoch 3/10:  34%|███▍      | 124/360 [00:06<00:11, 20.38it/s]

 batch_loss: 66857.64  |  KL+NLL: 96979.22  (KL=96882.66, NLL=96.57)
 batch_loss: 66702.80  |  KL+NLL: 97033.76  (KL=96905.23, NLL=128.54)
 batch_loss: 66627.31  |  KL+NLL: 96972.95  (KL=96860.09, NLL=112.86)
 batch_loss: 66457.09  |  KL+NLL: 96906.54  (KL=96834.75, NLL=71.79)
 batch_loss: 66263.99  |  KL+NLL: 96910.41  (KL=96834.70, NLL=75.72)


Epoch 3/10:  36%|███▌      | 130/360 [00:06<00:11, 20.85it/s]

 batch_loss: 66188.23  |  KL+NLL: 96900.18  (KL=96833.61, NLL=66.57)
 batch_loss: 66023.71  |  KL+NLL: 96934.04  (KL=96835.70, NLL=98.34)
 batch_loss: 65629.00  |  KL+NLL: 96983.41  (KL=96867.09, NLL=116.32)
 batch_loss: 65760.82  |  KL+NLL: 96907.34  (KL=96816.81, NLL=90.53)
 batch_loss: 65877.44  |  KL+NLL: 96912.49  (KL=96824.75, NLL=87.74)


Epoch 3/10:  37%|███▋      | 133/360 [00:06<00:11, 20.58it/s]

 batch_loss: 65612.45  |  KL+NLL: 96933.36  (KL=96846.28, NLL=87.07)
 batch_loss: 65800.19  |  KL+NLL: 96887.98  (KL=96780.98, NLL=106.99)
 batch_loss: 65721.90  |  KL+NLL: 96879.19  (KL=96722.69, NLL=156.50)
 batch_loss: 65509.05  |  KL+NLL: 96833.94  (KL=96767.17, NLL=66.76)
 batch_loss: 65265.78  |  KL+NLL: 96906.98  (KL=96757.23, NLL=149.75)


Epoch 3/10:  39%|███▊      | 139/360 [00:06<00:10, 20.72it/s]

 batch_loss: 65385.22  |  KL+NLL: 96821.72  (KL=96742.55, NLL=79.17)
 batch_loss: 65180.38  |  KL+NLL: 96850.61  (KL=96773.55, NLL=77.05)
 batch_loss: 65294.10  |  KL+NLL: 96843.30  (KL=96737.15, NLL=106.15)
 batch_loss: 64863.59  |  KL+NLL: 96858.97  (KL=96762.77, NLL=96.19)
 batch_loss: 64841.67  |  KL+NLL: 96883.11  (KL=96723.40, NLL=159.71)


Epoch 3/10:  40%|████      | 145/360 [00:07<00:10, 20.21it/s]

 batch_loss: 64583.34  |  KL+NLL: 96810.75  (KL=96707.81, NLL=102.94)
 batch_loss: 64757.64  |  KL+NLL: 96773.62  (KL=96666.16, NLL=107.46)
 batch_loss: 64455.93  |  KL+NLL: 96800.29  (KL=96722.31, NLL=77.98)
 batch_loss: 64869.92  |  KL+NLL: 96783.33  (KL=96690.24, NLL=93.09)
 batch_loss: 64490.90  |  KL+NLL: 96790.25  (KL=96696.81, NLL=93.44)


Epoch 3/10:  41%|████      | 148/360 [00:07<00:10, 19.75it/s]

 batch_loss: 64484.89  |  KL+NLL: 96780.66  (KL=96682.82, NLL=97.84)
 batch_loss: 64358.92  |  KL+NLL: 96807.83  (KL=96685.96, NLL=121.87)
 batch_loss: 64119.09  |  KL+NLL: 96749.81  (KL=96682.59, NLL=67.22)
 batch_loss: 64504.55  |  KL+NLL: 96782.25  (KL=96688.72, NLL=93.53)
 batch_loss: 63971.42  |  KL+NLL: 96776.33  (KL=96656.32, NLL=120.01)


Epoch 3/10:  43%|████▎     | 154/360 [00:07<00:10, 20.15it/s]

 batch_loss: 64212.18  |  KL+NLL: 96754.65  (KL=96667.30, NLL=87.34)
 batch_loss: 63858.83  |  KL+NLL: 96753.45  (KL=96661.03, NLL=92.42)
 batch_loss: 63624.94  |  KL+NLL: 96762.79  (KL=96650.73, NLL=112.06)
 batch_loss: 63930.94  |  KL+NLL: 96682.83  (KL=96621.51, NLL=61.32)
 batch_loss: 63800.30  |  KL+NLL: 96859.09  (KL=96637.11, NLL=221.98)


Epoch 3/10:  44%|████▍     | 160/360 [00:07<00:09, 20.91it/s]

 batch_loss: 63778.31  |  KL+NLL: 96748.85  (KL=96607.12, NLL=141.74)
 batch_loss: 63277.47  |  KL+NLL: 96708.37  (KL=96593.55, NLL=114.82)
 batch_loss: 63164.48  |  KL+NLL: 96683.76  (KL=96574.84, NLL=108.92)
 batch_loss: 63192.06  |  KL+NLL: 96764.95  (KL=96619.60, NLL=145.35)
 batch_loss: 63483.60  |  KL+NLL: 96699.79  (KL=96605.51, NLL=94.28)


Epoch 3/10:  45%|████▌     | 163/360 [00:08<00:09, 20.80it/s]

 batch_loss: 63097.45  |  KL+NLL: 96725.29  (KL=96578.28, NLL=147.01)
 batch_loss: 63431.48  |  KL+NLL: 96784.78  (KL=96609.91, NLL=174.87)
 batch_loss: 63146.10  |  KL+NLL: 96707.18  (KL=96591.40, NLL=115.78)
 batch_loss: 62630.54  |  KL+NLL: 96672.78  (KL=96573.29, NLL=99.50)
 batch_loss: 62746.83  |  KL+NLL: 96635.55  (KL=96531.87, NLL=103.69)


Epoch 3/10:  47%|████▋     | 169/360 [00:08<00:09, 20.65it/s]

 batch_loss: 62518.21  |  KL+NLL: 96656.10  (KL=96541.93, NLL=114.17)
 batch_loss: 62678.52  |  KL+NLL: 96654.09  (KL=96573.82, NLL=80.27)
 batch_loss: 62611.15  |  KL+NLL: 96691.13  (KL=96556.52, NLL=134.61)
 batch_loss: 62537.78  |  KL+NLL: 96742.84  (KL=96478.56, NLL=264.27)
 batch_loss: 62482.48  |  KL+NLL: 96615.41  (KL=96515.98, NLL=99.43)


Epoch 3/10:  49%|████▊     | 175/360 [00:08<00:08, 20.89it/s]

 batch_loss: 62419.29  |  KL+NLL: 96626.04  (KL=96512.00, NLL=114.04)
 batch_loss: 62111.40  |  KL+NLL: 96627.83  (KL=96502.57, NLL=125.26)
 batch_loss: 62123.10  |  KL+NLL: 96591.77  (KL=96475.97, NLL=115.80)
 batch_loss: 62015.85  |  KL+NLL: 96592.60  (KL=96508.02, NLL=84.58)
 batch_loss: 61864.50  |  KL+NLL: 96685.11  (KL=96558.11, NLL=127.00)


Epoch 3/10:  49%|████▉     | 178/360 [00:08<00:08, 20.95it/s]

 batch_loss: 61834.28  |  KL+NLL: 96565.95  (KL=96444.67, NLL=121.28)
 batch_loss: 61689.26  |  KL+NLL: 96587.45  (KL=96485.96, NLL=101.49)
 batch_loss: 61392.94  |  KL+NLL: 96579.84  (KL=96470.76, NLL=109.08)
 batch_loss: 61662.69  |  KL+NLL: 96593.61  (KL=96475.84, NLL=117.76)
 batch_loss: 61598.16  |  KL+NLL: 96666.08  (KL=96498.91, NLL=167.17)


Epoch 3/10:  51%|█████     | 184/360 [00:09<00:08, 21.33it/s]

 batch_loss: 61450.35  |  KL+NLL: 96606.61  (KL=96488.34, NLL=118.27)
 batch_loss: 61216.77  |  KL+NLL: 96505.78  (KL=96394.69, NLL=111.09)
 batch_loss: 61511.45  |  KL+NLL: 96531.40  (KL=96417.44, NLL=113.96)
 batch_loss: 61601.11  |  KL+NLL: 96531.41  (KL=96418.61, NLL=112.80)
 batch_loss: 60996.99  |  KL+NLL: 96653.14  (KL=96483.36, NLL=169.78)


Epoch 3/10:  53%|█████▎    | 190/360 [00:09<00:08, 20.47it/s]

 batch_loss: 60808.31  |  KL+NLL: 96486.05  (KL=96351.34, NLL=134.70)
 batch_loss: 60907.77  |  KL+NLL: 96469.38  (KL=96389.73, NLL=79.65)
 batch_loss: 60865.05  |  KL+NLL: 96517.48  (KL=96387.22, NLL=130.27)
 batch_loss: 60890.48  |  KL+NLL: 96442.99  (KL=96398.98, NLL=44.01)
 batch_loss: 60486.18  |  KL+NLL: 96443.01  (KL=96379.99, NLL=63.01)


Epoch 3/10:  54%|█████▎    | 193/360 [00:09<00:08, 19.98it/s]

 batch_loss: 60638.48  |  KL+NLL: 96515.97  (KL=96431.64, NLL=84.33)
 batch_loss: 60552.93  |  KL+NLL: 96463.71  (KL=96373.70, NLL=90.01)
 batch_loss: 60484.94  |  KL+NLL: 96432.66  (KL=96355.60, NLL=77.06)
 batch_loss: 60515.05  |  KL+NLL: 96544.40  (KL=96382.81, NLL=161.59)
 batch_loss: 60290.77  |  KL+NLL: 96520.22  (KL=96401.56, NLL=118.66)


Epoch 3/10:  55%|█████▌    | 199/360 [00:09<00:07, 20.47it/s]

 batch_loss: 60048.12  |  KL+NLL: 96468.30  (KL=96376.84, NLL=91.46)
 batch_loss: 60139.95  |  KL+NLL: 96427.14  (KL=96318.79, NLL=108.35)
 batch_loss: 59878.99  |  KL+NLL: 96430.90  (KL=96353.49, NLL=77.41)
 batch_loss: 59756.87  |  KL+NLL: 96418.89  (KL=96304.08, NLL=114.81)
 batch_loss: 59600.29  |  KL+NLL: 96446.20  (KL=96356.30, NLL=89.90)


Epoch 3/10:  57%|█████▋    | 205/360 [00:10<00:07, 20.89it/s]

 batch_loss: 59643.56  |  KL+NLL: 96433.09  (KL=96328.94, NLL=104.15)
 batch_loss: 59898.59  |  KL+NLL: 96434.58  (KL=96306.61, NLL=127.98)
 batch_loss: 59885.23  |  KL+NLL: 96443.57  (KL=96335.41, NLL=108.15)
 batch_loss: 59548.90  |  KL+NLL: 96384.46  (KL=96304.74, NLL=79.71)
 batch_loss: 59397.84  |  KL+NLL: 96503.05  (KL=96336.69, NLL=166.37)


Epoch 3/10:  58%|█████▊    | 208/360 [00:10<00:07, 20.83it/s]

 batch_loss: 59408.23  |  KL+NLL: 96466.52  (KL=96298.77, NLL=167.75)
 batch_loss: 59183.11  |  KL+NLL: 96425.43  (KL=96249.15, NLL=176.29)
 batch_loss: 59219.90  |  KL+NLL: 96424.69  (KL=96296.73, NLL=127.96)
 batch_loss: 59033.81  |  KL+NLL: 96347.15  (KL=96274.55, NLL=72.60)
 batch_loss: 58738.17  |  KL+NLL: 96507.92  (KL=96314.19, NLL=193.73)


Epoch 3/10:  59%|█████▉    | 214/360 [00:10<00:07, 20.79it/s]

 batch_loss: 58772.74  |  KL+NLL: 96412.19  (KL=96284.64, NLL=127.55)
 batch_loss: 59012.67  |  KL+NLL: 96378.95  (KL=96243.34, NLL=135.60)
 batch_loss: 58977.92  |  KL+NLL: 96359.42  (KL=96299.61, NLL=59.81)
 batch_loss: 59017.38  |  KL+NLL: 96429.25  (KL=96247.34, NLL=181.91)
 batch_loss: 58532.40  |  KL+NLL: 96275.85  (KL=96185.38, NLL=90.46)


Epoch 3/10:  61%|██████    | 220/360 [00:10<00:06, 21.29it/s]

 batch_loss: 58635.02  |  KL+NLL: 96437.95  (KL=96270.21, NLL=167.74)
 batch_loss: 58480.71  |  KL+NLL: 96356.64  (KL=96265.43, NLL=91.21)
 batch_loss: 58234.31  |  KL+NLL: 96342.65  (KL=96237.60, NLL=105.04)
 batch_loss: 58459.35  |  KL+NLL: 96320.38  (KL=96190.13, NLL=130.25)
 batch_loss: 58146.21  |  KL+NLL: 96270.15  (KL=96213.05, NLL=57.10)


Epoch 3/10:  62%|██████▏   | 223/360 [00:10<00:06, 20.77it/s]

 batch_loss: 58267.02  |  KL+NLL: 96347.38  (KL=96200.16, NLL=147.21)
 batch_loss: 57999.88  |  KL+NLL: 96298.44  (KL=96204.54, NLL=93.90)
 batch_loss: 58092.30  |  KL+NLL: 96307.95  (KL=96168.68, NLL=139.27)
 batch_loss: 57777.76  |  KL+NLL: 96291.10  (KL=96168.11, NLL=122.99)


Epoch 3/10:  64%|██████▎   | 229/360 [00:11<00:06, 20.55it/s]

 batch_loss: 57864.13  |  KL+NLL: 96378.85  (KL=96192.91, NLL=185.94)
 batch_loss: 57949.28  |  KL+NLL: 96292.82  (KL=96179.00, NLL=113.82)
 batch_loss: 57622.00  |  KL+NLL: 96285.95  (KL=96178.65, NLL=107.31)
 batch_loss: 57469.51  |  KL+NLL: 96287.31  (KL=96195.32, NLL=91.99)
 batch_loss: 57373.09  |  KL+NLL: 96261.57  (KL=96164.40, NLL=97.17)


Epoch 3/10:  64%|██████▍   | 232/360 [00:11<00:06, 20.52it/s]

 batch_loss: 57409.88  |  KL+NLL: 96341.71  (KL=96163.94, NLL=177.78)
 batch_loss: 57385.16  |  KL+NLL: 96276.76  (KL=96160.30, NLL=116.45)
 batch_loss: 57440.12  |  KL+NLL: 96238.54  (KL=96152.24, NLL=86.30)
 batch_loss: 57353.26  |  KL+NLL: 96228.93  (KL=96134.11, NLL=94.82)
 batch_loss: 56756.54  |  KL+NLL: 96232.37  (KL=96164.01, NLL=68.36)


Epoch 3/10:  66%|██████▌   | 238/360 [00:11<00:06, 20.12it/s]

 batch_loss: 56798.17  |  KL+NLL: 96241.67  (KL=96125.62, NLL=116.05)
 batch_loss: 56955.10  |  KL+NLL: 96184.75  (KL=96109.45, NLL=75.30)
 batch_loss: 56927.09  |  KL+NLL: 96244.21  (KL=96102.29, NLL=141.92)
 batch_loss: 56561.75  |  KL+NLL: 96265.04  (KL=96145.58, NLL=119.46)


Epoch 3/10:  67%|██████▋   | 241/360 [00:11<00:05, 20.31it/s]

 batch_loss: 56663.73  |  KL+NLL: 96203.13  (KL=96081.52, NLL=121.61)
 batch_loss: 56579.87  |  KL+NLL: 96275.96  (KL=96106.70, NLL=169.26)
 batch_loss: 56339.89  |  KL+NLL: 96191.55  (KL=96054.81, NLL=136.74)
 batch_loss: 56291.55  |  KL+NLL: 96150.50  (KL=96082.52, NLL=67.99)
 batch_loss: 56399.71  |  KL+NLL: 96204.67  (KL=96102.57, NLL=102.10)


Epoch 3/10:  69%|██████▊   | 247/360 [00:12<00:05, 20.68it/s]

 batch_loss: 56295.40  |  KL+NLL: 96130.30  (KL=96046.99, NLL=83.30)
 batch_loss: 56292.85  |  KL+NLL: 96175.68  (KL=96031.92, NLL=143.76)
 batch_loss: 55981.59  |  KL+NLL: 96212.31  (KL=96107.88, NLL=104.43)
 batch_loss: 56201.45  |  KL+NLL: 96204.90  (KL=96071.23, NLL=133.67)
 batch_loss: 56098.63  |  KL+NLL: 96144.05  (KL=96054.02, NLL=90.03)


Epoch 3/10:  69%|██████▉   | 250/360 [00:12<00:05, 20.85it/s]

 batch_loss: 55838.66  |  KL+NLL: 96172.35  (KL=96081.74, NLL=90.61)
 batch_loss: 55772.72  |  KL+NLL: 96109.73  (KL=95996.77, NLL=112.96)
 batch_loss: 55450.20  |  KL+NLL: 96130.91  (KL=96054.80, NLL=76.11)
 batch_loss: 55459.58  |  KL+NLL: 96101.10  (KL=95987.08, NLL=114.02)


Epoch 3/10:  71%|███████   | 256/360 [00:12<00:05, 20.38it/s]

 batch_loss: 55354.91  |  KL+NLL: 96212.93  (KL=96049.91, NLL=163.03)
 batch_loss: 55397.91  |  KL+NLL: 96141.95  (KL=96012.26, NLL=129.69)
 batch_loss: 55588.17  |  KL+NLL: 96128.35  (KL=96042.68, NLL=85.67)
 batch_loss: 55317.24  |  KL+NLL: 96177.40  (KL=96054.52, NLL=122.88)
 batch_loss: 55510.31  |  KL+NLL: 96166.75  (KL=96033.46, NLL=133.29)


Epoch 3/10:  73%|███████▎  | 262/360 [00:12<00:04, 20.36it/s]

 batch_loss: 55157.08  |  KL+NLL: 96089.63  (KL=95999.71, NLL=89.92)
 batch_loss: 55098.62  |  KL+NLL: 96080.57  (KL=95998.88, NLL=81.69)
 batch_loss: 55249.28  |  KL+NLL: 96255.90  (KL=96043.31, NLL=212.59)
 batch_loss: 54887.05  |  KL+NLL: 96128.93  (KL=96053.68, NLL=75.25)
 batch_loss: 54937.99  |  KL+NLL: 96156.48  (KL=96048.73, NLL=107.75)


Epoch 3/10:  74%|███████▎  | 265/360 [00:13<00:04, 20.54it/s]

 batch_loss: 54758.39  |  KL+NLL: 96128.12  (KL=95981.58, NLL=146.54)
 batch_loss: 54698.62  |  KL+NLL: 96099.96  (KL=95950.73, NLL=149.24)
 batch_loss: 55029.05  |  KL+NLL: 96071.42  (KL=95952.45, NLL=118.96)
 batch_loss: 54689.25  |  KL+NLL: 96079.23  (KL=96002.79, NLL=76.44)
 batch_loss: 54654.90  |  KL+NLL: 96160.07  (KL=96013.56, NLL=146.51)


Epoch 3/10:  75%|███████▌  | 271/360 [00:13<00:04, 21.18it/s]

 batch_loss: 54226.81  |  KL+NLL: 96058.95  (KL=95970.75, NLL=88.20)
 batch_loss: 54312.16  |  KL+NLL: 96090.76  (KL=95945.23, NLL=145.53)
 batch_loss: 53878.85  |  KL+NLL: 96058.03  (KL=95941.43, NLL=116.60)
 batch_loss: 54217.90  |  KL+NLL: 96056.72  (KL=95945.36, NLL=111.36)
 batch_loss: 54021.23  |  KL+NLL: 96104.87  (KL=95965.53, NLL=139.33)


Epoch 3/10:  77%|███████▋  | 277/360 [00:13<00:03, 21.01it/s]

 batch_loss: 53994.29  |  KL+NLL: 96031.67  (KL=95916.32, NLL=115.35)
 batch_loss: 54003.68  |  KL+NLL: 96023.90  (KL=95940.49, NLL=83.41)
 batch_loss: 53849.70  |  KL+NLL: 95958.72  (KL=95867.04, NLL=91.68)
 batch_loss: 53887.28  |  KL+NLL: 96081.80  (KL=95958.20, NLL=123.59)
 batch_loss: 53756.84  |  KL+NLL: 95943.90  (KL=95867.00, NLL=76.90)


Epoch 3/10:  78%|███████▊  | 280/360 [00:13<00:03, 20.97it/s]

 batch_loss: 53480.16  |  KL+NLL: 96116.73  (KL=95940.84, NLL=175.89)
 batch_loss: 53643.81  |  KL+NLL: 96066.74  (KL=95949.81, NLL=116.93)
 batch_loss: 53439.01  |  KL+NLL: 96046.29  (KL=95954.96, NLL=91.33)
 batch_loss: 53546.66  |  KL+NLL: 96042.63  (KL=95919.25, NLL=123.38)
 batch_loss: 53064.87  |  KL+NLL: 96103.16  (KL=95979.56, NLL=123.59)


Epoch 3/10:  79%|███████▉  | 286/360 [00:14<00:03, 20.69it/s]

 batch_loss: 53350.16  |  KL+NLL: 96033.73  (KL=95857.90, NLL=175.83)
 batch_loss: 53406.25  |  KL+NLL: 95970.63  (KL=95854.67, NLL=115.96)
 batch_loss: 52957.94  |  KL+NLL: 96007.57  (KL=95893.98, NLL=113.59)
 batch_loss: 53382.03  |  KL+NLL: 96034.03  (KL=95870.42, NLL=163.61)


Epoch 3/10:  80%|████████  | 289/360 [00:14<00:03, 20.61it/s]

 batch_loss: 52947.75  |  KL+NLL: 96006.69  (KL=95901.76, NLL=104.93)
 batch_loss: 52816.20  |  KL+NLL: 95999.10  (KL=95887.98, NLL=111.12)
 batch_loss: 52723.06  |  KL+NLL: 96044.62  (KL=95859.15, NLL=185.47)
 batch_loss: 52586.85  |  KL+NLL: 95944.69  (KL=95841.22, NLL=103.47)
 batch_loss: 52860.37  |  KL+NLL: 95962.55  (KL=95895.84, NLL=66.71)


Epoch 3/10:  82%|████████▏ | 295/360 [00:14<00:03, 20.53it/s]

 batch_loss: 52530.42  |  KL+NLL: 96068.26  (KL=95867.88, NLL=200.38)
 batch_loss: 52510.83  |  KL+NLL: 96123.11  (KL=95968.37, NLL=154.75)
 batch_loss: 52482.33  |  KL+NLL: 95965.42  (KL=95875.44, NLL=89.98)
 batch_loss: 52497.34  |  KL+NLL: 95962.45  (KL=95856.64, NLL=105.81)
 batch_loss: 52333.10  |  KL+NLL: 95927.57  (KL=95855.09, NLL=72.48)


Epoch 3/10:  84%|████████▎ | 301/360 [00:14<00:02, 21.24it/s]

 batch_loss: 52173.29  |  KL+NLL: 95955.85  (KL=95809.16, NLL=146.70)
 batch_loss: 51935.50  |  KL+NLL: 95985.19  (KL=95848.29, NLL=136.90)
 batch_loss: 52184.82  |  KL+NLL: 95924.12  (KL=95849.05, NLL=75.07)
 batch_loss: 52139.60  |  KL+NLL: 95934.99  (KL=95806.30, NLL=128.69)
 batch_loss: 51700.90  |  KL+NLL: 95958.48  (KL=95836.35, NLL=122.12)


Epoch 3/10:  84%|████████▍ | 304/360 [00:14<00:02, 21.42it/s]

 batch_loss: 52016.51  |  KL+NLL: 95844.12  (KL=95792.63, NLL=51.49)
 batch_loss: 52023.68  |  KL+NLL: 95893.94  (KL=95825.28, NLL=68.66)
 batch_loss: 51482.70  |  KL+NLL: 95972.59  (KL=95825.73, NLL=146.86)
 batch_loss: 51762.66  |  KL+NLL: 95992.19  (KL=95824.70, NLL=167.50)
 batch_loss: 51457.31  |  KL+NLL: 95957.80  (KL=95828.91, NLL=128.89)


Epoch 3/10:  86%|████████▌ | 310/360 [00:15<00:02, 21.82it/s]

 batch_loss: 51640.89  |  KL+NLL: 95954.39  (KL=95798.23, NLL=156.17)
 batch_loss: 51238.42  |  KL+NLL: 95915.43  (KL=95809.10, NLL=106.33)
 batch_loss: 51567.96  |  KL+NLL: 95985.81  (KL=95884.82, NLL=100.99)
 batch_loss: 51435.51  |  KL+NLL: 95985.87  (KL=95877.20, NLL=108.66)
 batch_loss: 51282.89  |  KL+NLL: 95891.32  (KL=95793.94, NLL=97.39)


Epoch 3/10:  88%|████████▊ | 316/360 [00:15<00:02, 21.70it/s]

 batch_loss: 50881.09  |  KL+NLL: 95912.60  (KL=95833.86, NLL=78.74)
 batch_loss: 51061.21  |  KL+NLL: 95975.84  (KL=95801.98, NLL=173.86)
 batch_loss: 50754.97  |  KL+NLL: 95932.47  (KL=95820.89, NLL=111.58)
 batch_loss: 50763.39  |  KL+NLL: 95914.65  (KL=95781.28, NLL=133.37)
 batch_loss: 51075.86  |  KL+NLL: 96011.91  (KL=95852.23, NLL=159.68)


Epoch 3/10:  89%|████████▊ | 319/360 [00:15<00:01, 21.67it/s]

 batch_loss: 50700.00  |  KL+NLL: 95896.56  (KL=95778.08, NLL=118.48)
 batch_loss: 50584.89  |  KL+NLL: 95870.63  (KL=95786.62, NLL=84.02)
 batch_loss: 50812.15  |  KL+NLL: 95872.58  (KL=95778.77, NLL=93.81)
 batch_loss: 50374.68  |  KL+NLL: 95866.56  (KL=95758.12, NLL=108.45)
 batch_loss: 50317.19  |  KL+NLL: 95817.92  (KL=95737.25, NLL=80.67)


Epoch 3/10:  90%|█████████ | 325/360 [00:15<00:01, 21.11it/s]

 batch_loss: 49953.63  |  KL+NLL: 95978.01  (KL=95729.88, NLL=248.13)
 batch_loss: 50404.93  |  KL+NLL: 95838.01  (KL=95770.00, NLL=68.01)
 batch_loss: 50430.25  |  KL+NLL: 95822.66  (KL=95743.30, NLL=79.36)
 batch_loss: 50346.85  |  KL+NLL: 95907.24  (KL=95725.06, NLL=182.18)
 batch_loss: 49798.79  |  KL+NLL: 95782.41  (KL=95671.81, NLL=110.59)


Epoch 3/10:  92%|█████████▏| 331/360 [00:16<00:01, 21.33it/s]

 batch_loss: 49978.59  |  KL+NLL: 95983.32  (KL=95780.58, NLL=202.74)
 batch_loss: 49963.08  |  KL+NLL: 95828.80  (KL=95729.79, NLL=99.01)
 batch_loss: 49763.55  |  KL+NLL: 95822.98  (KL=95720.28, NLL=102.70)
 batch_loss: 49382.83  |  KL+NLL: 95817.41  (KL=95715.04, NLL=102.37)
 batch_loss: 49630.58  |  KL+NLL: 95831.55  (KL=95741.55, NLL=90.00)


Epoch 3/10:  93%|█████████▎| 334/360 [00:16<00:01, 21.09it/s]

 batch_loss: 49450.74  |  KL+NLL: 95857.43  (KL=95731.98, NLL=125.45)
 batch_loss: 49784.08  |  KL+NLL: 95847.51  (KL=95731.48, NLL=116.04)
 batch_loss: 49529.32  |  KL+NLL: 95853.00  (KL=95740.57, NLL=112.43)
 batch_loss: 49559.94  |  KL+NLL: 95911.57  (KL=95739.73, NLL=171.83)
 batch_loss: 49208.03  |  KL+NLL: 95802.71  (KL=95710.59, NLL=92.12)


Epoch 3/10:  94%|█████████▍| 340/360 [00:16<00:00, 21.34it/s]

 batch_loss: 49275.65  |  KL+NLL: 95825.21  (KL=95715.71, NLL=109.50)
 batch_loss: 48942.03  |  KL+NLL: 95826.99  (KL=95744.41, NLL=82.58)
 batch_loss: 48695.62  |  KL+NLL: 95787.50  (KL=95680.23, NLL=107.27)
 batch_loss: 49081.29  |  KL+NLL: 95827.54  (KL=95742.18, NLL=85.36)
 batch_loss: 49228.12  |  KL+NLL: 95791.81  (KL=95687.45, NLL=104.37)


Epoch 3/10:  95%|█████████▌| 343/360 [00:16<00:00, 20.96it/s]

 batch_loss: 49088.49  |  KL+NLL: 95786.09  (KL=95690.26, NLL=95.83)
 batch_loss: 48888.26  |  KL+NLL: 95792.47  (KL=95675.27, NLL=117.20)
 batch_loss: 48683.86  |  KL+NLL: 95848.21  (KL=95729.47, NLL=118.74)
 batch_loss: 48660.57  |  KL+NLL: 95774.42  (KL=95645.58, NLL=128.84)


Epoch 3/10:  97%|█████████▋| 349/360 [00:16<00:00, 20.85it/s]

 batch_loss: 48612.70  |  KL+NLL: 95884.07  (KL=95783.56, NLL=100.51)
 batch_loss: 48198.85  |  KL+NLL: 95806.00  (KL=95722.67, NLL=83.33)
 batch_loss: 48267.58  |  KL+NLL: 95868.62  (KL=95693.78, NLL=174.84)
 batch_loss: 48549.22  |  KL+NLL: 95838.19  (KL=95730.39, NLL=107.80)
 batch_loss: 48182.00  |  KL+NLL: 95739.01  (KL=95664.80, NLL=74.21)


Epoch 3/10:  99%|█████████▊| 355/360 [00:17<00:00, 21.29it/s]

 batch_loss: 48556.42  |  KL+NLL: 95823.72  (KL=95730.95, NLL=92.76)
 batch_loss: 47935.90  |  KL+NLL: 95840.01  (KL=95748.68, NLL=91.33)
 batch_loss: 48097.00  |  KL+NLL: 95814.40  (KL=95697.66, NLL=116.73)
 batch_loss: 48160.63  |  KL+NLL: 95734.86  (KL=95621.21, NLL=113.65)
 batch_loss: 48001.55  |  KL+NLL: 95786.91  (KL=95704.93, NLL=81.98)


Epoch 3/10: 100%|██████████| 360/360 [00:17<00:00, 20.60it/s]


 batch_loss: 47815.83  |  KL+NLL: 95866.82  (KL=95706.43, NLL=160.39)
 batch_loss: 48018.31  |  KL+NLL: 95804.31  (KL=95663.31, NLL=141.00)
 batch_loss: 47704.57  |  KL+NLL: 95824.18  (KL=95683.73, NLL=140.44)
 batch_loss: 47832.69  |  KL+NLL: 95778.22  (KL=95605.89, NLL=172.33)
 batch_loss: 47718.94  |  KL+NLL: 95759.89  (KL=95676.97, NLL=82.92)
Epoch 3 - ELBO Loss: 61883.2849


Epoch 4/10:   1%|          | 4/360 [00:00<00:22, 15.78it/s]

 batch_loss: 47233.39  |  KL+NLL: 95776.57  (KL=95699.48, NLL=77.08)
 batch_loss: 47496.89  |  KL+NLL: 95833.94  (KL=95681.02, NLL=152.93)
 batch_loss: 47346.18  |  KL+NLL: 95701.12  (KL=95649.18, NLL=51.94)
 batch_loss: 47416.30  |  KL+NLL: 95726.19  (KL=95648.27, NLL=77.91)
 batch_loss: 47344.45  |  KL+NLL: 95826.75  (KL=95662.55, NLL=164.21)


Epoch 4/10:   3%|▎         | 10/360 [00:00<00:18, 19.30it/s]

 batch_loss: 46854.42  |  KL+NLL: 95727.85  (KL=95642.12, NLL=85.73)
 batch_loss: 47167.87  |  KL+NLL: 95800.23  (KL=95629.27, NLL=170.97)
 batch_loss: 46707.97  |  KL+NLL: 95865.37  (KL=95630.21, NLL=235.16)
 batch_loss: 46790.93  |  KL+NLL: 95767.93  (KL=95645.30, NLL=122.64)
 batch_loss: 46867.53  |  KL+NLL: 95772.28  (KL=95636.71, NLL=135.57)


Epoch 4/10:   4%|▎         | 13/360 [00:00<00:17, 20.06it/s]

 batch_loss: 47066.16  |  KL+NLL: 95728.82  (KL=95603.54, NLL=125.28)
 batch_loss: 46820.09  |  KL+NLL: 95759.61  (KL=95625.80, NLL=133.81)
 batch_loss: 46847.48  |  KL+NLL: 95804.46  (KL=95650.20, NLL=154.27)
 batch_loss: 46559.06  |  KL+NLL: 95712.07  (KL=95645.34, NLL=66.73)
 batch_loss: 46337.57  |  KL+NLL: 95724.19  (KL=95631.05, NLL=93.14)


Epoch 4/10:   5%|▌         | 19/360 [00:00<00:16, 20.70it/s]

 batch_loss: 46246.39  |  KL+NLL: 95729.77  (KL=95646.55, NLL=83.22)
 batch_loss: 46463.53  |  KL+NLL: 95817.52  (KL=95655.19, NLL=162.33)
 batch_loss: 46094.69  |  KL+NLL: 95662.16  (KL=95610.33, NLL=51.84)
 batch_loss: 46060.03  |  KL+NLL: 95688.92  (KL=95626.34, NLL=62.58)
 batch_loss: 46314.86  |  KL+NLL: 95742.61  (KL=95623.75, NLL=118.86)


Epoch 4/10:   7%|▋         | 25/360 [00:01<00:16, 20.72it/s]

 batch_loss: 46066.99  |  KL+NLL: 95704.03  (KL=95633.41, NLL=70.61)
 batch_loss: 45936.21  |  KL+NLL: 95740.06  (KL=95633.34, NLL=106.72)
 batch_loss: 45885.80  |  KL+NLL: 95657.26  (KL=95568.62, NLL=88.63)
 batch_loss: 45908.16  |  KL+NLL: 95710.86  (KL=95602.66, NLL=108.19)
 batch_loss: 45768.00  |  KL+NLL: 95767.12  (KL=95653.66, NLL=113.46)


Epoch 4/10:   8%|▊         | 28/360 [00:01<00:16, 20.52it/s]

 batch_loss: 45848.00  |  KL+NLL: 95755.68  (KL=95614.98, NLL=140.70)
 batch_loss: 45886.80  |  KL+NLL: 95768.86  (KL=95630.71, NLL=138.15)
 batch_loss: 45592.84  |  KL+NLL: 95837.09  (KL=95684.38, NLL=152.72)
 batch_loss: 45423.89  |  KL+NLL: 95799.03  (KL=95617.80, NLL=181.23)
 batch_loss: 45363.62  |  KL+NLL: 95732.68  (KL=95580.96, NLL=151.72)


Epoch 4/10:   9%|▉         | 34/360 [00:01<00:15, 20.64it/s]

 batch_loss: 45202.98  |  KL+NLL: 95677.47  (KL=95565.27, NLL=112.20)
 batch_loss: 45513.47  |  KL+NLL: 95715.71  (KL=95616.59, NLL=99.12)
 batch_loss: 45122.22  |  KL+NLL: 95760.59  (KL=95663.05, NLL=97.53)
 batch_loss: 44940.35  |  KL+NLL: 95721.40  (KL=95563.91, NLL=157.50)
 batch_loss: 44962.78  |  KL+NLL: 95722.68  (KL=95604.35, NLL=118.33)


Epoch 4/10:  11%|█         | 40/360 [00:02<00:15, 20.62it/s]

 batch_loss: 45193.55  |  KL+NLL: 95766.93  (KL=95666.95, NLL=99.98)
 batch_loss: 44922.94  |  KL+NLL: 95721.79  (KL=95586.85, NLL=134.94)
 batch_loss: 44777.00  |  KL+NLL: 95715.48  (KL=95614.77, NLL=100.70)
 batch_loss: 44566.88  |  KL+NLL: 95724.98  (KL=95579.45, NLL=145.54)
 batch_loss: 44824.00  |  KL+NLL: 95619.94  (KL=95531.87, NLL=88.07)


Epoch 4/10:  12%|█▏        | 43/360 [00:02<00:15, 20.22it/s]

 batch_loss: 44855.65  |  KL+NLL: 95678.20  (KL=95558.41, NLL=119.79)
 batch_loss: 44706.23  |  KL+NLL: 95689.01  (KL=95557.88, NLL=131.12)
 batch_loss: 44470.69  |  KL+NLL: 95712.56  (KL=95621.74, NLL=90.82)
 batch_loss: 44559.56  |  KL+NLL: 95736.90  (KL=95648.73, NLL=88.17)
 batch_loss: 44431.49  |  KL+NLL: 95689.26  (KL=95626.61, NLL=62.65)


Epoch 4/10:  14%|█▎        | 49/360 [00:02<00:15, 20.72it/s]

 batch_loss: 44072.27  |  KL+NLL: 95716.43  (KL=95585.41, NLL=131.01)
 batch_loss: 44317.40  |  KL+NLL: 95640.80  (KL=95548.48, NLL=92.32)
 batch_loss: 44538.77  |  KL+NLL: 95659.26  (KL=95573.31, NLL=85.95)
 batch_loss: 43913.01  |  KL+NLL: 95729.10  (KL=95622.46, NLL=106.63)
 batch_loss: 44202.26  |  KL+NLL: 95807.26  (KL=95650.48, NLL=156.77)


Epoch 4/10:  15%|█▌        | 55/360 [00:02<00:14, 21.43it/s]

 batch_loss: 44164.93  |  KL+NLL: 95731.22  (KL=95599.29, NLL=131.93)
 batch_loss: 44077.15  |  KL+NLL: 95754.93  (KL=95604.05, NLL=150.89)
 batch_loss: 44036.23  |  KL+NLL: 95840.34  (KL=95667.48, NLL=172.87)
 batch_loss: 43888.83  |  KL+NLL: 95718.78  (KL=95596.27, NLL=122.51)
 batch_loss: 43472.30  |  KL+NLL: 95683.51  (KL=95590.05, NLL=93.47)


Epoch 4/10:  16%|█▌        | 58/360 [00:02<00:14, 21.47it/s]

 batch_loss: 43797.46  |  KL+NLL: 95684.94  (KL=95570.00, NLL=114.94)
 batch_loss: 43381.97  |  KL+NLL: 95745.16  (KL=95584.20, NLL=160.97)
 batch_loss: 43383.41  |  KL+NLL: 95837.99  (KL=95651.09, NLL=186.89)
 batch_loss: 43413.08  |  KL+NLL: 95651.87  (KL=95554.86, NLL=97.01)
 batch_loss: 43640.69  |  KL+NLL: 95707.38  (KL=95576.25, NLL=131.13)


Epoch 4/10:  18%|█▊        | 64/360 [00:03<00:14, 21.13it/s]

 batch_loss: 43420.68  |  KL+NLL: 95833.33  (KL=95616.71, NLL=216.62)
 batch_loss: 43328.53  |  KL+NLL: 95735.68  (KL=95568.95, NLL=166.73)
 batch_loss: 43227.02  |  KL+NLL: 95627.75  (KL=95524.24, NLL=103.50)
 batch_loss: 43137.57  |  KL+NLL: 95681.93  (KL=95565.62, NLL=116.30)
 batch_loss: 42892.19  |  KL+NLL: 95693.10  (KL=95527.04, NLL=166.06)


Epoch 4/10:  19%|█▉        | 70/360 [00:03<00:13, 20.97it/s]

 batch_loss: 43133.61  |  KL+NLL: 95725.30  (KL=95590.41, NLL=134.89)
 batch_loss: 42894.60  |  KL+NLL: 95658.73  (KL=95559.62, NLL=99.12)
 batch_loss: 42501.18  |  KL+NLL: 95593.01  (KL=95510.98, NLL=82.03)
 batch_loss: 42526.94  |  KL+NLL: 95562.69  (KL=95456.90, NLL=105.79)
 batch_loss: 43022.52  |  KL+NLL: 95741.22  (KL=95584.05, NLL=157.17)


Epoch 4/10:  20%|██        | 73/360 [00:03<00:13, 21.17it/s]

 batch_loss: 42697.22  |  KL+NLL: 95703.31  (KL=95585.13, NLL=118.18)
 batch_loss: 42658.54  |  KL+NLL: 95672.03  (KL=95546.48, NLL=125.54)
 batch_loss: 42269.51  |  KL+NLL: 95639.54  (KL=95571.89, NLL=67.65)
 batch_loss: 42422.71  |  KL+NLL: 95736.06  (KL=95555.96, NLL=180.10)
 batch_loss: 42234.14  |  KL+NLL: 95760.31  (KL=95608.29, NLL=152.02)


Epoch 4/10:  22%|██▏       | 79/360 [00:03<00:13, 20.91it/s]

 batch_loss: 42233.29  |  KL+NLL: 95670.87  (KL=95543.41, NLL=127.46)
 batch_loss: 42415.85  |  KL+NLL: 95759.06  (KL=95612.56, NLL=146.50)
 batch_loss: 42119.15  |  KL+NLL: 95657.20  (KL=95538.67, NLL=118.53)
 batch_loss: 42156.59  |  KL+NLL: 95679.71  (KL=95550.13, NLL=129.58)
 batch_loss: 42044.99  |  KL+NLL: 95747.88  (KL=95589.30, NLL=158.59)


Epoch 4/10:  24%|██▎       | 85/360 [00:04<00:13, 20.65it/s]

 batch_loss: 41922.32  |  KL+NLL: 95697.95  (KL=95610.57, NLL=87.38)
 batch_loss: 41815.24  |  KL+NLL: 95684.25  (KL=95558.06, NLL=126.19)
 batch_loss: 41919.52  |  KL+NLL: 95668.68  (KL=95545.92, NLL=122.76)
 batch_loss: 41609.44  |  KL+NLL: 95677.28  (KL=95581.62, NLL=95.66)
 batch_loss: 41659.10  |  KL+NLL: 95657.34  (KL=95545.10, NLL=112.23)


Epoch 4/10:  24%|██▍       | 88/360 [00:04<00:13, 20.55it/s]

 batch_loss: 41929.90  |  KL+NLL: 95646.56  (KL=95513.73, NLL=132.83)
 batch_loss: 41340.87  |  KL+NLL: 95665.72  (KL=95524.61, NLL=141.11)
 batch_loss: 41622.86  |  KL+NLL: 95728.05  (KL=95602.77, NLL=125.28)
 batch_loss: 41524.04  |  KL+NLL: 95699.26  (KL=95562.08, NLL=137.18)
 batch_loss: 41297.29  |  KL+NLL: 95671.64  (KL=95522.64, NLL=149.00)


Epoch 4/10:  26%|██▌       | 94/360 [00:04<00:12, 20.48it/s]

 batch_loss: 41218.75  |  KL+NLL: 95673.71  (KL=95557.33, NLL=116.39)
 batch_loss: 41449.25  |  KL+NLL: 95729.64  (KL=95544.48, NLL=185.17)
 batch_loss: 41376.16  |  KL+NLL: 95656.81  (KL=95551.79, NLL=105.02)
 batch_loss: 40978.50  |  KL+NLL: 95619.76  (KL=95530.92, NLL=88.84)
 batch_loss: 41105.64  |  KL+NLL: 95631.81  (KL=95516.09, NLL=115.72)


Epoch 4/10:  28%|██▊       | 100/360 [00:04<00:12, 20.49it/s]

 batch_loss: 40991.02  |  KL+NLL: 95689.31  (KL=95573.15, NLL=116.16)
 batch_loss: 40753.61  |  KL+NLL: 95673.49  (KL=95557.05, NLL=116.45)
 batch_loss: 40956.23  |  KL+NLL: 95703.64  (KL=95586.09, NLL=117.55)
 batch_loss: 40730.13  |  KL+NLL: 95677.87  (KL=95524.64, NLL=153.23)
 batch_loss: 40668.60  |  KL+NLL: 95610.94  (KL=95530.94, NLL=80.00)


Epoch 4/10:  29%|██▊       | 103/360 [00:05<00:12, 20.33it/s]

 batch_loss: 40745.19  |  KL+NLL: 95709.43  (KL=95607.30, NLL=102.12)
 batch_loss: 40270.99  |  KL+NLL: 95612.32  (KL=95504.16, NLL=108.17)
 batch_loss: 40488.22  |  KL+NLL: 95694.43  (KL=95591.13, NLL=103.30)
 batch_loss: 40362.10  |  KL+NLL: 95683.98  (KL=95582.86, NLL=101.12)
 batch_loss: 40633.73  |  KL+NLL: 95757.44  (KL=95549.85, NLL=207.59)


Epoch 4/10:  30%|███       | 109/360 [00:05<00:12, 20.24it/s]

 batch_loss: 40441.85  |  KL+NLL: 95568.49  (KL=95495.20, NLL=73.29)
 batch_loss: 40530.54  |  KL+NLL: 95683.78  (KL=95568.98, NLL=114.81)
 batch_loss: 40314.06  |  KL+NLL: 95729.35  (KL=95568.38, NLL=160.98)
 batch_loss: 40191.98  |  KL+NLL: 95718.83  (KL=95538.83, NLL=180.00)


Epoch 4/10:  31%|███       | 112/360 [00:05<00:12, 20.35it/s]

 batch_loss: 40370.19  |  KL+NLL: 95743.61  (KL=95635.02, NLL=108.59)
 batch_loss: 39688.72  |  KL+NLL: 95633.34  (KL=95541.37, NLL=91.97)
 batch_loss: 40087.96  |  KL+NLL: 95832.43  (KL=95644.92, NLL=187.51)
 batch_loss: 40066.30  |  KL+NLL: 95694.66  (KL=95529.32, NLL=165.34)
 batch_loss: 40012.95  |  KL+NLL: 95689.65  (KL=95546.38, NLL=143.27)


Epoch 4/10:  33%|███▎      | 118/360 [00:05<00:11, 20.40it/s]

 batch_loss: 39599.14  |  KL+NLL: 95788.16  (KL=95676.95, NLL=111.22)
 batch_loss: 39830.26  |  KL+NLL: 95626.45  (KL=95496.09, NLL=130.35)
 batch_loss: 39663.09  |  KL+NLL: 95705.06  (KL=95578.82, NLL=126.24)
 batch_loss: 39567.52  |  KL+NLL: 95633.76  (KL=95535.75, NLL=98.01)
 batch_loss: 39405.27  |  KL+NLL: 95686.81  (KL=95587.09, NLL=99.73)


Epoch 4/10:  34%|███▍      | 124/360 [00:06<00:11, 20.40it/s]

 batch_loss: 39361.65  |  KL+NLL: 95632.62  (KL=95514.07, NLL=118.55)
 batch_loss: 39566.69  |  KL+NLL: 95665.44  (KL=95526.63, NLL=138.80)
 batch_loss: 39495.19  |  KL+NLL: 95684.50  (KL=95530.25, NLL=154.25)
 batch_loss: 39589.31  |  KL+NLL: 95764.83  (KL=95623.96, NLL=140.87)
 batch_loss: 39333.96  |  KL+NLL: 95707.75  (KL=95600.92, NLL=106.82)


Epoch 4/10:  35%|███▌      | 127/360 [00:06<00:11, 20.67it/s]

 batch_loss: 39194.37  |  KL+NLL: 95742.91  (KL=95584.27, NLL=158.65)
 batch_loss: 39003.28  |  KL+NLL: 95641.13  (KL=95553.41, NLL=87.71)
 batch_loss: 38977.90  |  KL+NLL: 95682.84  (KL=95535.93, NLL=146.91)
 batch_loss: 38973.09  |  KL+NLL: 95774.53  (KL=95585.70, NLL=188.83)
 batch_loss: 38773.48  |  KL+NLL: 95727.71  (KL=95567.34, NLL=160.38)


Epoch 4/10:  37%|███▋      | 133/360 [00:06<00:11, 20.24it/s]

 batch_loss: 39085.11  |  KL+NLL: 95603.34  (KL=95502.98, NLL=100.35)
 batch_loss: 39063.28  |  KL+NLL: 95716.24  (KL=95574.55, NLL=141.69)
 batch_loss: 38738.08  |  KL+NLL: 95695.74  (KL=95610.55, NLL=85.18)
 batch_loss: 38634.51  |  KL+NLL: 95703.02  (KL=95556.88, NLL=146.14)
 batch_loss: 38289.64  |  KL+NLL: 95719.79  (KL=95582.45, NLL=137.33)


Epoch 4/10:  39%|███▊      | 139/360 [00:06<00:10, 20.70it/s]

 batch_loss: 38520.40  |  KL+NLL: 95686.45  (KL=95557.41, NLL=129.03)
 batch_loss: 38348.50  |  KL+NLL: 95689.48  (KL=95571.00, NLL=118.48)
 batch_loss: 38206.31  |  KL+NLL: 95689.58  (KL=95612.25, NLL=77.33)
 batch_loss: 38206.03  |  KL+NLL: 95674.43  (KL=95537.94, NLL=136.50)
 batch_loss: 38142.31  |  KL+NLL: 95711.08  (KL=95594.77, NLL=116.32)


Epoch 4/10:  39%|███▉      | 142/360 [00:06<00:10, 20.77it/s]

 batch_loss: 38412.38  |  KL+NLL: 95674.35  (KL=95542.13, NLL=132.22)
 batch_loss: 38391.07  |  KL+NLL: 95644.14  (KL=95494.57, NLL=149.57)
 batch_loss: 38037.13  |  KL+NLL: 95689.18  (KL=95511.25, NLL=177.93)
 batch_loss: 37988.74  |  KL+NLL: 95659.76  (KL=95555.28, NLL=104.48)
 batch_loss: 38084.17  |  KL+NLL: 95692.01  (KL=95581.09, NLL=110.92)


Epoch 4/10:  41%|████      | 148/360 [00:07<00:10, 20.78it/s]

 batch_loss: 37533.72  |  KL+NLL: 95692.37  (KL=95570.67, NLL=121.70)
 batch_loss: 37833.14  |  KL+NLL: 95767.63  (KL=95598.87, NLL=168.76)
 batch_loss: 37729.36  |  KL+NLL: 95653.96  (KL=95503.88, NLL=150.09)
 batch_loss: 37720.26  |  KL+NLL: 95642.62  (KL=95508.83, NLL=133.79)
 batch_loss: 37538.63  |  KL+NLL: 95649.65  (KL=95520.10, NLL=129.54)


Epoch 4/10:  43%|████▎     | 154/360 [00:07<00:09, 21.25it/s]

 batch_loss: 37537.19  |  KL+NLL: 95716.26  (KL=95538.13, NLL=178.12)
 batch_loss: 37575.54  |  KL+NLL: 95694.69  (KL=95594.53, NLL=100.15)
 batch_loss: 37424.55  |  KL+NLL: 95792.48  (KL=95618.74, NLL=173.74)
 batch_loss: 37477.58  |  KL+NLL: 95549.13  (KL=95447.66, NLL=101.47)
 batch_loss: 37249.65  |  KL+NLL: 95666.97  (KL=95524.02, NLL=142.94)


Epoch 4/10:  44%|████▎     | 157/360 [00:07<00:09, 21.17it/s]

 batch_loss: 37285.18  |  KL+NLL: 95737.85  (KL=95596.27, NLL=141.59)
 batch_loss: 37247.75  |  KL+NLL: 95660.27  (KL=95530.45, NLL=129.82)
 batch_loss: 37052.57  |  KL+NLL: 95662.56  (KL=95509.58, NLL=152.99)
 batch_loss: 36926.54  |  KL+NLL: 95689.55  (KL=95582.94, NLL=106.61)
 batch_loss: 36989.03  |  KL+NLL: 95627.58  (KL=95532.14, NLL=95.44)


Epoch 4/10:  45%|████▌     | 163/360 [00:07<00:09, 21.40it/s]

 batch_loss: 37224.69  |  KL+NLL: 95693.05  (KL=95595.38, NLL=97.67)
 batch_loss: 36848.08  |  KL+NLL: 95698.87  (KL=95524.05, NLL=174.82)
 batch_loss: 36754.10  |  KL+NLL: 95695.63  (KL=95604.40, NLL=91.23)
 batch_loss: 36989.65  |  KL+NLL: 95738.72  (KL=95617.18, NLL=121.54)
 batch_loss: 36731.23  |  KL+NLL: 95707.55  (KL=95573.25, NLL=134.30)


Epoch 4/10:  47%|████▋     | 169/360 [00:08<00:09, 20.89it/s]

 batch_loss: 36526.52  |  KL+NLL: 95720.44  (KL=95596.36, NLL=124.08)
 batch_loss: 36597.33  |  KL+NLL: 95676.52  (KL=95546.09, NLL=130.42)
 batch_loss: 36306.87  |  KL+NLL: 95739.89  (KL=95643.96, NLL=95.93)
 batch_loss: 36416.42  |  KL+NLL: 95746.08  (KL=95599.91, NLL=146.17)
 batch_loss: 36324.50  |  KL+NLL: 95749.19  (KL=95636.25, NLL=112.94)


Epoch 4/10:  48%|████▊     | 172/360 [00:08<00:09, 20.42it/s]

 batch_loss: 36069.47  |  KL+NLL: 95666.68  (KL=95579.39, NLL=87.29)
 batch_loss: 36205.59  |  KL+NLL: 95663.39  (KL=95525.24, NLL=138.14)
 batch_loss: 36417.13  |  KL+NLL: 95685.30  (KL=95609.12, NLL=76.19)
 batch_loss: 36039.06  |  KL+NLL: 95668.23  (KL=95588.06, NLL=80.17)
 batch_loss: 35972.23  |  KL+NLL: 95698.13  (KL=95529.03, NLL=169.10)


Epoch 4/10:  49%|████▉     | 178/360 [00:08<00:08, 21.01it/s]

 batch_loss: 35956.34  |  KL+NLL: 95747.31  (KL=95627.61, NLL=119.70)
 batch_loss: 35839.89  |  KL+NLL: 95707.86  (KL=95598.91, NLL=108.95)
 batch_loss: 36039.80  |  KL+NLL: 95717.49  (KL=95612.81, NLL=104.68)
 batch_loss: 35910.43  |  KL+NLL: 95809.19  (KL=95667.08, NLL=142.12)
 batch_loss: 35846.16  |  KL+NLL: 95686.11  (KL=95592.05, NLL=94.05)


Epoch 4/10:  51%|█████     | 184/360 [00:08<00:08, 21.03it/s]

 batch_loss: 35701.97  |  KL+NLL: 95728.97  (KL=95594.85, NLL=134.12)
 batch_loss: 35676.78  |  KL+NLL: 95817.11  (KL=95635.87, NLL=181.25)
 batch_loss: 35676.60  |  KL+NLL: 95734.47  (KL=95606.04, NLL=128.43)
 batch_loss: 35248.76  |  KL+NLL: 95763.59  (KL=95623.84, NLL=139.75)
 batch_loss: 35373.50  |  KL+NLL: 95694.77  (KL=95591.09, NLL=103.68)


Epoch 4/10:  52%|█████▏    | 187/360 [00:09<00:08, 21.32it/s]

 batch_loss: 35464.77  |  KL+NLL: 95804.01  (KL=95666.23, NLL=137.79)
 batch_loss: 35429.87  |  KL+NLL: 95804.20  (KL=95642.11, NLL=162.09)
 batch_loss: 35414.04  |  KL+NLL: 95900.45  (KL=95717.37, NLL=183.08)
 batch_loss: 35311.91  |  KL+NLL: 95663.47  (KL=95544.04, NLL=119.43)
 batch_loss: 35347.21  |  KL+NLL: 95751.38  (KL=95589.65, NLL=161.73)


Epoch 4/10:  54%|█████▎    | 193/360 [00:09<00:07, 20.97it/s]

 batch_loss: 35258.64  |  KL+NLL: 95826.55  (KL=95639.68, NLL=186.87)
 batch_loss: 35117.39  |  KL+NLL: 95751.95  (KL=95666.00, NLL=85.95)
 batch_loss: 34891.18  |  KL+NLL: 95869.72  (KL=95727.26, NLL=142.46)
 batch_loss: 35022.24  |  KL+NLL: 95866.53  (KL=95673.82, NLL=192.71)
 batch_loss: 35083.81  |  KL+NLL: 95781.83  (KL=95632.09, NLL=149.74)


Epoch 4/10:  55%|█████▌    | 199/360 [00:09<00:07, 20.47it/s]

 batch_loss: 35167.22  |  KL+NLL: 95760.86  (KL=95654.55, NLL=106.31)
 batch_loss: 34532.26  |  KL+NLL: 95795.68  (KL=95616.16, NLL=179.51)
 batch_loss: 34745.24  |  KL+NLL: 95788.55  (KL=95649.23, NLL=139.31)
 batch_loss: 34626.66  |  KL+NLL: 95783.13  (KL=95662.20, NLL=120.93)
 batch_loss: 34693.60  |  KL+NLL: 95792.15  (KL=95655.03, NLL=137.12)


Epoch 4/10:  56%|█████▌    | 202/360 [00:09<00:07, 20.90it/s]

 batch_loss: 34796.53  |  KL+NLL: 95819.87  (KL=95645.53, NLL=174.34)
 batch_loss: 34718.72  |  KL+NLL: 95838.86  (KL=95681.95, NLL=156.90)
 batch_loss: 34179.89  |  KL+NLL: 95759.49  (KL=95672.21, NLL=87.28)
 batch_loss: 34446.49  |  KL+NLL: 95708.22  (KL=95635.13, NLL=73.09)
 batch_loss: 34396.69  |  KL+NLL: 95688.03  (KL=95572.67, NLL=115.36)


Epoch 4/10:  58%|█████▊    | 208/360 [00:10<00:07, 20.70it/s]

 batch_loss: 34051.65  |  KL+NLL: 95747.29  (KL=95626.84, NLL=120.44)
 batch_loss: 34307.83  |  KL+NLL: 95829.41  (KL=95668.73, NLL=160.68)
 batch_loss: 33850.30  |  KL+NLL: 95762.69  (KL=95650.95, NLL=111.74)
 batch_loss: 34298.63  |  KL+NLL: 95903.81  (KL=95747.01, NLL=156.81)
 batch_loss: 34168.32  |  KL+NLL: 95764.83  (KL=95646.41, NLL=118.42)


Epoch 4/10:  59%|█████▉    | 214/360 [00:10<00:06, 20.88it/s]

 batch_loss: 34405.52  |  KL+NLL: 95831.03  (KL=95681.43, NLL=149.60)
 batch_loss: 33828.79  |  KL+NLL: 95876.56  (KL=95704.88, NLL=171.68)
 batch_loss: 33919.02  |  KL+NLL: 95854.31  (KL=95676.34, NLL=177.96)
 batch_loss: 34124.00  |  KL+NLL: 95887.85  (KL=95689.11, NLL=198.74)
 batch_loss: 33884.70  |  KL+NLL: 95785.91  (KL=95643.92, NLL=141.99)


Epoch 4/10:  60%|██████    | 217/360 [00:10<00:06, 20.77it/s]

 batch_loss: 33736.88  |  KL+NLL: 95809.68  (KL=95725.46, NLL=84.22)
 batch_loss: 33391.32  |  KL+NLL: 95880.97  (KL=95694.22, NLL=186.75)
 batch_loss: 33382.38  |  KL+NLL: 95839.30  (KL=95710.87, NLL=128.43)
 batch_loss: 33762.35  |  KL+NLL: 95871.60  (KL=95697.59, NLL=174.01)
 batch_loss: 33436.50  |  KL+NLL: 95870.91  (KL=95688.47, NLL=182.44)


Epoch 4/10:  62%|██████▏   | 223/360 [00:10<00:06, 20.69it/s]

 batch_loss: 33436.06  |  KL+NLL: 95851.36  (KL=95666.93, NLL=184.43)
 batch_loss: 33376.43  |  KL+NLL: 95881.87  (KL=95754.63, NLL=127.24)
 batch_loss: 33242.08  |  KL+NLL: 95934.82  (KL=95699.40, NLL=235.42)
 batch_loss: 33173.46  |  KL+NLL: 95819.56  (KL=95701.36, NLL=118.20)
 batch_loss: 33270.25  |  KL+NLL: 95809.94  (KL=95716.33, NLL=93.62)


Epoch 4/10:  64%|██████▎   | 229/360 [00:11<00:06, 20.97it/s]

 batch_loss: 33016.28  |  KL+NLL: 95895.19  (KL=95741.34, NLL=153.84)
 batch_loss: 33205.99  |  KL+NLL: 95811.51  (KL=95688.55, NLL=122.97)
 batch_loss: 32597.51  |  KL+NLL: 95835.21  (KL=95728.98, NLL=106.23)
 batch_loss: 33150.81  |  KL+NLL: 95852.24  (KL=95716.12, NLL=136.13)
 batch_loss: 32939.08  |  KL+NLL: 95825.97  (KL=95718.54, NLL=107.43)


Epoch 4/10:  64%|██████▍   | 232/360 [00:11<00:06, 20.98it/s]

 batch_loss: 32709.03  |  KL+NLL: 95885.10  (KL=95716.02, NLL=169.08)
 batch_loss: 32821.50  |  KL+NLL: 95816.30  (KL=95713.41, NLL=102.89)
 batch_loss: 32999.84  |  KL+NLL: 95922.06  (KL=95796.96, NLL=125.10)
 batch_loss: 32880.17  |  KL+NLL: 95797.63  (KL=95707.37, NLL=90.26)
 batch_loss: 32743.95  |  KL+NLL: 95945.07  (KL=95742.03, NLL=203.04)


Epoch 4/10:  66%|██████▌   | 238/360 [00:11<00:05, 20.36it/s]

 batch_loss: 32682.06  |  KL+NLL: 95862.37  (KL=95756.36, NLL=106.01)
 batch_loss: 32777.63  |  KL+NLL: 95848.84  (KL=95719.83, NLL=129.01)
 batch_loss: 32634.48  |  KL+NLL: 95932.42  (KL=95790.62, NLL=141.80)
 batch_loss: 32564.38  |  KL+NLL: 95875.18  (KL=95751.03, NLL=124.14)
 batch_loss: 32365.38  |  KL+NLL: 95835.25  (KL=95716.38, NLL=118.87)


Epoch 4/10:  68%|██████▊   | 244/360 [00:11<00:05, 20.58it/s]

 batch_loss: 32236.48  |  KL+NLL: 95816.72  (KL=95736.78, NLL=79.94)
 batch_loss: 32600.63  |  KL+NLL: 95885.71  (KL=95821.80, NLL=63.90)
 batch_loss: 32243.40  |  KL+NLL: 95975.01  (KL=95809.08, NLL=165.93)
 batch_loss: 32148.88  |  KL+NLL: 95926.67  (KL=95770.04, NLL=156.63)
 batch_loss: 31904.05  |  KL+NLL: 95930.38  (KL=95812.24, NLL=118.14)


Epoch 4/10:  69%|██████▊   | 247/360 [00:11<00:05, 20.62it/s]

 batch_loss: 32107.26  |  KL+NLL: 95953.13  (KL=95851.44, NLL=101.69)
 batch_loss: 31914.30  |  KL+NLL: 95931.56  (KL=95791.69, NLL=139.88)
 batch_loss: 32086.02  |  KL+NLL: 95911.02  (KL=95740.82, NLL=170.20)
 batch_loss: 32005.48  |  KL+NLL: 95888.35  (KL=95763.55, NLL=124.81)
 batch_loss: 31933.36  |  KL+NLL: 95933.74  (KL=95792.80, NLL=140.94)


Epoch 4/10:  70%|███████   | 253/360 [00:12<00:05, 20.15it/s]

 batch_loss: 31630.77  |  KL+NLL: 95898.79  (KL=95813.91, NLL=84.87)
 batch_loss: 31889.44  |  KL+NLL: 95944.78  (KL=95828.44, NLL=116.34)
 batch_loss: 31462.09  |  KL+NLL: 95784.72  (KL=95666.87, NLL=117.86)
 batch_loss: 31654.19  |  KL+NLL: 95911.35  (KL=95789.12, NLL=122.23)


Epoch 4/10:  71%|███████   | 256/360 [00:12<00:05, 19.98it/s]

 batch_loss: 31560.29  |  KL+NLL: 96010.33  (KL=95829.62, NLL=180.71)
 batch_loss: 31599.43  |  KL+NLL: 95955.12  (KL=95762.53, NLL=192.59)
 batch_loss: 31509.29  |  KL+NLL: 95964.28  (KL=95852.16, NLL=112.12)
 batch_loss: 31581.21  |  KL+NLL: 96003.84  (KL=95827.73, NLL=176.12)
 batch_loss: 31369.93  |  KL+NLL: 95903.48  (KL=95810.45, NLL=93.02)


Epoch 4/10:  73%|███████▎  | 262/360 [00:12<00:04, 20.32it/s]

 batch_loss: 31278.66  |  KL+NLL: 95921.22  (KL=95834.66, NLL=86.56)
 batch_loss: 31096.74  |  KL+NLL: 96034.23  (KL=95901.21, NLL=133.01)
 batch_loss: 31289.16  |  KL+NLL: 95962.01  (KL=95808.70, NLL=153.30)
 batch_loss: 31581.14  |  KL+NLL: 96012.49  (KL=95874.89, NLL=137.60)
 batch_loss: 30779.21  |  KL+NLL: 95967.46  (KL=95883.88, NLL=83.58)


Epoch 4/10:  74%|███████▍  | 268/360 [00:13<00:04, 20.21it/s]

 batch_loss: 31105.77  |  KL+NLL: 95885.78  (KL=95787.56, NLL=98.22)
 batch_loss: 31087.24  |  KL+NLL: 96021.90  (KL=95880.98, NLL=140.91)
 batch_loss: 30759.67  |  KL+NLL: 96048.51  (KL=95856.96, NLL=191.55)
 batch_loss: 30976.51  |  KL+NLL: 96006.89  (KL=95843.37, NLL=163.52)
 batch_loss: 30760.37  |  KL+NLL: 95823.20  (KL=95747.42, NLL=75.77)


Epoch 4/10:  75%|███████▌  | 271/360 [00:13<00:04, 20.02it/s]

 batch_loss: 30818.51  |  KL+NLL: 96019.90  (KL=95866.37, NLL=153.54)
 batch_loss: 30605.40  |  KL+NLL: 96068.91  (KL=95904.30, NLL=164.60)
 batch_loss: 30565.94  |  KL+NLL: 95939.89  (KL=95811.44, NLL=128.45)
 batch_loss: 30535.70  |  KL+NLL: 96021.54  (KL=95896.66, NLL=124.88)
 batch_loss: 30332.10  |  KL+NLL: 96021.90  (KL=95839.00, NLL=182.90)


Epoch 4/10:  77%|███████▋  | 277/360 [00:13<00:04, 20.15it/s]

 batch_loss: 30481.24  |  KL+NLL: 96122.86  (KL=95938.12, NLL=184.74)
 batch_loss: 30284.56  |  KL+NLL: 96007.89  (KL=95884.94, NLL=122.95)
 batch_loss: 30366.16  |  KL+NLL: 96079.72  (KL=95921.85, NLL=157.87)
 batch_loss: 30300.56  |  KL+NLL: 95981.17  (KL=95901.34, NLL=79.83)
 batch_loss: 30277.27  |  KL+NLL: 96125.48  (KL=95966.96, NLL=158.52)


Epoch 4/10:  79%|███████▊  | 283/360 [00:13<00:03, 20.61it/s]

 batch_loss: 30102.93  |  KL+NLL: 95988.22  (KL=95903.83, NLL=84.40)
 batch_loss: 30205.68  |  KL+NLL: 95898.83  (KL=95817.95, NLL=80.88)
 batch_loss: 30144.55  |  KL+NLL: 95960.93  (KL=95818.47, NLL=142.46)
 batch_loss: 30114.09  |  KL+NLL: 95961.94  (KL=95871.73, NLL=90.21)
 batch_loss: 29871.52  |  KL+NLL: 95945.35  (KL=95837.46, NLL=107.89)


Epoch 4/10:  79%|███████▉  | 286/360 [00:13<00:03, 20.30it/s]

 batch_loss: 30007.04  |  KL+NLL: 96019.41  (KL=95944.30, NLL=75.11)
 batch_loss: 29963.19  |  KL+NLL: 96003.99  (KL=95878.46, NLL=125.52)
 batch_loss: 29794.82  |  KL+NLL: 96121.60  (KL=95995.84, NLL=125.76)
 batch_loss: 29912.74  |  KL+NLL: 96005.39  (KL=95885.55, NLL=119.84)
 batch_loss: 29995.95  |  KL+NLL: 96066.16  (KL=95922.64, NLL=143.52)


Epoch 4/10:  81%|████████  | 292/360 [00:14<00:03, 20.67it/s]

 batch_loss: 29619.62  |  KL+NLL: 96109.07  (KL=95966.61, NLL=142.46)
 batch_loss: 29456.27  |  KL+NLL: 95987.02  (KL=95883.55, NLL=103.46)
 batch_loss: 29623.60  |  KL+NLL: 96028.45  (KL=95944.28, NLL=84.17)
 batch_loss: 29670.60  |  KL+NLL: 96053.51  (KL=95931.43, NLL=122.08)
 batch_loss: 29535.52  |  KL+NLL: 96028.06  (KL=95898.98, NLL=129.07)


Epoch 4/10:  83%|████████▎ | 298/360 [00:14<00:03, 20.41it/s]

 batch_loss: 29509.94  |  KL+NLL: 96094.13  (KL=95932.96, NLL=161.17)
 batch_loss: 29246.66  |  KL+NLL: 96085.36  (KL=95947.36, NLL=138.00)
 batch_loss: 29485.27  |  KL+NLL: 96072.38  (KL=95988.27, NLL=84.11)
 batch_loss: 29468.43  |  KL+NLL: 96154.23  (KL=96001.42, NLL=152.81)
 batch_loss: 29289.77  |  KL+NLL: 96046.86  (KL=95928.91, NLL=117.95)


Epoch 4/10:  84%|████████▎ | 301/360 [00:14<00:02, 20.29it/s]

 batch_loss: 29041.36  |  KL+NLL: 96075.52  (KL=95981.96, NLL=93.56)
 batch_loss: 29203.22  |  KL+NLL: 96145.75  (KL=95980.39, NLL=165.36)
 batch_loss: 28910.00  |  KL+NLL: 96220.09  (KL=95985.73, NLL=234.36)
 batch_loss: 29104.92  |  KL+NLL: 96142.94  (KL=96007.80, NLL=135.14)
 batch_loss: 28914.03  |  KL+NLL: 96141.91  (KL=96052.35, NLL=89.56)


Epoch 4/10:  85%|████████▌ | 307/360 [00:14<00:02, 20.27it/s]

 batch_loss: 28976.76  |  KL+NLL: 96131.48  (KL=96015.71, NLL=115.77)
 batch_loss: 28708.14  |  KL+NLL: 96081.76  (KL=95984.49, NLL=97.27)
 batch_loss: 28706.51  |  KL+NLL: 96154.02  (KL=96008.12, NLL=145.89)
 batch_loss: 28697.85  |  KL+NLL: 96102.05  (KL=96005.17, NLL=96.88)
 batch_loss: 28947.02  |  KL+NLL: 96181.02  (KL=95964.98, NLL=216.04)


Epoch 4/10:  87%|████████▋ | 313/360 [00:15<00:02, 20.65it/s]

 batch_loss: 28744.04  |  KL+NLL: 96302.93  (KL=96084.10, NLL=218.83)
 batch_loss: 28509.10  |  KL+NLL: 96067.65  (KL=95960.38, NLL=107.27)
 batch_loss: 28462.43  |  KL+NLL: 96115.09  (KL=96005.22, NLL=109.87)
 batch_loss: 28636.49  |  KL+NLL: 96195.23  (KL=95980.59, NLL=214.64)
 batch_loss: 28366.22  |  KL+NLL: 96204.30  (KL=96022.82, NLL=181.48)


Epoch 4/10:  88%|████████▊ | 316/360 [00:15<00:02, 20.55it/s]

 batch_loss: 28501.15  |  KL+NLL: 96145.42  (KL=96059.40, NLL=86.02)
 batch_loss: 28504.29  |  KL+NLL: 96139.35  (KL=96056.27, NLL=83.08)
 batch_loss: 28483.02  |  KL+NLL: 96186.46  (KL=96077.23, NLL=109.23)
 batch_loss: 28396.10  |  KL+NLL: 96190.90  (KL=96049.62, NLL=141.27)
 batch_loss: 28141.51  |  KL+NLL: 96178.57  (KL=96082.91, NLL=95.66)


Epoch 4/10:  89%|████████▉ | 322/360 [00:15<00:01, 20.25it/s]

 batch_loss: 27988.81  |  KL+NLL: 96120.26  (KL=96003.72, NLL=116.54)
 batch_loss: 28189.14  |  KL+NLL: 96160.89  (KL=96020.80, NLL=140.09)
 batch_loss: 28467.22  |  KL+NLL: 96176.73  (KL=96001.80, NLL=174.93)
 batch_loss: 28108.10  |  KL+NLL: 96221.57  (KL=96092.15, NLL=129.43)
 batch_loss: 28142.40  |  KL+NLL: 96229.59  (KL=96129.26, NLL=100.33)


Epoch 4/10:  91%|█████████ | 328/360 [00:15<00:01, 20.48it/s]

 batch_loss: 27804.55  |  KL+NLL: 96097.93  (KL=96003.09, NLL=94.84)
 batch_loss: 27751.16  |  KL+NLL: 96337.26  (KL=96166.62, NLL=170.64)
 batch_loss: 27928.17  |  KL+NLL: 96251.59  (KL=96140.79, NLL=110.80)
 batch_loss: 27449.28  |  KL+NLL: 96241.97  (KL=96048.76, NLL=193.21)
 batch_loss: 27643.55  |  KL+NLL: 96280.63  (KL=96164.27, NLL=116.36)


Epoch 4/10:  92%|█████████▏| 331/360 [00:16<00:01, 20.65it/s]

 batch_loss: 27876.28  |  KL+NLL: 96272.57  (KL=96135.84, NLL=136.73)
 batch_loss: 27669.33  |  KL+NLL: 96237.41  (KL=96089.63, NLL=147.78)
 batch_loss: 27713.83  |  KL+NLL: 96190.08  (KL=96096.42, NLL=93.66)
 batch_loss: 27648.72  |  KL+NLL: 96289.34  (KL=96153.35, NLL=135.99)
 batch_loss: 27310.07  |  KL+NLL: 96253.36  (KL=96092.28, NLL=161.08)


Epoch 4/10:  94%|█████████▎| 337/360 [00:16<00:01, 20.17it/s]

 batch_loss: 27224.94  |  KL+NLL: 96359.24  (KL=96119.91, NLL=239.33)
 batch_loss: 27535.67  |  KL+NLL: 96251.35  (KL=96168.01, NLL=83.35)
 batch_loss: 27202.26  |  KL+NLL: 96212.73  (KL=96118.12, NLL=94.61)
 batch_loss: 27479.59  |  KL+NLL: 96373.21  (KL=96198.12, NLL=175.10)


Epoch 4/10:  94%|█████████▍| 340/360 [00:16<00:00, 20.09it/s]

 batch_loss: 27271.34  |  KL+NLL: 96325.14  (KL=96173.60, NLL=151.54)
 batch_loss: 27378.10  |  KL+NLL: 96241.77  (KL=96133.27, NLL=108.50)
 batch_loss: 27232.27  |  KL+NLL: 96314.61  (KL=96190.05, NLL=124.55)
 batch_loss: 27150.32  |  KL+NLL: 96256.14  (KL=96162.84, NLL=93.30)
 batch_loss: 27062.13  |  KL+NLL: 96307.46  (KL=96179.27, NLL=128.19)


Epoch 4/10:  96%|█████████▌| 346/360 [00:16<00:00, 20.03it/s]

 batch_loss: 27235.94  |  KL+NLL: 96346.61  (KL=96205.74, NLL=140.87)
 batch_loss: 26898.64  |  KL+NLL: 96295.22  (KL=96139.57, NLL=155.65)
 batch_loss: 26932.75  |  KL+NLL: 96291.94  (KL=96167.02, NLL=124.92)
 batch_loss: 26726.72  |  KL+NLL: 96355.00  (KL=96240.75, NLL=114.25)
 batch_loss: 26854.98  |  KL+NLL: 96357.59  (KL=96230.37, NLL=127.23)


Epoch 4/10:  98%|█████████▊| 352/360 [00:17<00:00, 20.19it/s]

 batch_loss: 26561.09  |  KL+NLL: 96399.21  (KL=96269.07, NLL=130.14)
 batch_loss: 26750.42  |  KL+NLL: 96399.26  (KL=96267.13, NLL=132.13)
 batch_loss: 26747.48  |  KL+NLL: 96438.40  (KL=96225.20, NLL=213.21)
 batch_loss: 26424.93  |  KL+NLL: 96350.87  (KL=96207.47, NLL=143.40)
 batch_loss: 26287.82  |  KL+NLL: 96373.87  (KL=96227.68, NLL=146.19)


Epoch 4/10:  99%|█████████▊| 355/360 [00:17<00:00, 20.76it/s]

 batch_loss: 26551.91  |  KL+NLL: 96387.65  (KL=96286.68, NLL=100.97)
 batch_loss: 26436.17  |  KL+NLL: 96402.93  (KL=96204.56, NLL=198.37)
 batch_loss: 26397.41  |  KL+NLL: 96353.59  (KL=96221.94, NLL=131.65)
 batch_loss: 26535.38  |  KL+NLL: 96303.01  (KL=96173.51, NLL=129.50)
 batch_loss: 26333.96  |  KL+NLL: 96370.83  (KL=96216.73, NLL=154.09)


Epoch 4/10: 100%|██████████| 360/360 [00:17<00:00, 20.55it/s]


 batch_loss: 26299.50  |  KL+NLL: 96353.94  (KL=96238.90, NLL=115.04)
 batch_loss: 26289.55  |  KL+NLL: 96389.95  (KL=96279.63, NLL=110.32)
 batch_loss: 26323.20  |  KL+NLL: 96411.28  (KL=96299.28, NLL=112.00)
Epoch 4 - ELBO Loss: 36136.4104


Epoch 5/10:   1%|          | 3/360 [00:00<00:26, 13.55it/s]

 batch_loss: 26199.59  |  KL+NLL: 96311.65  (KL=96214.81, NLL=96.83)
 batch_loss: 26057.23  |  KL+NLL: 96412.54  (KL=96325.66, NLL=86.88)
 batch_loss: 26067.03  |  KL+NLL: 96355.43  (KL=96185.52, NLL=169.91)
 batch_loss: 25993.69  |  KL+NLL: 96480.26  (KL=96311.87, NLL=168.40)
 batch_loss: 25993.92  |  KL+NLL: 96473.14  (KL=96321.76, NLL=151.38)


Epoch 5/10:   2%|▏         | 8/360 [00:00<00:20, 17.52it/s]

 batch_loss: 26009.83  |  KL+NLL: 96424.19  (KL=96292.12, NLL=132.07)
 batch_loss: 25946.95  |  KL+NLL: 96584.39  (KL=96370.74, NLL=213.65)
 batch_loss: 25897.03  |  KL+NLL: 96462.38  (KL=96362.88, NLL=99.49)
 batch_loss: 25796.45  |  KL+NLL: 96447.00  (KL=96347.52, NLL=99.48)


Epoch 5/10:   4%|▎         | 13/360 [00:00<00:17, 19.43it/s]

 batch_loss: 25741.43  |  KL+NLL: 96403.05  (KL=96300.68, NLL=102.37)
 batch_loss: 25751.43  |  KL+NLL: 96471.53  (KL=96302.46, NLL=169.06)
 batch_loss: 25675.29  |  KL+NLL: 96582.52  (KL=96434.84, NLL=147.69)
 batch_loss: 25552.31  |  KL+NLL: 96539.93  (KL=96371.20, NLL=168.73)
 batch_loss: 25753.36  |  KL+NLL: 96493.06  (KL=96335.34, NLL=157.72)


Epoch 5/10:   5%|▌         | 19/360 [00:01<00:16, 20.26it/s]

 batch_loss: 25424.37  |  KL+NLL: 96450.55  (KL=96300.87, NLL=149.68)
 batch_loss: 25427.83  |  KL+NLL: 96516.63  (KL=96361.79, NLL=154.84)
 batch_loss: 25271.56  |  KL+NLL: 96511.16  (KL=96315.05, NLL=196.10)
 batch_loss: 25536.47  |  KL+NLL: 96616.26  (KL=96416.93, NLL=199.33)
 batch_loss: 25297.15  |  KL+NLL: 96543.15  (KL=96413.30, NLL=129.84)


Epoch 5/10:   6%|▌         | 22/360 [00:01<00:16, 20.29it/s]

 batch_loss: 24990.59  |  KL+NLL: 96406.69  (KL=96294.67, NLL=112.02)
 batch_loss: 25313.37  |  KL+NLL: 96561.29  (KL=96418.58, NLL=142.71)
 batch_loss: 25408.18  |  KL+NLL: 96488.79  (KL=96336.02, NLL=152.77)
 batch_loss: 25063.88  |  KL+NLL: 96620.65  (KL=96430.14, NLL=190.51)
 batch_loss: 25126.33  |  KL+NLL: 96607.94  (KL=96392.27, NLL=215.67)


Epoch 5/10:   8%|▊         | 28/360 [00:01<00:16, 20.40it/s]

 batch_loss: 24979.47  |  KL+NLL: 96588.18  (KL=96455.98, NLL=132.20)
 batch_loss: 25073.03  |  KL+NLL: 96408.33  (KL=96280.40, NLL=127.93)
 batch_loss: 24929.51  |  KL+NLL: 96531.58  (KL=96411.84, NLL=119.73)
 batch_loss: 24602.43  |  KL+NLL: 96624.60  (KL=96440.43, NLL=184.17)
 batch_loss: 24680.38  |  KL+NLL: 96508.81  (KL=96398.75, NLL=110.06)


Epoch 5/10:   9%|▉         | 34/360 [00:01<00:15, 20.73it/s]

 batch_loss: 24885.30  |  KL+NLL: 96600.28  (KL=96384.91, NLL=215.37)
 batch_loss: 24627.84  |  KL+NLL: 96564.99  (KL=96429.41, NLL=135.57)
 batch_loss: 24873.40  |  KL+NLL: 96782.62  (KL=96587.29, NLL=195.33)
 batch_loss: 24675.70  |  KL+NLL: 96691.99  (KL=96534.16, NLL=157.83)
 batch_loss: 24494.43  |  KL+NLL: 96698.33  (KL=96540.69, NLL=157.64)


Epoch 5/10:  10%|█         | 37/360 [00:01<00:15, 20.62it/s]

 batch_loss: 24579.47  |  KL+NLL: 96737.63  (KL=96472.52, NLL=265.11)
 batch_loss: 24575.74  |  KL+NLL: 96575.54  (KL=96491.25, NLL=84.29)
 batch_loss: 24341.39  |  KL+NLL: 96551.44  (KL=96407.23, NLL=144.21)
 batch_loss: 24275.55  |  KL+NLL: 96654.35  (KL=96515.82, NLL=138.53)
 batch_loss: 24345.63  |  KL+NLL: 96665.64  (KL=96543.45, NLL=122.20)


Epoch 5/10:  12%|█▏        | 43/360 [00:02<00:15, 20.97it/s]

 batch_loss: 24346.91  |  KL+NLL: 96653.10  (KL=96513.61, NLL=139.50)
 batch_loss: 24181.38  |  KL+NLL: 96634.31  (KL=96471.62, NLL=162.68)
 batch_loss: 24121.77  |  KL+NLL: 96754.47  (KL=96606.12, NLL=148.36)
 batch_loss: 24245.24  |  KL+NLL: 96672.33  (KL=96549.84, NLL=122.49)
 batch_loss: 24033.06  |  KL+NLL: 96693.84  (KL=96546.34, NLL=147.49)


Epoch 5/10:  14%|█▎        | 49/360 [00:02<00:15, 20.49it/s]

 batch_loss: 24243.59  |  KL+NLL: 96653.83  (KL=96520.62, NLL=133.21)
 batch_loss: 24220.64  |  KL+NLL: 96728.43  (KL=96551.42, NLL=177.01)
 batch_loss: 24156.99  |  KL+NLL: 96660.55  (KL=96524.94, NLL=135.61)
 batch_loss: 23795.81  |  KL+NLL: 96612.19  (KL=96510.27, NLL=101.92)
 batch_loss: 23831.10  |  KL+NLL: 96656.10  (KL=96542.53, NLL=113.57)


Epoch 5/10:  14%|█▍        | 52/360 [00:02<00:14, 20.89it/s]

 batch_loss: 23919.73  |  KL+NLL: 96629.72  (KL=96538.98, NLL=90.73)
 batch_loss: 24206.26  |  KL+NLL: 96654.09  (KL=96526.98, NLL=127.11)
 batch_loss: 23660.57  |  KL+NLL: 96730.96  (KL=96643.75, NLL=87.21)
 batch_loss: 23819.15  |  KL+NLL: 96707.03  (KL=96557.08, NLL=149.95)
 batch_loss: 23715.72  |  KL+NLL: 96741.25  (KL=96609.84, NLL=131.41)


Epoch 5/10:  16%|█▌        | 58/360 [00:02<00:14, 21.19it/s]

 batch_loss: 23638.57  |  KL+NLL: 96756.84  (KL=96640.37, NLL=116.47)
 batch_loss: 23894.43  |  KL+NLL: 96878.22  (KL=96635.67, NLL=242.55)
 batch_loss: 23656.36  |  KL+NLL: 96701.52  (KL=96567.34, NLL=134.17)
 batch_loss: 23460.91  |  KL+NLL: 96713.39  (KL=96580.32, NLL=133.07)
 batch_loss: 23325.13  |  KL+NLL: 96895.63  (KL=96718.71, NLL=176.92)


Epoch 5/10:  18%|█▊        | 64/360 [00:03<00:13, 21.15it/s]

 batch_loss: 23418.62  |  KL+NLL: 96741.98  (KL=96626.51, NLL=115.48)
 batch_loss: 23553.26  |  KL+NLL: 96678.44  (KL=96545.95, NLL=132.49)
 batch_loss: 23369.34  |  KL+NLL: 96747.94  (KL=96631.53, NLL=116.41)
 batch_loss: 23575.36  |  KL+NLL: 96747.07  (KL=96605.11, NLL=141.96)
 batch_loss: 23254.74  |  KL+NLL: 96861.16  (KL=96671.62, NLL=189.53)


Epoch 5/10:  19%|█▊        | 67/360 [00:03<00:13, 21.17it/s]

 batch_loss: 23221.28  |  KL+NLL: 96809.41  (KL=96672.87, NLL=136.55)
 batch_loss: 23175.23  |  KL+NLL: 96825.44  (KL=96671.37, NLL=154.07)
 batch_loss: 23265.01  |  KL+NLL: 96784.19  (KL=96666.04, NLL=118.15)
 batch_loss: 23140.31  |  KL+NLL: 96875.35  (KL=96736.18, NLL=139.17)
 batch_loss: 23127.61  |  KL+NLL: 96846.69  (KL=96686.80, NLL=159.88)


Epoch 5/10:  20%|██        | 73/360 [00:03<00:13, 20.86it/s]

 batch_loss: 22771.39  |  KL+NLL: 96800.61  (KL=96693.40, NLL=107.21)
 batch_loss: 22902.41  |  KL+NLL: 96822.98  (KL=96637.39, NLL=185.59)
 batch_loss: 22943.61  |  KL+NLL: 96791.38  (KL=96683.64, NLL=107.74)
 batch_loss: 22822.08  |  KL+NLL: 96775.17  (KL=96653.17, NLL=122.00)
 batch_loss: 22838.41  |  KL+NLL: 96851.84  (KL=96698.37, NLL=153.48)


Epoch 5/10:  22%|██▏       | 79/360 [00:03<00:13, 20.67it/s]

 batch_loss: 22631.79  |  KL+NLL: 96871.10  (KL=96738.30, NLL=132.81)
 batch_loss: 22663.28  |  KL+NLL: 96900.64  (KL=96734.66, NLL=165.97)
 batch_loss: 22668.32  |  KL+NLL: 97005.87  (KL=96833.44, NLL=172.43)
 batch_loss: 22621.03  |  KL+NLL: 96829.45  (KL=96729.29, NLL=100.16)
 batch_loss: 22738.39  |  KL+NLL: 96951.79  (KL=96798.38, NLL=153.41)


Epoch 5/10:  23%|██▎       | 82/360 [00:04<00:13, 20.45it/s]

 batch_loss: 22734.18  |  KL+NLL: 96910.39  (KL=96824.84, NLL=85.56)
 batch_loss: 22628.53  |  KL+NLL: 96785.55  (KL=96690.76, NLL=94.79)
 batch_loss: 22355.30  |  KL+NLL: 96988.73  (KL=96836.69, NLL=152.05)
 batch_loss: 22523.21  |  KL+NLL: 96878.58  (KL=96761.78, NLL=116.80)
 batch_loss: 22548.49  |  KL+NLL: 96978.13  (KL=96818.32, NLL=159.81)


Epoch 5/10:  24%|██▍       | 88/360 [00:04<00:13, 20.28it/s]

 batch_loss: 22283.09  |  KL+NLL: 96957.05  (KL=96830.62, NLL=126.43)
 batch_loss: 22477.01  |  KL+NLL: 96964.87  (KL=96840.06, NLL=124.81)
 batch_loss: 22333.43  |  KL+NLL: 96935.72  (KL=96810.51, NLL=125.22)
 batch_loss: 22202.92  |  KL+NLL: 96919.10  (KL=96781.56, NLL=137.54)
 batch_loss: 22113.57  |  KL+NLL: 96920.46  (KL=96803.23, NLL=117.23)


Epoch 5/10:  26%|██▌       | 94/360 [00:04<00:12, 20.78it/s]

 batch_loss: 22158.93  |  KL+NLL: 96909.25  (KL=96773.61, NLL=135.64)
 batch_loss: 22303.26  |  KL+NLL: 97060.99  (KL=96898.14, NLL=162.85)
 batch_loss: 22186.61  |  KL+NLL: 96995.56  (KL=96831.20, NLL=164.36)
 batch_loss: 21912.02  |  KL+NLL: 97102.45  (KL=96906.11, NLL=196.34)
 batch_loss: 21999.24  |  KL+NLL: 97017.39  (KL=96892.26, NLL=125.13)


Epoch 5/10:  27%|██▋       | 97/360 [00:04<00:12, 21.18it/s]

 batch_loss: 22024.25  |  KL+NLL: 96980.52  (KL=96813.60, NLL=166.92)
 batch_loss: 22038.39  |  KL+NLL: 97026.20  (KL=96866.91, NLL=159.29)
 batch_loss: 22021.03  |  KL+NLL: 97094.47  (KL=96938.76, NLL=155.71)
 batch_loss: 21907.17  |  KL+NLL: 97026.84  (KL=96806.84, NLL=219.99)
 batch_loss: 21632.64  |  KL+NLL: 96946.98  (KL=96821.55, NLL=125.44)


Epoch 5/10:  29%|██▊       | 103/360 [00:05<00:12, 20.96it/s]

 batch_loss: 21650.95  |  KL+NLL: 96992.79  (KL=96895.94, NLL=96.85)
 batch_loss: 21555.40  |  KL+NLL: 97013.56  (KL=96897.03, NLL=116.53)
 batch_loss: 21620.92  |  KL+NLL: 97074.42  (KL=96934.09, NLL=140.34)
 batch_loss: 21557.93  |  KL+NLL: 97135.94  (KL=96992.58, NLL=143.36)
 batch_loss: 21434.70  |  KL+NLL: 97110.77  (KL=96939.71, NLL=171.06)


Epoch 5/10:  30%|███       | 109/360 [00:05<00:12, 20.89it/s]

 batch_loss: 21492.35  |  KL+NLL: 97029.43  (KL=96928.17, NLL=101.25)
 batch_loss: 21521.54  |  KL+NLL: 96939.89  (KL=96830.95, NLL=108.95)
 batch_loss: 21371.13  |  KL+NLL: 97123.21  (KL=96960.95, NLL=162.26)
 batch_loss: 21452.42  |  KL+NLL: 97039.22  (KL=96920.93, NLL=118.29)
 batch_loss: 21416.84  |  KL+NLL: 97176.83  (KL=97050.55, NLL=126.28)


Epoch 5/10:  31%|███       | 112/360 [00:05<00:12, 20.41it/s]

 batch_loss: 21414.37  |  KL+NLL: 97258.98  (KL=97171.57, NLL=87.41)
 batch_loss: 21816.07  |  KL+NLL: 97182.57  (KL=97089.22, NLL=93.35)
 batch_loss: 21183.67  |  KL+NLL: 97085.48  (KL=96975.96, NLL=109.52)
 batch_loss: 21336.45  |  KL+NLL: 97110.63  (KL=97024.80, NLL=85.83)


Epoch 5/10:  32%|███▏      | 115/360 [00:05<00:12, 20.01it/s]

 batch_loss: 21179.92  |  KL+NLL: 97153.51  (KL=96970.28, NLL=183.23)
 batch_loss: 21191.48  |  KL+NLL: 97206.14  (KL=97040.59, NLL=165.55)
 batch_loss: 21196.63  |  KL+NLL: 97169.23  (KL=97047.12, NLL=122.11)
 batch_loss: 21253.57  |  KL+NLL: 97146.22  (KL=97035.83, NLL=110.40)


Epoch 5/10:  34%|███▎      | 121/360 [00:05<00:12, 19.87it/s]

 batch_loss: 20846.20  |  KL+NLL: 97103.72  (KL=96961.95, NLL=141.78)
 batch_loss: 21012.52  |  KL+NLL: 97226.16  (KL=97144.84, NLL=81.31)
 batch_loss: 21109.54  |  KL+NLL: 97209.28  (KL=97093.87, NLL=115.41)
 batch_loss: 21080.42  |  KL+NLL: 97149.29  (KL=97010.51, NLL=138.78)
 batch_loss: 20874.29  |  KL+NLL: 97180.67  (KL=97069.91, NLL=110.77)


Epoch 5/10:  35%|███▌      | 127/360 [00:06<00:11, 20.35it/s]

 batch_loss: 20667.71  |  KL+NLL: 97230.91  (KL=97134.92, NLL=95.99)
 batch_loss: 20885.15  |  KL+NLL: 97232.00  (KL=97099.23, NLL=132.76)
 batch_loss: 20751.57  |  KL+NLL: 97350.89  (KL=97152.37, NLL=198.52)
 batch_loss: 20728.77  |  KL+NLL: 97246.83  (KL=97107.40, NLL=139.43)
 batch_loss: 20796.90  |  KL+NLL: 97317.98  (KL=97202.06, NLL=115.92)


Epoch 5/10:  36%|███▌      | 130/360 [00:06<00:11, 20.84it/s]

 batch_loss: 20569.88  |  KL+NLL: 97367.89  (KL=97207.51, NLL=160.38)
 batch_loss: 20518.37  |  KL+NLL: 97285.48  (KL=97136.93, NLL=148.55)
 batch_loss: 20611.43  |  KL+NLL: 97262.11  (KL=97091.93, NLL=170.18)
 batch_loss: 20519.30  |  KL+NLL: 97170.94  (KL=97067.27, NLL=103.67)
 batch_loss: 20504.67  |  KL+NLL: 97307.34  (KL=97139.55, NLL=167.78)


Epoch 5/10:  38%|███▊      | 136/360 [00:06<00:10, 21.06it/s]

 batch_loss: 20335.04  |  KL+NLL: 97289.17  (KL=97133.78, NLL=155.39)
 batch_loss: 20361.15  |  KL+NLL: 97254.17  (KL=97129.34, NLL=124.84)
 batch_loss: 20355.47  |  KL+NLL: 97297.66  (KL=97136.71, NLL=160.95)
 batch_loss: 20400.30  |  KL+NLL: 97261.60  (KL=97138.45, NLL=123.15)
 batch_loss: 20166.25  |  KL+NLL: 97187.89  (KL=97080.70, NLL=107.20)


Epoch 5/10:  39%|███▉      | 142/360 [00:06<00:10, 21.61it/s]

 batch_loss: 20266.40  |  KL+NLL: 97360.84  (KL=97226.71, NLL=134.13)
 batch_loss: 20318.65  |  KL+NLL: 97513.16  (KL=97295.91, NLL=217.25)
 batch_loss: 19958.18  |  KL+NLL: 97359.95  (KL=97212.48, NLL=147.47)
 batch_loss: 19968.12  |  KL+NLL: 97384.45  (KL=97264.33, NLL=120.12)
 batch_loss: 20119.12  |  KL+NLL: 97357.75  (KL=97096.98, NLL=260.77)


Epoch 5/10:  40%|████      | 145/360 [00:07<00:09, 21.78it/s]

 batch_loss: 19899.58  |  KL+NLL: 97225.07  (KL=97114.17, NLL=110.90)
 batch_loss: 19756.40  |  KL+NLL: 97421.02  (KL=97300.66, NLL=120.35)
 batch_loss: 19836.14  |  KL+NLL: 97509.12  (KL=97375.93, NLL=133.19)
 batch_loss: 19809.71  |  KL+NLL: 97357.64  (KL=97214.45, NLL=143.20)
 batch_loss: 19733.53  |  KL+NLL: 97392.88  (KL=97226.35, NLL=166.53)


Epoch 5/10:  42%|████▏     | 151/360 [00:07<00:09, 21.83it/s]

 batch_loss: 20016.69  |  KL+NLL: 97451.94  (KL=97278.64, NLL=173.29)
 batch_loss: 19566.41  |  KL+NLL: 97344.83  (KL=97259.84, NLL=84.99)
 batch_loss: 19721.54  |  KL+NLL: 97388.22  (KL=97238.92, NLL=149.29)
 batch_loss: 19688.46  |  KL+NLL: 97389.61  (KL=97261.12, NLL=128.49)
 batch_loss: 19872.34  |  KL+NLL: 97406.34  (KL=97254.38, NLL=151.97)


Epoch 5/10:  44%|████▎     | 157/360 [00:07<00:09, 21.17it/s]

 batch_loss: 19512.74  |  KL+NLL: 97481.73  (KL=97316.34, NLL=165.40)
 batch_loss: 19827.86  |  KL+NLL: 97421.28  (KL=97279.98, NLL=141.29)
 batch_loss: 19354.75  |  KL+NLL: 97408.64  (KL=97248.17, NLL=160.46)
 batch_loss: 19771.62  |  KL+NLL: 97483.74  (KL=97307.68, NLL=176.06)
 batch_loss: 19539.04  |  KL+NLL: 97444.21  (KL=97342.47, NLL=101.74)


Epoch 5/10:  44%|████▍     | 160/360 [00:07<00:09, 20.70it/s]

 batch_loss: 19445.30  |  KL+NLL: 97343.02  (KL=97199.61, NLL=143.41)
 batch_loss: 19492.96  |  KL+NLL: 97562.04  (KL=97405.09, NLL=156.95)
 batch_loss: 19242.28  |  KL+NLL: 97565.44  (KL=97362.09, NLL=203.35)
 batch_loss: 19275.33  |  KL+NLL: 97488.94  (KL=97349.44, NLL=139.50)
 batch_loss: 19377.79  |  KL+NLL: 97552.70  (KL=97395.62, NLL=157.07)


Epoch 5/10:  46%|████▌     | 166/360 [00:08<00:09, 21.08it/s]

 batch_loss: 19089.78  |  KL+NLL: 97522.13  (KL=97406.36, NLL=115.77)
 batch_loss: 19002.51  |  KL+NLL: 97516.16  (KL=97385.34, NLL=130.82)
 batch_loss: 19226.48  |  KL+NLL: 97589.89  (KL=97450.31, NLL=139.58)
 batch_loss: 19216.64  |  KL+NLL: 97527.63  (KL=97380.30, NLL=147.34)
 batch_loss: 18879.01  |  KL+NLL: 97504.49  (KL=97408.30, NLL=96.18)


Epoch 5/10:  48%|████▊     | 172/360 [00:08<00:09, 20.85it/s]

 batch_loss: 19251.76  |  KL+NLL: 97561.24  (KL=97440.46, NLL=120.78)
 batch_loss: 19068.09  |  KL+NLL: 97484.04  (KL=97403.09, NLL=80.95)
 batch_loss: 19224.92  |  KL+NLL: 97521.62  (KL=97406.95, NLL=114.68)
 batch_loss: 18966.31  |  KL+NLL: 97556.55  (KL=97427.80, NLL=128.75)
 batch_loss: 18870.64  |  KL+NLL: 97563.66  (KL=97400.46, NLL=163.20)


Epoch 5/10:  49%|████▊     | 175/360 [00:08<00:09, 20.31it/s]

 batch_loss: 18717.03  |  KL+NLL: 97667.19  (KL=97470.99, NLL=196.20)
 batch_loss: 18777.91  |  KL+NLL: 97598.96  (KL=97510.27, NLL=88.70)
 batch_loss: 18937.10  |  KL+NLL: 97624.66  (KL=97510.90, NLL=113.76)
 batch_loss: 18806.41  |  KL+NLL: 97617.36  (KL=97485.74, NLL=131.62)
 batch_loss: 18625.66  |  KL+NLL: 97647.47  (KL=97511.24, NLL=136.23)


Epoch 5/10:  50%|█████     | 181/360 [00:08<00:08, 20.47it/s]

 batch_loss: 18522.16  |  KL+NLL: 97700.32  (KL=97504.34, NLL=195.98)
 batch_loss: 18449.70  |  KL+NLL: 97616.73  (KL=97489.23, NLL=127.51)
 batch_loss: 18733.10  |  KL+NLL: 97601.29  (KL=97459.72, NLL=141.57)
 batch_loss: 18548.60  |  KL+NLL: 97713.89  (KL=97585.62, NLL=128.27)


Epoch 5/10:  51%|█████     | 184/360 [00:08<00:08, 20.27it/s]

 batch_loss: 18449.76  |  KL+NLL: 97696.17  (KL=97554.92, NLL=141.25)
 batch_loss: 18562.10  |  KL+NLL: 97723.90  (KL=97505.84, NLL=218.06)
 batch_loss: 18554.29  |  KL+NLL: 97622.68  (KL=97493.30, NLL=129.38)
 batch_loss: 18468.65  |  KL+NLL: 97694.49  (KL=97594.85, NLL=99.64)
 batch_loss: 18267.89  |  KL+NLL: 97730.72  (KL=97582.52, NLL=148.19)


Epoch 5/10:  52%|█████▎    | 189/360 [00:09<00:08, 19.81it/s]

 batch_loss: 18439.06  |  KL+NLL: 97768.75  (KL=97616.61, NLL=152.14)
 batch_loss: 18433.38  |  KL+NLL: 97631.86  (KL=97498.97, NLL=132.89)
 batch_loss: 18131.36  |  KL+NLL: 97759.49  (KL=97608.98, NLL=150.50)
 batch_loss: 18104.54  |  KL+NLL: 97756.18  (KL=97647.50, NLL=108.68)
 batch_loss: 18207.91  |  KL+NLL: 97784.19  (KL=97566.87, NLL=217.32)


Epoch 5/10:  54%|█████▍    | 195/360 [00:09<00:08, 20.48it/s]

 batch_loss: 18113.49  |  KL+NLL: 97715.12  (KL=97589.39, NLL=125.73)
 batch_loss: 18146.75  |  KL+NLL: 97797.18  (KL=97659.96, NLL=137.22)
 batch_loss: 18115.31  |  KL+NLL: 97887.06  (KL=97743.07, NLL=143.99)
 batch_loss: 18261.97  |  KL+NLL: 97749.88  (KL=97625.63, NLL=124.25)
 batch_loss: 17859.33  |  KL+NLL: 97837.90  (KL=97720.10, NLL=117.80)


Epoch 5/10:  56%|█████▌    | 201/360 [00:09<00:07, 20.76it/s]

 batch_loss: 18157.85  |  KL+NLL: 97813.12  (KL=97703.29, NLL=109.83)
 batch_loss: 18053.01  |  KL+NLL: 97806.32  (KL=97731.35, NLL=74.97)
 batch_loss: 17788.14  |  KL+NLL: 97797.69  (KL=97640.80, NLL=156.89)
 batch_loss: 18008.14  |  KL+NLL: 97813.60  (KL=97673.55, NLL=140.06)
 batch_loss: 17894.09  |  KL+NLL: 97758.66  (KL=97601.33, NLL=157.33)


Epoch 5/10:  57%|█████▋    | 204/360 [00:09<00:07, 21.06it/s]

 batch_loss: 17781.52  |  KL+NLL: 97742.10  (KL=97637.18, NLL=104.92)
 batch_loss: 17907.49  |  KL+NLL: 97906.65  (KL=97759.52, NLL=147.14)
 batch_loss: 17897.87  |  KL+NLL: 97856.71  (KL=97761.48, NLL=95.23)
 batch_loss: 17654.50  |  KL+NLL: 97851.09  (KL=97688.78, NLL=162.30)
 batch_loss: 17735.52  |  KL+NLL: 97852.76  (KL=97749.23, NLL=103.54)


Epoch 5/10:  58%|█████▊    | 210/360 [00:10<00:07, 20.41it/s]

 batch_loss: 17675.16  |  KL+NLL: 97801.78  (KL=97709.20, NLL=92.59)
 batch_loss: 17605.11  |  KL+NLL: 97889.10  (KL=97775.80, NLL=113.30)
 batch_loss: 17537.16  |  KL+NLL: 97944.88  (KL=97799.14, NLL=145.74)
 batch_loss: 17813.25  |  KL+NLL: 97976.31  (KL=97810.26, NLL=166.05)
 batch_loss: 17511.81  |  KL+NLL: 97982.01  (KL=97822.80, NLL=159.21)


Epoch 5/10:  60%|██████    | 216/360 [00:10<00:07, 20.41it/s]

 batch_loss: 17631.10  |  KL+NLL: 98026.19  (KL=97851.54, NLL=174.65)
 batch_loss: 17384.91  |  KL+NLL: 97904.94  (KL=97794.41, NLL=110.53)
 batch_loss: 17450.37  |  KL+NLL: 98089.37  (KL=97866.81, NLL=222.56)
 batch_loss: 17584.35  |  KL+NLL: 97953.53  (KL=97813.61, NLL=139.92)
 batch_loss: 17407.17  |  KL+NLL: 98040.42  (KL=97875.26, NLL=165.16)


Epoch 5/10:  61%|██████    | 219/360 [00:10<00:06, 20.34it/s]

 batch_loss: 17276.74  |  KL+NLL: 97939.96  (KL=97856.30, NLL=83.66)
 batch_loss: 17364.56  |  KL+NLL: 97998.38  (KL=97872.77, NLL=125.62)
 batch_loss: 17355.25  |  KL+NLL: 98036.08  (KL=97921.64, NLL=114.44)
 batch_loss: 17325.25  |  KL+NLL: 97898.07  (KL=97757.95, NLL=140.11)
 batch_loss: 17208.67  |  KL+NLL: 97942.01  (KL=97836.62, NLL=105.39)


Epoch 5/10:  62%|██████▎   | 225/360 [00:10<00:06, 20.11it/s]

 batch_loss: 16970.72  |  KL+NLL: 98047.34  (KL=97911.16, NLL=136.18)
 batch_loss: 16994.26  |  KL+NLL: 98080.39  (KL=97850.96, NLL=229.43)
 batch_loss: 16914.96  |  KL+NLL: 98078.49  (KL=97950.11, NLL=128.38)
 batch_loss: 17087.91  |  KL+NLL: 98084.83  (KL=97987.23, NLL=97.60)
 batch_loss: 16904.85  |  KL+NLL: 98002.98  (KL=97881.21, NLL=121.77)


Epoch 5/10:  64%|██████▍   | 231/360 [00:11<00:06, 20.33it/s]

 batch_loss: 16954.13  |  KL+NLL: 98044.09  (KL=97936.05, NLL=108.04)
 batch_loss: 16970.08  |  KL+NLL: 98081.60  (KL=97936.39, NLL=145.21)
 batch_loss: 16817.43  |  KL+NLL: 98012.77  (KL=97869.30, NLL=143.47)
 batch_loss: 16616.25  |  KL+NLL: 98053.23  (KL=97878.42, NLL=174.81)
 batch_loss: 16809.08  |  KL+NLL: 98154.13  (KL=97967.15, NLL=186.99)


Epoch 5/10:  65%|██████▌   | 234/360 [00:11<00:06, 20.81it/s]

 batch_loss: 17006.16  |  KL+NLL: 98223.44  (KL=98060.61, NLL=162.83)
 batch_loss: 16709.27  |  KL+NLL: 98222.52  (KL=97969.35, NLL=253.17)
 batch_loss: 16450.32  |  KL+NLL: 98202.87  (KL=98060.09, NLL=142.77)
 batch_loss: 16680.82  |  KL+NLL: 98228.29  (KL=98028.53, NLL=199.76)
 batch_loss: 16922.37  |  KL+NLL: 98301.66  (KL=98157.04, NLL=144.62)


Epoch 5/10:  67%|██████▋   | 240/360 [00:11<00:05, 21.51it/s]

 batch_loss: 16647.31  |  KL+NLL: 98060.70  (KL=97911.90, NLL=148.80)
 batch_loss: 16598.14  |  KL+NLL: 98235.55  (KL=98108.56, NLL=126.99)
 batch_loss: 16531.99  |  KL+NLL: 98132.54  (KL=98006.55, NLL=125.99)
 batch_loss: 16676.48  |  KL+NLL: 98252.33  (KL=98074.61, NLL=177.72)
 batch_loss: 16372.43  |  KL+NLL: 98190.71  (KL=98063.62, NLL=127.09)


Epoch 5/10:  68%|██████▊   | 246/360 [00:11<00:05, 20.99it/s]

 batch_loss: 16398.79  |  KL+NLL: 98151.52  (KL=98046.15, NLL=105.37)
 batch_loss: 16656.92  |  KL+NLL: 98181.02  (KL=98051.81, NLL=129.21)
 batch_loss: 16204.43  |  KL+NLL: 98335.67  (KL=98197.93, NLL=137.74)
 batch_loss: 16310.03  |  KL+NLL: 98197.20  (KL=98094.25, NLL=102.95)
 batch_loss: 16416.79  |  KL+NLL: 98193.16  (KL=98066.72, NLL=126.44)


Epoch 5/10:  69%|██████▉   | 249/360 [00:12<00:05, 21.26it/s]

 batch_loss: 16082.36  |  KL+NLL: 98371.56  (KL=98181.74, NLL=189.81)
 batch_loss: 16250.61  |  KL+NLL: 98307.98  (KL=98156.41, NLL=151.56)
 batch_loss: 16117.45  |  KL+NLL: 98201.18  (KL=98083.22, NLL=117.97)
 batch_loss: 16175.86  |  KL+NLL: 98368.56  (KL=98213.86, NLL=154.70)
 batch_loss: 15892.61  |  KL+NLL: 98104.27  (KL=98023.73, NLL=80.54)


Epoch 5/10:  71%|███████   | 255/360 [00:12<00:04, 21.14it/s]

 batch_loss: 16059.10  |  KL+NLL: 98341.54  (KL=98257.83, NLL=83.71)
 batch_loss: 16279.57  |  KL+NLL: 98385.82  (KL=98228.66, NLL=157.17)
 batch_loss: 16155.55  |  KL+NLL: 98289.97  (KL=98168.18, NLL=121.79)
 batch_loss: 16014.11  |  KL+NLL: 98411.15  (KL=98306.80, NLL=104.34)
 batch_loss: 15764.08  |  KL+NLL: 98266.49  (KL=98180.75, NLL=85.74)


Epoch 5/10:  72%|███████▎  | 261/360 [00:12<00:04, 20.78it/s]

 batch_loss: 16077.82  |  KL+NLL: 98434.98  (KL=98219.31, NLL=215.66)
 batch_loss: 16194.34  |  KL+NLL: 98403.64  (KL=98202.89, NLL=200.75)
 batch_loss: 15984.26  |  KL+NLL: 98326.99  (KL=98224.22, NLL=102.77)
 batch_loss: 15846.97  |  KL+NLL: 98325.65  (KL=98173.25, NLL=152.40)
 batch_loss: 15986.05  |  KL+NLL: 98378.92  (KL=98233.18, NLL=145.74)


Epoch 5/10:  73%|███████▎  | 264/360 [00:12<00:04, 20.83it/s]

 batch_loss: 15961.93  |  KL+NLL: 98302.66  (KL=98196.88, NLL=105.79)
 batch_loss: 15887.11  |  KL+NLL: 98318.71  (KL=98154.96, NLL=163.75)
 batch_loss: 15790.65  |  KL+NLL: 98403.73  (KL=98265.40, NLL=138.33)
 batch_loss: 15866.43  |  KL+NLL: 98364.18  (KL=98182.41, NLL=181.77)
 batch_loss: 15682.06  |  KL+NLL: 98407.14  (KL=98235.77, NLL=171.37)


Epoch 5/10:  75%|███████▌  | 270/360 [00:13<00:04, 19.34it/s]

 batch_loss: 15475.76  |  KL+NLL: 98457.52  (KL=98338.76, NLL=118.76)
 batch_loss: 15639.54  |  KL+NLL: 98363.60  (KL=98251.70, NLL=111.90)
 batch_loss: 15769.35  |  KL+NLL: 98445.67  (KL=98306.78, NLL=138.88)
 batch_loss: 15444.74  |  KL+NLL: 98442.98  (KL=98331.18, NLL=111.80)


Epoch 5/10:  76%|███████▌  | 274/360 [00:13<00:05, 16.10it/s]

 batch_loss: 15559.37  |  KL+NLL: 98573.10  (KL=98449.08, NLL=124.02)
 batch_loss: 15691.97  |  KL+NLL: 98565.13  (KL=98384.05, NLL=181.08)
 batch_loss: 15513.12  |  KL+NLL: 98533.38  (KL=98346.52, NLL=186.87)
 batch_loss: 15395.31  |  KL+NLL: 98570.60  (KL=98385.70, NLL=184.90)


Epoch 5/10:  77%|███████▋  | 277/360 [00:13<00:04, 17.49it/s]

 batch_loss: 15442.54  |  KL+NLL: 98494.99  (KL=98394.39, NLL=100.59)
 batch_loss: 15264.28  |  KL+NLL: 98516.87  (KL=98379.33, NLL=137.54)
 batch_loss: 15290.90  |  KL+NLL: 98604.10  (KL=98482.36, NLL=121.74)
 batch_loss: 15221.49  |  KL+NLL: 98621.75  (KL=98401.95, NLL=219.81)
 batch_loss: 15032.91  |  KL+NLL: 98497.18  (KL=98351.72, NLL=145.47)


Epoch 5/10:  78%|███████▊  | 282/360 [00:13<00:04, 18.99it/s]

 batch_loss: 15487.26  |  KL+NLL: 98585.18  (KL=98456.98, NLL=128.20)
 batch_loss: 15187.07  |  KL+NLL: 98490.31  (KL=98395.45, NLL=94.86)
 batch_loss: 15112.67  |  KL+NLL: 98513.92  (KL=98417.17, NLL=96.75)
 batch_loss: 15184.95  |  KL+NLL: 98561.75  (KL=98393.96, NLL=167.79)
 batch_loss: 15183.95  |  KL+NLL: 98667.08  (KL=98516.51, NLL=150.57)


Epoch 5/10:  80%|████████  | 288/360 [00:14<00:03, 20.05it/s]

 batch_loss: 14831.15  |  KL+NLL: 98526.28  (KL=98399.17, NLL=127.11)
 batch_loss: 14994.97  |  KL+NLL: 98791.98  (KL=98515.66, NLL=276.33)
 batch_loss: 14954.31  |  KL+NLL: 98737.50  (KL=98555.04, NLL=182.46)
 batch_loss: 15002.95  |  KL+NLL: 98587.49  (KL=98447.62, NLL=139.86)
 batch_loss: 14712.08  |  KL+NLL: 98629.07  (KL=98487.04, NLL=142.03)


Epoch 5/10:  82%|████████▏ | 294/360 [00:14<00:03, 20.09it/s]

 batch_loss: 14918.51  |  KL+NLL: 98655.58  (KL=98512.74, NLL=142.84)
 batch_loss: 15076.62  |  KL+NLL: 98652.84  (KL=98493.77, NLL=159.07)
 batch_loss: 14982.67  |  KL+NLL: 98573.65  (KL=98429.23, NLL=144.42)
 batch_loss: 14920.03  |  KL+NLL: 98739.91  (KL=98579.62, NLL=160.30)
 batch_loss: 14848.74  |  KL+NLL: 98661.34  (KL=98481.02, NLL=180.32)


Epoch 5/10:  82%|████████▎ | 297/360 [00:14<00:03, 19.71it/s]

 batch_loss: 14652.48  |  KL+NLL: 98666.20  (KL=98515.83, NLL=150.38)
 batch_loss: 14695.24  |  KL+NLL: 98710.09  (KL=98586.76, NLL=123.33)
 batch_loss: 14700.41  |  KL+NLL: 98682.37  (KL=98532.40, NLL=149.98)
 batch_loss: 14679.24  |  KL+NLL: 98783.85  (KL=98634.94, NLL=148.91)
 batch_loss: 14793.52  |  KL+NLL: 98757.28  (KL=98606.62, NLL=150.65)


Epoch 5/10:  84%|████████▍ | 303/360 [00:14<00:02, 20.18it/s]

 batch_loss: 14666.29  |  KL+NLL: 98800.87  (KL=98660.64, NLL=140.23)
 batch_loss: 14742.53  |  KL+NLL: 98741.74  (KL=98608.85, NLL=132.89)
 batch_loss: 14593.50  |  KL+NLL: 98908.72  (KL=98653.70, NLL=255.02)
 batch_loss: 14537.67  |  KL+NLL: 98827.62  (KL=98672.41, NLL=155.21)
 batch_loss: 14616.44  |  KL+NLL: 98851.55  (KL=98661.52, NLL=190.02)


Epoch 5/10:  86%|████████▌ | 309/360 [00:15<00:02, 20.88it/s]

 batch_loss: 14240.31  |  KL+NLL: 99032.78  (KL=98760.91, NLL=271.87)
 batch_loss: 14338.16  |  KL+NLL: 98778.41  (KL=98645.97, NLL=132.44)
 batch_loss: 14454.83  |  KL+NLL: 98911.50  (KL=98787.55, NLL=123.95)
 batch_loss: 14329.78  |  KL+NLL: 98951.05  (KL=98785.46, NLL=165.58)
 batch_loss: 14294.25  |  KL+NLL: 98869.23  (KL=98746.62, NLL=122.62)


Epoch 5/10:  87%|████████▋ | 312/360 [00:15<00:02, 20.70it/s]

 batch_loss: 14332.75  |  KL+NLL: 98777.97  (KL=98656.45, NLL=121.52)
 batch_loss: 14539.10  |  KL+NLL: 98841.70  (KL=98705.35, NLL=136.35)
 batch_loss: 14056.86  |  KL+NLL: 98894.35  (KL=98769.67, NLL=124.68)
 batch_loss: 14227.26  |  KL+NLL: 98835.99  (KL=98671.27, NLL=164.72)
 batch_loss: 14118.12  |  KL+NLL: 98914.11  (KL=98786.97, NLL=127.14)


Epoch 5/10:  88%|████████▊ | 318/360 [00:15<00:02, 20.66it/s]

 batch_loss: 14139.30  |  KL+NLL: 98909.77  (KL=98700.09, NLL=209.68)
 batch_loss: 13987.74  |  KL+NLL: 98856.50  (KL=98740.14, NLL=116.36)
 batch_loss: 14029.88  |  KL+NLL: 98952.66  (KL=98797.54, NLL=155.12)
 batch_loss: 13926.02  |  KL+NLL: 99058.14  (KL=98910.27, NLL=147.88)


Epoch 5/10:  89%|████████▉ | 321/360 [00:15<00:01, 20.48it/s]

 batch_loss: 13883.43  |  KL+NLL: 98959.05  (KL=98804.98, NLL=154.07)
 batch_loss: 14171.23  |  KL+NLL: 98910.81  (KL=98746.96, NLL=163.85)
 batch_loss: 14105.35  |  KL+NLL: 99011.51  (KL=98910.93, NLL=100.58)
 batch_loss: 14050.59  |  KL+NLL: 98933.94  (KL=98810.55, NLL=123.39)
 batch_loss: 14153.68  |  KL+NLL: 99001.11  (KL=98858.17, NLL=142.94)


Epoch 5/10:  91%|█████████ | 327/360 [00:16<00:01, 20.56it/s]

 batch_loss: 13906.63  |  KL+NLL: 98947.54  (KL=98819.78, NLL=127.76)
 batch_loss: 13876.82  |  KL+NLL: 99041.27  (KL=98914.80, NLL=126.46)
 batch_loss: 13792.08  |  KL+NLL: 99034.97  (KL=98912.05, NLL=122.91)
 batch_loss: 13931.06  |  KL+NLL: 99005.05  (KL=98862.38, NLL=142.67)
 batch_loss: 13839.85  |  KL+NLL: 99133.84  (KL=98894.11, NLL=239.73)


Epoch 5/10:  92%|█████████▎| 333/360 [00:16<00:01, 20.69it/s]

 batch_loss: 13618.54  |  KL+NLL: 99003.47  (KL=98904.30, NLL=99.18)
 batch_loss: 13617.95  |  KL+NLL: 99111.54  (KL=98916.78, NLL=194.76)
 batch_loss: 13483.16  |  KL+NLL: 98948.18  (KL=98840.44, NLL=107.74)
 batch_loss: 13744.13  |  KL+NLL: 99173.92  (KL=99018.96, NLL=154.96)
 batch_loss: 13666.52  |  KL+NLL: 99131.90  (KL=99030.33, NLL=101.57)


Epoch 5/10:  93%|█████████▎| 336/360 [00:16<00:01, 20.54it/s]

 batch_loss: 13730.83  |  KL+NLL: 99066.18  (KL=98940.67, NLL=125.51)
 batch_loss: 13699.77  |  KL+NLL: 99099.09  (KL=98958.38, NLL=140.71)
 batch_loss: 13699.03  |  KL+NLL: 99308.97  (KL=99088.87, NLL=220.10)
 batch_loss: 13454.62  |  KL+NLL: 99162.67  (KL=99028.21, NLL=134.46)
 batch_loss: 13824.51  |  KL+NLL: 99074.29  (KL=98910.48, NLL=163.82)


Epoch 5/10:  95%|█████████▌| 342/360 [00:16<00:00, 20.68it/s]

 batch_loss: 13520.61  |  KL+NLL: 99123.82  (KL=98976.71, NLL=147.11)
 batch_loss: 13446.86  |  KL+NLL: 99101.71  (KL=98984.14, NLL=117.57)
 batch_loss: 13578.56  |  KL+NLL: 99163.22  (KL=99009.66, NLL=153.56)
 batch_loss: 13192.91  |  KL+NLL: 99044.62  (KL=98952.91, NLL=91.70)


Epoch 5/10:  96%|█████████▌| 345/360 [00:16<00:00, 20.38it/s]

 batch_loss: 13353.26  |  KL+NLL: 99239.61  (KL=99059.48, NLL=180.14)
 batch_loss: 13161.72  |  KL+NLL: 99250.87  (KL=99096.02, NLL=154.84)
 batch_loss: 13200.59  |  KL+NLL: 99153.15  (KL=99059.83, NLL=93.32)
 batch_loss: 13417.20  |  KL+NLL: 99125.61  (KL=99008.38, NLL=117.23)
 batch_loss: 13080.84  |  KL+NLL: 99308.35  (KL=99173.71, NLL=134.63)


Epoch 5/10:  98%|█████████▊| 351/360 [00:17<00:00, 20.98it/s]

 batch_loss: 13039.00  |  KL+NLL: 99162.87  (KL=99031.77, NLL=131.10)
 batch_loss: 13252.59  |  KL+NLL: 99317.06  (KL=99119.80, NLL=197.26)
 batch_loss: 13179.76  |  KL+NLL: 99214.85  (KL=99055.70, NLL=159.15)
 batch_loss: 12948.98  |  KL+NLL: 99404.39  (KL=99234.83, NLL=169.56)
 batch_loss: 13364.02  |  KL+NLL: 99411.09  (KL=99202.17, NLL=208.92)


Epoch 5/10:  99%|█████████▉| 357/360 [00:17<00:00, 21.30it/s]

 batch_loss: 13102.32  |  KL+NLL: 99400.99  (KL=99213.95, NLL=187.04)
 batch_loss: 13036.51  |  KL+NLL: 99260.48  (KL=99125.38, NLL=135.10)
 batch_loss: 12993.24  |  KL+NLL: 99231.62  (KL=99094.88, NLL=136.75)
 batch_loss: 13068.20  |  KL+NLL: 99338.20  (KL=99208.66, NLL=129.54)
 batch_loss: 12865.98  |  KL+NLL: 99247.58  (KL=99110.51, NLL=137.07)


Epoch 5/10: 100%|██████████| 360/360 [00:17<00:00, 20.45it/s]


 batch_loss: 12847.38  |  KL+NLL: 99379.40  (KL=99224.22, NLL=155.18)
 batch_loss: 12785.89  |  KL+NLL: 99445.10  (KL=99294.74, NLL=150.36)
 batch_loss: 12778.36  |  KL+NLL: 99318.32  (KL=99222.77, NLL=95.55)
Epoch 5 - ELBO Loss: 18900.5907


Epoch 6/10:   1%|          | 3/360 [00:00<00:25, 14.17it/s]

 batch_loss: 12610.99  |  KL+NLL: 99507.32  (KL=99341.23, NLL=166.10)
 batch_loss: 12593.61  |  KL+NLL: 99399.34  (KL=99191.86, NLL=207.48)
 batch_loss: 12738.78  |  KL+NLL: 99454.40  (KL=99312.98, NLL=141.41)
 batch_loss: 12814.65  |  KL+NLL: 99417.10  (KL=99275.80, NLL=141.30)
 batch_loss: 12871.94  |  KL+NLL: 99345.86  (KL=99195.67, NLL=150.19)


Epoch 6/10:   2%|▎         | 9/360 [00:00<00:18, 19.28it/s]

 batch_loss: 12699.69  |  KL+NLL: 99451.55  (KL=99330.70, NLL=120.85)
 batch_loss: 12876.10  |  KL+NLL: 99443.17  (KL=99341.16, NLL=102.02)
 batch_loss: 12478.01  |  KL+NLL: 99490.72  (KL=99342.70, NLL=148.02)
 batch_loss: 12477.21  |  KL+NLL: 99432.39  (KL=99309.27, NLL=123.13)
 batch_loss: 12371.62  |  KL+NLL: 99606.84  (KL=99341.75, NLL=265.09)


Epoch 6/10:   4%|▍         | 15/360 [00:00<00:16, 20.47it/s]

 batch_loss: 12531.98  |  KL+NLL: 99448.12  (KL=99328.82, NLL=119.30)
 batch_loss: 12519.26  |  KL+NLL: 99686.94  (KL=99510.76, NLL=176.18)
 batch_loss: 12419.80  |  KL+NLL: 99583.73  (KL=99449.09, NLL=134.64)
 batch_loss: 12462.99  |  KL+NLL: 99560.27  (KL=99333.77, NLL=226.51)
 batch_loss: 12462.89  |  KL+NLL: 99492.57  (KL=99325.11, NLL=167.46)


Epoch 6/10:   5%|▌         | 18/360 [00:00<00:16, 20.93it/s]

 batch_loss: 12404.52  |  KL+NLL: 99686.59  (KL=99563.87, NLL=122.72)
 batch_loss: 12446.84  |  KL+NLL: 99661.91  (KL=99460.52, NLL=201.38)
 batch_loss: 12709.56  |  KL+NLL: 99552.88  (KL=99411.01, NLL=141.87)
 batch_loss: 12459.86  |  KL+NLL: 99513.90  (KL=99423.09, NLL=90.81)
 batch_loss: 12387.28  |  KL+NLL: 99455.12  (KL=99335.84, NLL=119.27)


Epoch 6/10:   7%|▋         | 24/360 [00:01<00:15, 21.26it/s]

 batch_loss: 12298.57  |  KL+NLL: 99635.38  (KL=99501.99, NLL=133.39)
 batch_loss: 12456.11  |  KL+NLL: 99574.07  (KL=99462.11, NLL=111.96)
 batch_loss: 12167.02  |  KL+NLL: 99615.47  (KL=99499.20, NLL=116.26)
 batch_loss: 12099.73  |  KL+NLL: 99708.78  (KL=99579.20, NLL=129.58)
 batch_loss: 12387.93  |  KL+NLL: 99666.16  (KL=99487.44, NLL=178.72)


Epoch 6/10:   8%|▊         | 27/360 [00:01<00:15, 21.33it/s]

 batch_loss: 12090.22  |  KL+NLL: 99661.17  (KL=99547.36, NLL=113.81)
 batch_loss: 12189.35  |  KL+NLL: 99732.63  (KL=99584.10, NLL=148.53)
 batch_loss: 12128.35  |  KL+NLL: 99651.77  (KL=99482.36, NLL=169.41)
 batch_loss: 12104.49  |  KL+NLL: 99716.52  (KL=99612.79, NLL=103.73)


Epoch 6/10:   9%|▉         | 33/360 [00:01<00:15, 20.94it/s]

 batch_loss: 11916.83  |  KL+NLL: 99691.01  (KL=99541.20, NLL=149.80)
 batch_loss: 12128.40  |  KL+NLL: 99585.66  (KL=99466.88, NLL=118.78)
 batch_loss: 12106.78  |  KL+NLL: 99683.04  (KL=99515.63, NLL=167.41)
 batch_loss: 12022.91  |  KL+NLL: 99774.79  (KL=99675.17, NLL=99.62)
 batch_loss: 11979.37  |  KL+NLL: 99918.63  (KL=99742.80, NLL=175.83)


Epoch 6/10:  11%|█         | 39/360 [00:01<00:15, 20.86it/s]

 batch_loss: 11836.55  |  KL+NLL: 99676.22  (KL=99527.50, NLL=148.72)
 batch_loss: 12055.28  |  KL+NLL: 99846.73  (KL=99765.40, NLL=81.33)
 batch_loss: 11914.72  |  KL+NLL: 99813.09  (KL=99684.00, NLL=129.09)
 batch_loss: 11712.67  |  KL+NLL: 99856.57  (KL=99638.41, NLL=218.17)
 batch_loss: 11936.73  |  KL+NLL: 99849.91  (KL=99682.23, NLL=167.68)


Epoch 6/10:  12%|█▏        | 42/360 [00:02<00:14, 21.24it/s]

 batch_loss: 11830.82  |  KL+NLL: 99743.17  (KL=99622.04, NLL=121.14)
 batch_loss: 11880.84  |  KL+NLL: 100031.08  (KL=99836.20, NLL=194.87)
 batch_loss: 11805.24  |  KL+NLL: 99943.33  (KL=99760.67, NLL=182.66)
 batch_loss: 11809.32  |  KL+NLL: 99882.46  (KL=99763.01, NLL=119.45)
 batch_loss: 11705.26  |  KL+NLL: 99877.47  (KL=99670.01, NLL=207.46)


Epoch 6/10:  13%|█▎        | 48/360 [00:02<00:14, 21.35it/s]

 batch_loss: 11620.44  |  KL+NLL: 99835.52  (KL=99681.45, NLL=154.06)
 batch_loss: 11632.48  |  KL+NLL: 99915.91  (KL=99710.51, NLL=205.40)
 batch_loss: 11636.67  |  KL+NLL: 99940.83  (KL=99801.89, NLL=138.94)
 batch_loss: 11778.88  |  KL+NLL: 99804.16  (KL=99707.58, NLL=96.58)
 batch_loss: 11339.19  |  KL+NLL: 99912.87  (KL=99780.16, NLL=132.71)


Epoch 6/10:  15%|█▌        | 54/360 [00:02<00:14, 21.08it/s]

 batch_loss: 11668.27  |  KL+NLL: 99946.84  (KL=99742.28, NLL=204.56)
 batch_loss: 11581.11  |  KL+NLL: 99990.13  (KL=99830.48, NLL=159.65)
 batch_loss: 11803.76  |  KL+NLL: 100082.99  (KL=99960.59, NLL=122.40)
 batch_loss: 11710.99  |  KL+NLL: 100019.22  (KL=99863.48, NLL=155.74)
 batch_loss: 11471.15  |  KL+NLL: 100034.30  (KL=99905.64, NLL=128.66)


Epoch 6/10:  16%|█▌        | 57/360 [00:02<00:14, 20.63it/s]

 batch_loss: 11547.65  |  KL+NLL: 100077.06  (KL=99893.21, NLL=183.85)
 batch_loss: 11425.38  |  KL+NLL: 100135.96  (KL=99928.95, NLL=207.01)
 batch_loss: 11419.55  |  KL+NLL: 99947.27  (KL=99829.81, NLL=117.46)
 batch_loss: 11229.41  |  KL+NLL: 100025.40  (KL=99877.82, NLL=147.58)
 batch_loss: 11385.35  |  KL+NLL: 99996.33  (KL=99806.20, NLL=190.13)


Epoch 6/10:  18%|█▊        | 63/360 [00:03<00:14, 20.44it/s]

 batch_loss: 11324.55  |  KL+NLL: 100072.28  (KL=99920.24, NLL=152.03)
 batch_loss: 11212.14  |  KL+NLL: 100134.22  (KL=99987.72, NLL=146.50)
 batch_loss: 11326.82  |  KL+NLL: 99987.76  (KL=99855.22, NLL=132.54)
 batch_loss: 11189.02  |  KL+NLL: 100032.32  (KL=99859.88, NLL=172.44)
 batch_loss: 11157.89  |  KL+NLL: 100003.19  (KL=99845.58, NLL=157.61)


Epoch 6/10:  19%|█▉        | 69/360 [00:03<00:14, 20.52it/s]

 batch_loss: 11150.46  |  KL+NLL: 100089.16  (KL=99984.59, NLL=104.57)
 batch_loss: 11049.97  |  KL+NLL: 100072.14  (KL=99940.55, NLL=131.58)
 batch_loss: 11199.41  |  KL+NLL: 100250.21  (KL=100069.50, NLL=180.71)
 batch_loss: 11058.47  |  KL+NLL: 100099.05  (KL=99987.74, NLL=111.31)
 batch_loss: 11061.48  |  KL+NLL: 99994.32  (KL=99859.02, NLL=135.31)


Epoch 6/10:  20%|██        | 72/360 [00:03<00:13, 20.89it/s]

 batch_loss: 11008.25  |  KL+NLL: 100077.38  (KL=99980.15, NLL=97.23)
 batch_loss: 10959.05  |  KL+NLL: 100216.35  (KL=100084.73, NLL=131.62)
 batch_loss: 10887.42  |  KL+NLL: 100191.91  (KL=100064.21, NLL=127.69)
 batch_loss: 10888.79  |  KL+NLL: 100073.85  (KL=99923.30, NLL=150.54)
 batch_loss: 11005.25  |  KL+NLL: 100143.50  (KL=100037.41, NLL=106.09)


Epoch 6/10:  22%|██▏       | 78/360 [00:03<00:13, 20.33it/s]

 batch_loss: 10927.75  |  KL+NLL: 100169.97  (KL=100023.16, NLL=146.81)
 batch_loss: 11051.26  |  KL+NLL: 100237.69  (KL=100074.57, NLL=163.12)
 batch_loss: 11091.89  |  KL+NLL: 100213.12  (KL=100070.90, NLL=142.22)
 batch_loss: 10650.56  |  KL+NLL: 100116.32  (KL=100006.21, NLL=110.11)
 batch_loss: 10845.92  |  KL+NLL: 100204.63  (KL=100030.13, NLL=174.50)


Epoch 6/10:  23%|██▎       | 84/360 [00:04<00:13, 20.04it/s]

 batch_loss: 10925.11  |  KL+NLL: 100239.31  (KL=100126.49, NLL=112.82)
 batch_loss: 10871.65  |  KL+NLL: 100102.77  (KL=99993.64, NLL=109.13)
 batch_loss: 10713.00  |  KL+NLL: 100257.17  (KL=100079.16, NLL=178.01)
 batch_loss: 10913.04  |  KL+NLL: 100278.61  (KL=100084.22, NLL=194.39)
 batch_loss: 10574.92  |  KL+NLL: 100185.18  (KL=100073.79, NLL=111.39)


Epoch 6/10:  24%|██▍       | 87/360 [00:04<00:13, 20.23it/s]

 batch_loss: 10637.49  |  KL+NLL: 100313.36  (KL=100192.09, NLL=121.27)
 batch_loss: 10690.19  |  KL+NLL: 100209.86  (KL=100099.15, NLL=110.71)
 batch_loss: 10463.98  |  KL+NLL: 100270.55  (KL=100169.95, NLL=100.60)
 batch_loss: 10606.06  |  KL+NLL: 100271.37  (KL=100164.62, NLL=106.75)
 batch_loss: 10752.18  |  KL+NLL: 100284.47  (KL=100129.11, NLL=155.36)


Epoch 6/10:  26%|██▌       | 93/360 [00:04<00:12, 20.59it/s]

 batch_loss: 10698.19  |  KL+NLL: 100338.06  (KL=100190.97, NLL=147.09)
 batch_loss: 10559.44  |  KL+NLL: 100285.86  (KL=100158.09, NLL=127.77)
 batch_loss: 10677.00  |  KL+NLL: 100319.00  (KL=100200.50, NLL=118.50)
 batch_loss: 10906.92  |  KL+NLL: 100471.80  (KL=100277.03, NLL=194.76)
 batch_loss: 10501.36  |  KL+NLL: 100381.70  (KL=100254.23, NLL=127.46)


Epoch 6/10:  28%|██▊       | 99/360 [00:04<00:12, 21.07it/s]

 batch_loss: 10324.83  |  KL+NLL: 100430.01  (KL=100312.42, NLL=117.59)
 batch_loss: 10501.24  |  KL+NLL: 100184.57  (KL=100109.77, NLL=74.80)
 batch_loss: 10273.88  |  KL+NLL: 100463.34  (KL=100339.40, NLL=123.94)
 batch_loss: 10577.68  |  KL+NLL: 100430.93  (KL=100271.14, NLL=159.79)
 batch_loss: 10388.51  |  KL+NLL: 100373.53  (KL=100238.02, NLL=135.52)


Epoch 6/10:  28%|██▊       | 102/360 [00:04<00:12, 21.37it/s]

 batch_loss: 10322.92  |  KL+NLL: 100400.13  (KL=100297.97, NLL=102.16)
 batch_loss: 10475.76  |  KL+NLL: 100506.73  (KL=100320.97, NLL=185.76)
 batch_loss: 10162.80  |  KL+NLL: 100491.41  (KL=100396.41, NLL=95.00)
 batch_loss: 10378.71  |  KL+NLL: 100535.08  (KL=100385.61, NLL=149.47)
 batch_loss: 10385.52  |  KL+NLL: 100537.91  (KL=100383.12, NLL=154.79)


Epoch 6/10:  30%|███       | 108/360 [00:05<00:12, 20.79it/s]

 batch_loss: 10320.44  |  KL+NLL: 100540.39  (KL=100411.03, NLL=129.36)
 batch_loss: 10301.64  |  KL+NLL: 100537.06  (KL=100372.27, NLL=164.79)
 batch_loss: 10189.88  |  KL+NLL: 100733.22  (KL=100550.63, NLL=182.59)
 batch_loss: 10036.61  |  KL+NLL: 100732.34  (KL=100567.88, NLL=164.46)


Epoch 6/10:  31%|███       | 111/360 [00:05<00:12, 20.50it/s]

 batch_loss: 10174.24  |  KL+NLL: 100540.62  (KL=100391.35, NLL=149.27)
 batch_loss: 10176.42  |  KL+NLL: 100631.78  (KL=100512.04, NLL=119.75)
 batch_loss: 10120.24  |  KL+NLL: 100492.07  (KL=100362.31, NLL=129.76)
 batch_loss: 10058.45  |  KL+NLL: 100633.80  (KL=100485.81, NLL=147.99)
 batch_loss: 10034.67  |  KL+NLL: 100586.44  (KL=100394.90, NLL=191.54)


Epoch 6/10:  32%|███▎      | 117/360 [00:05<00:11, 21.01it/s]

 batch_loss: 10013.96  |  KL+NLL: 100675.35  (KL=100462.55, NLL=212.79)
 batch_loss: 10002.40  |  KL+NLL: 100672.01  (KL=100515.38, NLL=156.63)
 batch_loss: 9917.68  |  KL+NLL: 100639.43  (KL=100490.12, NLL=149.31)
 batch_loss: 9875.40  |  KL+NLL: 100610.89  (KL=100498.53, NLL=112.36)
 batch_loss: 9987.74  |  KL+NLL: 100673.77  (KL=100584.45, NLL=89.32)


Epoch 6/10:  34%|███▍      | 123/360 [00:05<00:11, 20.76it/s]

 batch_loss: 9936.53  |  KL+NLL: 100617.16  (KL=100483.91, NLL=133.26)
 batch_loss: 9935.29  |  KL+NLL: 100776.96  (KL=100655.27, NLL=121.68)
 batch_loss: 9904.65  |  KL+NLL: 100738.40  (KL=100609.27, NLL=129.14)
 batch_loss: 9917.90  |  KL+NLL: 100589.81  (KL=100344.77, NLL=245.04)
 batch_loss: 10038.74  |  KL+NLL: 100829.73  (KL=100678.16, NLL=151.57)


Epoch 6/10:  35%|███▌      | 126/360 [00:06<00:11, 20.47it/s]

 batch_loss: 9565.53  |  KL+NLL: 100645.73  (KL=100499.29, NLL=146.44)
 batch_loss: 9728.50  |  KL+NLL: 100680.51  (KL=100526.55, NLL=153.95)
 batch_loss: 9551.45  |  KL+NLL: 100780.87  (KL=100700.73, NLL=80.14)
 batch_loss: 9633.17  |  KL+NLL: 100675.67  (KL=100558.90, NLL=116.77)
 batch_loss: 9725.80  |  KL+NLL: 100816.70  (KL=100669.55, NLL=147.15)


Epoch 6/10:  37%|███▋      | 132/360 [00:06<00:11, 19.89it/s]

 batch_loss: 9697.64  |  KL+NLL: 100811.43  (KL=100676.98, NLL=134.45)
 batch_loss: 9630.23  |  KL+NLL: 100850.48  (KL=100719.95, NLL=130.53)
 batch_loss: 9511.11  |  KL+NLL: 100749.49  (KL=100613.70, NLL=135.80)
 batch_loss: 9734.84  |  KL+NLL: 100771.65  (KL=100651.81, NLL=119.83)


Epoch 6/10:  38%|███▊      | 136/360 [00:06<00:11, 19.60it/s]

 batch_loss: 9732.58  |  KL+NLL: 100907.21  (KL=100673.90, NLL=233.31)
 batch_loss: 9572.92  |  KL+NLL: 100825.95  (KL=100672.51, NLL=153.44)
 batch_loss: 9590.28  |  KL+NLL: 100839.64  (KL=100725.90, NLL=113.74)
 batch_loss: 9529.05  |  KL+NLL: 100822.28  (KL=100699.88, NLL=122.40)


Epoch 6/10:  39%|███▉      | 141/360 [00:06<00:11, 19.78it/s]

 batch_loss: 9665.33  |  KL+NLL: 100844.68  (KL=100667.50, NLL=177.18)
 batch_loss: 9387.52  |  KL+NLL: 100915.60  (KL=100801.64, NLL=113.95)
 batch_loss: 9583.99  |  KL+NLL: 101020.28  (KL=100805.33, NLL=214.95)
 batch_loss: 9374.92  |  KL+NLL: 100870.95  (KL=100723.27, NLL=147.68)
 batch_loss: 9456.29  |  KL+NLL: 100964.16  (KL=100839.42, NLL=124.74)


Epoch 6/10:  41%|████      | 146/360 [00:07<00:10, 19.77it/s]

 batch_loss: 9522.38  |  KL+NLL: 101037.32  (KL=100873.66, NLL=163.67)
 batch_loss: 9417.07  |  KL+NLL: 100898.88  (KL=100784.55, NLL=114.32)
 batch_loss: 9505.09  |  KL+NLL: 101034.36  (KL=100927.53, NLL=106.83)
 batch_loss: 9481.41  |  KL+NLL: 100982.17  (KL=100781.33, NLL=200.84)
 batch_loss: 9339.69  |  KL+NLL: 101058.91  (KL=100923.66, NLL=135.25)


Epoch 6/10:  41%|████▏     | 149/360 [00:07<00:10, 19.93it/s]

 batch_loss: 9573.83  |  KL+NLL: 101043.90  (KL=100914.64, NLL=129.26)
 batch_loss: 9230.37  |  KL+NLL: 101168.97  (KL=101037.71, NLL=131.25)
 batch_loss: 9390.64  |  KL+NLL: 101126.85  (KL=100975.68, NLL=151.17)
 batch_loss: 9434.95  |  KL+NLL: 101112.79  (KL=100988.09, NLL=124.69)
 batch_loss: 9318.74  |  KL+NLL: 101094.69  (KL=100967.77, NLL=126.91)


Epoch 6/10:  43%|████▎     | 155/360 [00:07<00:10, 20.21it/s]

 batch_loss: 9134.43  |  KL+NLL: 101093.34  (KL=100954.56, NLL=138.78)
 batch_loss: 9315.90  |  KL+NLL: 101142.19  (KL=101007.79, NLL=134.40)
 batch_loss: 9371.52  |  KL+NLL: 100958.42  (KL=100850.45, NLL=107.98)
 batch_loss: 9155.67  |  KL+NLL: 101184.11  (KL=101087.55, NLL=96.55)
 batch_loss: 9123.09  |  KL+NLL: 101127.66  (KL=101016.13, NLL=111.52)


Epoch 6/10:  45%|████▍     | 161/360 [00:07<00:09, 20.65it/s]

 batch_loss: 9085.43  |  KL+NLL: 101084.82  (KL=100915.11, NLL=169.71)
 batch_loss: 9028.80  |  KL+NLL: 101106.77  (KL=100988.85, NLL=117.92)
 batch_loss: 9098.55  |  KL+NLL: 100945.36  (KL=100819.21, NLL=126.15)
 batch_loss: 8916.81  |  KL+NLL: 101106.76  (KL=100970.11, NLL=136.65)
 batch_loss: 9062.92  |  KL+NLL: 101129.13  (KL=100947.12, NLL=182.00)


Epoch 6/10:  46%|████▌     | 164/360 [00:08<00:09, 20.87it/s]

 batch_loss: 8860.05  |  KL+NLL: 101125.07  (KL=100998.48, NLL=126.59)
 batch_loss: 8839.51  |  KL+NLL: 101196.02  (KL=101031.34, NLL=164.68)
 batch_loss: 8909.80  |  KL+NLL: 101147.69  (KL=101034.57, NLL=113.11)
 batch_loss: 8990.03  |  KL+NLL: 101282.29  (KL=101191.12, NLL=91.17)
 batch_loss: 9062.71  |  KL+NLL: 101265.06  (KL=101091.07, NLL=173.99)


Epoch 6/10:  47%|████▋     | 170/360 [00:08<00:09, 20.90it/s]

 batch_loss: 9020.81  |  KL+NLL: 101329.99  (KL=101204.46, NLL=125.53)
 batch_loss: 8922.61  |  KL+NLL: 101209.99  (KL=101069.26, NLL=140.73)
 batch_loss: 8967.76  |  KL+NLL: 101247.87  (KL=101112.30, NLL=135.57)
 batch_loss: 9045.20  |  KL+NLL: 101260.57  (KL=101119.04, NLL=141.53)
 batch_loss: 8785.31  |  KL+NLL: 101237.37  (KL=101088.70, NLL=148.67)


Epoch 6/10:  49%|████▉     | 176/360 [00:08<00:08, 20.93it/s]

 batch_loss: 8900.42  |  KL+NLL: 101281.42  (KL=101144.45, NLL=136.96)
 batch_loss: 8912.77  |  KL+NLL: 101315.74  (KL=101165.38, NLL=150.36)
 batch_loss: 8799.78  |  KL+NLL: 101181.81  (KL=101047.36, NLL=134.45)
 batch_loss: 8831.87  |  KL+NLL: 101253.63  (KL=101060.97, NLL=192.67)
 batch_loss: 8817.99  |  KL+NLL: 101275.90  (KL=101136.95, NLL=138.95)


Epoch 6/10:  50%|████▉     | 179/360 [00:08<00:08, 20.91it/s]

 batch_loss: 8623.12  |  KL+NLL: 101507.18  (KL=101284.65, NLL=222.53)
 batch_loss: 8740.00  |  KL+NLL: 101189.06  (KL=101031.66, NLL=157.40)
 batch_loss: 8823.38  |  KL+NLL: 101394.30  (KL=101237.88, NLL=156.42)
 batch_loss: 8728.21  |  KL+NLL: 101397.48  (KL=101262.09, NLL=135.39)
 batch_loss: 8599.36  |  KL+NLL: 101287.11  (KL=101130.10, NLL=157.01)


Epoch 6/10:  51%|█████▏    | 185/360 [00:09<00:08, 20.84it/s]

 batch_loss: 8560.31  |  KL+NLL: 101280.75  (KL=101150.96, NLL=129.79)
 batch_loss: 8705.80  |  KL+NLL: 101358.16  (KL=101234.28, NLL=123.88)
 batch_loss: 8556.32  |  KL+NLL: 101397.67  (KL=101207.52, NLL=190.15)
 batch_loss: 8648.27  |  KL+NLL: 101395.20  (KL=101300.24, NLL=94.96)
 batch_loss: 8644.47  |  KL+NLL: 101400.47  (KL=101272.74, NLL=127.73)


Epoch 6/10:  53%|█████▎    | 191/360 [00:09<00:07, 21.29it/s]

 batch_loss: 8479.60  |  KL+NLL: 101443.45  (KL=101303.72, NLL=139.73)
 batch_loss: 8432.18  |  KL+NLL: 101352.71  (KL=101294.05, NLL=58.66)
 batch_loss: 8516.48  |  KL+NLL: 101353.93  (KL=101238.07, NLL=115.86)
 batch_loss: 8530.82  |  KL+NLL: 101478.48  (KL=101326.27, NLL=152.22)
 batch_loss: 8541.80  |  KL+NLL: 101645.63  (KL=101528.09, NLL=117.54)


Epoch 6/10:  54%|█████▍    | 194/360 [00:09<00:07, 20.77it/s]

 batch_loss: 8321.28  |  KL+NLL: 101532.44  (KL=101381.53, NLL=150.91)
 batch_loss: 8661.40  |  KL+NLL: 101385.82  (KL=101282.86, NLL=102.96)
 batch_loss: 8556.37  |  KL+NLL: 101542.20  (KL=101393.12, NLL=149.07)
 batch_loss: 8489.52  |  KL+NLL: 101498.25  (KL=101381.69, NLL=116.57)
 batch_loss: 8396.98  |  KL+NLL: 101573.75  (KL=101462.02, NLL=111.73)


Epoch 6/10:  56%|█████▌    | 200/360 [00:09<00:07, 21.07it/s]

 batch_loss: 8543.59  |  KL+NLL: 101516.22  (KL=101392.15, NLL=124.07)
 batch_loss: 8281.94  |  KL+NLL: 101589.79  (KL=101475.29, NLL=114.50)
 batch_loss: 8287.50  |  KL+NLL: 101516.30  (KL=101392.85, NLL=123.45)
 batch_loss: 8398.05  |  KL+NLL: 101631.26  (KL=101492.01, NLL=139.25)
 batch_loss: 8464.40  |  KL+NLL: 101539.13  (KL=101382.69, NLL=156.44)


Epoch 6/10:  57%|█████▋    | 206/360 [00:10<00:07, 20.94it/s]

 batch_loss: 8350.36  |  KL+NLL: 101668.36  (KL=101537.20, NLL=131.16)
 batch_loss: 8270.87  |  KL+NLL: 101712.32  (KL=101605.92, NLL=106.40)
 batch_loss: 8208.45  |  KL+NLL: 101634.72  (KL=101519.82, NLL=114.90)
 batch_loss: 8175.00  |  KL+NLL: 101707.57  (KL=101614.38, NLL=93.19)
 batch_loss: 8235.69  |  KL+NLL: 101660.04  (KL=101492.56, NLL=167.48)


Epoch 6/10:  58%|█████▊    | 209/360 [00:10<00:07, 20.20it/s]

 batch_loss: 8295.92  |  KL+NLL: 101712.99  (KL=101542.16, NLL=170.82)
 batch_loss: 8084.94  |  KL+NLL: 101757.36  (KL=101613.17, NLL=144.19)
 batch_loss: 8139.92  |  KL+NLL: 101810.88  (KL=101581.66, NLL=229.22)
 batch_loss: 7893.86  |  KL+NLL: 101711.33  (KL=101539.50, NLL=171.83)


Epoch 6/10:  60%|█████▉    | 215/360 [00:10<00:07, 20.02it/s]

 batch_loss: 7912.91  |  KL+NLL: 101873.80  (KL=101689.30, NLL=184.49)
 batch_loss: 8116.69  |  KL+NLL: 101883.00  (KL=101747.14, NLL=135.86)
 batch_loss: 8085.02  |  KL+NLL: 101699.31  (KL=101563.28, NLL=136.03)
 batch_loss: 8097.82  |  KL+NLL: 101899.56  (KL=101760.35, NLL=139.21)
 batch_loss: 8004.69  |  KL+NLL: 101727.04  (KL=101637.30, NLL=89.73)


Epoch 6/10:  61%|██████    | 218/360 [00:10<00:07, 20.24it/s]

 batch_loss: 7862.78  |  KL+NLL: 101826.70  (KL=101702.80, NLL=123.90)
 batch_loss: 7839.98  |  KL+NLL: 101797.07  (KL=101675.93, NLL=121.14)
 batch_loss: 7886.39  |  KL+NLL: 101783.81  (KL=101649.04, NLL=134.77)
 batch_loss: 7911.23  |  KL+NLL: 101941.79  (KL=101828.28, NLL=113.51)
 batch_loss: 7956.93  |  KL+NLL: 101861.77  (KL=101649.48, NLL=212.28)


Epoch 6/10:  62%|██████▏   | 224/360 [00:10<00:06, 19.98it/s]

 batch_loss: 7983.06  |  KL+NLL: 101827.84  (KL=101654.68, NLL=173.17)
 batch_loss: 7814.41  |  KL+NLL: 101858.73  (KL=101735.41, NLL=123.32)
 batch_loss: 7881.70  |  KL+NLL: 101944.00  (KL=101817.54, NLL=126.46)
 batch_loss: 7813.66  |  KL+NLL: 101715.56  (KL=101629.70, NLL=85.86)


Epoch 6/10:  63%|██████▎   | 227/360 [00:11<00:06, 20.23it/s]

 batch_loss: 7941.32  |  KL+NLL: 101872.85  (KL=101738.76, NLL=134.09)
 batch_loss: 7957.59  |  KL+NLL: 101904.65  (KL=101740.14, NLL=164.51)
 batch_loss: 7822.40  |  KL+NLL: 101935.02  (KL=101782.70, NLL=152.32)
 batch_loss: 7863.10  |  KL+NLL: 101840.47  (KL=101701.20, NLL=139.28)
 batch_loss: 7726.16  |  KL+NLL: 101855.99  (KL=101745.97, NLL=110.02)


Epoch 6/10:  65%|██████▍   | 233/360 [00:11<00:06, 20.71it/s]

 batch_loss: 7670.65  |  KL+NLL: 101913.01  (KL=101735.87, NLL=177.14)
 batch_loss: 7769.48  |  KL+NLL: 102011.32  (KL=101867.41, NLL=143.91)
 batch_loss: 7652.14  |  KL+NLL: 102002.02  (KL=101853.95, NLL=148.08)
 batch_loss: 7599.17  |  KL+NLL: 102092.29  (KL=101903.12, NLL=189.16)
 batch_loss: 7645.54  |  KL+NLL: 101843.74  (KL=101709.66, NLL=134.08)


Epoch 6/10:  66%|██████▋   | 239/360 [00:11<00:05, 20.80it/s]

 batch_loss: 7662.56  |  KL+NLL: 102056.27  (KL=101876.89, NLL=179.38)
 batch_loss: 7649.25  |  KL+NLL: 101890.69  (KL=101709.46, NLL=181.23)
 batch_loss: 7492.08  |  KL+NLL: 101946.13  (KL=101799.58, NLL=146.55)
 batch_loss: 7572.97  |  KL+NLL: 102162.07  (KL=102056.28, NLL=105.79)
 batch_loss: 7789.56  |  KL+NLL: 102025.70  (KL=101899.46, NLL=126.24)


Epoch 6/10:  67%|██████▋   | 242/360 [00:11<00:05, 20.80it/s]

 batch_loss: 7371.48  |  KL+NLL: 102060.65  (KL=101872.43, NLL=188.22)
 batch_loss: 7559.09  |  KL+NLL: 102006.35  (KL=101879.80, NLL=126.54)
 batch_loss: 7712.42  |  KL+NLL: 102028.27  (KL=101915.58, NLL=112.70)
 batch_loss: 7420.04  |  KL+NLL: 102059.14  (KL=101931.39, NLL=127.75)
 batch_loss: 7374.54  |  KL+NLL: 102130.88  (KL=102003.11, NLL=127.77)


Epoch 6/10:  69%|██████▉   | 248/360 [00:12<00:05, 21.20it/s]

 batch_loss: 7453.99  |  KL+NLL: 102109.69  (KL=101922.72, NLL=186.97)
 batch_loss: 7601.21  |  KL+NLL: 102122.11  (KL=101940.66, NLL=181.45)
 batch_loss: 7418.19  |  KL+NLL: 102298.69  (KL=102163.60, NLL=135.08)
 batch_loss: 7469.39  |  KL+NLL: 102184.95  (KL=102056.42, NLL=128.52)
 batch_loss: 7531.90  |  KL+NLL: 102298.02  (KL=102172.98, NLL=125.04)


Epoch 6/10:  71%|███████   | 254/360 [00:12<00:05, 20.81it/s]

 batch_loss: 7435.42  |  KL+NLL: 102156.13  (KL=102039.69, NLL=116.45)
 batch_loss: 7352.63  |  KL+NLL: 102027.09  (KL=101911.49, NLL=115.60)
 batch_loss: 7370.35  |  KL+NLL: 102249.59  (KL=102130.46, NLL=119.13)
 batch_loss: 7378.79  |  KL+NLL: 102111.72  (KL=101924.70, NLL=187.02)
 batch_loss: 7449.27  |  KL+NLL: 102338.19  (KL=102206.10, NLL=132.09)


Epoch 6/10:  71%|███████▏  | 257/360 [00:12<00:04, 20.78it/s]

 batch_loss: 7370.01  |  KL+NLL: 102295.32  (KL=102169.91, NLL=125.40)
 batch_loss: 7295.07  |  KL+NLL: 102255.98  (KL=102095.16, NLL=160.82)
 batch_loss: 7380.97  |  KL+NLL: 102341.43  (KL=102153.60, NLL=187.83)
 batch_loss: 7360.47  |  KL+NLL: 102241.40  (KL=102088.73, NLL=152.66)
 batch_loss: 7203.16  |  KL+NLL: 102264.90  (KL=102136.78, NLL=128.12)


Epoch 6/10:  73%|███████▎  | 263/360 [00:12<00:04, 20.84it/s]

 batch_loss: 7237.44  |  KL+NLL: 102365.12  (KL=102201.02, NLL=164.10)
 batch_loss: 7166.16  |  KL+NLL: 102206.74  (KL=102098.18, NLL=108.56)
 batch_loss: 7283.63  |  KL+NLL: 102347.14  (KL=102150.40, NLL=196.74)
 batch_loss: 7190.89  |  KL+NLL: 102368.18  (KL=102247.10, NLL=121.08)
 batch_loss: 6880.34  |  KL+NLL: 102341.77  (KL=102230.73, NLL=111.04)


Epoch 6/10:  75%|███████▍  | 269/360 [00:13<00:04, 20.65it/s]

 batch_loss: 7156.28  |  KL+NLL: 102399.53  (KL=102253.38, NLL=146.15)
 batch_loss: 7156.67  |  KL+NLL: 102433.12  (KL=102306.21, NLL=126.91)
 batch_loss: 7117.42  |  KL+NLL: 102343.15  (KL=102242.62, NLL=100.53)
 batch_loss: 7194.08  |  KL+NLL: 102390.10  (KL=102283.49, NLL=106.61)
 batch_loss: 7216.89  |  KL+NLL: 102365.79  (KL=102208.56, NLL=157.23)


Epoch 6/10:  76%|███████▌  | 272/360 [00:13<00:04, 20.98it/s]

 batch_loss: 7073.50  |  KL+NLL: 102491.56  (KL=102303.20, NLL=188.36)
 batch_loss: 7092.34  |  KL+NLL: 102454.47  (KL=102290.55, NLL=163.92)
 batch_loss: 7054.93  |  KL+NLL: 102503.62  (KL=102359.44, NLL=144.18)
 batch_loss: 6904.39  |  KL+NLL: 102529.55  (KL=102406.72, NLL=122.83)
 batch_loss: 7174.00  |  KL+NLL: 102549.30  (KL=102393.15, NLL=156.15)


Epoch 6/10:  77%|███████▋  | 278/360 [00:13<00:03, 21.02it/s]

 batch_loss: 6966.40  |  KL+NLL: 102364.91  (KL=102256.56, NLL=108.35)
 batch_loss: 6777.32  |  KL+NLL: 102469.57  (KL=102361.70, NLL=107.88)
 batch_loss: 6841.21  |  KL+NLL: 102424.29  (KL=102274.99, NLL=149.30)
 batch_loss: 6944.55  |  KL+NLL: 102516.19  (KL=102377.82, NLL=138.37)
 batch_loss: 6919.67  |  KL+NLL: 102671.29  (KL=102464.40, NLL=206.89)


Epoch 6/10:  79%|███████▉  | 284/360 [00:13<00:03, 20.55it/s]

 batch_loss: 6835.15  |  KL+NLL: 102636.38  (KL=102451.75, NLL=184.63)
 batch_loss: 6861.05  |  KL+NLL: 102538.07  (KL=102417.11, NLL=120.96)
 batch_loss: 6901.07  |  KL+NLL: 102482.98  (KL=102379.82, NLL=103.16)
 batch_loss: 6729.70  |  KL+NLL: 102493.66  (KL=102363.28, NLL=130.38)
 batch_loss: 6940.27  |  KL+NLL: 102593.70  (KL=102483.56, NLL=110.14)


Epoch 6/10:  80%|███████▉  | 287/360 [00:13<00:03, 21.02it/s]

 batch_loss: 6848.27  |  KL+NLL: 102602.72  (KL=102478.31, NLL=124.40)
 batch_loss: 6787.96  |  KL+NLL: 102639.67  (KL=102494.81, NLL=144.86)
 batch_loss: 7005.61  |  KL+NLL: 102501.31  (KL=102350.05, NLL=151.27)
 batch_loss: 6928.20  |  KL+NLL: 102721.37  (KL=102559.67, NLL=161.70)
 batch_loss: 6751.31  |  KL+NLL: 102740.49  (KL=102631.38, NLL=109.11)


Epoch 6/10:  81%|████████▏ | 293/360 [00:14<00:03, 20.90it/s]

 batch_loss: 6779.70  |  KL+NLL: 102747.33  (KL=102627.10, NLL=120.23)
 batch_loss: 6777.75  |  KL+NLL: 102534.27  (KL=102408.10, NLL=126.16)
 batch_loss: 6799.69  |  KL+NLL: 102758.36  (KL=102618.69, NLL=139.68)
 batch_loss: 6571.71  |  KL+NLL: 102535.80  (KL=102432.02, NLL=103.78)
 batch_loss: 6652.73  |  KL+NLL: 102719.60  (KL=102574.23, NLL=145.37)


Epoch 6/10:  83%|████████▎ | 299/360 [00:14<00:02, 21.34it/s]

 batch_loss: 6698.34  |  KL+NLL: 102864.87  (KL=102742.20, NLL=122.67)
 batch_loss: 6694.73  |  KL+NLL: 102617.09  (KL=102474.75, NLL=142.34)
 batch_loss: 6792.14  |  KL+NLL: 102699.69  (KL=102539.27, NLL=160.42)
 batch_loss: 6621.67  |  KL+NLL: 102706.86  (KL=102575.91, NLL=130.95)
 batch_loss: 6653.21  |  KL+NLL: 102681.70  (KL=102599.06, NLL=82.64)


Epoch 6/10:  84%|████████▍ | 302/360 [00:14<00:02, 21.63it/s]

 batch_loss: 6631.76  |  KL+NLL: 102554.66  (KL=102446.62, NLL=108.03)
 batch_loss: 6634.60  |  KL+NLL: 102730.41  (KL=102605.49, NLL=124.92)
 batch_loss: 6545.00  |  KL+NLL: 102723.23  (KL=102612.49, NLL=110.74)
 batch_loss: 6698.68  |  KL+NLL: 102783.75  (KL=102669.45, NLL=114.30)
 batch_loss: 6504.70  |  KL+NLL: 102756.68  (KL=102596.62, NLL=160.05)


Epoch 6/10:  86%|████████▌ | 308/360 [00:14<00:02, 20.71it/s]

 batch_loss: 6565.54  |  KL+NLL: 102850.96  (KL=102675.14, NLL=175.82)
 batch_loss: 6612.23  |  KL+NLL: 102993.07  (KL=102805.47, NLL=187.60)
 batch_loss: 6413.80  |  KL+NLL: 102904.42  (KL=102727.16, NLL=177.26)
 batch_loss: 6638.05  |  KL+NLL: 102817.86  (KL=102696.33, NLL=121.53)
 batch_loss: 6470.63  |  KL+NLL: 102981.72  (KL=102817.06, NLL=164.66)


Epoch 6/10:  87%|████████▋ | 314/360 [00:15<00:02, 20.70it/s]

 batch_loss: 6489.36  |  KL+NLL: 102843.38  (KL=102648.15, NLL=195.23)
 batch_loss: 6574.98  |  KL+NLL: 103053.30  (KL=102960.63, NLL=92.67)
 batch_loss: 6616.92  |  KL+NLL: 102825.24  (KL=102713.17, NLL=112.07)
 batch_loss: 6452.55  |  KL+NLL: 102883.31  (KL=102704.52, NLL=178.79)
 batch_loss: 6452.51  |  KL+NLL: 102907.66  (KL=102775.70, NLL=131.97)


Epoch 6/10:  88%|████████▊ | 317/360 [00:15<00:02, 20.90it/s]

 batch_loss: 6380.90  |  KL+NLL: 102991.34  (KL=102836.07, NLL=155.27)
 batch_loss: 6321.39  |  KL+NLL: 102990.22  (KL=102837.12, NLL=153.09)
 batch_loss: 6398.87  |  KL+NLL: 102869.95  (KL=102725.42, NLL=144.53)
 batch_loss: 6390.63  |  KL+NLL: 102952.57  (KL=102815.88, NLL=136.69)
 batch_loss: 6217.66  |  KL+NLL: 102924.43  (KL=102831.67, NLL=92.76)


Epoch 6/10:  90%|████████▉ | 323/360 [00:15<00:01, 20.45it/s]

 batch_loss: 6280.30  |  KL+NLL: 102960.98  (KL=102794.28, NLL=166.70)
 batch_loss: 6257.55  |  KL+NLL: 103009.30  (KL=102913.40, NLL=95.90)
 batch_loss: 6155.75  |  KL+NLL: 102965.85  (KL=102841.46, NLL=124.39)
 batch_loss: 6292.64  |  KL+NLL: 103121.93  (KL=102983.46, NLL=138.47)
 batch_loss: 6306.19  |  KL+NLL: 103024.89  (KL=102910.62, NLL=114.27)


Epoch 6/10:  91%|█████████▏| 329/360 [00:15<00:01, 20.51it/s]

 batch_loss: 6228.53  |  KL+NLL: 102971.41  (KL=102825.73, NLL=145.69)
 batch_loss: 6278.87  |  KL+NLL: 103052.26  (KL=102903.48, NLL=148.78)
 batch_loss: 6220.86  |  KL+NLL: 103199.28  (KL=103087.02, NLL=112.26)
 batch_loss: 6134.48  |  KL+NLL: 103051.29  (KL=102894.70, NLL=156.59)
 batch_loss: 6199.52  |  KL+NLL: 103026.52  (KL=102868.92, NLL=157.60)


Epoch 6/10:  92%|█████████▏| 332/360 [00:16<00:01, 20.77it/s]

 batch_loss: 6320.64  |  KL+NLL: 103153.66  (KL=102926.36, NLL=227.30)
 batch_loss: 6235.56  |  KL+NLL: 103181.15  (KL=103047.63, NLL=133.52)
 batch_loss: 6149.50  |  KL+NLL: 103126.61  (KL=102988.04, NLL=138.57)
 batch_loss: 6111.58  |  KL+NLL: 103428.20  (KL=103300.59, NLL=127.61)
 batch_loss: 6232.55  |  KL+NLL: 103056.44  (KL=102893.99, NLL=162.45)


Epoch 6/10:  94%|█████████▍| 338/360 [00:16<00:01, 20.04it/s]

 batch_loss: 6123.13  |  KL+NLL: 103123.72  (KL=103007.88, NLL=115.84)
 batch_loss: 6082.56  |  KL+NLL: 102980.76  (KL=102886.93, NLL=93.83)
 batch_loss: 5980.86  |  KL+NLL: 103142.56  (KL=103012.45, NLL=130.10)
 batch_loss: 5953.26  |  KL+NLL: 103333.35  (KL=103156.01, NLL=177.34)


Epoch 6/10:  95%|█████████▍| 341/360 [00:16<00:00, 20.14it/s]

 batch_loss: 6128.37  |  KL+NLL: 103243.28  (KL=103115.10, NLL=128.18)
 batch_loss: 6044.42  |  KL+NLL: 103387.41  (KL=103264.36, NLL=123.05)
 batch_loss: 5953.89  |  KL+NLL: 103265.48  (KL=103137.66, NLL=127.81)
 batch_loss: 5995.40  |  KL+NLL: 103127.39  (KL=103001.04, NLL=126.35)
 batch_loss: 5948.83  |  KL+NLL: 103342.40  (KL=103170.49, NLL=171.91)


Epoch 6/10:  96%|█████████▋| 347/360 [00:16<00:00, 20.83it/s]

 batch_loss: 6150.89  |  KL+NLL: 103383.42  (KL=103242.59, NLL=140.83)
 batch_loss: 5979.41  |  KL+NLL: 103365.10  (KL=103257.93, NLL=107.17)
 batch_loss: 5867.76  |  KL+NLL: 103287.77  (KL=103184.52, NLL=103.25)
 batch_loss: 5850.16  |  KL+NLL: 103199.07  (KL=103055.92, NLL=143.14)
 batch_loss: 5845.39  |  KL+NLL: 103362.65  (KL=103224.34, NLL=138.31)


Epoch 6/10:  98%|█████████▊| 353/360 [00:17<00:00, 20.84it/s]

 batch_loss: 5776.84  |  KL+NLL: 103364.17  (KL=103260.50, NLL=103.67)
 batch_loss: 5925.00  |  KL+NLL: 103302.63  (KL=103188.80, NLL=113.83)
 batch_loss: 5821.60  |  KL+NLL: 103225.16  (KL=103090.55, NLL=134.61)
 batch_loss: 5928.98  |  KL+NLL: 103280.68  (KL=103134.24, NLL=146.44)
 batch_loss: 5873.84  |  KL+NLL: 103361.22  (KL=103240.70, NLL=120.52)


Epoch 6/10:  99%|█████████▉| 356/360 [00:17<00:00, 21.36it/s]

 batch_loss: 5758.55  |  KL+NLL: 103364.08  (KL=103274.80, NLL=89.27)
 batch_loss: 5662.77  |  KL+NLL: 103441.48  (KL=103324.55, NLL=116.93)
 batch_loss: 5807.94  |  KL+NLL: 103522.42  (KL=103342.27, NLL=180.16)
 batch_loss: 5802.75  |  KL+NLL: 103420.42  (KL=103310.17, NLL=110.25)
 batch_loss: 5835.15  |  KL+NLL: 103481.51  (KL=103314.23, NLL=167.28)


Epoch 6/10: 100%|██████████| 360/360 [00:17<00:00, 20.64it/s]


 batch_loss: 5735.71  |  KL+NLL: 103431.72  (KL=103267.98, NLL=163.75)
 batch_loss: 5731.25  |  KL+NLL: 103452.25  (KL=103365.23, NLL=87.02)
Epoch 6 - ELBO Loss: 8889.8625


Epoch 7/10:   0%|          | 1/360 [00:00<00:50,  7.15it/s]

 batch_loss: 5814.94  |  KL+NLL: 103367.12  (KL=103270.40, NLL=96.72)


Epoch 7/10:   1%|          | 4/360 [00:00<00:23, 15.25it/s]

 batch_loss: 5794.66  |  KL+NLL: 103571.31  (KL=103418.62, NLL=152.69)
 batch_loss: 5687.04  |  KL+NLL: 103540.78  (KL=103442.35, NLL=98.43)
 batch_loss: 5791.85  |  KL+NLL: 103590.47  (KL=103445.60, NLL=144.87)
 batch_loss: 5750.66  |  KL+NLL: 103425.25  (KL=103290.80, NLL=134.45)
 batch_loss: 5587.91  |  KL+NLL: 103536.53  (KL=103379.85, NLL=156.68)


Epoch 7/10:   3%|▎         | 10/360 [00:00<00:18, 18.78it/s]

 batch_loss: 5624.63  |  KL+NLL: 103398.06  (KL=103274.00, NLL=124.06)
 batch_loss: 5614.18  |  KL+NLL: 103474.56  (KL=103353.31, NLL=121.25)
 batch_loss: 5707.62  |  KL+NLL: 103609.20  (KL=103465.40, NLL=143.80)
 batch_loss: 5638.04  |  KL+NLL: 103517.16  (KL=103402.42, NLL=114.73)
 batch_loss: 5604.76  |  KL+NLL: 103539.39  (KL=103448.29, NLL=91.10)


Epoch 7/10:   4%|▎         | 13/360 [00:00<00:17, 19.70it/s]

 batch_loss: 5575.30  |  KL+NLL: 103626.30  (KL=103478.30, NLL=147.99)
 batch_loss: 5576.62  |  KL+NLL: 103726.73  (KL=103542.81, NLL=183.92)
 batch_loss: 5504.32  |  KL+NLL: 103660.81  (KL=103482.65, NLL=178.16)
 batch_loss: 5549.30  |  KL+NLL: 103723.71  (KL=103599.24, NLL=124.47)


Epoch 7/10:   4%|▍         | 16/360 [00:00<00:17, 20.11it/s]

 batch_loss: 5666.23  |  KL+NLL: 103738.77  (KL=103629.84, NLL=108.93)


Epoch 7/10:   5%|▌         | 19/360 [00:01<00:16, 20.28it/s]

 batch_loss: 5530.72  |  KL+NLL: 103582.69  (KL=103453.51, NLL=129.18)
 batch_loss: 5434.89  |  KL+NLL: 103717.99  (KL=103583.97, NLL=134.02)
 batch_loss: 5447.76  |  KL+NLL: 103824.95  (KL=103696.42, NLL=128.52)
 batch_loss: 5609.52  |  KL+NLL: 103756.32  (KL=103634.00, NLL=122.32)
 batch_loss: 5549.50  |  KL+NLL: 103662.41  (KL=103521.02, NLL=141.38)


Epoch 7/10:   7%|▋         | 25/360 [00:01<00:16, 20.39it/s]

 batch_loss: 5378.31  |  KL+NLL: 103870.00  (KL=103739.80, NLL=130.19)
 batch_loss: 5424.30  |  KL+NLL: 103786.01  (KL=103641.09, NLL=144.92)
 batch_loss: 5459.41  |  KL+NLL: 103679.98  (KL=103535.02, NLL=144.96)
 batch_loss: 5598.84  |  KL+NLL: 103647.71  (KL=103524.09, NLL=123.62)
 batch_loss: 5550.65  |  KL+NLL: 103737.47  (KL=103596.84, NLL=140.64)


Epoch 7/10:   8%|▊         | 28/360 [00:01<00:16, 20.16it/s]

 batch_loss: 5404.21  |  KL+NLL: 103609.63  (KL=103460.98, NLL=148.65)
 batch_loss: 5458.21  |  KL+NLL: 103885.11  (KL=103755.81, NLL=129.29)
 batch_loss: 5497.26  |  KL+NLL: 103802.19  (KL=103691.41, NLL=110.78)
 batch_loss: 5494.63  |  KL+NLL: 103952.12  (KL=103794.89, NLL=157.23)


Epoch 7/10:   9%|▊         | 31/360 [00:01<00:15, 20.79it/s]

 batch_loss: 5496.70  |  KL+NLL: 103817.02  (KL=103680.05, NLL=136.97)


Epoch 7/10:   9%|▉         | 34/360 [00:01<00:15, 20.73it/s]

 batch_loss: 5355.31  |  KL+NLL: 103699.52  (KL=103557.35, NLL=142.16)
 batch_loss: 5204.37  |  KL+NLL: 103833.03  (KL=103678.80, NLL=154.23)
 batch_loss: 5153.86  |  KL+NLL: 103867.70  (KL=103762.02, NLL=105.68)
 batch_loss: 5356.60  |  KL+NLL: 103719.05  (KL=103623.26, NLL=95.80)
 batch_loss: 5362.80  |  KL+NLL: 103870.29  (KL=103770.96, NLL=99.33)


Epoch 7/10:  11%|█         | 40/360 [00:02<00:15, 20.77it/s]

 batch_loss: 5432.19  |  KL+NLL: 103832.62  (KL=103710.79, NLL=121.83)
 batch_loss: 5325.51  |  KL+NLL: 103941.07  (KL=103805.59, NLL=135.48)
 batch_loss: 5143.72  |  KL+NLL: 103860.23  (KL=103725.14, NLL=135.09)
 batch_loss: 5149.49  |  KL+NLL: 104020.89  (KL=103901.54, NLL=119.35)
 batch_loss: 5239.57  |  KL+NLL: 103930.94  (KL=103802.66, NLL=128.28)


Epoch 7/10:  12%|█▏        | 43/360 [00:02<00:15, 21.01it/s]

 batch_loss: 5170.20  |  KL+NLL: 103829.30  (KL=103692.97, NLL=136.33)
 batch_loss: 5162.67  |  KL+NLL: 103991.75  (KL=103815.88, NLL=175.87)
 batch_loss: 5243.40  |  KL+NLL: 103768.44  (KL=103680.36, NLL=88.08)
 batch_loss: 5211.27  |  KL+NLL: 103817.64  (KL=103712.01, NLL=105.63)


Epoch 7/10:  13%|█▎        | 46/360 [00:02<00:14, 21.16it/s]

 batch_loss: 5220.63  |  KL+NLL: 104025.89  (KL=103823.23, NLL=202.66)


Epoch 7/10:  14%|█▎        | 49/360 [00:02<00:14, 20.80it/s]

 batch_loss: 5179.37  |  KL+NLL: 104026.66  (KL=103896.49, NLL=130.17)
 batch_loss: 5297.69  |  KL+NLL: 103825.52  (KL=103671.95, NLL=153.57)
 batch_loss: 4998.95  |  KL+NLL: 104075.23  (KL=103847.41, NLL=227.83)
 batch_loss: 4960.11  |  KL+NLL: 103934.09  (KL=103834.24, NLL=99.84)
 batch_loss: 5005.30  |  KL+NLL: 104082.93  (KL=103950.76, NLL=132.17)


Epoch 7/10:  15%|█▌        | 55/360 [00:02<00:14, 20.96it/s]

 batch_loss: 5235.61  |  KL+NLL: 104172.96  (KL=104043.09, NLL=129.87)
 batch_loss: 5163.92  |  KL+NLL: 104345.24  (KL=104212.37, NLL=132.87)
 batch_loss: 5102.34  |  KL+NLL: 103958.14  (KL=103834.00, NLL=124.14)
 batch_loss: 5126.75  |  KL+NLL: 103980.10  (KL=103812.13, NLL=167.97)
 batch_loss: 5196.24  |  KL+NLL: 104219.96  (KL=103994.11, NLL=225.85)


Epoch 7/10:  16%|█▌        | 58/360 [00:02<00:14, 20.65it/s]

 batch_loss: 5090.87  |  KL+NLL: 104113.54  (KL=103982.81, NLL=130.72)
 batch_loss: 5077.45  |  KL+NLL: 104021.08  (KL=103804.45, NLL=216.63)
 batch_loss: 5136.96  |  KL+NLL: 104130.12  (KL=104042.19, NLL=87.93)
 batch_loss: 4975.27  |  KL+NLL: 104210.84  (KL=104087.63, NLL=123.21)


Epoch 7/10:  17%|█▋        | 61/360 [00:03<00:14, 20.57it/s]

 batch_loss: 5068.45  |  KL+NLL: 104000.26  (KL=103894.26, NLL=106.00)


Epoch 7/10:  18%|█▊        | 64/360 [00:03<00:14, 20.56it/s]

 batch_loss: 5048.49  |  KL+NLL: 104342.08  (KL=104213.51, NLL=128.57)
 batch_loss: 4960.69  |  KL+NLL: 104240.98  (KL=104139.30, NLL=101.68)
 batch_loss: 4962.48  |  KL+NLL: 104264.90  (KL=104167.31, NLL=97.59)
 batch_loss: 5059.92  |  KL+NLL: 104240.85  (KL=104111.39, NLL=129.45)
 batch_loss: 5049.74  |  KL+NLL: 104250.08  (KL=104054.58, NLL=195.50)


Epoch 7/10:  19%|█▉        | 70/360 [00:03<00:13, 20.83it/s]

 batch_loss: 5031.25  |  KL+NLL: 104116.60  (KL=103923.99, NLL=192.61)
 batch_loss: 4920.33  |  KL+NLL: 104318.81  (KL=104189.33, NLL=129.48)
 batch_loss: 4832.45  |  KL+NLL: 104148.82  (KL=104001.74, NLL=147.08)
 batch_loss: 4957.17  |  KL+NLL: 104427.52  (KL=104233.35, NLL=194.16)
 batch_loss: 4930.14  |  KL+NLL: 104484.46  (KL=104295.74, NLL=188.72)


Epoch 7/10:  20%|██        | 73/360 [00:03<00:13, 20.74it/s]

 batch_loss: 4845.97  |  KL+NLL: 104159.15  (KL=103982.12, NLL=177.03)
 batch_loss: 4903.10  |  KL+NLL: 104212.65  (KL=104104.29, NLL=108.36)
 batch_loss: 4850.74  |  KL+NLL: 104355.14  (KL=104250.27, NLL=104.87)
 batch_loss: 4931.22  |  KL+NLL: 104300.90  (KL=104167.75, NLL=133.15)


Epoch 7/10:  21%|██        | 76/360 [00:03<00:13, 20.99it/s]

 batch_loss: 4935.67  |  KL+NLL: 104128.17  (KL=104026.81, NLL=101.36)


Epoch 7/10:  22%|██▏       | 79/360 [00:03<00:13, 21.34it/s]

 batch_loss: 4846.87  |  KL+NLL: 104134.67  (KL=104009.20, NLL=125.47)
 batch_loss: 4722.18  |  KL+NLL: 104276.69  (KL=104171.70, NLL=104.99)
 batch_loss: 4839.56  |  KL+NLL: 104455.44  (KL=104350.00, NLL=105.44)
 batch_loss: 4789.29  |  KL+NLL: 104431.76  (KL=104304.50, NLL=127.26)
 batch_loss: 4594.24  |  KL+NLL: 104363.67  (KL=104232.95, NLL=130.72)


Epoch 7/10:  24%|██▎       | 85/360 [00:04<00:12, 21.50it/s]

 batch_loss: 4847.52  |  KL+NLL: 104409.93  (KL=104241.88, NLL=168.06)
 batch_loss: 4769.16  |  KL+NLL: 104565.53  (KL=104467.05, NLL=98.48)
 batch_loss: 4732.05  |  KL+NLL: 104584.86  (KL=104444.14, NLL=140.72)
 batch_loss: 4776.13  |  KL+NLL: 104297.71  (KL=104161.09, NLL=136.62)
 batch_loss: 4756.24  |  KL+NLL: 104485.07  (KL=104384.81, NLL=100.26)


Epoch 7/10:  24%|██▍       | 88/360 [00:04<00:13, 20.80it/s]

 batch_loss: 4665.70  |  KL+NLL: 104347.32  (KL=104245.71, NLL=101.61)
 batch_loss: 4787.10  |  KL+NLL: 104412.46  (KL=104289.88, NLL=122.57)
 batch_loss: 4747.46  |  KL+NLL: 104555.80  (KL=104399.72, NLL=156.08)
 batch_loss: 4590.19  |  KL+NLL: 104469.80  (KL=104367.45, NLL=102.35)


Epoch 7/10:  25%|██▌       | 91/360 [00:04<00:12, 20.90it/s]

 batch_loss: 4642.45  |  KL+NLL: 104458.72  (KL=104348.41, NLL=110.31)


Epoch 7/10:  26%|██▌       | 94/360 [00:04<00:12, 21.04it/s]

 batch_loss: 4649.45  |  KL+NLL: 104320.11  (KL=104181.17, NLL=138.94)
 batch_loss: 4702.53  |  KL+NLL: 104589.11  (KL=104459.40, NLL=129.71)
 batch_loss: 4753.50  |  KL+NLL: 104564.93  (KL=104407.95, NLL=156.98)
 batch_loss: 4784.11  |  KL+NLL: 104648.67  (KL=104473.77, NLL=174.89)
 batch_loss: 4563.59  |  KL+NLL: 104806.84  (KL=104679.55, NLL=127.28)


Epoch 7/10:  28%|██▊       | 100/360 [00:04<00:12, 21.23it/s]

 batch_loss: 4652.98  |  KL+NLL: 104577.15  (KL=104457.38, NLL=119.77)
 batch_loss: 4639.33  |  KL+NLL: 104626.30  (KL=104514.49, NLL=111.81)
 batch_loss: 4612.88  |  KL+NLL: 104412.41  (KL=104264.67, NLL=147.74)
 batch_loss: 4573.94  |  KL+NLL: 104663.31  (KL=104521.83, NLL=141.48)
 batch_loss: 4639.74  |  KL+NLL: 104736.28  (KL=104607.48, NLL=128.80)


Epoch 7/10:  29%|██▊       | 103/360 [00:05<00:12, 20.90it/s]

 batch_loss: 4624.97  |  KL+NLL: 104594.35  (KL=104463.84, NLL=130.51)
 batch_loss: 4525.47  |  KL+NLL: 104590.06  (KL=104479.31, NLL=110.74)
 batch_loss: 4515.14  |  KL+NLL: 104542.92  (KL=104424.50, NLL=118.42)
 batch_loss: 4449.65  |  KL+NLL: 104661.27  (KL=104524.74, NLL=136.53)


Epoch 7/10:  29%|██▉       | 106/360 [00:05<00:12, 20.85it/s]

 batch_loss: 4490.24  |  KL+NLL: 104666.40  (KL=104542.49, NLL=123.91)


Epoch 7/10:  30%|███       | 109/360 [00:05<00:12, 20.82it/s]

 batch_loss: 4507.07  |  KL+NLL: 104718.19  (KL=104604.05, NLL=114.13)
 batch_loss: 4624.81  |  KL+NLL: 104808.75  (KL=104672.60, NLL=136.15)
 batch_loss: 4354.01  |  KL+NLL: 104726.36  (KL=104549.55, NLL=176.81)
 batch_loss: 4520.16  |  KL+NLL: 104722.93  (KL=104604.73, NLL=118.20)
 batch_loss: 4522.49  |  KL+NLL: 104711.95  (KL=104558.78, NLL=153.17)


Epoch 7/10:  31%|███       | 112/360 [00:05<00:11, 20.86it/s]

 batch_loss: 4517.70  |  KL+NLL: 104677.18  (KL=104573.84, NLL=103.34)
 batch_loss: 4411.17  |  KL+NLL: 104774.83  (KL=104602.14, NLL=172.69)
 batch_loss: 4480.51  |  KL+NLL: 104670.53  (KL=104538.17, NLL=132.36)


Epoch 7/10:  32%|███▏      | 115/360 [00:05<00:11, 20.46it/s]

 batch_loss: 4485.85  |  KL+NLL: 104810.96  (KL=104672.80, NLL=138.15)
 batch_loss: 4523.15  |  KL+NLL: 104649.55  (KL=104526.99, NLL=122.55)


Epoch 7/10:  33%|███▎      | 118/360 [00:05<00:11, 20.72it/s]

 batch_loss: 4229.21  |  KL+NLL: 104638.83  (KL=104536.68, NLL=102.15)
 batch_loss: 4351.86  |  KL+NLL: 104959.36  (KL=104845.12, NLL=114.23)
 batch_loss: 4489.07  |  KL+NLL: 104793.23  (KL=104674.74, NLL=118.49)


Epoch 7/10:  34%|███▎      | 121/360 [00:05<00:11, 20.52it/s]

 batch_loss: 4426.55  |  KL+NLL: 104813.56  (KL=104652.44, NLL=161.12)
 batch_loss: 4213.65  |  KL+NLL: 104718.47  (KL=104614.19, NLL=104.28)


Epoch 7/10:  34%|███▍      | 124/360 [00:06<00:11, 20.50it/s]

 batch_loss: 4270.54  |  KL+NLL: 104759.60  (KL=104613.78, NLL=145.81)
 batch_loss: 4384.77  |  KL+NLL: 104778.60  (KL=104690.92, NLL=87.67)
 batch_loss: 4409.76  |  KL+NLL: 104787.81  (KL=104586.91, NLL=200.90)
 batch_loss: 4319.31  |  KL+NLL: 104749.95  (KL=104638.26, NLL=111.69)
 batch_loss: 4360.72  |  KL+NLL: 104987.18  (KL=104854.94, NLL=132.24)


Epoch 7/10:  35%|███▌      | 127/360 [00:06<00:11, 20.62it/s]

 batch_loss: 4195.01  |  KL+NLL: 104841.97  (KL=104725.21, NLL=116.76)
 batch_loss: 4347.37  |  KL+NLL: 105051.60  (KL=104874.62, NLL=176.98)
 batch_loss: 4332.85  |  KL+NLL: 105078.46  (KL=104935.46, NLL=142.99)


Epoch 7/10:  36%|███▌      | 130/360 [00:06<00:11, 20.53it/s]

 batch_loss: 4248.93  |  KL+NLL: 104876.42  (KL=104771.89, NLL=104.53)
 batch_loss: 4256.63  |  KL+NLL: 104835.99  (KL=104693.17, NLL=142.82)


Epoch 7/10:  37%|███▋      | 133/360 [00:06<00:11, 20.32it/s]

 batch_loss: 4257.03  |  KL+NLL: 104995.58  (KL=104828.22, NLL=167.36)
 batch_loss: 4375.07  |  KL+NLL: 105064.40  (KL=104881.22, NLL=183.19)
 batch_loss: 4271.51  |  KL+NLL: 105074.40  (KL=104927.07, NLL=147.33)


Epoch 7/10:  38%|███▊      | 136/360 [00:06<00:10, 20.45it/s]

 batch_loss: 4253.49  |  KL+NLL: 105070.59  (KL=104892.27, NLL=178.32)
 batch_loss: 4212.15  |  KL+NLL: 105023.49  (KL=104868.23, NLL=155.26)


Epoch 7/10:  39%|███▊      | 139/360 [00:06<00:10, 20.30it/s]

 batch_loss: 4242.78  |  KL+NLL: 105236.99  (KL=105043.66, NLL=193.33)
 batch_loss: 4046.15  |  KL+NLL: 105115.77  (KL=105007.72, NLL=108.05)
 batch_loss: 4148.61  |  KL+NLL: 105014.80  (KL=104841.13, NLL=173.67)
 batch_loss: 4118.17  |  KL+NLL: 105039.15  (KL=104874.14, NLL=165.01)
 batch_loss: 4304.81  |  KL+NLL: 104736.28  (KL=104611.88, NLL=124.41)


Epoch 7/10:  39%|███▉      | 142/360 [00:06<00:10, 20.52it/s]

 batch_loss: 4083.46  |  KL+NLL: 105073.49  (KL=104927.85, NLL=145.64)
 batch_loss: 4102.43  |  KL+NLL: 104797.82  (KL=104681.61, NLL=116.21)
 batch_loss: 4259.54  |  KL+NLL: 105182.41  (KL=105028.77, NLL=153.65)


Epoch 7/10:  40%|████      | 145/360 [00:07<00:10, 20.85it/s]

 batch_loss: 4096.97  |  KL+NLL: 105180.66  (KL=105041.11, NLL=139.55)
 batch_loss: 4158.95  |  KL+NLL: 105082.91  (KL=104977.19, NLL=105.72)


Epoch 7/10:  41%|████      | 148/360 [00:07<00:10, 21.10it/s]

 batch_loss: 4217.22  |  KL+NLL: 105038.35  (KL=104918.55, NLL=119.80)
 batch_loss: 4155.51  |  KL+NLL: 105194.77  (KL=105055.45, NLL=139.32)
 batch_loss: 3992.31  |  KL+NLL: 105145.80  (KL=104962.73, NLL=183.07)


Epoch 7/10:  42%|████▏     | 151/360 [00:07<00:09, 21.20it/s]

 batch_loss: 4042.39  |  KL+NLL: 105199.12  (KL=105072.05, NLL=127.06)
 batch_loss: 4187.82  |  KL+NLL: 105001.22  (KL=104856.56, NLL=144.65)


Epoch 7/10:  43%|████▎     | 154/360 [00:07<00:09, 21.52it/s]

 batch_loss: 4088.70  |  KL+NLL: 105033.68  (KL=104918.15, NLL=115.53)
 batch_loss: 4015.37  |  KL+NLL: 105186.07  (KL=105048.37, NLL=137.70)
 batch_loss: 4139.27  |  KL+NLL: 105344.68  (KL=105222.42, NLL=122.26)
 batch_loss: 4096.11  |  KL+NLL: 105063.06  (KL=104941.68, NLL=121.38)
 batch_loss: 4038.01  |  KL+NLL: 105366.05  (KL=105221.35, NLL=144.70)


Epoch 7/10:  44%|████▎     | 157/360 [00:07<00:09, 21.25it/s]

 batch_loss: 3956.27  |  KL+NLL: 105232.91  (KL=105108.82, NLL=124.09)
 batch_loss: 3919.33  |  KL+NLL: 105319.06  (KL=105193.74, NLL=125.32)
 batch_loss: 4033.86  |  KL+NLL: 105292.12  (KL=105168.92, NLL=123.20)


Epoch 7/10:  44%|████▍     | 160/360 [00:07<00:09, 21.02it/s]

 batch_loss: 3902.26  |  KL+NLL: 105438.93  (KL=105274.02, NLL=164.91)
 batch_loss: 3897.84  |  KL+NLL: 105303.94  (KL=105154.85, NLL=149.09)


Epoch 7/10:  45%|████▌     | 163/360 [00:07<00:09, 20.98it/s]

 batch_loss: 3988.26  |  KL+NLL: 105242.96  (KL=105131.38, NLL=111.58)
 batch_loss: 3882.92  |  KL+NLL: 105331.36  (KL=105163.36, NLL=168.00)
 batch_loss: 4035.35  |  KL+NLL: 105264.35  (KL=105092.73, NLL=171.62)


Epoch 7/10:  46%|████▌     | 166/360 [00:08<00:09, 20.89it/s]

 batch_loss: 3921.93  |  KL+NLL: 105477.00  (KL=105328.96, NLL=148.04)
 batch_loss: 3984.79  |  KL+NLL: 105369.68  (KL=105230.59, NLL=139.09)
 batch_loss: 3837.36  |  KL+NLL: 105357.71  (KL=105208.38, NLL=149.33)
 batch_loss: 3942.06  |  KL+NLL: 105483.49  (KL=105322.88, NLL=160.61)


Epoch 7/10:  47%|████▋     | 169/360 [00:08<00:09, 20.50it/s]

 batch_loss: 4009.83  |  KL+NLL: 105303.45  (KL=105172.42, NLL=131.03)
 batch_loss: 3972.86  |  KL+NLL: 105272.43  (KL=105146.24, NLL=126.19)
 batch_loss: 3920.96  |  KL+NLL: 105283.03  (KL=105176.54, NLL=106.49)


Epoch 7/10:  48%|████▊     | 172/360 [00:08<00:09, 20.42it/s]

 batch_loss: 3851.26  |  KL+NLL: 105403.65  (KL=105264.08, NLL=139.57)
 batch_loss: 3694.59  |  KL+NLL: 105144.70  (KL=105011.31, NLL=133.38)


Epoch 7/10:  49%|████▊     | 175/360 [00:08<00:09, 20.47it/s]

 batch_loss: 3790.18  |  KL+NLL: 105306.75  (KL=105176.84, NLL=129.91)
 batch_loss: 3762.79  |  KL+NLL: 105373.29  (KL=105259.41, NLL=113.88)
 batch_loss: 3865.52  |  KL+NLL: 105603.18  (KL=105461.22, NLL=141.96)


Epoch 7/10:  49%|████▉     | 178/360 [00:08<00:08, 20.47it/s]

 batch_loss: 3902.96  |  KL+NLL: 105476.65  (KL=105328.42, NLL=148.23)
 batch_loss: 3756.78  |  KL+NLL: 105436.96  (KL=105285.54, NLL=151.42)


Epoch 7/10:  50%|█████     | 181/360 [00:08<00:08, 20.62it/s]

 batch_loss: 3777.61  |  KL+NLL: 105739.63  (KL=105560.56, NLL=179.07)
 batch_loss: 3906.82  |  KL+NLL: 105445.35  (KL=105332.65, NLL=112.70)
 batch_loss: 3795.25  |  KL+NLL: 105387.56  (KL=105269.41, NLL=118.16)
 batch_loss: 3840.79  |  KL+NLL: 105536.12  (KL=105425.15, NLL=110.97)
 batch_loss: 3739.72  |  KL+NLL: 105441.59  (KL=105347.16, NLL=94.43)


Epoch 7/10:  51%|█████     | 184/360 [00:08<00:08, 21.05it/s]

 batch_loss: 3642.39  |  KL+NLL: 105389.13  (KL=105289.44, NLL=99.69)
 batch_loss: 3707.84  |  KL+NLL: 105586.80  (KL=105479.76, NLL=107.04)
 batch_loss: 3658.95  |  KL+NLL: 105595.68  (KL=105479.92, NLL=115.76)


Epoch 7/10:  52%|█████▏    | 187/360 [00:09<00:08, 21.26it/s]

 batch_loss: 3665.69  |  KL+NLL: 105478.28  (KL=105286.45, NLL=191.83)
 batch_loss: 3747.56  |  KL+NLL: 105476.61  (KL=105246.33, NLL=230.28)


Epoch 7/10:  53%|█████▎    | 190/360 [00:09<00:08, 21.09it/s]

 batch_loss: 3797.38  |  KL+NLL: 105506.87  (KL=105401.59, NLL=105.28)
 batch_loss: 3757.49  |  KL+NLL: 105524.90  (KL=105405.51, NLL=119.39)
 batch_loss: 3753.04  |  KL+NLL: 105633.55  (KL=105499.30, NLL=134.24)


Epoch 7/10:  54%|█████▎    | 193/360 [00:09<00:07, 21.13it/s]

 batch_loss: 3707.91  |  KL+NLL: 105853.75  (KL=105742.85, NLL=110.90)
 batch_loss: 3726.94  |  KL+NLL: 105528.27  (KL=105362.52, NLL=165.75)


Epoch 7/10:  54%|█████▍    | 196/360 [00:09<00:07, 20.82it/s]

 batch_loss: 3579.06  |  KL+NLL: 105549.99  (KL=105413.83, NLL=136.17)
 batch_loss: 3535.79  |  KL+NLL: 105807.92  (KL=105667.95, NLL=139.97)
 batch_loss: 3714.49  |  KL+NLL: 105643.28  (KL=105520.97, NLL=122.31)
 batch_loss: 3760.81  |  KL+NLL: 105682.15  (KL=105532.45, NLL=149.70)
 batch_loss: 3588.96  |  KL+NLL: 105617.09  (KL=105516.38, NLL=100.71)


Epoch 7/10:  55%|█████▌    | 199/360 [00:09<00:07, 20.54it/s]

 batch_loss: 3711.44  |  KL+NLL: 105570.87  (KL=105427.73, NLL=143.14)
 batch_loss: 3713.28  |  KL+NLL: 105577.68  (KL=105456.28, NLL=121.40)
 batch_loss: 3642.61  |  KL+NLL: 105849.31  (KL=105740.34, NLL=108.96)


Epoch 7/10:  56%|█████▌    | 202/360 [00:09<00:07, 20.45it/s]

 batch_loss: 3754.43  |  KL+NLL: 105552.51  (KL=105462.05, NLL=90.45)
 batch_loss: 3510.16  |  KL+NLL: 105690.82  (KL=105589.29, NLL=101.53)


Epoch 7/10:  57%|█████▋    | 205/360 [00:09<00:07, 20.67it/s]

 batch_loss: 3481.14  |  KL+NLL: 105787.01  (KL=105652.18, NLL=134.83)
 batch_loss: 3638.37  |  KL+NLL: 105674.74  (KL=105583.09, NLL=91.65)
 batch_loss: 3646.95  |  KL+NLL: 105693.57  (KL=105581.91, NLL=111.67)


Epoch 7/10:  58%|█████▊    | 208/360 [00:10<00:07, 20.54it/s]

 batch_loss: 3681.64  |  KL+NLL: 105629.40  (KL=105463.74, NLL=165.66)
 batch_loss: 3760.10  |  KL+NLL: 105777.55  (KL=105642.39, NLL=135.16)


Epoch 7/10:  59%|█████▊    | 211/360 [00:10<00:07, 20.70it/s]

 batch_loss: 3752.99  |  KL+NLL: 105816.54  (KL=105694.94, NLL=121.61)
 batch_loss: 3571.86  |  KL+NLL: 105623.15  (KL=105479.55, NLL=143.60)
 batch_loss: 3485.26  |  KL+NLL: 105829.15  (KL=105683.38, NLL=145.77)
 batch_loss: 3511.88  |  KL+NLL: 105893.68  (KL=105764.80, NLL=128.89)


Epoch 7/10:  59%|█████▉    | 214/360 [00:10<00:07, 20.42it/s]

 batch_loss: 3523.93  |  KL+NLL: 105919.23  (KL=105747.12, NLL=172.11)
 batch_loss: 3453.09  |  KL+NLL: 105827.60  (KL=105686.66, NLL=140.94)
 batch_loss: 3559.82  |  KL+NLL: 105939.58  (KL=105791.39, NLL=148.19)
 batch_loss: 3435.64  |  KL+NLL: 105986.41  (KL=105891.21, NLL=95.20)


Epoch 7/10:  61%|██████    | 220/360 [00:10<00:06, 20.31it/s]

 batch_loss: 3551.37  |  KL+NLL: 105870.16  (KL=105700.77, NLL=169.39)
 batch_loss: 3436.38  |  KL+NLL: 105746.57  (KL=105581.20, NLL=165.38)
 batch_loss: 3421.76  |  KL+NLL: 106002.69  (KL=105852.00, NLL=150.69)
 batch_loss: 3431.73  |  KL+NLL: 105932.73  (KL=105808.53, NLL=124.20)
 batch_loss: 3417.32  |  KL+NLL: 105693.56  (KL=105585.18, NLL=108.38)


Epoch 7/10:  62%|██████▏   | 223/360 [00:10<00:06, 20.56it/s]

 batch_loss: 3573.45  |  KL+NLL: 106019.04  (KL=105903.39, NLL=115.65)
 batch_loss: 3328.70  |  KL+NLL: 106011.81  (KL=105893.74, NLL=118.07)
 batch_loss: 3523.51  |  KL+NLL: 106046.83  (KL=105941.30, NLL=105.52)
 batch_loss: 3488.17  |  KL+NLL: 106094.40  (KL=105983.05, NLL=111.35)


Epoch 7/10:  63%|██████▎   | 226/360 [00:10<00:06, 20.46it/s]

 batch_loss: 3419.54  |  KL+NLL: 105892.84  (KL=105772.12, NLL=120.72)


Epoch 7/10:  64%|██████▎   | 229/360 [00:11<00:06, 20.35it/s]

 batch_loss: 3388.68  |  KL+NLL: 106085.23  (KL=105954.23, NLL=131.00)
 batch_loss: 3330.05  |  KL+NLL: 106031.77  (KL=105908.34, NLL=123.43)
 batch_loss: 3396.92  |  KL+NLL: 105847.54  (KL=105673.71, NLL=173.83)
 batch_loss: 3397.90  |  KL+NLL: 105886.77  (KL=105766.37, NLL=120.41)
 batch_loss: 3461.46  |  KL+NLL: 106210.65  (KL=106024.80, NLL=185.86)


Epoch 7/10:  65%|██████▌   | 235/360 [00:11<00:06, 20.22it/s]

 batch_loss: 3511.23  |  KL+NLL: 106027.39  (KL=105832.88, NLL=194.51)
 batch_loss: 3391.29  |  KL+NLL: 106040.28  (KL=105914.59, NLL=125.69)
 batch_loss: 3421.65  |  KL+NLL: 105972.18  (KL=105820.77, NLL=151.42)
 batch_loss: 3400.40  |  KL+NLL: 106056.75  (KL=105932.44, NLL=124.31)
 batch_loss: 3356.83  |  KL+NLL: 106100.51  (KL=105966.30, NLL=134.21)


Epoch 7/10:  66%|██████▌   | 238/360 [00:11<00:05, 20.38it/s]

 batch_loss: 3387.36  |  KL+NLL: 106008.85  (KL=105857.59, NLL=151.26)
 batch_loss: 3413.73  |  KL+NLL: 106096.29  (KL=105988.99, NLL=107.30)
 batch_loss: 3374.30  |  KL+NLL: 105969.54  (KL=105856.79, NLL=112.75)
 batch_loss: 3331.50  |  KL+NLL: 106285.48  (KL=106169.78, NLL=115.70)


Epoch 7/10:  67%|██████▋   | 241/360 [00:11<00:05, 20.43it/s]

 batch_loss: 3374.30  |  KL+NLL: 105863.78  (KL=105716.86, NLL=146.92)


Epoch 7/10:  68%|██████▊   | 244/360 [00:11<00:05, 20.20it/s]

 batch_loss: 3352.06  |  KL+NLL: 105981.85  (KL=105832.86, NLL=148.99)
 batch_loss: 3403.17  |  KL+NLL: 105988.01  (KL=105868.56, NLL=119.45)
 batch_loss: 3339.84  |  KL+NLL: 106082.47  (KL=105929.00, NLL=153.47)
 batch_loss: 3364.84  |  KL+NLL: 106128.99  (KL=105996.74, NLL=132.25)
 batch_loss: 3270.81  |  KL+NLL: 106386.78  (KL=106235.23, NLL=151.55)


Epoch 7/10:  69%|██████▊   | 247/360 [00:12<00:05, 20.02it/s]

 batch_loss: 3276.52  |  KL+NLL: 106062.72  (KL=105930.64, NLL=132.08)
 batch_loss: 3160.42  |  KL+NLL: 106266.24  (KL=106103.08, NLL=163.16)


Epoch 7/10:  69%|██████▉   | 250/360 [00:12<00:05, 20.05it/s]

 batch_loss: 3263.84  |  KL+NLL: 106199.46  (KL=106113.43, NLL=86.03)
 batch_loss: 3153.10  |  KL+NLL: 105959.08  (KL=105836.56, NLL=122.51)
 batch_loss: 3128.19  |  KL+NLL: 106154.09  (KL=106007.94, NLL=146.15)


Epoch 7/10:  70%|███████   | 253/360 [00:12<00:05, 20.37it/s]

 batch_loss: 3306.65  |  KL+NLL: 106275.23  (KL=106173.93, NLL=101.30)
 batch_loss: 3293.44  |  KL+NLL: 106207.12  (KL=106098.23, NLL=108.89)
 batch_loss: 3206.98  |  KL+NLL: 106155.71  (KL=105996.02, NLL=159.69)
 batch_loss: 3312.35  |  KL+NLL: 106307.52  (KL=106172.78, NLL=134.74)


Epoch 7/10:  71%|███████   | 256/360 [00:12<00:05, 20.12it/s]

 batch_loss: 3146.12  |  KL+NLL: 106364.62  (KL=106229.41, NLL=135.21)
 batch_loss: 3131.99  |  KL+NLL: 106256.06  (KL=106121.00, NLL=135.06)
 batch_loss: 3267.50  |  KL+NLL: 106508.94  (KL=106274.22, NLL=234.72)


Epoch 7/10:  72%|███████▏  | 259/360 [00:12<00:04, 20.53it/s]

 batch_loss: 3212.03  |  KL+NLL: 106394.22  (KL=106237.91, NLL=156.31)
 batch_loss: 3148.46  |  KL+NLL: 106367.79  (KL=106226.62, NLL=141.17)


Epoch 7/10:  73%|███████▎  | 262/360 [00:12<00:04, 20.41it/s]

 batch_loss: 3161.52  |  KL+NLL: 106301.65  (KL=106202.17, NLL=99.48)
 batch_loss: 3201.36  |  KL+NLL: 106292.26  (KL=106173.80, NLL=118.46)
 batch_loss: 3080.56  |  KL+NLL: 106236.16  (KL=106087.98, NLL=148.18)


Epoch 7/10:  74%|███████▎  | 265/360 [00:12<00:04, 20.46it/s]

 batch_loss: 3155.07  |  KL+NLL: 106450.70  (KL=106349.02, NLL=101.68)
 batch_loss: 3078.08  |  KL+NLL: 106334.77  (KL=106151.20, NLL=183.57)


Epoch 7/10:  74%|███████▍  | 268/360 [00:13<00:04, 20.56it/s]

 batch_loss: 3139.08  |  KL+NLL: 106417.91  (KL=106270.34, NLL=147.58)
 batch_loss: 3127.99  |  KL+NLL: 106247.98  (KL=106098.40, NLL=149.58)
 batch_loss: 3193.59  |  KL+NLL: 106428.56  (KL=106258.58, NLL=169.98)
 batch_loss: 3087.61  |  KL+NLL: 106359.01  (KL=106209.18, NLL=149.83)
 batch_loss: 2977.53  |  KL+NLL: 106298.80  (KL=106113.94, NLL=184.87)


Epoch 7/10:  75%|███████▌  | 271/360 [00:13<00:04, 20.41it/s]

 batch_loss: 3053.25  |  KL+NLL: 106273.78  (KL=106145.26, NLL=128.52)
 batch_loss: 3138.46  |  KL+NLL: 106442.91  (KL=106318.29, NLL=124.62)
 batch_loss: 3184.11  |  KL+NLL: 106226.42  (KL=106123.66, NLL=102.76)


Epoch 7/10:  76%|███████▌  | 274/360 [00:13<00:04, 20.89it/s]

 batch_loss: 3149.10  |  KL+NLL: 106411.32  (KL=106278.88, NLL=132.44)
 batch_loss: 3116.32  |  KL+NLL: 106437.44  (KL=106323.39, NLL=114.04)


Epoch 7/10:  77%|███████▋  | 277/360 [00:13<00:04, 20.68it/s]

 batch_loss: 3059.12  |  KL+NLL: 106342.66  (KL=106224.60, NLL=118.06)
 batch_loss: 3046.25  |  KL+NLL: 106387.00  (KL=106229.85, NLL=157.15)
 batch_loss: 3007.78  |  KL+NLL: 106894.50  (KL=106666.67, NLL=227.83)
 batch_loss: 3095.00  |  KL+NLL: 106473.43  (KL=106288.34, NLL=185.10)


Epoch 7/10:  78%|███████▊  | 280/360 [00:13<00:03, 20.27it/s]

 batch_loss: 3071.15  |  KL+NLL: 106318.15  (KL=106198.41, NLL=119.74)
 batch_loss: 2967.09  |  KL+NLL: 106416.27  (KL=106232.89, NLL=183.38)


Epoch 7/10:  79%|███████▊  | 283/360 [00:13<00:03, 19.94it/s]

 batch_loss: 2978.21  |  KL+NLL: 106583.36  (KL=106474.44, NLL=108.92)
 batch_loss: 3119.48  |  KL+NLL: 106665.55  (KL=106519.44, NLL=146.11)


Epoch 7/10:  79%|███████▉  | 286/360 [00:13<00:03, 20.32it/s]

 batch_loss: 2927.92  |  KL+NLL: 106421.53  (KL=106252.67, NLL=168.86)
 batch_loss: 3042.95  |  KL+NLL: 106548.33  (KL=106404.09, NLL=144.24)
 batch_loss: 3035.94  |  KL+NLL: 106438.44  (KL=106303.12, NLL=135.32)
 batch_loss: 3065.99  |  KL+NLL: 106579.60  (KL=106438.93, NLL=140.67)
 batch_loss: 2891.53  |  KL+NLL: 106654.41  (KL=106497.45, NLL=156.95)


Epoch 7/10:  80%|████████  | 289/360 [00:14<00:03, 20.18it/s]

 batch_loss: 2912.19  |  KL+NLL: 106574.80  (KL=106452.84, NLL=121.97)
 batch_loss: 2925.30  |  KL+NLL: 106682.27  (KL=106467.41, NLL=214.86)
 batch_loss: 3019.43  |  KL+NLL: 106718.62  (KL=106565.34, NLL=153.28)


Epoch 7/10:  81%|████████  | 292/360 [00:14<00:03, 20.53it/s]

 batch_loss: 3009.80  |  KL+NLL: 106734.73  (KL=106520.73, NLL=214.00)
 batch_loss: 2985.79  |  KL+NLL: 106688.39  (KL=106574.91, NLL=113.48)


Epoch 7/10:  82%|████████▏ | 295/360 [00:14<00:03, 20.71it/s]

 batch_loss: 2859.18  |  KL+NLL: 106563.31  (KL=106449.11, NLL=114.20)
 batch_loss: 2956.99  |  KL+NLL: 106622.94  (KL=106482.46, NLL=140.48)
 batch_loss: 3044.83  |  KL+NLL: 106889.09  (KL=106748.77, NLL=140.32)


Epoch 7/10:  83%|████████▎ | 298/360 [00:14<00:02, 20.90it/s]

 batch_loss: 2916.01  |  KL+NLL: 106653.83  (KL=106552.85, NLL=100.98)
 batch_loss: 2896.36  |  KL+NLL: 106640.79  (KL=106512.64, NLL=128.15)


Epoch 7/10:  84%|████████▎ | 301/360 [00:14<00:02, 21.09it/s]

 batch_loss: 2927.57  |  KL+NLL: 106851.68  (KL=106683.90, NLL=167.79)
 batch_loss: 2776.09  |  KL+NLL: 106547.95  (KL=106440.69, NLL=107.27)
 batch_loss: 2788.23  |  KL+NLL: 106588.23  (KL=106462.16, NLL=126.07)
 batch_loss: 2840.81  |  KL+NLL: 106833.98  (KL=106707.71, NLL=126.27)
 batch_loss: 2861.32  |  KL+NLL: 106694.39  (KL=106568.51, NLL=125.88)


Epoch 7/10:  84%|████████▍ | 304/360 [00:14<00:02, 20.74it/s]

 batch_loss: 2963.74  |  KL+NLL: 106798.44  (KL=106664.18, NLL=134.26)
 batch_loss: 2823.57  |  KL+NLL: 106997.84  (KL=106869.48, NLL=128.37)
 batch_loss: 2760.48  |  KL+NLL: 106661.06  (KL=106523.22, NLL=137.84)


Epoch 7/10:  85%|████████▌ | 307/360 [00:14<00:02, 20.95it/s]

 batch_loss: 2746.73  |  KL+NLL: 106865.44  (KL=106732.18, NLL=133.26)
 batch_loss: 2777.20  |  KL+NLL: 106616.53  (KL=106534.95, NLL=81.58)


Epoch 7/10:  86%|████████▌ | 310/360 [00:15<00:02, 20.94it/s]

 batch_loss: 2792.45  |  KL+NLL: 106676.38  (KL=106555.71, NLL=120.67)
 batch_loss: 2877.83  |  KL+NLL: 106670.38  (KL=106534.42, NLL=135.96)
 batch_loss: 2860.44  |  KL+NLL: 106884.52  (KL=106766.10, NLL=118.41)
 batch_loss: 2826.30  |  KL+NLL: 106957.09  (KL=106806.52, NLL=150.58)


Epoch 7/10:  88%|████████▊ | 316/360 [00:15<00:02, 20.41it/s]

 batch_loss: 2749.93  |  KL+NLL: 106871.30  (KL=106774.89, NLL=96.41)
 batch_loss: 2900.05  |  KL+NLL: 106940.74  (KL=106805.73, NLL=135.00)
 batch_loss: 2603.02  |  KL+NLL: 106832.10  (KL=106706.71, NLL=125.39)
 batch_loss: 2738.31  |  KL+NLL: 106874.76  (KL=106777.46, NLL=97.29)
 batch_loss: 2826.83  |  KL+NLL: 106758.30  (KL=106642.73, NLL=115.58)


Epoch 7/10:  89%|████████▊ | 319/360 [00:15<00:02, 20.48it/s]

 batch_loss: 2731.96  |  KL+NLL: 106972.20  (KL=106815.18, NLL=157.02)
 batch_loss: 2718.95  |  KL+NLL: 106963.84  (KL=106842.52, NLL=121.32)
 batch_loss: 2802.61  |  KL+NLL: 106920.36  (KL=106790.38, NLL=129.97)
 batch_loss: 2722.21  |  KL+NLL: 106993.41  (KL=106870.35, NLL=123.05)


Epoch 7/10:  89%|████████▉ | 322/360 [00:15<00:01, 20.42it/s]

 batch_loss: 2831.92  |  KL+NLL: 107058.69  (KL=106930.96, NLL=127.73)


Epoch 7/10:  90%|█████████ | 325/360 [00:15<00:01, 20.37it/s]

 batch_loss: 2746.75  |  KL+NLL: 106889.73  (KL=106742.95, NLL=146.78)
 batch_loss: 2877.95  |  KL+NLL: 107152.83  (KL=107039.03, NLL=113.79)
 batch_loss: 2882.70  |  KL+NLL: 106931.25  (KL=106756.51, NLL=174.74)
 batch_loss: 2724.32  |  KL+NLL: 107044.52  (KL=106878.06, NLL=166.46)
 batch_loss: 2832.94  |  KL+NLL: 107009.58  (KL=106887.62, NLL=121.97)


Epoch 7/10:  92%|█████████▏| 331/360 [00:16<00:01, 20.79it/s]

 batch_loss: 2626.84  |  KL+NLL: 106865.89  (KL=106761.05, NLL=104.84)
 batch_loss: 2679.51  |  KL+NLL: 106874.83  (KL=106738.84, NLL=135.99)
 batch_loss: 2624.25  |  KL+NLL: 107091.68  (KL=106958.21, NLL=133.47)
 batch_loss: 2725.81  |  KL+NLL: 107039.97  (KL=106911.59, NLL=128.39)
 batch_loss: 2811.81  |  KL+NLL: 106901.84  (KL=106773.97, NLL=127.87)


Epoch 7/10:  93%|█████████▎| 334/360 [00:16<00:01, 20.49it/s]

 batch_loss: 2739.67  |  KL+NLL: 106900.25  (KL=106777.44, NLL=122.81)
 batch_loss: 2695.30  |  KL+NLL: 106929.46  (KL=106770.97, NLL=158.49)
 batch_loss: 2650.01  |  KL+NLL: 106699.05  (KL=106603.73, NLL=95.32)
 batch_loss: 2697.67  |  KL+NLL: 107060.10  (KL=106949.76, NLL=110.34)


Epoch 7/10:  94%|█████████▍| 340/360 [00:16<00:00, 20.54it/s]

 batch_loss: 2593.28  |  KL+NLL: 107029.08  (KL=106891.88, NLL=137.20)
 batch_loss: 2679.04  |  KL+NLL: 106956.42  (KL=106837.80, NLL=118.63)
 batch_loss: 2608.49  |  KL+NLL: 107184.35  (KL=107045.30, NLL=139.05)
 batch_loss: 2679.95  |  KL+NLL: 107373.13  (KL=107237.67, NLL=135.46)
 batch_loss: 2695.28  |  KL+NLL: 106975.26  (KL=106852.56, NLL=122.69)


Epoch 7/10:  95%|█████████▌| 343/360 [00:16<00:00, 20.60it/s]

 batch_loss: 2741.98  |  KL+NLL: 106916.77  (KL=106805.63, NLL=111.14)
 batch_loss: 2722.79  |  KL+NLL: 107393.57  (KL=107231.85, NLL=161.72)
 batch_loss: 2568.51  |  KL+NLL: 107200.77  (KL=107087.64, NLL=113.12)
 batch_loss: 2662.99  |  KL+NLL: 107065.99  (KL=106960.97, NLL=105.02)


Epoch 7/10:  96%|█████████▌| 346/360 [00:16<00:00, 20.53it/s]

 batch_loss: 2611.07  |  KL+NLL: 107446.00  (KL=107246.39, NLL=199.61)


Epoch 7/10:  97%|█████████▋| 349/360 [00:16<00:00, 20.60it/s]

 batch_loss: 2647.06  |  KL+NLL: 107228.24  (KL=107090.27, NLL=137.97)
 batch_loss: 2537.95  |  KL+NLL: 107257.11  (KL=107090.72, NLL=166.40)
 batch_loss: 2594.00  |  KL+NLL: 107106.85  (KL=107006.78, NLL=100.07)
 batch_loss: 2567.94  |  KL+NLL: 107191.03  (KL=107050.99, NLL=140.04)
 batch_loss: 2626.35  |  KL+NLL: 107052.41  (KL=106926.65, NLL=125.76)


Epoch 7/10:  99%|█████████▊| 355/360 [00:17<00:00, 21.47it/s]

 batch_loss: 2471.29  |  KL+NLL: 107268.72  (KL=107160.19, NLL=108.53)
 batch_loss: 2605.06  |  KL+NLL: 107015.56  (KL=106897.46, NLL=118.09)
 batch_loss: 2603.67  |  KL+NLL: 107172.62  (KL=107063.87, NLL=108.75)
 batch_loss: 2600.94  |  KL+NLL: 107186.91  (KL=107026.67, NLL=160.24)
 batch_loss: 2620.75  |  KL+NLL: 107385.49  (KL=107219.38, NLL=166.11)


Epoch 7/10: 100%|██████████| 360/360 [00:17<00:00, 20.60it/s]


 batch_loss: 2534.94  |  KL+NLL: 107220.23  (KL=107118.91, NLL=101.31)
 batch_loss: 2593.97  |  KL+NLL: 107315.58  (KL=107187.55, NLL=128.03)
 batch_loss: 2569.65  |  KL+NLL: 107190.31  (KL=107037.70, NLL=152.60)
 batch_loss: 2538.26  |  KL+NLL: 107191.95  (KL=107065.64, NLL=126.31)
Epoch 7 - ELBO Loss: 3926.1544


Epoch 8/10:   1%|          | 3/360 [00:00<00:25, 13.87it/s]

 batch_loss: 2473.49  |  KL+NLL: 107216.95  (KL=107071.67, NLL=145.28)
 batch_loss: 2453.32  |  KL+NLL: 107321.81  (KL=107183.84, NLL=137.97)
 batch_loss: 2586.89  |  KL+NLL: 107357.41  (KL=107232.98, NLL=124.43)
 batch_loss: 2424.94  |  KL+NLL: 107113.99  (KL=106988.52, NLL=125.46)
 batch_loss: 2459.47  |  KL+NLL: 107286.77  (KL=107172.98, NLL=113.78)


Epoch 8/10:   2%|▎         | 9/360 [00:00<00:18, 19.40it/s]

 batch_loss: 2558.86  |  KL+NLL: 107449.02  (KL=107323.29, NLL=125.73)
 batch_loss: 2554.94  |  KL+NLL: 107129.38  (KL=107031.95, NLL=97.43)
 batch_loss: 2364.77  |  KL+NLL: 107588.85  (KL=107467.44, NLL=121.41)
 batch_loss: 2436.07  |  KL+NLL: 107479.51  (KL=107352.24, NLL=127.27)
 batch_loss: 2456.28  |  KL+NLL: 107387.54  (KL=107199.91, NLL=187.63)


Epoch 8/10:   4%|▍         | 15/360 [00:00<00:17, 20.29it/s]

 batch_loss: 2491.77  |  KL+NLL: 107434.42  (KL=107328.69, NLL=105.73)
 batch_loss: 2530.57  |  KL+NLL: 107400.03  (KL=107296.67, NLL=103.35)
 batch_loss: 2357.08  |  KL+NLL: 107178.27  (KL=107047.26, NLL=131.01)
 batch_loss: 2462.10  |  KL+NLL: 107275.07  (KL=107153.53, NLL=121.54)
 batch_loss: 2330.89  |  KL+NLL: 107423.42  (KL=107309.44, NLL=113.98)


Epoch 8/10:   5%|▌         | 18/360 [00:00<00:16, 20.58it/s]

 batch_loss: 2454.23  |  KL+NLL: 107448.85  (KL=107348.25, NLL=100.60)
 batch_loss: 2446.83  |  KL+NLL: 107459.19  (KL=107332.69, NLL=126.50)
 batch_loss: 2430.12  |  KL+NLL: 107433.67  (KL=107299.89, NLL=133.78)
 batch_loss: 2496.47  |  KL+NLL: 107544.49  (KL=107397.90, NLL=146.59)


Epoch 8/10:   7%|▋         | 24/360 [00:01<00:16, 20.54it/s]

 batch_loss: 2510.47  |  KL+NLL: 107474.43  (KL=107323.62, NLL=150.81)
 batch_loss: 2310.57  |  KL+NLL: 107349.98  (KL=107238.70, NLL=111.29)
 batch_loss: 2347.61  |  KL+NLL: 107397.01  (KL=107285.59, NLL=111.43)
 batch_loss: 2467.99  |  KL+NLL: 107643.93  (KL=107512.17, NLL=131.76)
 batch_loss: 2374.08  |  KL+NLL: 107489.73  (KL=107367.21, NLL=122.51)


Epoch 8/10:   8%|▊         | 27/360 [00:01<00:16, 20.66it/s]

 batch_loss: 2539.68  |  KL+NLL: 107287.51  (KL=107165.36, NLL=122.15)
 batch_loss: 2425.56  |  KL+NLL: 107397.71  (KL=107295.54, NLL=102.17)
 batch_loss: 2483.35  |  KL+NLL: 107381.90  (KL=107285.38, NLL=96.52)
 batch_loss: 2393.87  |  KL+NLL: 107619.73  (KL=107435.05, NLL=184.69)
 batch_loss: 2412.25  |  KL+NLL: 107667.92  (KL=107521.47, NLL=146.45)


Epoch 8/10:   9%|▉         | 33/360 [00:01<00:15, 21.02it/s]

 batch_loss: 2299.46  |  KL+NLL: 107524.50  (KL=107412.02, NLL=112.48)
 batch_loss: 2436.70  |  KL+NLL: 107551.95  (KL=107414.79, NLL=137.16)
 batch_loss: 2315.89  |  KL+NLL: 107441.14  (KL=107303.05, NLL=138.09)
 batch_loss: 2442.40  |  KL+NLL: 107435.18  (KL=107308.56, NLL=126.62)
 batch_loss: 2435.91  |  KL+NLL: 107385.54  (KL=107192.72, NLL=192.82)


Epoch 8/10:  11%|█         | 39/360 [00:01<00:15, 21.16it/s]

 batch_loss: 2362.68  |  KL+NLL: 107715.22  (KL=107609.06, NLL=106.16)
 batch_loss: 2369.34  |  KL+NLL: 107663.04  (KL=107540.59, NLL=122.45)
 batch_loss: 2294.12  |  KL+NLL: 107490.95  (KL=107377.02, NLL=113.94)
 batch_loss: 2356.02  |  KL+NLL: 107613.54  (KL=107506.07, NLL=107.47)
 batch_loss: 2280.37  |  KL+NLL: 107751.09  (KL=107614.52, NLL=136.58)


Epoch 8/10:  12%|█▏        | 42/360 [00:02<00:15, 21.06it/s]

 batch_loss: 2295.62  |  KL+NLL: 107832.18  (KL=107705.58, NLL=126.60)
 batch_loss: 2354.72  |  KL+NLL: 107759.22  (KL=107651.88, NLL=107.34)
 batch_loss: 2312.34  |  KL+NLL: 107731.33  (KL=107618.37, NLL=112.97)
 batch_loss: 2216.38  |  KL+NLL: 107619.28  (KL=107507.15, NLL=112.13)
 batch_loss: 2275.11  |  KL+NLL: 107487.52  (KL=107377.98, NLL=109.54)


Epoch 8/10:  13%|█▎        | 48/360 [00:02<00:15, 20.74it/s]

 batch_loss: 2225.14  |  KL+NLL: 107819.10  (KL=107696.58, NLL=122.52)
 batch_loss: 2285.92  |  KL+NLL: 107763.97  (KL=107662.32, NLL=101.65)
 batch_loss: 2312.75  |  KL+NLL: 107762.42  (KL=107635.51, NLL=126.92)
 batch_loss: 2266.48  |  KL+NLL: 107780.91  (KL=107650.44, NLL=130.48)
 batch_loss: 2255.57  |  KL+NLL: 107637.36  (KL=107517.27, NLL=120.09)


Epoch 8/10:  15%|█▌        | 54/360 [00:02<00:14, 20.69it/s]

 batch_loss: 2305.06  |  KL+NLL: 107817.83  (KL=107710.77, NLL=107.07)
 batch_loss: 2242.71  |  KL+NLL: 107728.73  (KL=107603.76, NLL=124.97)
 batch_loss: 2251.70  |  KL+NLL: 107840.59  (KL=107708.29, NLL=132.30)
 batch_loss: 2169.47  |  KL+NLL: 107455.34  (KL=107360.52, NLL=94.82)
 batch_loss: 2207.93  |  KL+NLL: 107649.15  (KL=107510.58, NLL=138.57)


Epoch 8/10:  16%|█▌        | 57/360 [00:02<00:14, 20.31it/s]

 batch_loss: 2172.54  |  KL+NLL: 107661.49  (KL=107532.37, NLL=129.12)
 batch_loss: 2303.44  |  KL+NLL: 107772.33  (KL=107680.84, NLL=91.49)
 batch_loss: 2301.66  |  KL+NLL: 107735.14  (KL=107608.92, NLL=126.21)
 batch_loss: 2149.21  |  KL+NLL: 107806.62  (KL=107698.33, NLL=108.30)
 batch_loss: 2240.67  |  KL+NLL: 107804.29  (KL=107678.27, NLL=126.03)


Epoch 8/10:  18%|█▊        | 63/360 [00:03<00:14, 20.71it/s]

 batch_loss: 2237.96  |  KL+NLL: 107812.20  (KL=107669.01, NLL=143.19)
 batch_loss: 2287.19  |  KL+NLL: 107775.96  (KL=107684.44, NLL=91.52)
 batch_loss: 2212.99  |  KL+NLL: 107716.50  (KL=107607.26, NLL=109.24)
 batch_loss: 2191.53  |  KL+NLL: 107751.25  (KL=107600.23, NLL=151.02)
 batch_loss: 2219.46  |  KL+NLL: 107817.58  (KL=107676.09, NLL=141.49)


Epoch 8/10:  19%|█▉        | 69/360 [00:03<00:13, 21.08it/s]

 batch_loss: 2213.32  |  KL+NLL: 107800.02  (KL=107595.21, NLL=204.81)
 batch_loss: 2258.26  |  KL+NLL: 107952.04  (KL=107812.16, NLL=139.88)
 batch_loss: 2169.07  |  KL+NLL: 107881.67  (KL=107772.91, NLL=108.77)
 batch_loss: 2099.09  |  KL+NLL: 107774.06  (KL=107666.41, NLL=107.64)
 batch_loss: 2129.41  |  KL+NLL: 107790.51  (KL=107656.02, NLL=134.49)


Epoch 8/10:  20%|██        | 72/360 [00:03<00:13, 20.85it/s]

 batch_loss: 2102.17  |  KL+NLL: 108036.26  (KL=107919.29, NLL=116.97)
 batch_loss: 2213.95  |  KL+NLL: 107524.90  (KL=107413.23, NLL=111.67)
 batch_loss: 2177.06  |  KL+NLL: 107883.45  (KL=107774.48, NLL=108.97)
 batch_loss: 2077.23  |  KL+NLL: 107789.32  (KL=107657.77, NLL=131.55)


Epoch 8/10:  21%|██        | 75/360 [00:03<00:13, 20.55it/s]

 batch_loss: 2275.09  |  KL+NLL: 107742.04  (KL=107602.20, NLL=139.84)
 batch_loss: 2062.45  |  KL+NLL: 107991.73  (KL=107895.34, NLL=96.39)
 batch_loss: 2188.99  |  KL+NLL: 107967.08  (KL=107856.63, NLL=110.45)
 batch_loss: 2148.43  |  KL+NLL: 107932.85  (KL=107820.16, NLL=112.69)


Epoch 8/10:  22%|██▎       | 81/360 [00:03<00:13, 20.16it/s]

 batch_loss: 2120.09  |  KL+NLL: 107886.84  (KL=107772.37, NLL=114.48)
 batch_loss: 2197.87  |  KL+NLL: 108073.61  (KL=107942.47, NLL=131.14)
 batch_loss: 2092.59  |  KL+NLL: 108227.49  (KL=108072.40, NLL=155.09)
 batch_loss: 2155.71  |  KL+NLL: 107965.34  (KL=107837.84, NLL=127.49)
 batch_loss: 2179.69  |  KL+NLL: 107866.41  (KL=107710.50, NLL=155.91)


Epoch 8/10:  24%|██▍       | 87/360 [00:04<00:13, 20.29it/s]

 batch_loss: 2075.73  |  KL+NLL: 107951.75  (KL=107857.71, NLL=94.04)
 batch_loss: 2093.51  |  KL+NLL: 108109.96  (KL=108005.56, NLL=104.40)
 batch_loss: 2138.50  |  KL+NLL: 107898.47  (KL=107784.74, NLL=113.72)
 batch_loss: 2063.36  |  KL+NLL: 108085.45  (KL=107921.94, NLL=163.52)
 batch_loss: 2064.35  |  KL+NLL: 108074.38  (KL=107940.05, NLL=134.33)


Epoch 8/10:  25%|██▌       | 90/360 [00:04<00:13, 20.46it/s]

 batch_loss: 2094.76  |  KL+NLL: 107937.30  (KL=107842.36, NLL=94.94)
 batch_loss: 2033.22  |  KL+NLL: 108074.03  (KL=107959.47, NLL=114.56)
 batch_loss: 2179.98  |  KL+NLL: 107960.51  (KL=107793.84, NLL=166.67)
 batch_loss: 2121.63  |  KL+NLL: 108217.96  (KL=108068.90, NLL=149.07)
 batch_loss: 2120.12  |  KL+NLL: 108052.48  (KL=107941.93, NLL=110.55)


Epoch 8/10:  27%|██▋       | 96/360 [00:04<00:12, 20.60it/s]

 batch_loss: 2118.39  |  KL+NLL: 108138.95  (KL=108035.66, NLL=103.28)
 batch_loss: 1984.01  |  KL+NLL: 108317.43  (KL=108184.35, NLL=133.08)
 batch_loss: 2037.83  |  KL+NLL: 108144.32  (KL=108024.09, NLL=120.23)
 batch_loss: 2141.13  |  KL+NLL: 108197.50  (KL=108052.39, NLL=145.11)
 batch_loss: 1960.79  |  KL+NLL: 108097.31  (KL=107997.76, NLL=99.55)


Epoch 8/10:  28%|██▊       | 102/360 [00:04<00:12, 20.72it/s]

 batch_loss: 2037.86  |  KL+NLL: 107925.81  (KL=107832.42, NLL=93.39)
 batch_loss: 2023.95  |  KL+NLL: 108197.90  (KL=108080.52, NLL=117.38)
 batch_loss: 1996.29  |  KL+NLL: 108129.28  (KL=107980.93, NLL=148.35)
 batch_loss: 1963.36  |  KL+NLL: 108081.14  (KL=107979.47, NLL=101.67)
 batch_loss: 2060.02  |  KL+NLL: 108076.61  (KL=107985.20, NLL=91.40)


Epoch 8/10:  29%|██▉       | 105/360 [00:05<00:12, 20.72it/s]

 batch_loss: 2008.62  |  KL+NLL: 108249.44  (KL=108141.03, NLL=108.41)
 batch_loss: 2171.85  |  KL+NLL: 108294.00  (KL=108173.36, NLL=120.64)
 batch_loss: 2012.89  |  KL+NLL: 108259.84  (KL=108107.73, NLL=152.11)
 batch_loss: 1958.26  |  KL+NLL: 108277.16  (KL=108096.01, NLL=181.15)
 batch_loss: 2004.40  |  KL+NLL: 108246.00  (KL=108137.71, NLL=108.28)


Epoch 8/10:  31%|███       | 111/360 [00:05<00:12, 20.33it/s]

 batch_loss: 1964.06  |  KL+NLL: 108068.97  (KL=107910.83, NLL=158.14)
 batch_loss: 2195.44  |  KL+NLL: 108127.55  (KL=108003.16, NLL=124.39)
 batch_loss: 2033.20  |  KL+NLL: 108303.21  (KL=108181.95, NLL=121.26)
 batch_loss: 1987.93  |  KL+NLL: 108032.02  (KL=107905.12, NLL=126.90)
 batch_loss: 1925.67  |  KL+NLL: 108021.62  (KL=107925.23, NLL=96.39)


Epoch 8/10:  32%|███▎      | 117/360 [00:05<00:12, 20.21it/s]

 batch_loss: 1991.22  |  KL+NLL: 108164.55  (KL=108038.14, NLL=126.41)
 batch_loss: 1848.15  |  KL+NLL: 108293.17  (KL=108149.37, NLL=143.80)
 batch_loss: 2029.23  |  KL+NLL: 108594.52  (KL=108467.59, NLL=126.93)
 batch_loss: 1934.63  |  KL+NLL: 108110.95  (KL=108008.41, NLL=102.55)
 batch_loss: 1987.46  |  KL+NLL: 108525.67  (KL=108427.70, NLL=97.97)


Epoch 8/10:  33%|███▎      | 120/360 [00:05<00:11, 20.43it/s]

 batch_loss: 1920.99  |  KL+NLL: 108237.89  (KL=108129.06, NLL=108.83)
 batch_loss: 1996.66  |  KL+NLL: 108307.83  (KL=108200.93, NLL=106.90)
 batch_loss: 1905.03  |  KL+NLL: 108404.24  (KL=108295.29, NLL=108.95)
 batch_loss: 1965.36  |  KL+NLL: 108433.26  (KL=108286.38, NLL=146.89)
 batch_loss: 1932.94  |  KL+NLL: 108317.45  (KL=108213.26, NLL=104.19)


Epoch 8/10:  35%|███▌      | 126/360 [00:06<00:11, 20.02it/s]

 batch_loss: 1927.49  |  KL+NLL: 108327.74  (KL=108215.45, NLL=112.28)
 batch_loss: 1933.69  |  KL+NLL: 108323.63  (KL=108199.98, NLL=123.65)
 batch_loss: 1980.44  |  KL+NLL: 108515.06  (KL=108408.12, NLL=106.93)
 batch_loss: 1788.28  |  KL+NLL: 108711.93  (KL=108557.20, NLL=154.73)
 batch_loss: 1878.33  |  KL+NLL: 108438.79  (KL=108323.77, NLL=115.03)


Epoch 8/10:  36%|███▌      | 129/360 [00:06<00:11, 20.49it/s]

 batch_loss: 1874.34  |  KL+NLL: 108219.98  (KL=108101.59, NLL=118.39)
 batch_loss: 1790.44  |  KL+NLL: 108113.40  (KL=107999.41, NLL=113.99)
 batch_loss: 1908.49  |  KL+NLL: 108307.60  (KL=108204.96, NLL=102.64)
 batch_loss: 1846.10  |  KL+NLL: 108476.44  (KL=108376.01, NLL=100.43)


Epoch 8/10:  38%|███▊      | 135/360 [00:06<00:11, 20.35it/s]

 batch_loss: 1868.14  |  KL+NLL: 108115.65  (KL=108005.62, NLL=110.03)
 batch_loss: 1830.82  |  KL+NLL: 108352.86  (KL=108236.10, NLL=116.76)
 batch_loss: 1895.47  |  KL+NLL: 108643.53  (KL=108491.55, NLL=151.98)
 batch_loss: 2006.75  |  KL+NLL: 108185.66  (KL=108054.91, NLL=130.75)
 batch_loss: 1904.51  |  KL+NLL: 108642.07  (KL=108468.62, NLL=173.44)


Epoch 8/10:  39%|███▉      | 141/360 [00:06<00:10, 20.68it/s]

 batch_loss: 1943.50  |  KL+NLL: 108342.81  (KL=108200.57, NLL=142.24)
 batch_loss: 1950.85  |  KL+NLL: 108434.84  (KL=108342.98, NLL=91.86)
 batch_loss: 1875.60  |  KL+NLL: 108455.03  (KL=108339.80, NLL=115.23)
 batch_loss: 1834.99  |  KL+NLL: 108595.96  (KL=108448.91, NLL=147.05)
 batch_loss: 1751.89  |  KL+NLL: 108461.04  (KL=108355.85, NLL=105.19)


Epoch 8/10:  40%|████      | 144/360 [00:07<00:10, 20.59it/s]

 batch_loss: 1895.06  |  KL+NLL: 108666.74  (KL=108544.00, NLL=122.74)
 batch_loss: 1881.88  |  KL+NLL: 108296.32  (KL=108170.27, NLL=126.05)
 batch_loss: 1831.10  |  KL+NLL: 108565.04  (KL=108404.41, NLL=160.63)
 batch_loss: 1837.32  |  KL+NLL: 108432.88  (KL=108337.77, NLL=95.11)
 batch_loss: 1902.26  |  KL+NLL: 108480.25  (KL=108358.20, NLL=122.05)


Epoch 8/10:  42%|████▏     | 150/360 [00:07<00:10, 20.44it/s]

 batch_loss: 1828.39  |  KL+NLL: 108480.10  (KL=108345.55, NLL=134.55)
 batch_loss: 1814.92  |  KL+NLL: 108536.24  (KL=108404.77, NLL=131.48)
 batch_loss: 1811.14  |  KL+NLL: 108582.18  (KL=108438.10, NLL=144.08)
 batch_loss: 1827.51  |  KL+NLL: 108650.40  (KL=108518.70, NLL=131.70)


Epoch 8/10:  42%|████▎     | 153/360 [00:07<00:10, 20.14it/s]

 batch_loss: 1822.47  |  KL+NLL: 108753.35  (KL=108596.85, NLL=156.49)
 batch_loss: 1843.53  |  KL+NLL: 108408.97  (KL=108292.46, NLL=116.51)
 batch_loss: 1877.77  |  KL+NLL: 108347.41  (KL=108230.51, NLL=116.90)
 batch_loss: 1813.00  |  KL+NLL: 108772.53  (KL=108645.48, NLL=127.04)
 batch_loss: 1791.24  |  KL+NLL: 108694.93  (KL=108602.76, NLL=92.17)


Epoch 8/10:  44%|████▍     | 159/360 [00:07<00:09, 21.03it/s]

 batch_loss: 1787.45  |  KL+NLL: 108792.39  (KL=108667.56, NLL=124.82)
 batch_loss: 1815.02  |  KL+NLL: 108558.36  (KL=108432.95, NLL=125.41)
 batch_loss: 1718.20  |  KL+NLL: 108560.30  (KL=108447.37, NLL=112.93)
 batch_loss: 1873.23  |  KL+NLL: 108799.57  (KL=108652.43, NLL=147.14)
 batch_loss: 1676.34  |  KL+NLL: 108837.26  (KL=108717.84, NLL=119.42)


Epoch 8/10:  46%|████▌     | 165/360 [00:08<00:09, 20.79it/s]

 batch_loss: 1796.38  |  KL+NLL: 108233.33  (KL=108139.22, NLL=94.12)
 batch_loss: 1822.72  |  KL+NLL: 108597.76  (KL=108489.43, NLL=108.33)
 batch_loss: 1801.52  |  KL+NLL: 108766.95  (KL=108632.65, NLL=134.30)
 batch_loss: 1637.88  |  KL+NLL: 108677.56  (KL=108528.23, NLL=149.33)
 batch_loss: 1756.24  |  KL+NLL: 109004.48  (KL=108834.32, NLL=170.16)


Epoch 8/10:  47%|████▋     | 168/360 [00:08<00:09, 21.04it/s]

 batch_loss: 1815.44  |  KL+NLL: 108430.51  (KL=108277.03, NLL=153.48)
 batch_loss: 1760.75  |  KL+NLL: 108654.61  (KL=108537.85, NLL=116.75)
 batch_loss: 1739.90  |  KL+NLL: 109051.45  (KL=108933.90, NLL=117.55)
 batch_loss: 1654.12  |  KL+NLL: 108747.96  (KL=108600.05, NLL=147.91)
 batch_loss: 1651.86  |  KL+NLL: 108701.87  (KL=108576.54, NLL=125.33)


Epoch 8/10:  48%|████▊     | 174/360 [00:08<00:08, 21.01it/s]

 batch_loss: 1652.16  |  KL+NLL: 108360.37  (KL=108231.83, NLL=128.54)
 batch_loss: 1794.40  |  KL+NLL: 108553.85  (KL=108424.59, NLL=129.25)
 batch_loss: 1790.41  |  KL+NLL: 108893.31  (KL=108755.18, NLL=138.13)
 batch_loss: 1815.96  |  KL+NLL: 108625.13  (KL=108512.62, NLL=112.50)
 batch_loss: 1744.88  |  KL+NLL: 108789.90  (KL=108665.23, NLL=124.66)


Epoch 8/10:  50%|█████     | 180/360 [00:08<00:08, 21.01it/s]

 batch_loss: 1641.01  |  KL+NLL: 108747.69  (KL=108603.21, NLL=144.47)
 batch_loss: 1636.74  |  KL+NLL: 108760.89  (KL=108636.28, NLL=124.61)
 batch_loss: 1787.27  |  KL+NLL: 108671.27  (KL=108508.27, NLL=162.99)
 batch_loss: 1704.76  |  KL+NLL: 108871.24  (KL=108754.27, NLL=116.97)
 batch_loss: 1657.67  |  KL+NLL: 108537.95  (KL=108424.80, NLL=113.16)


Epoch 8/10:  51%|█████     | 183/360 [00:08<00:08, 21.13it/s]

 batch_loss: 1735.08  |  KL+NLL: 108875.06  (KL=108751.63, NLL=123.43)
 batch_loss: 1712.00  |  KL+NLL: 108759.83  (KL=108633.41, NLL=126.42)
 batch_loss: 1692.13  |  KL+NLL: 108580.52  (KL=108434.98, NLL=145.53)
 batch_loss: 1738.33  |  KL+NLL: 109073.08  (KL=108964.46, NLL=108.62)
 batch_loss: 1715.79  |  KL+NLL: 108812.52  (KL=108681.88, NLL=130.63)


Epoch 8/10:  52%|█████▎    | 189/360 [00:09<00:08, 20.72it/s]

 batch_loss: 1705.65  |  KL+NLL: 108894.63  (KL=108808.06, NLL=86.57)
 batch_loss: 1711.56  |  KL+NLL: 108747.44  (KL=108624.23, NLL=123.21)
 batch_loss: 1713.02  |  KL+NLL: 108682.71  (KL=108546.17, NLL=136.54)
 batch_loss: 1682.14  |  KL+NLL: 108899.68  (KL=108778.30, NLL=121.38)
 batch_loss: 1667.35  |  KL+NLL: 108675.73  (KL=108519.25, NLL=156.48)


Epoch 8/10:  54%|█████▍    | 195/360 [00:09<00:08, 20.55it/s]

 batch_loss: 1737.61  |  KL+NLL: 108995.78  (KL=108857.91, NLL=137.86)
 batch_loss: 1673.91  |  KL+NLL: 108950.11  (KL=108831.66, NLL=118.45)
 batch_loss: 1649.95  |  KL+NLL: 108747.83  (KL=108640.80, NLL=107.03)
 batch_loss: 1675.56  |  KL+NLL: 108963.78  (KL=108849.49, NLL=114.28)
 batch_loss: 1631.93  |  KL+NLL: 108850.03  (KL=108727.00, NLL=123.03)


Epoch 8/10:  55%|█████▌    | 198/360 [00:09<00:08, 19.28it/s]

 batch_loss: 1664.55  |  KL+NLL: 109002.56  (KL=108855.31, NLL=147.25)
 batch_loss: 1658.04  |  KL+NLL: 108929.14  (KL=108780.80, NLL=148.34)
 batch_loss: 1634.18  |  KL+NLL: 108794.17  (KL=108667.80, NLL=126.36)
 batch_loss: 1692.99  |  KL+NLL: 108812.70  (KL=108694.18, NLL=118.52)


Epoch 8/10:  56%|█████▋    | 203/360 [00:09<00:08, 19.53it/s]

 batch_loss: 1637.77  |  KL+NLL: 109094.93  (KL=108987.76, NLL=107.17)
 batch_loss: 1732.97  |  KL+NLL: 108907.66  (KL=108792.48, NLL=115.19)
 batch_loss: 1642.16  |  KL+NLL: 109120.94  (KL=109002.42, NLL=118.52)
 batch_loss: 1635.77  |  KL+NLL: 109013.35  (KL=108860.42, NLL=152.93)


Epoch 8/10:  57%|█████▊    | 207/360 [00:10<00:07, 19.54it/s]

 batch_loss: 1759.00  |  KL+NLL: 108951.33  (KL=108763.35, NLL=187.98)
 batch_loss: 1581.69  |  KL+NLL: 109121.69  (KL=109005.84, NLL=115.86)
 batch_loss: 1564.76  |  KL+NLL: 109006.17  (KL=108871.49, NLL=134.68)
 batch_loss: 1608.25  |  KL+NLL: 108900.09  (KL=108777.24, NLL=122.85)


Epoch 8/10:  59%|█████▊    | 211/360 [00:10<00:07, 19.49it/s]

 batch_loss: 1579.71  |  KL+NLL: 109154.55  (KL=109003.30, NLL=151.25)
 batch_loss: 1576.19  |  KL+NLL: 109159.49  (KL=109035.51, NLL=123.98)
 batch_loss: 1647.25  |  KL+NLL: 108966.03  (KL=108861.14, NLL=104.89)
 batch_loss: 1681.77  |  KL+NLL: 109212.61  (KL=109081.10, NLL=131.51)
 batch_loss: 1550.46  |  KL+NLL: 108810.27  (KL=108633.09, NLL=177.19)


Epoch 8/10:  60%|██████    | 217/360 [00:10<00:07, 20.00it/s]

 batch_loss: 1659.03  |  KL+NLL: 108962.48  (KL=108832.77, NLL=129.72)
 batch_loss: 1593.24  |  KL+NLL: 108825.97  (KL=108711.09, NLL=114.89)
 batch_loss: 1592.09  |  KL+NLL: 109171.58  (KL=109050.10, NLL=121.48)
 batch_loss: 1726.48  |  KL+NLL: 109212.87  (KL=109089.35, NLL=123.52)
 batch_loss: 1628.34  |  KL+NLL: 108942.18  (KL=108750.31, NLL=191.87)


Epoch 8/10:  61%|██████    | 220/360 [00:10<00:06, 20.37it/s]

 batch_loss: 1643.20  |  KL+NLL: 109112.86  (KL=108987.09, NLL=125.77)
 batch_loss: 1602.00  |  KL+NLL: 109148.17  (KL=109043.35, NLL=104.81)
 batch_loss: 1547.37  |  KL+NLL: 109275.76  (KL=109140.30, NLL=135.47)
 batch_loss: 1531.32  |  KL+NLL: 109162.92  (KL=109056.19, NLL=106.74)
 batch_loss: 1591.15  |  KL+NLL: 109097.67  (KL=108991.92, NLL=105.75)


Epoch 8/10:  63%|██████▎   | 226/360 [00:11<00:06, 20.37it/s]

 batch_loss: 1584.66  |  KL+NLL: 109313.74  (KL=109186.84, NLL=126.90)
 batch_loss: 1528.88  |  KL+NLL: 109122.69  (KL=109014.54, NLL=108.15)
 batch_loss: 1688.20  |  KL+NLL: 109163.71  (KL=109033.24, NLL=130.46)
 batch_loss: 1544.28  |  KL+NLL: 109260.31  (KL=109131.55, NLL=128.76)
 batch_loss: 1609.08  |  KL+NLL: 109144.12  (KL=109006.54, NLL=137.58)


Epoch 8/10:  64%|██████▎   | 229/360 [00:11<00:06, 20.69it/s]

 batch_loss: 1532.37  |  KL+NLL: 109162.10  (KL=109029.21, NLL=132.89)
 batch_loss: 1562.64  |  KL+NLL: 109125.96  (KL=108998.23, NLL=127.74)
 batch_loss: 1615.45  |  KL+NLL: 109071.96  (KL=108925.42, NLL=146.54)
 batch_loss: 1609.33  |  KL+NLL: 109292.78  (KL=109141.13, NLL=151.65)


Epoch 8/10:  65%|██████▌   | 235/360 [00:11<00:06, 18.89it/s]

 batch_loss: 1590.90  |  KL+NLL: 109249.39  (KL=109145.15, NLL=104.24)
 batch_loss: 1525.77  |  KL+NLL: 109144.09  (KL=109010.27, NLL=133.81)
 batch_loss: 1497.97  |  KL+NLL: 109253.34  (KL=109121.72, NLL=131.62)
 batch_loss: 1528.36  |  KL+NLL: 109104.21  (KL=108967.17, NLL=137.04)


Epoch 8/10:  67%|██████▋   | 240/360 [00:11<00:06, 19.55it/s]

 batch_loss: 1493.19  |  KL+NLL: 109140.29  (KL=109014.83, NLL=125.46)
 batch_loss: 1505.06  |  KL+NLL: 109152.03  (KL=109055.19, NLL=96.84)
 batch_loss: 1530.69  |  KL+NLL: 109217.57  (KL=109074.27, NLL=143.30)
 batch_loss: 1533.86  |  KL+NLL: 109120.45  (KL=109005.48, NLL=114.96)
 batch_loss: 1512.50  |  KL+NLL: 109330.44  (KL=109158.38, NLL=172.05)


Epoch 8/10:  68%|██████▊   | 243/360 [00:11<00:05, 19.66it/s]

 batch_loss: 1491.15  |  KL+NLL: 109295.24  (KL=109167.37, NLL=127.87)
 batch_loss: 1553.55  |  KL+NLL: 109367.86  (KL=109230.09, NLL=137.76)
 batch_loss: 1491.39  |  KL+NLL: 109414.35  (KL=109289.08, NLL=125.28)
 batch_loss: 1500.92  |  KL+NLL: 109308.32  (KL=109209.38, NLL=98.95)
 batch_loss: 1498.27  |  KL+NLL: 109248.07  (KL=109080.20, NLL=167.86)


Epoch 8/10:  69%|██████▉   | 249/360 [00:12<00:05, 20.26it/s]

 batch_loss: 1525.91  |  KL+NLL: 109271.47  (KL=109162.75, NLL=108.72)
 batch_loss: 1498.11  |  KL+NLL: 109192.75  (KL=109088.01, NLL=104.74)
 batch_loss: 1539.25  |  KL+NLL: 109453.02  (KL=109318.61, NLL=134.41)
 batch_loss: 1506.81  |  KL+NLL: 109431.13  (KL=109297.83, NLL=133.30)
 batch_loss: 1495.00  |  KL+NLL: 109184.33  (KL=109046.48, NLL=137.85)


Epoch 8/10:  71%|███████   | 255/360 [00:12<00:04, 21.00it/s]

 batch_loss: 1516.69  |  KL+NLL: 109167.73  (KL=109051.41, NLL=116.32)
 batch_loss: 1386.26  |  KL+NLL: 109380.39  (KL=109219.83, NLL=160.56)
 batch_loss: 1422.05  |  KL+NLL: 109210.57  (KL=109077.01, NLL=133.56)
 batch_loss: 1491.52  |  KL+NLL: 109124.04  (KL=108969.78, NLL=154.26)
 batch_loss: 1551.94  |  KL+NLL: 109416.09  (KL=109290.86, NLL=125.23)


Epoch 8/10:  72%|███████▏  | 258/360 [00:12<00:04, 20.95it/s]

 batch_loss: 1455.31  |  KL+NLL: 109370.87  (KL=109242.96, NLL=127.91)
 batch_loss: 1439.33  |  KL+NLL: 109514.60  (KL=109413.14, NLL=101.46)
 batch_loss: 1478.96  |  KL+NLL: 109390.61  (KL=109273.19, NLL=117.42)
 batch_loss: 1430.23  |  KL+NLL: 109402.53  (KL=109268.24, NLL=134.28)
 batch_loss: 1476.82  |  KL+NLL: 109507.82  (KL=109371.23, NLL=136.60)


Epoch 8/10:  73%|███████▎  | 264/360 [00:12<00:04, 20.73it/s]

 batch_loss: 1457.76  |  KL+NLL: 109335.39  (KL=109235.98, NLL=99.42)
 batch_loss: 1454.37  |  KL+NLL: 109400.08  (KL=109291.73, NLL=108.36)
 batch_loss: 1390.81  |  KL+NLL: 109597.13  (KL=109485.58, NLL=111.55)
 batch_loss: 1549.10  |  KL+NLL: 109500.87  (KL=109407.91, NLL=92.96)
 batch_loss: 1400.52  |  KL+NLL: 109288.64  (KL=109161.14, NLL=127.50)


Epoch 8/10:  75%|███████▌  | 270/360 [00:13<00:04, 20.36it/s]

 batch_loss: 1495.42  |  KL+NLL: 109528.13  (KL=109418.97, NLL=109.16)
 batch_loss: 1484.81  |  KL+NLL: 109612.87  (KL=109502.50, NLL=110.37)
 batch_loss: 1452.09  |  KL+NLL: 109490.73  (KL=109388.13, NLL=102.60)
 batch_loss: 1385.18  |  KL+NLL: 109716.97  (KL=109570.68, NLL=146.29)
 batch_loss: 1442.44  |  KL+NLL: 109545.05  (KL=109364.96, NLL=180.09)


Epoch 8/10:  76%|███████▌  | 273/360 [00:13<00:04, 20.11it/s]

 batch_loss: 1436.28  |  KL+NLL: 109424.88  (KL=109298.21, NLL=126.67)
 batch_loss: 1457.21  |  KL+NLL: 109412.41  (KL=109279.30, NLL=133.11)
 batch_loss: 1400.78  |  KL+NLL: 109662.44  (KL=109522.48, NLL=139.96)
 batch_loss: 1432.51  |  KL+NLL: 109476.87  (KL=109326.73, NLL=150.13)


Epoch 8/10:  77%|███████▋  | 276/360 [00:13<00:04, 19.81it/s]

 batch_loss: 1366.54  |  KL+NLL: 109452.61  (KL=109298.45, NLL=154.16)
 batch_loss: 1385.96  |  KL+NLL: 109895.88  (KL=109743.94, NLL=151.95)
 batch_loss: 1467.56  |  KL+NLL: 109364.85  (KL=109258.16, NLL=106.69)
 batch_loss: 1415.19  |  KL+NLL: 109446.20  (KL=109295.62, NLL=150.57)


Epoch 8/10:  78%|███████▊  | 282/360 [00:13<00:03, 20.21it/s]

 batch_loss: 1439.44  |  KL+NLL: 109330.34  (KL=109200.95, NLL=129.39)
 batch_loss: 1486.95  |  KL+NLL: 109832.87  (KL=109727.45, NLL=105.42)
 batch_loss: 1326.81  |  KL+NLL: 109272.48  (KL=109149.42, NLL=123.06)
 batch_loss: 1462.94  |  KL+NLL: 109509.31  (KL=109381.12, NLL=128.19)
 batch_loss: 1445.63  |  KL+NLL: 109499.79  (KL=109368.30, NLL=131.49)


Epoch 8/10:  80%|████████  | 288/360 [00:14<00:03, 20.07it/s]

 batch_loss: 1319.89  |  KL+NLL: 109380.28  (KL=109268.12, NLL=112.16)
 batch_loss: 1445.98  |  KL+NLL: 109464.10  (KL=109350.32, NLL=113.78)
 batch_loss: 1442.12  |  KL+NLL: 109567.59  (KL=109453.66, NLL=113.93)
 batch_loss: 1440.62  |  KL+NLL: 109638.71  (KL=109515.40, NLL=123.31)
 batch_loss: 1360.22  |  KL+NLL: 109642.06  (KL=109534.91, NLL=107.15)


Epoch 8/10:  81%|████████  | 291/360 [00:14<00:03, 20.37it/s]

 batch_loss: 1341.78  |  KL+NLL: 109456.25  (KL=109346.96, NLL=109.29)
 batch_loss: 1447.05  |  KL+NLL: 109388.62  (KL=109296.16, NLL=92.45)
 batch_loss: 1317.44  |  KL+NLL: 109813.87  (KL=109708.68, NLL=105.19)
 batch_loss: 1347.88  |  KL+NLL: 109602.48  (KL=109494.48, NLL=108.00)
 batch_loss: 1410.50  |  KL+NLL: 109448.14  (KL=109328.95, NLL=119.20)


Epoch 8/10:  82%|████████▎ | 297/360 [00:14<00:03, 20.69it/s]

 batch_loss: 1397.18  |  KL+NLL: 109686.03  (KL=109595.82, NLL=90.21)
 batch_loss: 1375.69  |  KL+NLL: 109568.28  (KL=109451.68, NLL=116.60)
 batch_loss: 1432.49  |  KL+NLL: 109532.83  (KL=109428.84, NLL=103.99)
 batch_loss: 1353.83  |  KL+NLL: 109603.81  (KL=109456.94, NLL=146.87)
 batch_loss: 1373.71  |  KL+NLL: 109677.93  (KL=109512.77, NLL=165.16)


Epoch 8/10:  84%|████████▍ | 303/360 [00:14<00:02, 20.65it/s]

 batch_loss: 1291.15  |  KL+NLL: 109397.75  (KL=109278.00, NLL=119.75)
 batch_loss: 1389.50  |  KL+NLL: 109663.41  (KL=109557.94, NLL=105.47)
 batch_loss: 1371.06  |  KL+NLL: 109691.03  (KL=109585.19, NLL=105.84)
 batch_loss: 1423.94  |  KL+NLL: 109753.91  (KL=109585.58, NLL=168.33)
 batch_loss: 1194.85  |  KL+NLL: 109636.86  (KL=109507.41, NLL=129.45)


Epoch 8/10:  85%|████████▌ | 306/360 [00:15<00:02, 20.45it/s]

 batch_loss: 1263.39  |  KL+NLL: 109804.47  (KL=109653.79, NLL=150.68)
 batch_loss: 1262.99  |  KL+NLL: 109820.26  (KL=109677.86, NLL=142.40)
 batch_loss: 1306.52  |  KL+NLL: 109529.99  (KL=109411.52, NLL=118.48)
 batch_loss: 1351.61  |  KL+NLL: 109884.00  (KL=109773.42, NLL=110.57)
 batch_loss: 1339.23  |  KL+NLL: 109756.86  (KL=109626.12, NLL=130.74)


Epoch 8/10:  87%|████████▋ | 312/360 [00:15<00:02, 20.74it/s]

 batch_loss: 1286.86  |  KL+NLL: 109657.99  (KL=109535.66, NLL=122.33)
 batch_loss: 1303.33  |  KL+NLL: 109830.08  (KL=109709.30, NLL=120.79)
 batch_loss: 1333.72  |  KL+NLL: 109634.15  (KL=109510.65, NLL=123.50)
 batch_loss: 1317.69  |  KL+NLL: 110068.02  (KL=109971.23, NLL=96.80)
 batch_loss: 1358.97  |  KL+NLL: 109646.24  (KL=109541.42, NLL=104.82)


Epoch 8/10:  88%|████████▊ | 315/360 [00:15<00:02, 20.38it/s]

 batch_loss: 1255.60  |  KL+NLL: 109693.82  (KL=109588.91, NLL=104.90)
 batch_loss: 1346.37  |  KL+NLL: 109724.73  (KL=109609.61, NLL=115.12)
 batch_loss: 1368.85  |  KL+NLL: 110064.13  (KL=109935.19, NLL=128.94)
 batch_loss: 1181.42  |  KL+NLL: 109794.04  (KL=109659.14, NLL=134.90)


Epoch 8/10:  89%|████████▉ | 321/360 [00:15<00:01, 20.87it/s]

 batch_loss: 1236.14  |  KL+NLL: 109923.23  (KL=109830.41, NLL=92.82)
 batch_loss: 1282.27  |  KL+NLL: 109785.46  (KL=109660.47, NLL=124.99)
 batch_loss: 1305.31  |  KL+NLL: 109701.93  (KL=109566.01, NLL=135.93)
 batch_loss: 1233.94  |  KL+NLL: 109869.07  (KL=109748.09, NLL=120.98)
 batch_loss: 1288.33  |  KL+NLL: 109621.85  (KL=109502.99, NLL=118.86)


Epoch 8/10:  91%|█████████ | 327/360 [00:16<00:01, 21.32it/s]

 batch_loss: 1236.42  |  KL+NLL: 109778.21  (KL=109643.29, NLL=134.92)
 batch_loss: 1244.30  |  KL+NLL: 109632.39  (KL=109517.84, NLL=114.56)
 batch_loss: 1211.48  |  KL+NLL: 109928.97  (KL=109818.03, NLL=110.93)
 batch_loss: 1315.68  |  KL+NLL: 110031.01  (KL=109859.72, NLL=171.29)
 batch_loss: 1226.76  |  KL+NLL: 109664.97  (KL=109536.67, NLL=128.30)


Epoch 8/10:  92%|█████████▏| 330/360 [00:16<00:01, 21.25it/s]

 batch_loss: 1296.66  |  KL+NLL: 109741.20  (KL=109631.62, NLL=109.58)
 batch_loss: 1223.60  |  KL+NLL: 109972.05  (KL=109863.05, NLL=108.99)
 batch_loss: 1299.71  |  KL+NLL: 110030.48  (KL=109897.49, NLL=132.99)
 batch_loss: 1251.16  |  KL+NLL: 109668.14  (KL=109550.38, NLL=117.76)
 batch_loss: 1362.54  |  KL+NLL: 109755.51  (KL=109633.44, NLL=122.07)


Epoch 8/10:  93%|█████████▎| 336/360 [00:16<00:01, 20.58it/s]

 batch_loss: 1239.19  |  KL+NLL: 109803.35  (KL=109686.10, NLL=117.24)
 batch_loss: 1218.66  |  KL+NLL: 109479.02  (KL=109374.95, NLL=104.07)
 batch_loss: 1203.63  |  KL+NLL: 109963.93  (KL=109824.25, NLL=139.68)
 batch_loss: 1283.38  |  KL+NLL: 109867.20  (KL=109743.53, NLL=123.67)


Epoch 8/10:  94%|█████████▍| 339/360 [00:16<00:01, 20.80it/s]

 batch_loss: 1211.73  |  KL+NLL: 109780.59  (KL=109683.09, NLL=97.50)
 batch_loss: 1272.72  |  KL+NLL: 110015.42  (KL=109889.11, NLL=126.31)
 batch_loss: 1246.10  |  KL+NLL: 109924.55  (KL=109816.05, NLL=108.49)
 batch_loss: 1171.42  |  KL+NLL: 109585.90  (KL=109452.22, NLL=133.68)
 batch_loss: 1198.10  |  KL+NLL: 109928.92  (KL=109806.35, NLL=122.57)


Epoch 8/10:  96%|█████████▌| 345/360 [00:16<00:00, 20.32it/s]

 batch_loss: 1272.73  |  KL+NLL: 109630.11  (KL=109481.09, NLL=149.02)
 batch_loss: 1320.91  |  KL+NLL: 109769.45  (KL=109643.70, NLL=125.75)
 batch_loss: 1164.79  |  KL+NLL: 109622.66  (KL=109506.13, NLL=116.52)
 batch_loss: 1284.94  |  KL+NLL: 110137.95  (KL=110015.05, NLL=122.90)


Epoch 8/10:  97%|█████████▋| 348/360 [00:17<00:00, 20.20it/s]

 batch_loss: 1235.11  |  KL+NLL: 109804.28  (KL=109664.98, NLL=139.29)
 batch_loss: 1215.94  |  KL+NLL: 110087.87  (KL=109969.20, NLL=118.66)
 batch_loss: 1167.69  |  KL+NLL: 109987.84  (KL=109881.40, NLL=106.44)
 batch_loss: 1206.41  |  KL+NLL: 110171.86  (KL=110029.98, NLL=141.88)
 batch_loss: 1286.03  |  KL+NLL: 109930.80  (KL=109798.19, NLL=132.61)


Epoch 8/10:  98%|█████████▊| 354/360 [00:17<00:00, 20.66it/s]

 batch_loss: 1268.64  |  KL+NLL: 109781.52  (KL=109664.91, NLL=116.61)
 batch_loss: 1179.62  |  KL+NLL: 110132.88  (KL=110026.65, NLL=106.23)
 batch_loss: 1202.73  |  KL+NLL: 110033.58  (KL=109912.16, NLL=121.43)
 batch_loss: 1182.06  |  KL+NLL: 110033.09  (KL=109911.53, NLL=121.55)
 batch_loss: 1205.72  |  KL+NLL: 110166.70  (KL=110030.12, NLL=136.58)


Epoch 8/10: 100%|██████████| 360/360 [00:17<00:00, 20.43it/s]


 batch_loss: 1164.97  |  KL+NLL: 110158.65  (KL=110018.75, NLL=139.90)
 batch_loss: 1183.09  |  KL+NLL: 110030.27  (KL=109925.91, NLL=104.36)
 batch_loss: 1243.83  |  KL+NLL: 110118.53  (KL=109978.11, NLL=140.42)
 batch_loss: 1220.95  |  KL+NLL: 109877.73  (KL=109762.48, NLL=115.26)
 batch_loss: 1174.24  |  KL+NLL: 110055.68  (KL=109921.01, NLL=134.68)
Epoch 8 - ELBO Loss: 1766.3067


Epoch 9/10:   1%|          | 4/360 [00:00<00:22, 15.71it/s]

 batch_loss: 1154.56  |  KL+NLL: 109976.46  (KL=109811.60, NLL=164.86)
 batch_loss: 1219.63  |  KL+NLL: 109900.82  (KL=109769.81, NLL=131.01)
 batch_loss: 1201.24  |  KL+NLL: 109893.24  (KL=109768.01, NLL=125.23)
 batch_loss: 1220.61  |  KL+NLL: 109915.91  (KL=109807.73, NLL=108.18)
 batch_loss: 1132.46  |  KL+NLL: 110026.69  (KL=109895.11, NLL=131.58)


Epoch 9/10:   3%|▎         | 10/360 [00:00<00:17, 19.56it/s]

 batch_loss: 1251.81  |  KL+NLL: 110033.59  (KL=109884.59, NLL=149.00)
 batch_loss: 1204.09  |  KL+NLL: 109810.34  (KL=109707.11, NLL=103.23)
 batch_loss: 1179.16  |  KL+NLL: 110357.08  (KL=110255.72, NLL=101.36)
 batch_loss: 1170.53  |  KL+NLL: 109861.34  (KL=109734.92, NLL=126.42)
 batch_loss: 1218.41  |  KL+NLL: 110285.75  (KL=110170.01, NLL=115.75)


Epoch 9/10:   4%|▎         | 13/360 [00:00<00:17, 20.40it/s]

 batch_loss: 1123.83  |  KL+NLL: 109968.14  (KL=109876.12, NLL=92.02)
 batch_loss: 1125.09  |  KL+NLL: 109961.77  (KL=109855.87, NLL=105.90)
 batch_loss: 1094.88  |  KL+NLL: 110308.71  (KL=110215.82, NLL=92.89)
 batch_loss: 1077.65  |  KL+NLL: 110062.62  (KL=109942.81, NLL=119.81)
 batch_loss: 1210.06  |  KL+NLL: 110092.12  (KL=109933.16, NLL=158.96)


Epoch 9/10:   5%|▌         | 19/360 [00:00<00:16, 21.06it/s]

 batch_loss: 1157.22  |  KL+NLL: 110043.14  (KL=109939.21, NLL=103.93)
 batch_loss: 1199.30  |  KL+NLL: 110218.88  (KL=110104.94, NLL=113.95)
 batch_loss: 1159.29  |  KL+NLL: 110149.03  (KL=110003.85, NLL=145.18)
 batch_loss: 1181.77  |  KL+NLL: 109924.92  (KL=109797.42, NLL=127.50)
 batch_loss: 1132.01  |  KL+NLL: 110155.30  (KL=110053.89, NLL=101.41)


Epoch 9/10:   7%|▋         | 25/360 [00:01<00:15, 21.01it/s]

 batch_loss: 1118.21  |  KL+NLL: 110340.41  (KL=110212.98, NLL=127.42)
 batch_loss: 1151.97  |  KL+NLL: 110262.12  (KL=110108.38, NLL=153.74)
 batch_loss: 1206.38  |  KL+NLL: 110214.72  (KL=110087.73, NLL=126.99)
 batch_loss: 1201.96  |  KL+NLL: 110375.76  (KL=110269.73, NLL=106.04)
 batch_loss: 1127.18  |  KL+NLL: 110248.93  (KL=110082.76, NLL=166.17)


Epoch 9/10:   8%|▊         | 28/360 [00:01<00:15, 21.07it/s]

 batch_loss: 1108.10  |  KL+NLL: 110226.06  (KL=110131.62, NLL=94.44)
 batch_loss: 1085.17  |  KL+NLL: 110101.18  (KL=110021.27, NLL=79.92)
 batch_loss: 1139.85  |  KL+NLL: 110146.51  (KL=109995.60, NLL=150.91)
 batch_loss: 1073.64  |  KL+NLL: 109937.72  (KL=109845.83, NLL=91.90)
 batch_loss: 1079.91  |  KL+NLL: 110314.03  (KL=110201.35, NLL=112.67)


Epoch 9/10:   9%|▉         | 34/360 [00:01<00:15, 20.94it/s]

 batch_loss: 1141.65  |  KL+NLL: 110069.58  (KL=109910.74, NLL=158.84)
 batch_loss: 1126.11  |  KL+NLL: 110282.81  (KL=110178.62, NLL=104.19)
 batch_loss: 1147.84  |  KL+NLL: 110212.30  (KL=110094.16, NLL=118.15)
 batch_loss: 1064.30  |  KL+NLL: 110425.63  (KL=110314.82, NLL=110.81)


Epoch 9/10:  10%|█         | 37/360 [00:01<00:15, 20.54it/s]

 batch_loss: 1127.19  |  KL+NLL: 110148.08  (KL=110003.10, NLL=144.97)
 batch_loss: 1104.95  |  KL+NLL: 110183.34  (KL=110096.77, NLL=86.57)
 batch_loss: 1041.36  |  KL+NLL: 110233.28  (KL=110120.01, NLL=113.27)
 batch_loss: 1127.45  |  KL+NLL: 110324.00  (KL=110198.05, NLL=125.96)
 batch_loss: 1077.84  |  KL+NLL: 110173.04  (KL=110080.38, NLL=92.67)


Epoch 9/10:  12%|█▏        | 43/360 [00:02<00:15, 20.92it/s]

 batch_loss: 1101.47  |  KL+NLL: 110107.48  (KL=110001.23, NLL=106.25)
 batch_loss: 1107.72  |  KL+NLL: 110322.46  (KL=110166.35, NLL=156.11)
 batch_loss: 1069.74  |  KL+NLL: 110456.77  (KL=110361.27, NLL=95.51)
 batch_loss: 1168.18  |  KL+NLL: 110699.48  (KL=110575.16, NLL=124.33)
 batch_loss: 1116.39  |  KL+NLL: 110571.92  (KL=110431.94, NLL=139.98)


Epoch 9/10:  14%|█▎        | 49/360 [00:02<00:14, 20.80it/s]

 batch_loss: 1088.11  |  KL+NLL: 110416.92  (KL=110282.27, NLL=134.64)
 batch_loss: 1151.50  |  KL+NLL: 110419.29  (KL=110280.23, NLL=139.06)
 batch_loss: 1062.86  |  KL+NLL: 110215.59  (KL=110102.23, NLL=113.35)
 batch_loss: 1055.86  |  KL+NLL: 110290.38  (KL=110174.12, NLL=116.26)
 batch_loss: 1079.97  |  KL+NLL: 110237.67  (KL=110131.00, NLL=106.67)


Epoch 9/10:  14%|█▍        | 52/360 [00:02<00:14, 20.94it/s]

 batch_loss: 1087.01  |  KL+NLL: 110339.37  (KL=110199.91, NLL=139.46)
 batch_loss: 1082.76  |  KL+NLL: 110381.14  (KL=110248.55, NLL=132.59)
 batch_loss: 1062.25  |  KL+NLL: 110437.99  (KL=110300.02, NLL=137.97)
 batch_loss: 1047.28  |  KL+NLL: 110341.58  (KL=110212.07, NLL=129.51)
 batch_loss: 1054.07  |  KL+NLL: 110543.43  (KL=110388.65, NLL=154.78)


Epoch 9/10:  16%|█▌        | 58/360 [00:02<00:14, 20.72it/s]

 batch_loss: 1068.92  |  KL+NLL: 110458.93  (KL=110339.14, NLL=119.79)
 batch_loss: 1083.92  |  KL+NLL: 110412.91  (KL=110243.05, NLL=169.85)
 batch_loss: 1102.38  |  KL+NLL: 110266.35  (KL=110163.81, NLL=102.54)
 batch_loss: 1093.24  |  KL+NLL: 110520.70  (KL=110427.19, NLL=93.52)
 batch_loss: 1062.60  |  KL+NLL: 110294.07  (KL=110153.64, NLL=140.43)


Epoch 9/10:  17%|█▋        | 61/360 [00:02<00:14, 20.83it/s]

 batch_loss: 1058.33  |  KL+NLL: 110263.69  (KL=110146.28, NLL=117.40)
 batch_loss: 1088.42  |  KL+NLL: 110513.98  (KL=110378.73, NLL=135.26)
 batch_loss: 1078.50  |  KL+NLL: 110563.76  (KL=110402.94, NLL=160.82)
 batch_loss: 1009.90  |  KL+NLL: 110092.12  (KL=109985.34, NLL=106.78)


Epoch 9/10:  19%|█▊        | 67/360 [00:03<00:14, 20.89it/s]

 batch_loss: 1052.66  |  KL+NLL: 110593.18  (KL=110468.77, NLL=124.42)
 batch_loss: 1093.56  |  KL+NLL: 110334.63  (KL=110187.90, NLL=146.73)
 batch_loss: 1026.05  |  KL+NLL: 110514.61  (KL=110424.31, NLL=90.29)
 batch_loss: 1090.15  |  KL+NLL: 110364.22  (KL=110228.02, NLL=136.19)
 batch_loss: 1104.33  |  KL+NLL: 110509.66  (KL=110370.77, NLL=138.90)


Epoch 9/10:  19%|█▉        | 70/360 [00:03<00:13, 21.01it/s]

 batch_loss: 1049.83  |  KL+NLL: 110628.94  (KL=110519.57, NLL=109.37)
 batch_loss: 1000.62  |  KL+NLL: 110614.82  (KL=110490.17, NLL=124.65)
 batch_loss: 1042.90  |  KL+NLL: 110451.77  (KL=110325.30, NLL=126.46)
 batch_loss: 1036.93  |  KL+NLL: 110384.78  (KL=110277.88, NLL=106.90)


Epoch 9/10:  21%|██        | 76/360 [00:03<00:14, 19.94it/s]

 batch_loss: 1007.06  |  KL+NLL: 110389.06  (KL=110250.14, NLL=138.92)
 batch_loss: 986.78  |  KL+NLL: 110390.24  (KL=110258.22, NLL=132.02)
 batch_loss: 1010.60  |  KL+NLL: 110583.58  (KL=110450.72, NLL=132.86)
 batch_loss: 1030.01  |  KL+NLL: 110535.49  (KL=110392.91, NLL=142.59)


Epoch 9/10:  22%|██▏       | 79/360 [00:03<00:13, 20.25it/s]

 batch_loss: 942.53  |  KL+NLL: 110472.07  (KL=110332.23, NLL=139.85)
 batch_loss: 1075.18  |  KL+NLL: 110472.68  (KL=110311.51, NLL=161.18)
 batch_loss: 1041.24  |  KL+NLL: 110718.65  (KL=110598.34, NLL=120.32)
 batch_loss: 1000.16  |  KL+NLL: 110564.49  (KL=110441.79, NLL=122.70)
 batch_loss: 967.35  |  KL+NLL: 110469.80  (KL=110357.53, NLL=112.27)


Epoch 9/10:  24%|██▎       | 85/360 [00:04<00:13, 20.56it/s]

 batch_loss: 994.90  |  KL+NLL: 110478.41  (KL=110386.46, NLL=91.95)
 batch_loss: 1010.43  |  KL+NLL: 110540.03  (KL=110422.69, NLL=117.35)
 batch_loss: 1064.66  |  KL+NLL: 110718.80  (KL=110607.24, NLL=111.55)
 batch_loss: 1061.94  |  KL+NLL: 110520.23  (KL=110404.17, NLL=116.06)
 batch_loss: 1038.30  |  KL+NLL: 110585.99  (KL=110478.12, NLL=107.88)


Epoch 9/10:  25%|██▌       | 91/360 [00:04<00:13, 20.51it/s]

 batch_loss: 1027.85  |  KL+NLL: 110715.40  (KL=110588.78, NLL=126.62)
 batch_loss: 996.85  |  KL+NLL: 110594.50  (KL=110467.98, NLL=126.51)
 batch_loss: 1054.34  |  KL+NLL: 110524.05  (KL=110384.40, NLL=139.66)
 batch_loss: 1003.44  |  KL+NLL: 110626.06  (KL=110524.80, NLL=101.26)
 batch_loss: 965.28  |  KL+NLL: 110407.72  (KL=110286.38, NLL=121.33)


Epoch 9/10:  26%|██▌       | 94/360 [00:04<00:13, 20.39it/s]

 batch_loss: 973.25  |  KL+NLL: 110847.40  (KL=110720.02, NLL=127.38)
 batch_loss: 1004.97  |  KL+NLL: 110524.24  (KL=110394.53, NLL=129.70)
 batch_loss: 966.68  |  KL+NLL: 110280.49  (KL=110165.84, NLL=114.65)
 batch_loss: 1022.46  |  KL+NLL: 110927.73  (KL=110796.92, NLL=130.81)
 batch_loss: 985.29  |  KL+NLL: 110473.11  (KL=110355.69, NLL=117.42)


Epoch 9/10:  28%|██▊       | 100/360 [00:04<00:12, 20.63it/s]

 batch_loss: 966.76  |  KL+NLL: 110549.85  (KL=110402.21, NLL=147.64)
 batch_loss: 948.33  |  KL+NLL: 110534.24  (KL=110431.30, NLL=102.94)
 batch_loss: 966.15  |  KL+NLL: 110576.60  (KL=110480.59, NLL=96.01)
 batch_loss: 1002.22  |  KL+NLL: 110672.50  (KL=110543.37, NLL=129.14)
 batch_loss: 982.75  |  KL+NLL: 110703.82  (KL=110581.44, NLL=122.39)


Epoch 9/10:  29%|██▉       | 106/360 [00:05<00:12, 20.39it/s]

 batch_loss: 960.05  |  KL+NLL: 110667.00  (KL=110554.55, NLL=112.45)
 batch_loss: 932.37  |  KL+NLL: 110670.25  (KL=110575.13, NLL=95.12)
 batch_loss: 1029.36  |  KL+NLL: 110773.74  (KL=110639.67, NLL=134.07)
 batch_loss: 999.63  |  KL+NLL: 110772.91  (KL=110666.10, NLL=106.81)
 batch_loss: 968.00  |  KL+NLL: 110573.19  (KL=110452.34, NLL=120.85)


Epoch 9/10:  30%|███       | 109/360 [00:05<00:12, 20.56it/s]

 batch_loss: 836.16  |  KL+NLL: 110703.84  (KL=110565.32, NLL=138.52)
 batch_loss: 1027.69  |  KL+NLL: 110542.58  (KL=110404.59, NLL=137.99)
 batch_loss: 908.03  |  KL+NLL: 110544.77  (KL=110439.41, NLL=105.37)
 batch_loss: 940.10  |  KL+NLL: 110653.78  (KL=110538.60, NLL=115.18)
 batch_loss: 953.25  |  KL+NLL: 110651.13  (KL=110541.72, NLL=109.42)


Epoch 9/10:  32%|███▏      | 115/360 [00:05<00:11, 20.99it/s]

 batch_loss: 925.72  |  KL+NLL: 110670.05  (KL=110542.94, NLL=127.11)
 batch_loss: 1003.65  |  KL+NLL: 110626.19  (KL=110528.52, NLL=97.66)
 batch_loss: 910.63  |  KL+NLL: 110655.80  (KL=110540.40, NLL=115.40)
 batch_loss: 955.98  |  KL+NLL: 110615.67  (KL=110511.31, NLL=104.36)
 batch_loss: 906.49  |  KL+NLL: 110863.26  (KL=110747.05, NLL=116.21)


Epoch 9/10:  33%|███▎      | 118/360 [00:05<00:11, 21.10it/s]

 batch_loss: 976.14  |  KL+NLL: 110743.26  (KL=110600.45, NLL=142.80)
 batch_loss: 1000.87  |  KL+NLL: 110971.83  (KL=110868.32, NLL=103.51)
 batch_loss: 933.33  |  KL+NLL: 110978.03  (KL=110825.77, NLL=152.26)
 batch_loss: 944.32  |  KL+NLL: 110481.29  (KL=110361.30, NLL=119.99)


Epoch 9/10:  34%|███▍      | 124/360 [00:06<00:11, 20.65it/s]

 batch_loss: 915.91  |  KL+NLL: 110716.43  (KL=110612.88, NLL=103.55)
 batch_loss: 946.24  |  KL+NLL: 110728.44  (KL=110617.33, NLL=111.12)
 batch_loss: 942.94  |  KL+NLL: 110753.16  (KL=110632.51, NLL=120.65)
 batch_loss: 906.00  |  KL+NLL: 110768.14  (KL=110657.20, NLL=110.94)
 batch_loss: 976.27  |  KL+NLL: 110705.93  (KL=110608.84, NLL=97.09)


Epoch 9/10:  36%|███▌      | 130/360 [00:06<00:11, 20.76it/s]

 batch_loss: 894.58  |  KL+NLL: 110888.20  (KL=110765.94, NLL=122.26)
 batch_loss: 886.28  |  KL+NLL: 110680.89  (KL=110554.13, NLL=126.75)
 batch_loss: 912.31  |  KL+NLL: 110701.54  (KL=110589.12, NLL=112.42)
 batch_loss: 1024.04  |  KL+NLL: 110702.17  (KL=110563.16, NLL=139.01)
 batch_loss: 956.02  |  KL+NLL: 110677.15  (KL=110552.36, NLL=124.79)


Epoch 9/10:  37%|███▋      | 133/360 [00:06<00:11, 20.52it/s]

 batch_loss: 914.98  |  KL+NLL: 110605.51  (KL=110508.41, NLL=97.09)
 batch_loss: 932.60  |  KL+NLL: 110851.50  (KL=110740.20, NLL=111.30)
 batch_loss: 904.63  |  KL+NLL: 110733.88  (KL=110616.73, NLL=117.15)
 batch_loss: 909.00  |  KL+NLL: 111073.89  (KL=110919.62, NLL=154.26)
 batch_loss: 912.54  |  KL+NLL: 110751.87  (KL=110610.27, NLL=141.60)


Epoch 9/10:  39%|███▊      | 139/360 [00:06<00:10, 21.07it/s]

 batch_loss: 934.23  |  KL+NLL: 110918.75  (KL=110801.50, NLL=117.25)
 batch_loss: 903.50  |  KL+NLL: 110951.36  (KL=110823.38, NLL=127.99)
 batch_loss: 940.56  |  KL+NLL: 110729.97  (KL=110604.72, NLL=125.25)
 batch_loss: 947.57  |  KL+NLL: 111110.37  (KL=110968.74, NLL=141.63)
 batch_loss: 949.01  |  KL+NLL: 111004.20  (KL=110870.27, NLL=133.94)


Epoch 9/10:  39%|███▉      | 142/360 [00:06<00:10, 20.91it/s]

 batch_loss: 864.38  |  KL+NLL: 110975.28  (KL=110854.04, NLL=121.24)
 batch_loss: 837.42  |  KL+NLL: 110801.00  (KL=110676.36, NLL=124.64)
 batch_loss: 898.95  |  KL+NLL: 110852.91  (KL=110734.34, NLL=118.58)
 batch_loss: 852.67  |  KL+NLL: 110975.24  (KL=110834.15, NLL=141.10)


Epoch 9/10:  41%|████      | 148/360 [00:07<00:10, 20.74it/s]

 batch_loss: 862.09  |  KL+NLL: 110605.18  (KL=110504.41, NLL=100.78)
 batch_loss: 870.60  |  KL+NLL: 110872.95  (KL=110766.43, NLL=106.52)
 batch_loss: 833.10  |  KL+NLL: 110795.53  (KL=110674.20, NLL=121.32)
 batch_loss: 948.19  |  KL+NLL: 110852.47  (KL=110728.06, NLL=124.41)
 batch_loss: 850.71  |  KL+NLL: 111143.40  (KL=111013.26, NLL=130.14)


Epoch 9/10:  43%|████▎     | 154/360 [00:07<00:09, 21.15it/s]

 batch_loss: 862.98  |  KL+NLL: 111051.27  (KL=110927.93, NLL=123.34)
 batch_loss: 928.98  |  KL+NLL: 110784.31  (KL=110675.41, NLL=108.90)
 batch_loss: 915.06  |  KL+NLL: 110840.16  (KL=110739.54, NLL=100.62)
 batch_loss: 828.31  |  KL+NLL: 110902.01  (KL=110790.54, NLL=111.47)
 batch_loss: 881.92  |  KL+NLL: 110953.14  (KL=110843.27, NLL=109.87)


Epoch 9/10:  44%|████▎     | 157/360 [00:07<00:09, 21.38it/s]

 batch_loss: 923.44  |  KL+NLL: 110997.72  (KL=110895.79, NLL=101.93)
 batch_loss: 980.19  |  KL+NLL: 110868.34  (KL=110734.08, NLL=134.27)
 batch_loss: 870.59  |  KL+NLL: 111016.54  (KL=110906.94, NLL=109.61)
 batch_loss: 844.49  |  KL+NLL: 110794.03  (KL=110650.09, NLL=143.94)
 batch_loss: 875.85  |  KL+NLL: 110959.72  (KL=110818.46, NLL=141.26)


Epoch 9/10:  45%|████▌     | 163/360 [00:07<00:09, 20.76it/s]

 batch_loss: 901.81  |  KL+NLL: 110889.14  (KL=110768.33, NLL=120.81)
 batch_loss: 925.02  |  KL+NLL: 110760.92  (KL=110656.88, NLL=104.04)
 batch_loss: 876.63  |  KL+NLL: 111008.11  (KL=110883.70, NLL=124.42)
 batch_loss: 858.92  |  KL+NLL: 110970.52  (KL=110870.89, NLL=99.63)
 batch_loss: 867.77  |  KL+NLL: 110681.34  (KL=110582.40, NLL=98.94)


Epoch 9/10:  47%|████▋     | 169/360 [00:08<00:09, 20.70it/s]

 batch_loss: 826.54  |  KL+NLL: 111053.58  (KL=110922.16, NLL=131.42)
 batch_loss: 877.33  |  KL+NLL: 111162.49  (KL=111053.67, NLL=108.82)
 batch_loss: 883.34  |  KL+NLL: 111231.65  (KL=111075.72, NLL=155.93)
 batch_loss: 867.16  |  KL+NLL: 111142.01  (KL=111015.91, NLL=126.10)
 batch_loss: 830.74  |  KL+NLL: 111255.56  (KL=111139.39, NLL=116.17)


Epoch 9/10:  48%|████▊     | 172/360 [00:08<00:09, 20.47it/s]

 batch_loss: 918.52  |  KL+NLL: 111110.25  (KL=110990.24, NLL=120.01)
 batch_loss: 875.71  |  KL+NLL: 111067.61  (KL=110968.69, NLL=98.93)
 batch_loss: 949.73  |  KL+NLL: 110986.03  (KL=110874.84, NLL=111.19)
 batch_loss: 881.43  |  KL+NLL: 111135.92  (KL=111016.03, NLL=119.89)
 batch_loss: 883.73  |  KL+NLL: 110931.24  (KL=110823.42, NLL=107.82)


Epoch 9/10:  49%|████▉     | 178/360 [00:08<00:08, 20.72it/s]

 batch_loss: 875.98  |  KL+NLL: 111052.56  (KL=110913.65, NLL=138.91)
 batch_loss: 823.67  |  KL+NLL: 111031.67  (KL=110907.52, NLL=124.15)
 batch_loss: 828.29  |  KL+NLL: 110925.83  (KL=110822.39, NLL=103.44)
 batch_loss: 822.65  |  KL+NLL: 111110.06  (KL=110999.23, NLL=110.83)
 batch_loss: 879.02  |  KL+NLL: 111011.83  (KL=110888.36, NLL=123.47)


Epoch 9/10:  51%|█████     | 184/360 [00:08<00:08, 21.17it/s]

 batch_loss: 905.05  |  KL+NLL: 111212.67  (KL=111070.91, NLL=141.75)
 batch_loss: 829.19  |  KL+NLL: 111286.94  (KL=111141.19, NLL=145.75)
 batch_loss: 809.03  |  KL+NLL: 110852.32  (KL=110739.45, NLL=112.87)
 batch_loss: 885.64  |  KL+NLL: 110883.63  (KL=110759.95, NLL=123.68)
 batch_loss: 840.37  |  KL+NLL: 111015.39  (KL=110889.51, NLL=125.88)


Epoch 9/10:  52%|█████▏    | 187/360 [00:09<00:08, 21.01it/s]

 batch_loss: 853.09  |  KL+NLL: 111050.61  (KL=110956.20, NLL=94.40)
 batch_loss: 856.19  |  KL+NLL: 110884.34  (KL=110778.88, NLL=105.46)
 batch_loss: 863.75  |  KL+NLL: 111289.96  (KL=111190.31, NLL=99.64)
 batch_loss: 895.02  |  KL+NLL: 111210.05  (KL=111088.23, NLL=121.82)
 batch_loss: 860.89  |  KL+NLL: 111143.30  (KL=111021.55, NLL=121.75)


Epoch 9/10:  54%|█████▎    | 193/360 [00:09<00:07, 20.90it/s]

 batch_loss: 816.32  |  KL+NLL: 111191.65  (KL=111040.45, NLL=151.21)
 batch_loss: 785.45  |  KL+NLL: 111018.71  (KL=110916.09, NLL=102.61)
 batch_loss: 830.54  |  KL+NLL: 111181.51  (KL=111044.12, NLL=137.39)
 batch_loss: 873.76  |  KL+NLL: 111011.98  (KL=110875.89, NLL=136.09)
 batch_loss: 789.37  |  KL+NLL: 111063.21  (KL=110944.52, NLL=118.69)


Epoch 9/10:  55%|█████▌    | 199/360 [00:09<00:07, 20.63it/s]

 batch_loss: 805.91  |  KL+NLL: 111084.83  (KL=110977.06, NLL=107.77)
 batch_loss: 824.11  |  KL+NLL: 110923.95  (KL=110792.08, NLL=131.87)
 batch_loss: 815.22  |  KL+NLL: 111111.66  (KL=111001.17, NLL=110.49)
 batch_loss: 835.85  |  KL+NLL: 111363.16  (KL=111238.88, NLL=124.28)
 batch_loss: 856.98  |  KL+NLL: 111139.48  (KL=111016.27, NLL=123.22)


Epoch 9/10:  56%|█████▌    | 202/360 [00:09<00:07, 20.31it/s]

 batch_loss: 844.58  |  KL+NLL: 110953.28  (KL=110831.11, NLL=122.17)
 batch_loss: 797.55  |  KL+NLL: 111246.01  (KL=111139.27, NLL=106.74)
 batch_loss: 811.57  |  KL+NLL: 110841.68  (KL=110720.85, NLL=120.82)
 batch_loss: 802.08  |  KL+NLL: 111160.60  (KL=111060.45, NLL=100.15)
 batch_loss: 837.29  |  KL+NLL: 110972.02  (KL=110848.52, NLL=123.51)


Epoch 9/10:  58%|█████▊    | 208/360 [00:10<00:07, 20.71it/s]

 batch_loss: 883.90  |  KL+NLL: 111192.24  (KL=111082.10, NLL=110.14)
 batch_loss: 819.63  |  KL+NLL: 111205.71  (KL=111124.90, NLL=80.81)
 batch_loss: 808.35  |  KL+NLL: 111308.04  (KL=111188.89, NLL=119.15)
 batch_loss: 832.55  |  KL+NLL: 111314.33  (KL=111174.05, NLL=140.28)
 batch_loss: 838.72  |  KL+NLL: 111470.02  (KL=111357.73, NLL=112.29)


Epoch 9/10:  59%|█████▉    | 214/360 [00:10<00:07, 20.44it/s]

 batch_loss: 794.92  |  KL+NLL: 111118.78  (KL=111002.55, NLL=116.23)
 batch_loss: 869.63  |  KL+NLL: 111081.66  (KL=110936.96, NLL=144.70)
 batch_loss: 763.09  |  KL+NLL: 111243.65  (KL=111105.41, NLL=138.24)
 batch_loss: 798.35  |  KL+NLL: 111217.47  (KL=111070.46, NLL=147.01)
 batch_loss: 820.58  |  KL+NLL: 111463.04  (KL=111306.14, NLL=156.90)


Epoch 9/10:  60%|██████    | 217/360 [00:10<00:07, 20.32it/s]

 batch_loss: 799.72  |  KL+NLL: 111272.62  (KL=111154.65, NLL=117.98)
 batch_loss: 821.29  |  KL+NLL: 111574.42  (KL=111432.58, NLL=141.84)
 batch_loss: 807.24  |  KL+NLL: 111285.23  (KL=111102.94, NLL=182.29)
 batch_loss: 795.88  |  KL+NLL: 111148.10  (KL=111017.83, NLL=130.27)
 batch_loss: 810.87  |  KL+NLL: 111426.95  (KL=111282.02, NLL=144.93)


Epoch 9/10:  62%|██████▏   | 223/360 [00:10<00:06, 20.89it/s]

 batch_loss: 830.62  |  KL+NLL: 111214.97  (KL=111117.67, NLL=97.30)
 batch_loss: 732.71  |  KL+NLL: 111108.53  (KL=110988.16, NLL=120.37)
 batch_loss: 826.11  |  KL+NLL: 111429.03  (KL=111309.55, NLL=119.48)
 batch_loss: 819.44  |  KL+NLL: 111206.79  (KL=111105.02, NLL=101.77)
 batch_loss: 789.11  |  KL+NLL: 111116.64  (KL=111007.36, NLL=109.28)


Epoch 9/10:  64%|██████▎   | 229/360 [00:11<00:06, 21.03it/s]

 batch_loss: 830.26  |  KL+NLL: 111062.70  (KL=110930.08, NLL=132.62)
 batch_loss: 779.46  |  KL+NLL: 111002.33  (KL=110897.09, NLL=105.24)
 batch_loss: 813.40  |  KL+NLL: 111278.38  (KL=111160.16, NLL=118.22)
 batch_loss: 783.05  |  KL+NLL: 111190.66  (KL=111064.73, NLL=125.92)
 batch_loss: 777.64  |  KL+NLL: 111318.17  (KL=111187.30, NLL=130.87)


Epoch 9/10:  64%|██████▍   | 232/360 [00:11<00:06, 20.74it/s]

 batch_loss: 772.23  |  KL+NLL: 111110.02  (KL=110987.44, NLL=122.58)
 batch_loss: 821.84  |  KL+NLL: 111442.19  (KL=111334.13, NLL=108.05)
 batch_loss: 735.85  |  KL+NLL: 111283.63  (KL=111134.62, NLL=149.00)
 batch_loss: 793.39  |  KL+NLL: 111525.82  (KL=111407.30, NLL=118.51)
 batch_loss: 741.43  |  KL+NLL: 111311.60  (KL=111184.98, NLL=126.62)


Epoch 9/10:  66%|██████▌   | 238/360 [00:11<00:06, 20.02it/s]

 batch_loss: 727.86  |  KL+NLL: 111271.44  (KL=111164.32, NLL=107.12)
 batch_loss: 737.33  |  KL+NLL: 111287.76  (KL=111142.86, NLL=144.90)
 batch_loss: 732.94  |  KL+NLL: 111802.30  (KL=111681.97, NLL=120.33)
 batch_loss: 809.08  |  KL+NLL: 111450.85  (KL=111326.60, NLL=124.25)


Epoch 9/10:  67%|██████▋   | 241/360 [00:11<00:05, 20.11it/s]

 batch_loss: 789.77  |  KL+NLL: 111482.62  (KL=111355.43, NLL=127.19)
 batch_loss: 700.10  |  KL+NLL: 111355.91  (KL=111256.18, NLL=99.73)
 batch_loss: 663.18  |  KL+NLL: 111251.09  (KL=111123.27, NLL=127.81)
 batch_loss: 733.28  |  KL+NLL: 111278.34  (KL=111168.02, NLL=110.32)


Epoch 9/10:  68%|██████▊   | 246/360 [00:11<00:05, 19.88it/s]

 batch_loss: 817.47  |  KL+NLL: 111312.02  (KL=111217.11, NLL=94.91)
 batch_loss: 801.87  |  KL+NLL: 111247.56  (KL=111116.52, NLL=131.05)
 batch_loss: 777.31  |  KL+NLL: 111388.03  (KL=111212.93, NLL=175.10)
 batch_loss: 730.12  |  KL+NLL: 111197.80  (KL=111075.94, NLL=121.87)
 batch_loss: 749.05  |  KL+NLL: 111524.85  (KL=111419.46, NLL=105.39)


Epoch 9/10:  70%|███████   | 252/360 [00:12<00:05, 20.16it/s]

 batch_loss: 787.71  |  KL+NLL: 111247.29  (KL=111145.84, NLL=101.45)
 batch_loss: 726.53  |  KL+NLL: 111290.47  (KL=111192.14, NLL=98.33)
 batch_loss: 775.93  |  KL+NLL: 111472.86  (KL=111358.34, NLL=114.53)
 batch_loss: 783.29  |  KL+NLL: 111714.51  (KL=111580.77, NLL=133.74)
 batch_loss: 713.40  |  KL+NLL: 111286.14  (KL=111145.44, NLL=140.71)


Epoch 9/10:  71%|███████   | 255/360 [00:12<00:05, 20.30it/s]

 batch_loss: 717.05  |  KL+NLL: 111348.55  (KL=111221.48, NLL=127.07)
 batch_loss: 757.06  |  KL+NLL: 111197.89  (KL=111068.53, NLL=129.35)
 batch_loss: 721.50  |  KL+NLL: 111045.21  (KL=110935.94, NLL=109.27)
 batch_loss: 723.01  |  KL+NLL: 111495.25  (KL=111377.96, NLL=117.29)
 batch_loss: 755.03  |  KL+NLL: 111439.45  (KL=111314.23, NLL=125.22)


Epoch 9/10:  72%|███████▎  | 261/360 [00:12<00:04, 20.59it/s]

 batch_loss: 707.64  |  KL+NLL: 111724.80  (KL=111607.59, NLL=117.21)
 batch_loss: 703.64  |  KL+NLL: 111641.34  (KL=111517.21, NLL=124.13)
 batch_loss: 772.22  |  KL+NLL: 111320.91  (KL=111206.53, NLL=114.38)
 batch_loss: 763.54  |  KL+NLL: 111358.27  (KL=111234.84, NLL=123.43)


Epoch 9/10:  73%|███████▎  | 264/360 [00:12<00:04, 19.87it/s]

 batch_loss: 776.54  |  KL+NLL: 111502.69  (KL=111375.59, NLL=127.11)
 batch_loss: 746.40  |  KL+NLL: 111415.81  (KL=111313.87, NLL=101.94)
 batch_loss: 776.08  |  KL+NLL: 111377.64  (KL=111255.77, NLL=121.87)
 batch_loss: 753.79  |  KL+NLL: 111437.29  (KL=111334.23, NLL=103.05)


Epoch 9/10:  74%|███████▍  | 268/360 [00:13<00:04, 19.49it/s]

 batch_loss: 824.23  |  KL+NLL: 111291.84  (KL=111144.41, NLL=147.43)
 batch_loss: 709.24  |  KL+NLL: 111475.03  (KL=111366.59, NLL=108.44)
 batch_loss: 736.91  |  KL+NLL: 111481.03  (KL=111363.16, NLL=117.87)
 batch_loss: 715.58  |  KL+NLL: 111251.41  (KL=111151.29, NLL=100.12)


Epoch 9/10:  76%|███████▌  | 273/360 [00:13<00:04, 20.15it/s]

 batch_loss: 737.30  |  KL+NLL: 111749.24  (KL=111626.48, NLL=122.75)
 batch_loss: 758.83  |  KL+NLL: 111645.17  (KL=111510.84, NLL=134.33)
 batch_loss: 780.50  |  KL+NLL: 111492.50  (KL=111385.64, NLL=106.86)
 batch_loss: 720.47  |  KL+NLL: 111615.69  (KL=111500.48, NLL=115.21)
 batch_loss: 739.78  |  KL+NLL: 111350.09  (KL=111237.28, NLL=112.81)


Epoch 9/10:  77%|███████▋  | 276/360 [00:13<00:04, 20.20it/s]

 batch_loss: 721.34  |  KL+NLL: 111644.22  (KL=111535.87, NLL=108.36)
 batch_loss: 720.88  |  KL+NLL: 111501.48  (KL=111364.88, NLL=136.61)
 batch_loss: 740.08  |  KL+NLL: 111404.55  (KL=111285.56, NLL=118.99)
 batch_loss: 660.75  |  KL+NLL: 111611.00  (KL=111476.02, NLL=134.98)


Epoch 9/10:  78%|███████▊  | 282/360 [00:13<00:03, 19.89it/s]

 batch_loss: 695.30  |  KL+NLL: 111644.20  (KL=111495.62, NLL=148.58)
 batch_loss: 725.56  |  KL+NLL: 111321.83  (KL=111223.71, NLL=98.12)
 batch_loss: 746.35  |  KL+NLL: 111514.28  (KL=111388.21, NLL=126.07)
 batch_loss: 681.02  |  KL+NLL: 111423.23  (KL=111324.30, NLL=98.93)


Epoch 9/10:  79%|███████▉  | 286/360 [00:13<00:03, 19.44it/s]

 batch_loss: 735.72  |  KL+NLL: 111337.56  (KL=111229.95, NLL=107.61)
 batch_loss: 746.43  |  KL+NLL: 111531.61  (KL=111422.10, NLL=109.51)
 batch_loss: 698.63  |  KL+NLL: 111426.55  (KL=111303.88, NLL=122.67)
 batch_loss: 688.51  |  KL+NLL: 111423.16  (KL=111308.85, NLL=114.31)


Epoch 9/10:  80%|████████  | 288/360 [00:14<00:03, 19.37it/s]

 batch_loss: 703.31  |  KL+NLL: 111435.71  (KL=111336.20, NLL=99.52)
 batch_loss: 709.61  |  KL+NLL: 111560.74  (KL=111454.91, NLL=105.83)
 batch_loss: 757.21  |  KL+NLL: 111659.45  (KL=111540.32, NLL=119.13)
 batch_loss: 698.06  |  KL+NLL: 111542.24  (KL=111412.09, NLL=130.15)


Epoch 9/10:  82%|████████▏ | 294/360 [00:14<00:03, 19.82it/s]

 batch_loss: 751.49  |  KL+NLL: 111613.40  (KL=111492.09, NLL=121.31)
 batch_loss: 713.86  |  KL+NLL: 111568.30  (KL=111464.13, NLL=104.17)
 batch_loss: 721.28  |  KL+NLL: 111687.54  (KL=111529.41, NLL=158.13)
 batch_loss: 722.65  |  KL+NLL: 111751.14  (KL=111662.95, NLL=88.20)
 batch_loss: 646.56  |  KL+NLL: 111660.56  (KL=111546.66, NLL=113.90)


Epoch 9/10:  83%|████████▎ | 300/360 [00:14<00:02, 20.39it/s]

 batch_loss: 661.63  |  KL+NLL: 111631.55  (KL=111531.50, NLL=100.05)
 batch_loss: 691.65  |  KL+NLL: 111569.64  (KL=111458.14, NLL=111.50)
 batch_loss: 697.55  |  KL+NLL: 111399.10  (KL=111310.16, NLL=88.94)
 batch_loss: 692.05  |  KL+NLL: 111562.94  (KL=111426.62, NLL=136.33)
 batch_loss: 672.25  |  KL+NLL: 111727.10  (KL=111585.45, NLL=141.65)


Epoch 9/10:  84%|████████▍ | 303/360 [00:14<00:02, 20.60it/s]

 batch_loss: 638.65  |  KL+NLL: 111603.31  (KL=111493.62, NLL=109.68)
 batch_loss: 656.98  |  KL+NLL: 111723.36  (KL=111572.29, NLL=151.08)
 batch_loss: 735.10  |  KL+NLL: 111706.43  (KL=111596.85, NLL=109.58)
 batch_loss: 717.93  |  KL+NLL: 111398.39  (KL=111291.59, NLL=106.80)
 batch_loss: 740.59  |  KL+NLL: 111634.00  (KL=111496.14, NLL=137.86)


Epoch 9/10:  86%|████████▌ | 309/360 [00:15<00:02, 20.71it/s]

 batch_loss: 694.57  |  KL+NLL: 111729.43  (KL=111565.48, NLL=163.94)
 batch_loss: 725.97  |  KL+NLL: 111829.29  (KL=111725.73, NLL=103.56)
 batch_loss: 687.87  |  KL+NLL: 111591.82  (KL=111451.27, NLL=140.56)
 batch_loss: 674.05  |  KL+NLL: 111606.85  (KL=111485.33, NLL=121.53)


Epoch 9/10:  87%|████████▋ | 312/360 [00:15<00:02, 19.81it/s]

 batch_loss: 653.37  |  KL+NLL: 111722.27  (KL=111588.53, NLL=133.74)
 batch_loss: 744.17  |  KL+NLL: 111430.47  (KL=111296.23, NLL=134.25)
 batch_loss: 692.48  |  KL+NLL: 111671.76  (KL=111518.49, NLL=153.27)
 batch_loss: 689.34  |  KL+NLL: 111651.78  (KL=111527.24, NLL=124.54)


Epoch 9/10:  88%|████████▊ | 317/360 [00:15<00:02, 20.16it/s]

 batch_loss: 718.43  |  KL+NLL: 111766.27  (KL=111646.80, NLL=119.47)
 batch_loss: 647.88  |  KL+NLL: 111709.64  (KL=111607.66, NLL=101.98)
 batch_loss: 716.59  |  KL+NLL: 111552.13  (KL=111422.38, NLL=129.75)
 batch_loss: 707.00  |  KL+NLL: 111564.09  (KL=111451.73, NLL=112.35)
 batch_loss: 689.05  |  KL+NLL: 111848.78  (KL=111727.76, NLL=121.02)


Epoch 9/10:  90%|████████▉ | 323/360 [00:15<00:01, 20.17it/s]

 batch_loss: 623.91  |  KL+NLL: 111715.05  (KL=111602.02, NLL=113.03)
 batch_loss: 623.33  |  KL+NLL: 111766.90  (KL=111671.26, NLL=95.65)
 batch_loss: 667.63  |  KL+NLL: 111690.59  (KL=111589.38, NLL=101.21)
 batch_loss: 668.94  |  KL+NLL: 111576.26  (KL=111462.18, NLL=114.08)
 batch_loss: 651.13  |  KL+NLL: 111363.60  (KL=111215.10, NLL=148.50)


Epoch 9/10:  91%|█████████ | 326/360 [00:15<00:01, 20.42it/s]

 batch_loss: 710.56  |  KL+NLL: 111824.02  (KL=111692.51, NLL=131.51)
 batch_loss: 721.01  |  KL+NLL: 111818.97  (KL=111692.53, NLL=126.44)
 batch_loss: 652.24  |  KL+NLL: 111668.78  (KL=111566.81, NLL=101.97)
 batch_loss: 663.33  |  KL+NLL: 111733.22  (KL=111631.27, NLL=101.94)
 batch_loss: 675.81  |  KL+NLL: 111726.65  (KL=111618.95, NLL=107.71)


Epoch 9/10:  92%|█████████▏| 332/360 [00:16<00:01, 21.05it/s]

 batch_loss: 659.40  |  KL+NLL: 111918.56  (KL=111789.80, NLL=128.76)
 batch_loss: 662.44  |  KL+NLL: 111753.21  (KL=111656.73, NLL=96.48)
 batch_loss: 666.10  |  KL+NLL: 111842.66  (KL=111726.82, NLL=115.84)
 batch_loss: 601.01  |  KL+NLL: 111645.08  (KL=111525.08, NLL=120.00)
 batch_loss: 635.01  |  KL+NLL: 111361.81  (KL=111258.99, NLL=102.82)


Epoch 9/10:  94%|█████████▍| 338/360 [00:16<00:01, 21.35it/s]

 batch_loss: 691.21  |  KL+NLL: 111741.20  (KL=111626.83, NLL=114.37)
 batch_loss: 644.37  |  KL+NLL: 111811.76  (KL=111705.06, NLL=106.70)
 batch_loss: 668.78  |  KL+NLL: 111766.99  (KL=111656.29, NLL=110.70)
 batch_loss: 599.85  |  KL+NLL: 111596.48  (KL=111468.08, NLL=128.40)
 batch_loss: 660.44  |  KL+NLL: 111839.07  (KL=111732.02, NLL=107.06)


Epoch 9/10:  95%|█████████▍| 341/360 [00:16<00:00, 21.56it/s]

 batch_loss: 634.09  |  KL+NLL: 111904.64  (KL=111789.02, NLL=115.62)
 batch_loss: 729.39  |  KL+NLL: 111856.66  (KL=111757.60, NLL=99.06)
 batch_loss: 706.65  |  KL+NLL: 111791.75  (KL=111676.59, NLL=115.16)
 batch_loss: 703.01  |  KL+NLL: 111607.96  (KL=111476.78, NLL=131.17)
 batch_loss: 641.70  |  KL+NLL: 111757.96  (KL=111646.72, NLL=111.24)


Epoch 9/10:  96%|█████████▋| 347/360 [00:16<00:00, 20.86it/s]

 batch_loss: 674.29  |  KL+NLL: 111694.77  (KL=111563.30, NLL=131.47)
 batch_loss: 675.51  |  KL+NLL: 112118.42  (KL=111999.64, NLL=118.78)
 batch_loss: 624.29  |  KL+NLL: 112183.63  (KL=112104.33, NLL=79.30)
 batch_loss: 618.44  |  KL+NLL: 111934.37  (KL=111825.30, NLL=109.07)
 batch_loss: 662.50  |  KL+NLL: 111855.77  (KL=111739.01, NLL=116.76)


Epoch 9/10:  98%|█████████▊| 353/360 [00:17<00:00, 20.72it/s]

 batch_loss: 680.10  |  KL+NLL: 111918.20  (KL=111792.86, NLL=125.35)
 batch_loss: 640.42  |  KL+NLL: 111756.02  (KL=111659.02, NLL=97.00)
 batch_loss: 590.77  |  KL+NLL: 111581.44  (KL=111469.58, NLL=111.86)
 batch_loss: 596.39  |  KL+NLL: 111820.40  (KL=111696.61, NLL=123.79)
 batch_loss: 620.66  |  KL+NLL: 111656.75  (KL=111539.20, NLL=117.55)


Epoch 9/10:  99%|█████████▉| 356/360 [00:17<00:00, 21.26it/s]

 batch_loss: 620.16  |  KL+NLL: 111870.31  (KL=111762.55, NLL=107.77)
 batch_loss: 618.01  |  KL+NLL: 111855.32  (KL=111758.19, NLL=97.13)
 batch_loss: 631.73  |  KL+NLL: 111737.51  (KL=111622.95, NLL=114.55)
 batch_loss: 638.58  |  KL+NLL: 112022.92  (KL=111908.16, NLL=114.75)
 batch_loss: 581.37  |  KL+NLL: 111909.96  (KL=111787.16, NLL=122.81)


Epoch 9/10: 100%|██████████| 360/360 [00:17<00:00, 20.55it/s]


 batch_loss: 613.15  |  KL+NLL: 111875.09  (KL=111724.95, NLL=150.14)
 batch_loss: 640.91  |  KL+NLL: 111954.08  (KL=111815.52, NLL=138.57)
Epoch 9 - ELBO Loss: 869.5661


Epoch 10/10:   0%|          | 1/360 [00:00<00:48,  7.45it/s]

 batch_loss: 616.96  |  KL+NLL: 112037.36  (KL=111923.62, NLL=113.74)


Epoch 10/10:   1%|          | 3/360 [00:00<00:26, 13.55it/s]

 batch_loss: 611.60  |  KL+NLL: 111752.41  (KL=111622.70, NLL=129.72)
 batch_loss: 623.29  |  KL+NLL: 111720.87  (KL=111586.90, NLL=133.97)
 batch_loss: 589.08  |  KL+NLL: 111964.97  (KL=111839.45, NLL=125.52)


Epoch 10/10:   1%|▏         | 5/360 [00:00<00:22, 15.86it/s]

 batch_loss: 641.13  |  KL+NLL: 111903.93  (KL=111791.44, NLL=112.49)


Epoch 10/10:   2%|▏         | 8/360 [00:00<00:19, 17.70it/s]

 batch_loss: 674.39  |  KL+NLL: 111938.10  (KL=111831.11, NLL=106.99)
 batch_loss: 637.19  |  KL+NLL: 111962.64  (KL=111869.09, NLL=93.55)
 batch_loss: 606.95  |  KL+NLL: 112000.00  (KL=111869.30, NLL=130.69)
 batch_loss: 536.13  |  KL+NLL: 111770.72  (KL=111612.12, NLL=158.59)


Epoch 10/10:   3%|▎         | 12/360 [00:00<00:18, 18.64it/s]

 batch_loss: 639.00  |  KL+NLL: 111856.99  (KL=111718.80, NLL=138.18)
 batch_loss: 623.03  |  KL+NLL: 111980.18  (KL=111842.02, NLL=138.15)
 batch_loss: 589.74  |  KL+NLL: 111946.64  (KL=111801.01, NLL=145.63)
 batch_loss: 662.79  |  KL+NLL: 111910.64  (KL=111793.38, NLL=117.26)


Epoch 10/10:   5%|▌         | 18/360 [00:00<00:16, 20.40it/s]

 batch_loss: 657.15  |  KL+NLL: 111781.37  (KL=111660.39, NLL=120.98)
 batch_loss: 661.29  |  KL+NLL: 111934.24  (KL=111824.14, NLL=110.10)
 batch_loss: 593.98  |  KL+NLL: 111614.72  (KL=111510.79, NLL=103.93)
 batch_loss: 591.93  |  KL+NLL: 112050.93  (KL=111952.16, NLL=98.78)
 batch_loss: 614.17  |  KL+NLL: 112002.61  (KL=111889.27, NLL=113.34)


Epoch 10/10:   6%|▌         | 21/360 [00:01<00:16, 20.96it/s]

 batch_loss: 637.31  |  KL+NLL: 111801.33  (KL=111673.70, NLL=127.63)
 batch_loss: 601.24  |  KL+NLL: 112060.83  (KL=111968.45, NLL=92.38)
 batch_loss: 608.56  |  KL+NLL: 111675.09  (KL=111540.95, NLL=134.14)
 batch_loss: 588.46  |  KL+NLL: 112234.77  (KL=112119.04, NLL=115.74)
 batch_loss: 601.84  |  KL+NLL: 111924.28  (KL=111806.43, NLL=117.85)


Epoch 10/10:   8%|▊         | 27/360 [00:01<00:16, 20.67it/s]

 batch_loss: 618.27  |  KL+NLL: 112348.70  (KL=112224.91, NLL=123.79)
 batch_loss: 626.36  |  KL+NLL: 111945.32  (KL=111829.79, NLL=115.54)
 batch_loss: 603.14  |  KL+NLL: 111621.50  (KL=111505.71, NLL=115.78)
 batch_loss: 608.52  |  KL+NLL: 112347.89  (KL=112237.56, NLL=110.33)
 batch_loss: 592.04  |  KL+NLL: 112161.32  (KL=112012.91, NLL=148.41)


Epoch 10/10:   9%|▉         | 33/360 [00:01<00:15, 20.79it/s]

 batch_loss: 595.93  |  KL+NLL: 111938.11  (KL=111827.14, NLL=110.97)
 batch_loss: 554.47  |  KL+NLL: 112180.52  (KL=112068.29, NLL=112.23)
 batch_loss: 593.02  |  KL+NLL: 111928.07  (KL=111823.38, NLL=104.69)
 batch_loss: 629.30  |  KL+NLL: 111887.42  (KL=111786.99, NLL=100.42)
 batch_loss: 652.26  |  KL+NLL: 112287.56  (KL=112171.47, NLL=116.09)


Epoch 10/10:  10%|█         | 36/360 [00:01<00:15, 20.87it/s]

 batch_loss: 609.86  |  KL+NLL: 111961.69  (KL=111820.72, NLL=140.97)
 batch_loss: 598.85  |  KL+NLL: 111967.94  (KL=111853.17, NLL=114.76)
 batch_loss: 578.95  |  KL+NLL: 111824.62  (KL=111713.63, NLL=110.99)
 batch_loss: 600.48  |  KL+NLL: 112088.23  (KL=111980.38, NLL=107.85)
 batch_loss: 636.71  |  KL+NLL: 112143.32  (KL=112019.98, NLL=123.33)


Epoch 10/10:  12%|█▏        | 42/360 [00:02<00:15, 20.73it/s]

 batch_loss: 617.76  |  KL+NLL: 111873.56  (KL=111774.82, NLL=98.74)
 batch_loss: 614.06  |  KL+NLL: 112133.97  (KL=112009.07, NLL=124.89)
 batch_loss: 577.05  |  KL+NLL: 112262.47  (KL=112146.24, NLL=116.23)
 batch_loss: 602.25  |  KL+NLL: 112042.85  (KL=111920.29, NLL=122.56)
 batch_loss: 603.67  |  KL+NLL: 112080.52  (KL=111963.44, NLL=117.08)


Epoch 10/10:  13%|█▎        | 48/360 [00:02<00:15, 20.55it/s]

 batch_loss: 606.02  |  KL+NLL: 112005.99  (KL=111900.82, NLL=105.17)
 batch_loss: 571.62  |  KL+NLL: 111704.96  (KL=111600.86, NLL=104.10)
 batch_loss: 601.22  |  KL+NLL: 111985.96  (KL=111865.40, NLL=120.56)
 batch_loss: 562.53  |  KL+NLL: 111927.00  (KL=111831.78, NLL=95.22)
 batch_loss: 589.70  |  KL+NLL: 111831.84  (KL=111691.33, NLL=140.52)


Epoch 10/10:  14%|█▍        | 51/360 [00:02<00:14, 20.90it/s]

 batch_loss: 566.49  |  KL+NLL: 111891.30  (KL=111775.14, NLL=116.15)
 batch_loss: 554.82  |  KL+NLL: 111707.05  (KL=111586.27, NLL=120.78)
 batch_loss: 570.96  |  KL+NLL: 112080.95  (KL=111982.19, NLL=98.76)
 batch_loss: 530.76  |  KL+NLL: 112122.26  (KL=112010.25, NLL=112.01)
 batch_loss: 571.41  |  KL+NLL: 112123.41  (KL=111991.33, NLL=132.09)


Epoch 10/10:  16%|█▌        | 57/360 [00:02<00:14, 21.01it/s]

 batch_loss: 590.63  |  KL+NLL: 112438.93  (KL=112295.16, NLL=143.77)
 batch_loss: 626.68  |  KL+NLL: 112054.74  (KL=111909.36, NLL=145.38)
 batch_loss: 558.57  |  KL+NLL: 111846.24  (KL=111726.34, NLL=119.90)
 batch_loss: 563.59  |  KL+NLL: 112176.97  (KL=112019.33, NLL=157.64)
 batch_loss: 539.83  |  KL+NLL: 112284.57  (KL=112177.21, NLL=107.36)


Epoch 10/10:  17%|█▋        | 60/360 [00:02<00:14, 21.12it/s]

 batch_loss: 559.72  |  KL+NLL: 112121.75  (KL=112018.30, NLL=103.44)
 batch_loss: 593.88  |  KL+NLL: 112344.02  (KL=112217.26, NLL=126.76)
 batch_loss: 555.68  |  KL+NLL: 112100.12  (KL=111997.94, NLL=102.19)
 batch_loss: 565.43  |  KL+NLL: 111976.81  (KL=111888.61, NLL=88.20)


Epoch 10/10:  18%|█▊        | 66/360 [00:03<00:14, 20.58it/s]

 batch_loss: 606.35  |  KL+NLL: 112332.69  (KL=112204.53, NLL=128.16)
 batch_loss: 525.55  |  KL+NLL: 112657.28  (KL=112528.23, NLL=129.05)
 batch_loss: 541.80  |  KL+NLL: 112213.45  (KL=112080.56, NLL=132.89)
 batch_loss: 578.39  |  KL+NLL: 112065.77  (KL=111914.61, NLL=151.16)
 batch_loss: 552.36  |  KL+NLL: 112009.93  (KL=111878.11, NLL=131.82)


Epoch 10/10:  20%|██        | 72/360 [00:03<00:14, 20.48it/s]

 batch_loss: 574.37  |  KL+NLL: 111964.14  (KL=111836.16, NLL=127.97)
 batch_loss: 589.67  |  KL+NLL: 112177.89  (KL=112044.64, NLL=133.25)
 batch_loss: 562.35  |  KL+NLL: 111773.67  (KL=111662.16, NLL=111.51)
 batch_loss: 564.35  |  KL+NLL: 112279.73  (KL=112162.16, NLL=117.57)
 batch_loss: 541.13  |  KL+NLL: 112159.31  (KL=112041.66, NLL=117.66)


Epoch 10/10:  21%|██        | 75/360 [00:03<00:13, 20.79it/s]

 batch_loss: 581.85  |  KL+NLL: 112056.28  (KL=111933.34, NLL=122.93)
 batch_loss: 544.92  |  KL+NLL: 111990.46  (KL=111875.02, NLL=115.44)
 batch_loss: 593.42  |  KL+NLL: 112149.97  (KL=112030.09, NLL=119.87)
 batch_loss: 520.67  |  KL+NLL: 112311.62  (KL=112214.98, NLL=96.64)
 batch_loss: 598.85  |  KL+NLL: 112088.41  (KL=111926.76, NLL=161.65)


Epoch 10/10:  22%|██▎       | 81/360 [00:04<00:13, 20.24it/s]

 batch_loss: 557.76  |  KL+NLL: 111887.22  (KL=111791.09, NLL=96.13)
 batch_loss: 561.72  |  KL+NLL: 112141.14  (KL=112015.30, NLL=125.83)
 batch_loss: 586.79  |  KL+NLL: 112009.77  (KL=111892.69, NLL=117.09)
 batch_loss: 543.29  |  KL+NLL: 112158.54  (KL=112056.02, NLL=102.53)


Epoch 10/10:  23%|██▎       | 84/360 [00:04<00:13, 20.34it/s]

 batch_loss: 568.98  |  KL+NLL: 112030.42  (KL=111913.63, NLL=116.79)
 batch_loss: 527.94  |  KL+NLL: 112017.73  (KL=111914.09, NLL=103.64)
 batch_loss: 552.93  |  KL+NLL: 112092.35  (KL=111957.05, NLL=135.30)
 batch_loss: 566.47  |  KL+NLL: 112219.38  (KL=112116.66, NLL=102.71)
 batch_loss: 535.62  |  KL+NLL: 112103.45  (KL=111993.91, NLL=109.54)


Epoch 10/10:  25%|██▌       | 90/360 [00:04<00:13, 20.10it/s]

 batch_loss: 546.13  |  KL+NLL: 112163.11  (KL=112049.28, NLL=113.83)
 batch_loss: 620.46  |  KL+NLL: 112271.98  (KL=112143.52, NLL=128.46)
 batch_loss: 552.33  |  KL+NLL: 112284.88  (KL=112141.31, NLL=143.57)
 batch_loss: 542.65  |  KL+NLL: 112185.60  (KL=112090.33, NLL=95.28)
 batch_loss: 536.11  |  KL+NLL: 112286.69  (KL=112198.59, NLL=88.10)


Epoch 10/10:  26%|██▌       | 93/360 [00:04<00:13, 20.35it/s]

 batch_loss: 553.10  |  KL+NLL: 112298.19  (KL=112160.22, NLL=137.97)
 batch_loss: 565.08  |  KL+NLL: 112218.61  (KL=112126.91, NLL=91.70)
 batch_loss: 515.75  |  KL+NLL: 112209.39  (KL=112093.51, NLL=115.88)
 batch_loss: 513.33  |  KL+NLL: 112252.07  (KL=112119.88, NLL=132.20)


Epoch 10/10:  28%|██▊       | 99/360 [00:04<00:12, 20.70it/s]

 batch_loss: 522.90  |  KL+NLL: 112212.17  (KL=112116.71, NLL=95.46)
 batch_loss: 538.27  |  KL+NLL: 112233.41  (KL=112108.54, NLL=124.87)
 batch_loss: 603.36  |  KL+NLL: 112184.67  (KL=112051.81, NLL=132.85)
 batch_loss: 542.33  |  KL+NLL: 112405.64  (KL=112300.91, NLL=104.72)
 batch_loss: 591.50  |  KL+NLL: 112294.26  (KL=112178.83, NLL=115.43)


Epoch 10/10:  29%|██▉       | 105/360 [00:05<00:12, 20.54it/s]

 batch_loss: 554.71  |  KL+NLL: 112046.31  (KL=111918.00, NLL=128.31)
 batch_loss: 543.49  |  KL+NLL: 112172.48  (KL=112039.80, NLL=132.68)
 batch_loss: 549.37  |  KL+NLL: 112343.32  (KL=112208.58, NLL=134.74)
 batch_loss: 515.96  |  KL+NLL: 111941.00  (KL=111799.67, NLL=141.33)
 batch_loss: 544.77  |  KL+NLL: 112301.60  (KL=112195.83, NLL=105.77)


Epoch 10/10:  30%|███       | 108/360 [00:05<00:12, 20.37it/s]

 batch_loss: 520.94  |  KL+NLL: 112201.98  (KL=112072.35, NLL=129.63)
 batch_loss: 516.97  |  KL+NLL: 112235.37  (KL=112122.48, NLL=112.89)
 batch_loss: 569.17  |  KL+NLL: 112284.40  (KL=112159.67, NLL=124.73)
 batch_loss: 501.22  |  KL+NLL: 112361.51  (KL=112232.69, NLL=128.82)
 batch_loss: 515.22  |  KL+NLL: 112200.57  (KL=112083.38, NLL=117.18)


Epoch 10/10:  32%|███▏      | 114/360 [00:05<00:12, 20.37it/s]

 batch_loss: 528.06  |  KL+NLL: 112295.05  (KL=112169.91, NLL=125.15)
 batch_loss: 483.52  |  KL+NLL: 112121.74  (KL=112015.16, NLL=106.57)
 batch_loss: 552.79  |  KL+NLL: 112167.69  (KL=112041.55, NLL=126.15)
 batch_loss: 539.33  |  KL+NLL: 112310.13  (KL=112212.18, NLL=97.95)


Epoch 10/10:  32%|███▎      | 117/360 [00:05<00:12, 19.89it/s]

 batch_loss: 521.63  |  KL+NLL: 112503.67  (KL=112356.92, NLL=146.75)
 batch_loss: 549.81  |  KL+NLL: 112128.74  (KL=112005.62, NLL=123.11)
 batch_loss: 453.24  |  KL+NLL: 112569.82  (KL=112446.71, NLL=123.11)
 batch_loss: 539.33  |  KL+NLL: 112412.95  (KL=112261.92, NLL=151.03)


Epoch 10/10:  34%|███▎      | 121/360 [00:06<00:12, 19.44it/s]

 batch_loss: 514.01  |  KL+NLL: 112111.66  (KL=111992.54, NLL=119.12)
 batch_loss: 533.90  |  KL+NLL: 112189.93  (KL=112047.59, NLL=142.34)
 batch_loss: 500.19  |  KL+NLL: 112414.79  (KL=112301.70, NLL=113.09)
 batch_loss: 563.22  |  KL+NLL: 112159.55  (KL=112029.36, NLL=130.19)


Epoch 10/10:  35%|███▌      | 126/360 [00:06<00:11, 19.57it/s]

 batch_loss: 616.26  |  KL+NLL: 112454.91  (KL=112309.62, NLL=145.29)
 batch_loss: 519.26  |  KL+NLL: 112301.19  (KL=112156.45, NLL=144.75)
 batch_loss: 515.59  |  KL+NLL: 112447.15  (KL=112294.16, NLL=153.00)
 batch_loss: 555.61  |  KL+NLL: 112395.55  (KL=112248.61, NLL=146.94)
 batch_loss: 540.78  |  KL+NLL: 112302.56  (KL=112193.07, NLL=109.49)


Epoch 10/10:  36%|███▌      | 130/360 [00:06<00:12, 18.91it/s]

 batch_loss: 592.34  |  KL+NLL: 112157.22  (KL=112043.31, NLL=113.91)
 batch_loss: 496.34  |  KL+NLL: 112470.59  (KL=112339.20, NLL=131.40)
 batch_loss: 522.47  |  KL+NLL: 112396.16  (KL=112294.99, NLL=101.17)
 batch_loss: 520.79  |  KL+NLL: 112284.31  (KL=112152.79, NLL=131.52)


Epoch 10/10:  38%|███▊      | 136/360 [00:06<00:11, 19.68it/s]

 batch_loss: 536.19  |  KL+NLL: 112322.92  (KL=112202.49, NLL=120.43)
 batch_loss: 502.53  |  KL+NLL: 112427.21  (KL=112312.57, NLL=114.64)
 batch_loss: 498.77  |  KL+NLL: 112345.17  (KL=112243.40, NLL=101.78)
 batch_loss: 458.77  |  KL+NLL: 112202.88  (KL=112108.55, NLL=94.32)
 batch_loss: 524.80  |  KL+NLL: 112467.59  (KL=112356.00, NLL=111.59)


Epoch 10/10:  39%|███▊      | 139/360 [00:06<00:10, 20.10it/s]

 batch_loss: 514.76  |  KL+NLL: 112739.52  (KL=112628.16, NLL=111.35)
 batch_loss: 521.31  |  KL+NLL: 112346.41  (KL=112212.67, NLL=133.74)
 batch_loss: 562.14  |  KL+NLL: 112297.69  (KL=112180.22, NLL=117.47)
 batch_loss: 597.18  |  KL+NLL: 112486.84  (KL=112385.62, NLL=101.22)
 batch_loss: 470.40  |  KL+NLL: 112475.85  (KL=112364.45, NLL=111.40)


Epoch 10/10:  40%|████      | 145/360 [00:07<00:10, 20.33it/s]

 batch_loss: 495.13  |  KL+NLL: 112138.16  (KL=112035.62, NLL=102.54)
 batch_loss: 556.61  |  KL+NLL: 112186.64  (KL=112096.34, NLL=90.30)
 batch_loss: 509.62  |  KL+NLL: 112202.41  (KL=112060.97, NLL=141.44)
 batch_loss: 484.45  |  KL+NLL: 112672.55  (KL=112538.35, NLL=134.19)
 batch_loss: 504.20  |  KL+NLL: 112517.36  (KL=112412.31, NLL=105.04)


Epoch 10/10:  42%|████▏     | 151/360 [00:07<00:10, 20.42it/s]

 batch_loss: 558.63  |  KL+NLL: 112465.11  (KL=112346.87, NLL=118.24)
 batch_loss: 502.81  |  KL+NLL: 112304.52  (KL=112188.81, NLL=115.71)
 batch_loss: 538.93  |  KL+NLL: 112433.96  (KL=112326.91, NLL=107.06)
 batch_loss: 511.80  |  KL+NLL: 112151.49  (KL=112048.35, NLL=103.14)
 batch_loss: 494.86  |  KL+NLL: 112559.34  (KL=112431.80, NLL=127.54)


Epoch 10/10:  43%|████▎     | 154/360 [00:07<00:09, 20.60it/s]

 batch_loss: 495.23  |  KL+NLL: 112322.06  (KL=112164.86, NLL=157.20)
 batch_loss: 471.44  |  KL+NLL: 112339.78  (KL=112235.58, NLL=104.20)
 batch_loss: 452.56  |  KL+NLL: 112492.05  (KL=112380.95, NLL=111.10)
 batch_loss: 444.75  |  KL+NLL: 112616.32  (KL=112492.70, NLL=123.62)
 batch_loss: 545.72  |  KL+NLL: 112361.57  (KL=112240.73, NLL=120.83)


Epoch 10/10:  44%|████▍     | 160/360 [00:07<00:09, 20.68it/s]

 batch_loss: 526.18  |  KL+NLL: 112716.05  (KL=112601.41, NLL=114.64)
 batch_loss: 522.75  |  KL+NLL: 112457.13  (KL=112331.77, NLL=125.35)
 batch_loss: 565.39  |  KL+NLL: 112370.65  (KL=112258.45, NLL=112.20)
 batch_loss: 530.19  |  KL+NLL: 112656.36  (KL=112564.19, NLL=92.17)
 batch_loss: 515.63  |  KL+NLL: 112253.78  (KL=112144.33, NLL=109.45)


Epoch 10/10:  46%|████▌     | 166/360 [00:08<00:09, 20.66it/s]

 batch_loss: 442.45  |  KL+NLL: 112456.29  (KL=112321.04, NLL=135.25)
 batch_loss: 505.16  |  KL+NLL: 112388.84  (KL=112284.52, NLL=104.32)
 batch_loss: 398.95  |  KL+NLL: 112560.90  (KL=112434.12, NLL=126.78)
 batch_loss: 509.94  |  KL+NLL: 112424.02  (KL=112290.01, NLL=134.01)
 batch_loss: 516.31  |  KL+NLL: 112502.02  (KL=112365.78, NLL=136.24)


Epoch 10/10:  47%|████▋     | 169/360 [00:08<00:09, 20.50it/s]

 batch_loss: 495.94  |  KL+NLL: 112397.97  (KL=112290.27, NLL=107.70)
 batch_loss: 531.11  |  KL+NLL: 112329.07  (KL=112211.54, NLL=117.54)
 batch_loss: 492.39  |  KL+NLL: 112469.79  (KL=112350.59, NLL=119.20)
 batch_loss: 515.85  |  KL+NLL: 112473.77  (KL=112365.44, NLL=108.33)
 batch_loss: 532.97  |  KL+NLL: 112399.99  (KL=112283.12, NLL=116.87)


Epoch 10/10:  49%|████▊     | 175/360 [00:08<00:09, 20.14it/s]

 batch_loss: 509.72  |  KL+NLL: 112501.07  (KL=112367.07, NLL=134.00)
 batch_loss: 439.92  |  KL+NLL: 112395.72  (KL=112291.54, NLL=104.18)
 batch_loss: 516.74  |  KL+NLL: 112538.05  (KL=112389.98, NLL=148.07)
 batch_loss: 494.91  |  KL+NLL: 112471.12  (KL=112340.20, NLL=130.93)


Epoch 10/10:  49%|████▉     | 178/360 [00:08<00:08, 20.63it/s]

 batch_loss: 495.58  |  KL+NLL: 112636.25  (KL=112521.16, NLL=115.09)
 batch_loss: 463.15  |  KL+NLL: 112583.64  (KL=112470.37, NLL=113.27)
 batch_loss: 460.98  |  KL+NLL: 112568.90  (KL=112453.84, NLL=115.06)
 batch_loss: 497.96  |  KL+NLL: 112684.93  (KL=112543.85, NLL=141.07)
 batch_loss: 482.07  |  KL+NLL: 112597.04  (KL=112466.02, NLL=131.02)


Epoch 10/10:  51%|█████     | 184/360 [00:09<00:08, 20.60it/s]

 batch_loss: 479.11  |  KL+NLL: 112534.99  (KL=112394.52, NLL=140.46)
 batch_loss: 506.30  |  KL+NLL: 112639.63  (KL=112532.95, NLL=106.68)
 batch_loss: 484.14  |  KL+NLL: 112733.72  (KL=112603.06, NLL=130.66)
 batch_loss: 452.00  |  KL+NLL: 112505.55  (KL=112416.48, NLL=89.07)
 batch_loss: 488.75  |  KL+NLL: 112236.74  (KL=112122.68, NLL=114.06)


Epoch 10/10:  53%|█████▎    | 190/360 [00:09<00:08, 20.60it/s]

 batch_loss: 498.72  |  KL+NLL: 112682.97  (KL=112566.83, NLL=116.15)
 batch_loss: 466.98  |  KL+NLL: 112414.95  (KL=112311.16, NLL=103.78)
 batch_loss: 506.47  |  KL+NLL: 112324.30  (KL=112226.37, NLL=97.94)
 batch_loss: 486.55  |  KL+NLL: 112430.22  (KL=112311.59, NLL=118.63)
 batch_loss: 492.84  |  KL+NLL: 112355.80  (KL=112210.11, NLL=145.69)


Epoch 10/10:  54%|█████▎    | 193/360 [00:09<00:08, 20.78it/s]

 batch_loss: 520.48  |  KL+NLL: 112877.72  (KL=112749.62, NLL=128.11)
 batch_loss: 480.12  |  KL+NLL: 112342.55  (KL=112242.45, NLL=100.11)
 batch_loss: 452.51  |  KL+NLL: 112614.37  (KL=112507.15, NLL=107.22)
 batch_loss: 426.10  |  KL+NLL: 112461.05  (KL=112323.16, NLL=137.90)
 batch_loss: 506.53  |  KL+NLL: 112662.71  (KL=112533.61, NLL=129.10)


Epoch 10/10:  55%|█████▌    | 199/360 [00:09<00:07, 20.44it/s]

 batch_loss: 412.17  |  KL+NLL: 112468.21  (KL=112356.98, NLL=111.23)
 batch_loss: 465.83  |  KL+NLL: 112462.26  (KL=112353.52, NLL=108.74)
 batch_loss: 498.11  |  KL+NLL: 112858.43  (KL=112744.59, NLL=113.84)
 batch_loss: 469.30  |  KL+NLL: 112429.29  (KL=112315.30, NLL=113.98)


Epoch 10/10:  56%|█████▌    | 202/360 [00:09<00:07, 20.10it/s]

 batch_loss: 475.13  |  KL+NLL: 112629.73  (KL=112508.04, NLL=121.69)
 batch_loss: 435.53  |  KL+NLL: 112591.55  (KL=112481.48, NLL=110.07)
 batch_loss: 470.63  |  KL+NLL: 112527.88  (KL=112404.54, NLL=123.34)
 batch_loss: 485.44  |  KL+NLL: 112662.30  (KL=112540.90, NLL=121.40)
 batch_loss: 481.06  |  KL+NLL: 112469.65  (KL=112359.80, NLL=109.85)


Epoch 10/10:  58%|█████▊    | 208/360 [00:10<00:07, 20.25it/s]

 batch_loss: 468.12  |  KL+NLL: 112294.95  (KL=112161.06, NLL=133.88)
 batch_loss: 486.27  |  KL+NLL: 112441.13  (KL=112312.01, NLL=129.12)
 batch_loss: 507.19  |  KL+NLL: 112769.83  (KL=112655.34, NLL=114.49)
 batch_loss: 465.54  |  KL+NLL: 112451.92  (KL=112314.44, NLL=137.48)
 batch_loss: 488.88  |  KL+NLL: 112508.64  (KL=112404.70, NLL=103.94)


Epoch 10/10:  59%|█████▉    | 214/360 [00:10<00:07, 20.35it/s]

 batch_loss: 465.38  |  KL+NLL: 112351.40  (KL=112249.75, NLL=101.65)
 batch_loss: 452.21  |  KL+NLL: 112626.39  (KL=112486.95, NLL=139.44)
 batch_loss: 475.23  |  KL+NLL: 112681.72  (KL=112561.95, NLL=119.77)
 batch_loss: 477.20  |  KL+NLL: 112497.04  (KL=112336.00, NLL=161.04)
 batch_loss: 446.38  |  KL+NLL: 112636.82  (KL=112513.91, NLL=122.91)


Epoch 10/10:  60%|██████    | 217/360 [00:10<00:07, 20.26it/s]

 batch_loss: 466.44  |  KL+NLL: 112630.93  (KL=112526.19, NLL=104.74)
 batch_loss: 466.36  |  KL+NLL: 112613.04  (KL=112491.33, NLL=121.71)
 batch_loss: 476.42  |  KL+NLL: 112611.57  (KL=112488.80, NLL=122.77)
 batch_loss: 419.94  |  KL+NLL: 112773.31  (KL=112652.72, NLL=120.59)
 batch_loss: 456.82  |  KL+NLL: 112467.35  (KL=112338.97, NLL=128.38)


Epoch 10/10:  62%|██████▏   | 223/360 [00:11<00:06, 20.67it/s]

 batch_loss: 480.64  |  KL+NLL: 112318.05  (KL=112213.69, NLL=104.36)
 batch_loss: 476.68  |  KL+NLL: 112811.27  (KL=112704.69, NLL=106.58)
 batch_loss: 446.48  |  KL+NLL: 112963.04  (KL=112841.88, NLL=121.17)
 batch_loss: 397.89  |  KL+NLL: 112781.41  (KL=112657.04, NLL=124.37)
 batch_loss: 470.42  |  KL+NLL: 112683.91  (KL=112562.75, NLL=121.16)


Epoch 10/10:  64%|██████▎   | 229/360 [00:11<00:06, 20.50it/s]

 batch_loss: 457.16  |  KL+NLL: 112488.04  (KL=112377.52, NLL=110.53)
 batch_loss: 440.61  |  KL+NLL: 112812.77  (KL=112688.09, NLL=124.67)
 batch_loss: 466.89  |  KL+NLL: 112782.97  (KL=112668.08, NLL=114.89)
 batch_loss: 453.28  |  KL+NLL: 112556.93  (KL=112442.88, NLL=114.06)
 batch_loss: 448.43  |  KL+NLL: 112243.48  (KL=112128.12, NLL=115.36)


Epoch 10/10:  64%|██████▍   | 232/360 [00:11<00:06, 20.46it/s]

 batch_loss: 411.89  |  KL+NLL: 112561.85  (KL=112442.30, NLL=119.55)
 batch_loss: 433.00  |  KL+NLL: 112831.62  (KL=112716.43, NLL=115.19)
 batch_loss: 459.15  |  KL+NLL: 112823.36  (KL=112718.60, NLL=104.76)
 batch_loss: 477.45  |  KL+NLL: 112796.15  (KL=112686.17, NLL=109.98)
 batch_loss: 490.60  |  KL+NLL: 112725.76  (KL=112629.59, NLL=96.16)


Epoch 10/10:  66%|██████▌   | 238/360 [00:11<00:06, 20.16it/s]

 batch_loss: 465.94  |  KL+NLL: 112422.32  (KL=112299.26, NLL=123.06)
 batch_loss: 464.48  |  KL+NLL: 112717.15  (KL=112607.70, NLL=109.44)
 batch_loss: 449.97  |  KL+NLL: 112767.28  (KL=112613.27, NLL=154.01)
 batch_loss: 428.94  |  KL+NLL: 112675.29  (KL=112548.08, NLL=127.21)


Epoch 10/10:  67%|██████▋   | 241/360 [00:11<00:05, 20.57it/s]

 batch_loss: 434.57  |  KL+NLL: 112743.72  (KL=112646.56, NLL=97.16)
 batch_loss: 415.77  |  KL+NLL: 112421.06  (KL=112307.23, NLL=113.83)
 batch_loss: 404.90  |  KL+NLL: 112665.47  (KL=112539.38, NLL=126.10)
 batch_loss: 398.60  |  KL+NLL: 112679.99  (KL=112560.95, NLL=119.03)
 batch_loss: 451.97  |  KL+NLL: 112608.69  (KL=112488.70, NLL=119.99)


Epoch 10/10:  69%|██████▊   | 247/360 [00:12<00:05, 20.24it/s]

 batch_loss: 463.54  |  KL+NLL: 112584.52  (KL=112484.76, NLL=99.76)
 batch_loss: 492.50  |  KL+NLL: 112552.99  (KL=112404.19, NLL=148.81)
 batch_loss: 482.35  |  KL+NLL: 112463.00  (KL=112347.06, NLL=115.94)
 batch_loss: 473.39  |  KL+NLL: 112451.32  (KL=112349.43, NLL=101.89)
 batch_loss: 447.77  |  KL+NLL: 112610.75  (KL=112505.34, NLL=105.41)


Epoch 10/10:  70%|███████   | 253/360 [00:12<00:05, 21.14it/s]

 batch_loss: 441.85  |  KL+NLL: 112584.41  (KL=112480.72, NLL=103.69)
 batch_loss: 449.96  |  KL+NLL: 112609.53  (KL=112476.15, NLL=133.38)
 batch_loss: 459.20  |  KL+NLL: 112658.23  (KL=112539.58, NLL=118.66)
 batch_loss: 436.02  |  KL+NLL: 112373.65  (KL=112260.82, NLL=112.83)
 batch_loss: 410.05  |  KL+NLL: 112679.29  (KL=112543.48, NLL=135.81)


Epoch 10/10:  71%|███████   | 256/360 [00:12<00:04, 20.93it/s]

 batch_loss: 478.13  |  KL+NLL: 112757.37  (KL=112637.47, NLL=119.90)
 batch_loss: 466.39  |  KL+NLL: 112709.31  (KL=112583.17, NLL=126.14)
 batch_loss: 436.48  |  KL+NLL: 112744.48  (KL=112631.55, NLL=112.93)
 batch_loss: 448.80  |  KL+NLL: 112725.84  (KL=112626.26, NLL=99.58)
 batch_loss: 455.21  |  KL+NLL: 112712.87  (KL=112582.35, NLL=130.52)


Epoch 10/10:  73%|███████▎  | 262/360 [00:12<00:04, 20.93it/s]

 batch_loss: 421.20  |  KL+NLL: 112544.28  (KL=112429.26, NLL=115.02)
 batch_loss: 416.99  |  KL+NLL: 112839.18  (KL=112725.44, NLL=113.74)
 batch_loss: 416.95  |  KL+NLL: 112928.86  (KL=112819.74, NLL=109.12)
 batch_loss: 450.61  |  KL+NLL: 112535.16  (KL=112402.93, NLL=132.23)
 batch_loss: 409.51  |  KL+NLL: 112752.06  (KL=112650.67, NLL=101.39)


Epoch 10/10:  74%|███████▍  | 268/360 [00:13<00:04, 20.37it/s]

 batch_loss: 461.39  |  KL+NLL: 112499.49  (KL=112377.09, NLL=122.40)
 batch_loss: 392.90  |  KL+NLL: 112785.17  (KL=112672.12, NLL=113.05)
 batch_loss: 403.90  |  KL+NLL: 112695.74  (KL=112579.55, NLL=116.18)
 batch_loss: 457.27  |  KL+NLL: 112837.09  (KL=112728.41, NLL=108.67)
 batch_loss: 457.98  |  KL+NLL: 112915.21  (KL=112782.69, NLL=132.53)


Epoch 10/10:  75%|███████▌  | 271/360 [00:13<00:04, 19.98it/s]

 batch_loss: 416.63  |  KL+NLL: 112918.25  (KL=112798.02, NLL=120.23)
 batch_loss: 423.64  |  KL+NLL: 112524.04  (KL=112416.24, NLL=107.79)
 batch_loss: 420.55  |  KL+NLL: 112880.07  (KL=112755.34, NLL=124.72)
 batch_loss: 444.11  |  KL+NLL: 112839.32  (KL=112722.22, NLL=117.11)
 batch_loss: 417.66  |  KL+NLL: 112892.63  (KL=112783.40, NLL=109.23)


Epoch 10/10:  77%|███████▋  | 277/360 [00:13<00:04, 20.55it/s]

 batch_loss: 450.54  |  KL+NLL: 112701.10  (KL=112597.02, NLL=104.08)
 batch_loss: 481.53  |  KL+NLL: 112530.58  (KL=112413.38, NLL=117.19)
 batch_loss: 456.70  |  KL+NLL: 112773.99  (KL=112666.37, NLL=107.62)
 batch_loss: 450.10  |  KL+NLL: 112535.16  (KL=112433.72, NLL=101.44)
 batch_loss: 395.28  |  KL+NLL: 112368.84  (KL=112272.11, NLL=96.73)


Epoch 10/10:  78%|███████▊  | 280/360 [00:13<00:03, 20.38it/s]

 batch_loss: 430.34  |  KL+NLL: 112795.06  (KL=112688.13, NLL=106.92)
 batch_loss: 411.12  |  KL+NLL: 112869.15  (KL=112723.45, NLL=145.69)
 batch_loss: 442.80  |  KL+NLL: 112881.33  (KL=112759.61, NLL=121.72)
 batch_loss: 372.62  |  KL+NLL: 112975.04  (KL=112870.95, NLL=104.08)


Epoch 10/10:  79%|███████▉  | 286/360 [00:14<00:03, 20.47it/s]

 batch_loss: 429.14  |  KL+NLL: 112799.79  (KL=112678.21, NLL=121.58)
 batch_loss: 406.55  |  KL+NLL: 112835.46  (KL=112720.92, NLL=114.54)
 batch_loss: 452.81  |  KL+NLL: 112905.02  (KL=112800.45, NLL=104.57)
 batch_loss: 426.31  |  KL+NLL: 112907.84  (KL=112785.93, NLL=121.91)
 batch_loss: 384.34  |  KL+NLL: 112757.99  (KL=112645.64, NLL=112.35)


Epoch 10/10:  81%|████████  | 292/360 [00:14<00:03, 20.93it/s]

 batch_loss: 396.00  |  KL+NLL: 112518.41  (KL=112414.95, NLL=103.46)
 batch_loss: 448.04  |  KL+NLL: 112780.80  (KL=112682.09, NLL=98.71)
 batch_loss: 411.00  |  KL+NLL: 112881.43  (KL=112776.80, NLL=104.64)
 batch_loss: 410.21  |  KL+NLL: 112907.58  (KL=112781.77, NLL=125.81)
 batch_loss: 393.22  |  KL+NLL: 112650.61  (KL=112526.93, NLL=123.68)


Epoch 10/10:  82%|████████▏ | 295/360 [00:14<00:03, 20.92it/s]

 batch_loss: 409.55  |  KL+NLL: 112892.07  (KL=112768.47, NLL=123.60)
 batch_loss: 377.55  |  KL+NLL: 112645.67  (KL=112540.04, NLL=105.63)
 batch_loss: 389.51  |  KL+NLL: 112913.08  (KL=112797.38, NLL=115.70)
 batch_loss: 405.62  |  KL+NLL: 112743.10  (KL=112635.44, NLL=107.66)
 batch_loss: 417.26  |  KL+NLL: 112768.72  (KL=112651.96, NLL=116.76)


Epoch 10/10:  84%|████████▎ | 301/360 [00:14<00:02, 20.91it/s]

 batch_loss: 430.27  |  KL+NLL: 112700.98  (KL=112608.09, NLL=92.89)
 batch_loss: 416.30  |  KL+NLL: 112631.06  (KL=112510.95, NLL=120.11)
 batch_loss: 399.49  |  KL+NLL: 112400.82  (KL=112303.14, NLL=97.68)
 batch_loss: 403.30  |  KL+NLL: 112731.00  (KL=112627.64, NLL=103.36)
 batch_loss: 431.17  |  KL+NLL: 112749.84  (KL=112637.61, NLL=112.23)


Epoch 10/10:  85%|████████▌ | 307/360 [00:15<00:02, 20.66it/s]

 batch_loss: 439.97  |  KL+NLL: 112783.11  (KL=112678.41, NLL=104.71)
 batch_loss: 430.21  |  KL+NLL: 113074.94  (KL=112958.25, NLL=116.69)
 batch_loss: 419.08  |  KL+NLL: 112840.80  (KL=112709.10, NLL=131.70)
 batch_loss: 387.89  |  KL+NLL: 112878.51  (KL=112772.48, NLL=106.03)
 batch_loss: 400.75  |  KL+NLL: 112877.88  (KL=112739.24, NLL=138.64)


Epoch 10/10:  86%|████████▌ | 310/360 [00:15<00:02, 20.33it/s]

 batch_loss: 404.22  |  KL+NLL: 112836.80  (KL=112703.80, NLL=133.00)
 batch_loss: 372.23  |  KL+NLL: 112811.67  (KL=112692.44, NLL=119.23)
 batch_loss: 402.59  |  KL+NLL: 112780.00  (KL=112666.67, NLL=113.32)
 batch_loss: 420.31  |  KL+NLL: 112847.76  (KL=112681.73, NLL=166.03)
 batch_loss: 380.92  |  KL+NLL: 112978.45  (KL=112866.79, NLL=111.66)


Epoch 10/10:  88%|████████▊ | 316/360 [00:15<00:02, 20.08it/s]

 batch_loss: 393.04  |  KL+NLL: 112902.11  (KL=112797.18, NLL=104.93)
 batch_loss: 402.91  |  KL+NLL: 112948.91  (KL=112851.73, NLL=97.18)
 batch_loss: 398.09  |  KL+NLL: 112981.08  (KL=112869.41, NLL=111.66)
 batch_loss: 423.47  |  KL+NLL: 112915.80  (KL=112784.71, NLL=131.09)
 batch_loss: 397.96  |  KL+NLL: 112735.77  (KL=112630.83, NLL=104.94)


Epoch 10/10:  89%|████████▉ | 322/360 [00:15<00:01, 20.09it/s]

 batch_loss: 412.84  |  KL+NLL: 112694.49  (KL=112580.91, NLL=113.58)
 batch_loss: 428.41  |  KL+NLL: 112960.19  (KL=112852.21, NLL=107.98)
 batch_loss: 385.37  |  KL+NLL: 112729.03  (KL=112611.42, NLL=117.60)
 batch_loss: 380.61  |  KL+NLL: 113040.46  (KL=112932.78, NLL=107.68)
 batch_loss: 377.39  |  KL+NLL: 113237.61  (KL=113122.45, NLL=115.16)


Epoch 10/10:  90%|█████████ | 325/360 [00:15<00:01, 20.39it/s]

 batch_loss: 407.69  |  KL+NLL: 112722.99  (KL=112587.77, NLL=135.22)
 batch_loss: 369.81  |  KL+NLL: 113110.98  (KL=113010.66, NLL=100.31)
 batch_loss: 408.52  |  KL+NLL: 112902.87  (KL=112774.75, NLL=128.12)
 batch_loss: 379.58  |  KL+NLL: 112923.16  (KL=112810.97, NLL=112.19)
 batch_loss: 411.60  |  KL+NLL: 112704.83  (KL=112592.10, NLL=112.73)


Epoch 10/10:  92%|█████████▏| 331/360 [00:16<00:01, 20.36it/s]

 batch_loss: 394.77  |  KL+NLL: 112830.26  (KL=112726.60, NLL=103.66)
 batch_loss: 397.63  |  KL+NLL: 112980.44  (KL=112872.22, NLL=108.22)
 batch_loss: 407.71  |  KL+NLL: 112857.54  (KL=112746.37, NLL=111.18)
 batch_loss: 385.06  |  KL+NLL: 112945.01  (KL=112823.90, NLL=121.11)
 batch_loss: 417.12  |  KL+NLL: 112864.91  (KL=112756.89, NLL=108.02)


Epoch 10/10:  94%|█████████▎| 337/360 [00:16<00:01, 20.56it/s]

 batch_loss: 426.10  |  KL+NLL: 112880.80  (KL=112766.43, NLL=114.37)
 batch_loss: 367.02  |  KL+NLL: 112852.24  (KL=112728.50, NLL=123.74)
 batch_loss: 423.77  |  KL+NLL: 112877.01  (KL=112773.33, NLL=103.68)
 batch_loss: 383.09  |  KL+NLL: 113001.15  (KL=112876.07, NLL=125.08)
 batch_loss: 388.66  |  KL+NLL: 112837.49  (KL=112722.36, NLL=115.13)


Epoch 10/10:  94%|█████████▍| 340/360 [00:16<00:00, 20.51it/s]

 batch_loss: 377.31  |  KL+NLL: 113192.28  (KL=113069.84, NLL=122.44)
 batch_loss: 430.11  |  KL+NLL: 112801.74  (KL=112689.84, NLL=111.90)
 batch_loss: 369.43  |  KL+NLL: 112895.11  (KL=112762.16, NLL=132.95)
 batch_loss: 432.66  |  KL+NLL: 112852.05  (KL=112741.46, NLL=110.59)


Epoch 10/10:  96%|█████████▌| 346/360 [00:17<00:00, 20.20it/s]

 batch_loss: 343.54  |  KL+NLL: 112835.27  (KL=112709.03, NLL=126.24)
 batch_loss: 388.97  |  KL+NLL: 112991.40  (KL=112888.21, NLL=103.19)
 batch_loss: 395.88  |  KL+NLL: 112893.65  (KL=112776.02, NLL=117.63)
 batch_loss: 387.02  |  KL+NLL: 112846.68  (KL=112737.21, NLL=109.47)
 batch_loss: 393.33  |  KL+NLL: 113015.32  (KL=112880.49, NLL=134.83)


Epoch 10/10:  97%|█████████▋| 349/360 [00:17<00:00, 20.21it/s]

 batch_loss: 387.90  |  KL+NLL: 113051.69  (KL=112937.50, NLL=114.19)
 batch_loss: 353.67  |  KL+NLL: 113004.51  (KL=112890.91, NLL=113.59)
 batch_loss: 435.48  |  KL+NLL: 112724.07  (KL=112604.72, NLL=119.35)
 batch_loss: 376.48  |  KL+NLL: 112886.39  (KL=112738.97, NLL=147.42)
 batch_loss: 373.17  |  KL+NLL: 113013.15  (KL=112866.15, NLL=147.00)


Epoch 10/10:  99%|█████████▊| 355/360 [00:17<00:00, 20.89it/s]

 batch_loss: 400.38  |  KL+NLL: 112836.34  (KL=112689.55, NLL=146.79)
 batch_loss: 358.12  |  KL+NLL: 112924.94  (KL=112832.34, NLL=92.61)
 batch_loss: 384.12  |  KL+NLL: 113246.97  (KL=113129.37, NLL=117.60)
 batch_loss: 385.24  |  KL+NLL: 113095.90  (KL=113000.00, NLL=95.90)
 batch_loss: 372.96  |  KL+NLL: 112988.61  (KL=112878.38, NLL=110.23)


Epoch 10/10: 100%|██████████| 360/360 [00:17<00:00, 20.38it/s]


 batch_loss: 374.43  |  KL+NLL: 113159.55  (KL=113036.30, NLL=123.24)
 batch_loss: 389.71  |  KL+NLL: 113083.71  (KL=112963.16, NLL=120.54)
 batch_loss: 428.87  |  KL+NLL: 113007.77  (KL=112892.84, NLL=114.92)
 batch_loss: 401.30  |  KL+NLL: 112996.53  (KL=112861.52, NLL=135.00)
Epoch 10 - ELBO Loss: 491.7120


Acc check epoch 10: 100%|██████████| 360/360 [00:04<00:00, 79.42it/s]

Epoch 10 - Train Acc: 18.52%
AutoDiagonalNormal.loc: [ 0.00305209  0.70558995 -0.05031909 ... -1.1473572  -1.7199486
 -1.7610376 ]
AutoDiagonalNormal.scale: [1.9176738  1.8905535  1.9324698  ... 0.49652037 1.0352587  1.0843289 ]
Configuration saved to config_relu_gaussian_20250715_100559.json





In [13]:
def plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats, act_name, prior_name, timestamp):
    """Plot training results with weight and bias statistics"""
    plt.figure(figsize=(16, 12))
    
    # Plot 1: Training Loss
    plt.subplot(2, 2, 1)
    plt.plot(range(1, len(losses) + 1), losses)
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('ELBO Loss')
    plt.grid(True)
    
    # Plot 2: Training Accuracy
    plt.subplot(2, 2, 2)
    plt.plot(accuracy_epochs, accuracies, 'o-')
    plt.title('Training Accuracy (Every 10 Epochs)')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)
    
    # Plot 3: Weight Statistics Boxplot
    plt.subplot(2, 2, 3)
    weight_data = []
    weight_labels = []
    
    for i, epoch in enumerate(weight_stats['epochs']):
        # Combine means and stds for this epoch
        epoch_data = weight_stats['means'][i] + weight_stats['stds'][i]
        weight_data.append(epoch_data)
        weight_labels.append(f'Epoch {epoch}')
    
    if weight_data:
        bp1 = plt.boxplot(weight_data, labels=weight_labels, patch_artist=True)
        for patch in bp1['boxes']:
            patch.set_facecolor('lightblue')
    
    plt.title('LOC Statistics Distribution')
    plt.xlabel('Epoch')
    plt.ylabel('LOC Values')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    # Plot 4: Bias Statistics Boxplot
    plt.subplot(2, 2, 4)
    bias_data = []
    bias_labels = []
    
    for i, epoch in enumerate(bias_stats['epochs']):
        # Combine means and stds for this epoch
        epoch_data = bias_stats['means'][i] + bias_stats['stds'][i]
        bias_data.append(epoch_data)
        bias_labels.append(f'Epoch {epoch}')
    
    if bias_data:
        bp2 = plt.boxplot(bias_data, tick_labels=bias_labels, patch_artist=True)
        for patch in bp2['boxes']:
            patch.set_facecolor('lightcoral')
    
    plt.title('SCALE Statistics Distribution')
    plt.xlabel('Epoch')
    plt.ylabel('SCALE Values')
    plt.xticks(rotation=45)
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'results_GP_eurosat/bayesian_cnn_training_results_{act_name}_{prior_name}_{timestamp}.png')
    plt.show()

In [14]:
act_name = bayesian_model.activation_fn.__name__ if hasattr(bayesian_model.activation_fn, '__name__') else str(bayesian_model.activation_fn)
prior_name = getattr(bayesian_model, 'prior_dist', 'prior')

In [None]:
plot_training_results_with_stats(losses, accuracies, accuracy_epochs, weight_stats, bias_stats, act_name, prior_name, experiment_timestamp)

In [None]:
# clear the parameter store and reload the parameter store from the best result
pyro.clear_param_store()

bayesian_model.load_state_dict(torch.load(best_model_path))
guide.load_state_dict(torch.load(best_guide_path))
pyro.get_param_store().set_state(torch.load(best_param_store_path,weights_only=False))

In [None]:
# print confusion matrix
import numpy as np
from sklearn.metrics import confusion_matrix


def predict_data(model, loader_of_interest, num_samples=10):
    model.eval()
    guide.eval()

    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for images, labels in tqdm(loader_of_interest, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            logits_mc = torch.zeros(num_samples, images.size(0), model.fc1.out_features).to(device)

            for i in range(num_samples):
                guide_trace = pyro.poutine.trace(guide).get_trace(images)
                replayed_model = pyro.poutine.replay(model, trace=guide_trace)
                logits = replayed_model(images)
                logits_mc[i] = logits

            avg_logits = logits_mc.mean(dim=0)
            predictions = torch.argmax(avg_logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

    return all_labels, all_predictions

In [None]:
all_labels, all_predictions = predict_data(bayesian_model, test_loader, num_samples=10)
cm = confusion_matrix(all_labels, all_predictions)
#print accuracy from confusion matrix
accuracy = np.trace(cm) / np.sum(cm)
print(f"Accuracy from confusion matrix: {accuracy * 100:.6f}%")

In [None]:
df = pd.DataFrame({'True Label': all_labels, 'Predicted Label': all_predictions})

In [None]:
#save the prediction label and true label in a csv file

def save_predictions_to_csv(labels, predictions, filename='predictions.csv'):
    df = pd.DataFrame({'True Label': labels, 'Predicted Label': predictions})
    df.to_csv(filename, index=False)
    print(f"Predictions saved to {filename}")

In [None]:
save_predictions_to_csv(all_labels, all_predictions, os.path.join('results_GP_eurosat', f'predictions_{act_name}_{prior_name}_{experiment_timestamp}_{accuracy * 100:.0f}.csv'))

In [None]:
# print pyro param store values
for name, param in pyro.get_param_store().items():
    if 'loc' in name or 'scale' in name:
        print(f"{name}: {param.detach().cpu().numpy()}")