In [1]:
import logging
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scarches as sca
import scvi
from scvi.model import SCVI
import scipy.sparse
import anndata
import os
from scib.metrics import metrics
from lataq_reproduce.exp_dict import EXPERIMENT_INFO
from lataq_reproduce.utils import label_encoder

# 自定义 remove_sparsity
def custom_remove_sparsity(adata):
    if scipy.sparse.issparse(adata.X):
        X_dense = adata.X.toarray().astype(np.int64)
        return sc.AnnData(X=X_dense, obs=adata.obs.copy(deep=True), var=adata.var.copy(deep=True))
    return adata

# 设置日志
logging.basicConfig(level=logging.INFO)

# 数据和结果路径
DATA_DIR = "pancreas.h5ad"  # 修改为你的 pancreas.h5ad 文件路径
RES_PATH = "result"  # 结果保存路径，确保目录存在
data = "pancreas"
EXP_PARAMS = EXPERIMENT_INFO[data]
FILE_NAME = EXP_PARAMS["file_name"]

# 加载数据
logging.info(f"Loading dataset: {data}")
adata = sc.read("pancreas.h5ad")
condition_key = EXP_PARAMS["condition_key"]
cell_type_key = EXP_PARAMS["cell_type_key"]
reference = EXP_PARAMS["reference"]
query = EXP_PARAMS["query"]
print("Type of adata.X:", type(adata.X))
# 数据处理
import scipy.sparse
import scanpy as sc

def custom_remove_sparsity(adata):
    if scipy.sparse.issparse(adata.X):
        return sc.AnnData(X=adata.X.toarray(), obs=adata.obs.copy(deep=True), var=adata.var.copy(deep=True))
    return adata
adata = custom_remove_sparsity(adata)
source_adata = adata[adata.obs[condition_key].isin(reference)].copy()
target_adata = adata[adata.obs[condition_key].isin(query)].copy()
logging.info("Data loaded and processed successfully")
# 在 setup_anndata 之前
print("Batch categories:", source_adata.obs[condition_key].value_counts())
print("Label categories:", source_adata.obs[cell_type_key[0]].value_counts())
# 设置 anndata
SCVI.setup_anndata(source_adata, batch_key=condition_key, labels_key=cell_type_key[0])


# 训练参考模型
vae_ref = sca.models.SCVI(source_adata)
ref_time = time.time()
vae_ref.train()
vae_ref_scan = sca.models.SCANVI.from_scvi_model(vae_ref, unlabeled_category="Unknown")
vae_ref_scan.train(max_epochs=20)
ref_time = time.time() - ref_time
vae_ref_scan.save(f"{RES_PATH}/scanvi_model", overwrite=True)
logging.info("Reference model trained and saved")

# 训练查询模型
vae_q = sca.models.SCANVI.load_query_data(target_adata, f"{RES_PATH}/scanvi_model")
vae_q._unlabeled_indices = np.arange(target_adata.n_obs)
vae_q._labeled_indices = []
query_time = time.time()
vae_q.train(max_epochs=100, plan_kwargs=dict(weight_decay=0.0), check_val_every_n_epoch=10)
query_time = time.time() - query_time
vae_q.save(f"{RES_PATH}/scanvi_query_model", overwrite=True)
logging.info("Query model trained and saved")

# 评估整合指标
scores = None
try:
    print("target_adata shape:", target_adata.shape)
    print("source_adata shape:", source_adata.shape)
    adata_full = anndata.concat([target_adata, source_adata], axis=0, join="outer", label="query", keys=["Query", "Reference"])
    print("adata_full shape:", adata_full.shape)
    print("adata_full.obs columns:", adata_full.obs.columns)

    # 获取潜在表示
    adata_latent_full = sc.AnnData(vae_q.get_latent_representation(adata_full))
    adata_latent_full.obs["batch"] = adata_full.obs[condition_key].tolist()
    adata_latent_full.obs["celltype"] = adata_full.obs[cell_type_key[0]].tolist()

    # 集成指标计算
    conditions, _ = label_encoder(adata, condition_key=condition_key)
    labels, _ = label_encoder(adata, condition_key=cell_type_key[0])
    adata.obs["batch"] = conditions.squeeze(axis=1)
    adata.obs["celltype"] = labels.squeeze(axis=1)
    adata.obs["batch"] = adata.obs["batch"].astype("category")
    adata.obs["celltype"] = adata.obs["celltype"].astype("category")
    sc.pp.pca(adata)
    sc.pp.pca(adata_latent_full)

    adata.write(f"{RES_PATH}/adata_original.h5ad")
    adata_latent_full.write(f"{RES_PATH}/adata_latent.h5ad")

    scores = metrics(
        adata,
        adata_latent_full,
        "batch",
        "celltype",
        isolated_labels_asw_=True,
        silhouette_=True,
        graph_conn_=True,
        pcr_=True,
        isolated_labels_f1_=True,
        nmi_=True,
        ari_=True,
    )
    scores = scores.T
    scores = scores[
        [
            "NMI_cluster/label",
            "ARI_cluster/label",
            "ASW_label",
            "ASW_label/batch",
            "PCR_batch",
            "isolated_label_F1",
            "isolated_label_silhouette",
            "graph_conn",
        ]
    ]
    print("Integration Scores:\n", scores)
except Exception as e:
    print(f"Error in computing integration scores: {e}")

# 保存结果
results = {
    "reference_time": ref_time,
    "query_time": query_time,
    "integration_scores": scores,
}
logging.info("Results computed successfully")
print("\nResults:")
print("Reference Time (seconds):", results["reference_time"])
print("Query Time (seconds):", results["query_time"])
print("Integration Scores:\n", results["integration_scores"])

# 保存 integration_scores
if scores is not None:
    scores.to_csv(f"{RES_PATH}/integration_scores.csv")

 captum (see https://github.com/pytorch/captum).


Type of adata.X: <class 'scipy.sparse._csr.csr_matrix'>
Batch categories: study
inDrop3       3605
smartseq2     2394
inDrop1       1937
inDrop2       1724
smarter       1492
inDrop4       1303
fluidigmc1     638
Name: count, dtype: int64
Label categories: cell_type
alpha                 4459
beta                  3563
ductal                1557
acinar                1167
delta                  802
gamma                  571
activated_stellate     355
endothelial            287
quiescent_stellate     180
macrophage              63
mast                    35
epsilon                 27
schwann                 20
t_cell                   7
Name: count, dtype: int64


  self.validate_field(adata)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/miniconda3/envs/cmml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 1/400:   0%|          | 0/400 [00:00<?, ?it/s]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 2/400:   0%|          | 1/400 [00:01<10:27,  1.57s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=1.29e+3]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 3/400:   0%|          | 2/400 [00:02<08:57,  1.35s/it, v_num=1, train_loss_step=1.13e+3, train_loss_epoch=1.1e+3] 

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 4/400:   1%|          | 3/400 [00:03<08:30,  1.29s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=1.07e+3]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 5/400:   1%|          | 4/400 [00:05<08:17,  1.26s/it, v_num=1, train_loss_step=888, train_loss_epoch=1.05e+3]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 6/400:   1%|▏         | 5/400 [00:06<08:10,  1.24s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=1.04e+3]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 7/400:   2%|▏         | 6/400 [00:07<08:06,  1.24s/it, v_num=1, train_loss_step=1.15e+3, train_loss_epoch=1.03e+3]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 8/400:   2%|▏         | 7/400 [00:08<07:55,  1.21s/it, v_num=1, train_loss_step=959, train_loss_epoch=1.02e+3]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 9/400:   2%|▏         | 8/400 [00:10<07:55,  1.21s/it, v_num=1, train_loss_step=979, train_loss_epoch=1.01e+3]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 10/400:   2%|▏         | 9/400 [00:11<07:56,  1.22s/it, v_num=1, train_loss_step=927, train_loss_epoch=1.01e+3]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 11/400:   2%|▎         | 10/400 [00:12<07:56,  1.22s/it, v_num=1, train_loss_step=1e+3, train_loss_epoch=1e+3]  

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 12/400:   3%|▎         | 11/400 [00:13<07:56,  1.22s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=999]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 13/400:   3%|▎         | 12/400 [00:14<07:56,  1.23s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=995]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 14/400:   3%|▎         | 13/400 [00:16<07:55,  1.23s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=991] 

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 15/400:   4%|▎         | 14/400 [00:17<07:52,  1.22s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=987]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 16/400:   4%|▍         | 15/400 [00:18<07:49,  1.22s/it, v_num=1, train_loss_step=947, train_loss_epoch=984]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 17/400:   4%|▍         | 16/400 [00:19<07:46,  1.22s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=981]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 18/400:   4%|▍         | 17/400 [00:21<07:49,  1.23s/it, v_num=1, train_loss_step=858, train_loss_epoch=978]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 19/400:   4%|▍         | 18/400 [00:22<07:50,  1.23s/it, v_num=1, train_loss_step=976, train_loss_epoch=976]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 20/400:   5%|▍         | 19/400 [00:23<07:47,  1.23s/it, v_num=1, train_loss_step=1.17e+3, train_loss_epoch=973]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 21/400:   5%|▌         | 20/400 [00:24<07:45,  1.23s/it, v_num=1, train_loss_step=896, train_loss_epoch=971]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 22/400:   5%|▌         | 21/400 [00:25<07:44,  1.23s/it, v_num=1, train_loss_step=1.19e+3, train_loss_epoch=969]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 23/400:   6%|▌         | 22/400 [00:27<07:43,  1.23s/it, v_num=1, train_loss_step=1.13e+3, train_loss_epoch=967]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 24/400:   6%|▌         | 23/400 [00:28<07:44,  1.23s/it, v_num=1, train_loss_step=866, train_loss_epoch=965]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 25/400:   6%|▌         | 24/400 [00:29<07:44,  1.24s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=963]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 26/400:   6%|▋         | 25/400 [00:30<07:45,  1.24s/it, v_num=1, train_loss_step=844, train_loss_epoch=961]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 27/400:   6%|▋         | 26/400 [00:32<07:45,  1.24s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=960]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 28/400:   7%|▋         | 27/400 [00:33<07:47,  1.25s/it, v_num=1, train_loss_step=982, train_loss_epoch=958]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 29/400:   7%|▋         | 28/400 [00:34<07:53,  1.27s/it, v_num=1, train_loss_step=1.17e+3, train_loss_epoch=956]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 30/400:   7%|▋         | 29/400 [00:36<07:55,  1.28s/it, v_num=1, train_loss_step=870, train_loss_epoch=955]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 31/400:   8%|▊         | 30/400 [00:37<07:52,  1.28s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=954]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 32/400:   8%|▊         | 31/400 [00:38<07:49,  1.27s/it, v_num=1, train_loss_step=951, train_loss_epoch=952]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 33/400:   8%|▊         | 32/400 [00:39<07:49,  1.27s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=951]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 34/400:   8%|▊         | 33/400 [00:41<07:46,  1.27s/it, v_num=1, train_loss_step=845, train_loss_epoch=950]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 35/400:   8%|▊         | 34/400 [00:42<07:44,  1.27s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=949]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 36/400:   9%|▉         | 35/400 [00:43<07:45,  1.28s/it, v_num=1, train_loss_step=899, train_loss_epoch=948]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 37/400:   9%|▉         | 36/400 [00:45<07:48,  1.29s/it, v_num=1, train_loss_step=1.11e+3, train_loss_epoch=947]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 38/400:   9%|▉         | 37/400 [00:46<07:45,  1.28s/it, v_num=1, train_loss_step=957, train_loss_epoch=946]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 39/400:  10%|▉         | 38/400 [00:47<07:44,  1.28s/it, v_num=1, train_loss_step=992, train_loss_epoch=945]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 40/400:  10%|▉         | 39/400 [00:48<07:42,  1.28s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=944]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 41/400:  10%|█         | 40/400 [00:49<07:17,  1.22s/it, v_num=1, train_loss_step=890, train_loss_epoch=943]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 42/400:  10%|█         | 41/400 [00:51<07:15,  1.21s/it, v_num=1, train_loss_step=981, train_loss_epoch=942]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 43/400:  10%|█         | 42/400 [00:52<07:22,  1.24s/it, v_num=1, train_loss_step=1.15e+3, train_loss_epoch=942]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 44/400:  11%|█         | 43/400 [00:53<07:28,  1.26s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=941]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 45/400:  11%|█         | 44/400 [00:55<07:35,  1.28s/it, v_num=1, train_loss_step=1e+3, train_loss_epoch=940]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 46/400:  11%|█▏        | 45/400 [00:56<07:34,  1.28s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=940]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 47/400:  12%|█▏        | 46/400 [00:57<07:34,  1.28s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=939]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 48/400:  12%|█▏        | 47/400 [00:58<07:33,  1.29s/it, v_num=1, train_loss_step=904, train_loss_epoch=938]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 49/400:  12%|█▏        | 48/400 [01:00<07:29,  1.28s/it, v_num=1, train_loss_step=925, train_loss_epoch=938]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 50/400:  12%|█▏        | 49/400 [01:01<07:37,  1.30s/it, v_num=1, train_loss_step=827, train_loss_epoch=937]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 51/400:  12%|█▎        | 50/400 [01:02<07:34,  1.30s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=937]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 52/400:  13%|█▎        | 51/400 [01:04<07:31,  1.30s/it, v_num=1, train_loss_step=1.17e+3, train_loss_epoch=936]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 53/400:  13%|█▎        | 52/400 [01:05<07:30,  1.29s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=936]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 54/400:  13%|█▎        | 53/400 [01:06<07:35,  1.31s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=935]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 55/400:  14%|█▎        | 54/400 [01:08<07:38,  1.33s/it, v_num=1, train_loss_step=952, train_loss_epoch=935]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 56/400:  14%|█▍        | 55/400 [01:09<07:36,  1.32s/it, v_num=1, train_loss_step=921, train_loss_epoch=934]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 57/400:  14%|█▍        | 56/400 [01:10<07:30,  1.31s/it, v_num=1, train_loss_step=1e+3, train_loss_epoch=934]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 58/400:  14%|█▍        | 57/400 [01:11<07:21,  1.29s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=934]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 59/400:  14%|█▍        | 58/400 [01:13<07:13,  1.27s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=933]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 60/400:  15%|█▍        | 59/400 [01:14<07:11,  1.27s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=933]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 61/400:  15%|█▌        | 60/400 [01:15<07:08,  1.26s/it, v_num=1, train_loss_step=908, train_loss_epoch=932]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 62/400:  15%|█▌        | 61/400 [01:16<07:05,  1.26s/it, v_num=1, train_loss_step=871, train_loss_epoch=932]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 63/400:  16%|█▌        | 62/400 [01:18<07:00,  1.24s/it, v_num=1, train_loss_step=1.12e+3, train_loss_epoch=932]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 64/400:  16%|█▌        | 63/400 [01:19<06:57,  1.24s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=931]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 65/400:  16%|█▌        | 64/400 [01:20<06:55,  1.24s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=931]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 66/400:  16%|█▋        | 65/400 [01:21<06:52,  1.23s/it, v_num=1, train_loss_step=944, train_loss_epoch=931]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 67/400:  16%|█▋        | 66/400 [01:23<06:49,  1.23s/it, v_num=1, train_loss_step=933, train_loss_epoch=930]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 68/400:  17%|█▋        | 67/400 [01:24<06:47,  1.22s/it, v_num=1, train_loss_step=940, train_loss_epoch=930]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 69/400:  17%|█▋        | 68/400 [01:25<06:47,  1.23s/it, v_num=1, train_loss_step=920, train_loss_epoch=929]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 70/400:  17%|█▋        | 69/400 [01:26<06:45,  1.22s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=929]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 71/400:  18%|█▊        | 70/400 [01:27<06:41,  1.22s/it, v_num=1, train_loss_step=981, train_loss_epoch=929]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 72/400:  18%|█▊        | 71/400 [01:29<06:41,  1.22s/it, v_num=1, train_loss_step=978, train_loss_epoch=929]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 73/400:  18%|█▊        | 72/400 [01:30<06:42,  1.23s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=929]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 74/400:  18%|█▊        | 73/400 [01:31<06:38,  1.22s/it, v_num=1, train_loss_step=935, train_loss_epoch=928]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 75/400:  18%|█▊        | 74/400 [01:32<06:41,  1.23s/it, v_num=1, train_loss_step=1.24e+3, train_loss_epoch=928]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 76/400:  19%|█▉        | 75/400 [01:34<06:40,  1.23s/it, v_num=1, train_loss_step=887, train_loss_epoch=928]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 77/400:  19%|█▉        | 76/400 [01:35<06:37,  1.23s/it, v_num=1, train_loss_step=965, train_loss_epoch=928]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 78/400:  19%|█▉        | 77/400 [01:36<06:34,  1.22s/it, v_num=1, train_loss_step=956, train_loss_epoch=927]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 79/400:  20%|█▉        | 78/400 [01:37<06:32,  1.22s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=927]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 80/400:  20%|█▉        | 79/400 [01:38<06:31,  1.22s/it, v_num=1, train_loss_step=924, train_loss_epoch=927]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 81/400:  20%|██        | 80/400 [01:40<06:33,  1.23s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=927]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 82/400:  20%|██        | 81/400 [01:41<06:35,  1.24s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=927] 

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 83/400:  20%|██        | 82/400 [01:42<06:31,  1.23s/it, v_num=1, train_loss_step=909, train_loss_epoch=927]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 84/400:  21%|██        | 83/400 [01:43<06:30,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=926]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 85/400:  21%|██        | 84/400 [01:45<06:29,  1.23s/it, v_num=1, train_loss_step=925, train_loss_epoch=926]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 86/400:  21%|██▏       | 85/400 [01:46<06:30,  1.24s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=926]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 87/400:  22%|██▏       | 86/400 [01:47<06:29,  1.24s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=926]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 88/400:  22%|██▏       | 87/400 [01:48<06:30,  1.25s/it, v_num=1, train_loss_step=951, train_loss_epoch=926]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 89/400:  22%|██▏       | 88/400 [01:50<06:32,  1.26s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 90/400:  22%|██▏       | 89/400 [01:51<06:28,  1.25s/it, v_num=1, train_loss_step=906, train_loss_epoch=926]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 91/400:  22%|██▎       | 90/400 [01:52<06:24,  1.24s/it, v_num=1, train_loss_step=974, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 92/400:  23%|██▎       | 91/400 [01:53<06:22,  1.24s/it, v_num=1, train_loss_step=975, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 93/400:  23%|██▎       | 92/400 [01:55<06:20,  1.24s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 94/400:  23%|██▎       | 93/400 [01:56<06:17,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 95/400:  24%|██▎       | 94/400 [01:57<06:15,  1.23s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 96/400:  24%|██▍       | 95/400 [01:58<06:09,  1.21s/it, v_num=1, train_loss_step=886, train_loss_epoch=925]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 97/400:  24%|██▍       | 96/400 [01:59<06:12,  1.23s/it, v_num=1, train_loss_step=1.17e+3, train_loss_epoch=924]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 98/400:  24%|██▍       | 97/400 [02:01<06:13,  1.23s/it, v_num=1, train_loss_step=1.23e+3, train_loss_epoch=925]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 99/400:  24%|██▍       | 98/400 [02:02<06:12,  1.23s/it, v_num=1, train_loss_step=888, train_loss_epoch=925]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 100/400:  25%|██▍       | 99/400 [02:03<06:12,  1.24s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=924]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 101/400:  25%|██▌       | 100/400 [02:04<06:10,  1.23s/it, v_num=1, train_loss_step=899, train_loss_epoch=924]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 102/400:  25%|██▌       | 101/400 [02:06<06:08,  1.23s/it, v_num=1, train_loss_step=941, train_loss_epoch=924]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 103/400:  26%|██▌       | 102/400 [02:07<06:04,  1.22s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=924]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 104/400:  26%|██▌       | 103/400 [02:08<06:02,  1.22s/it, v_num=1, train_loss_step=966, train_loss_epoch=923]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 105/400:  26%|██▌       | 104/400 [02:09<06:03,  1.23s/it, v_num=1, train_loss_step=964, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 106/400:  26%|██▋       | 105/400 [02:11<06:01,  1.22s/it, v_num=1, train_loss_step=967, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 107/400:  26%|██▋       | 106/400 [02:12<06:00,  1.23s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 108/400:  27%|██▋       | 107/400 [02:13<05:58,  1.22s/it, v_num=1, train_loss_step=867, train_loss_epoch=923]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 109/400:  27%|██▋       | 108/400 [02:14<05:58,  1.23s/it, v_num=1, train_loss_step=826, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 110/400:  27%|██▋       | 109/400 [02:15<05:55,  1.22s/it, v_num=1, train_loss_step=903, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 111/400:  28%|██▊       | 110/400 [02:17<05:52,  1.22s/it, v_num=1, train_loss_step=910, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 112/400:  28%|██▊       | 111/400 [02:18<05:51,  1.22s/it, v_num=1, train_loss_step=916, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 113/400:  28%|██▊       | 112/400 [02:19<05:52,  1.23s/it, v_num=1, train_loss_step=1.14e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 114/400:  28%|██▊       | 113/400 [02:20<05:50,  1.22s/it, v_num=1, train_loss_step=1.11e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 115/400:  28%|██▊       | 114/400 [02:21<05:48,  1.22s/it, v_num=1, train_loss_step=933, train_loss_epoch=923]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 116/400:  29%|██▉       | 115/400 [02:23<05:48,  1.22s/it, v_num=1, train_loss_step=824, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 117/400:  29%|██▉       | 116/400 [02:24<05:46,  1.22s/it, v_num=1, train_loss_step=820, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 118/400:  29%|██▉       | 117/400 [02:25<05:45,  1.22s/it, v_num=1, train_loss_step=901, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 119/400:  30%|██▉       | 118/400 [02:26<05:43,  1.22s/it, v_num=1, train_loss_step=938, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 120/400:  30%|██▉       | 119/400 [02:28<05:41,  1.22s/it, v_num=1, train_loss_step=859, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 121/400:  30%|███       | 120/400 [02:29<05:43,  1.23s/it, v_num=1, train_loss_step=911, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 122/400:  30%|███       | 121/400 [02:30<05:41,  1.22s/it, v_num=1, train_loss_step=959, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 123/400:  30%|███       | 122/400 [02:31<05:40,  1.22s/it, v_num=1, train_loss_step=925, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 124/400:  31%|███       | 123/400 [02:33<05:39,  1.23s/it, v_num=1, train_loss_step=939, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 125/400:  31%|███       | 124/400 [02:34<05:37,  1.22s/it, v_num=1, train_loss_step=911, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 126/400:  31%|███▏      | 125/400 [02:35<05:35,  1.22s/it, v_num=1, train_loss_step=841, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 127/400:  32%|███▏      | 126/400 [02:36<05:35,  1.22s/it, v_num=1, train_loss_step=959, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 128/400:  32%|███▏      | 127/400 [02:37<05:33,  1.22s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 129/400:  32%|███▏      | 128/400 [02:39<05:34,  1.23s/it, v_num=1, train_loss_step=1.14e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 130/400:  32%|███▏      | 129/400 [02:40<05:31,  1.22s/it, v_num=1, train_loss_step=1.13e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 131/400:  32%|███▎      | 130/400 [02:41<05:31,  1.23s/it, v_num=1, train_loss_step=896, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 132/400:  33%|███▎      | 131/400 [02:42<05:30,  1.23s/it, v_num=1, train_loss_step=961, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 133/400:  33%|███▎      | 132/400 [02:44<05:29,  1.23s/it, v_num=1, train_loss_step=987, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 134/400:  33%|███▎      | 133/400 [02:45<05:27,  1.22s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 135/400:  34%|███▎      | 134/400 [02:46<05:22,  1.21s/it, v_num=1, train_loss_step=983, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 136/400:  34%|███▍      | 135/400 [02:47<05:20,  1.21s/it, v_num=1, train_loss_step=931, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 137/400:  34%|███▍      | 136/400 [02:48<05:18,  1.21s/it, v_num=1, train_loss_step=808, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 138/400:  34%|███▍      | 137/400 [02:50<05:18,  1.21s/it, v_num=1, train_loss_step=935, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 139/400:  34%|███▍      | 138/400 [02:51<05:16,  1.21s/it, v_num=1, train_loss_step=1.49e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 140/400:  35%|███▍      | 139/400 [02:52<05:13,  1.20s/it, v_num=1, train_loss_step=883, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 141/400:  35%|███▌      | 140/400 [02:53<04:55,  1.14s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 142/400:  35%|███▌      | 141/400 [02:54<05:00,  1.16s/it, v_num=1, train_loss_step=957, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 143/400:  36%|███▌      | 142/400 [02:55<05:04,  1.18s/it, v_num=1, train_loss_step=913, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 144/400:  36%|███▌      | 143/400 [02:57<05:09,  1.20s/it, v_num=1, train_loss_step=936, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 145/400:  36%|███▌      | 144/400 [02:58<05:09,  1.21s/it, v_num=1, train_loss_step=985, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 146/400:  36%|███▋      | 145/400 [02:59<05:09,  1.22s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 147/400:  36%|███▋      | 146/400 [03:00<05:10,  1.22s/it, v_num=1, train_loss_step=939, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 148/400:  37%|███▋      | 147/400 [03:02<05:10,  1.23s/it, v_num=1, train_loss_step=905, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 149/400:  37%|███▋      | 148/400 [03:03<05:10,  1.23s/it, v_num=1, train_loss_step=964, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 150/400:  37%|███▋      | 149/400 [03:04<05:09,  1.23s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 151/400:  38%|███▊      | 150/400 [03:05<05:08,  1.24s/it, v_num=1, train_loss_step=962, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 152/400:  38%|███▊      | 151/400 [03:07<05:07,  1.23s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 153/400:  38%|███▊      | 152/400 [03:08<05:06,  1.24s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 154/400:  38%|███▊      | 153/400 [03:09<05:05,  1.24s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 155/400:  38%|███▊      | 154/400 [03:10<05:04,  1.24s/it, v_num=1, train_loss_step=926, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 156/400:  39%|███▉      | 155/400 [03:11<05:02,  1.24s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 157/400:  39%|███▉      | 156/400 [03:13<05:00,  1.23s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 158/400:  39%|███▉      | 157/400 [03:14<04:59,  1.23s/it, v_num=1, train_loss_step=1.12e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 159/400:  40%|███▉      | 158/400 [03:15<04:57,  1.23s/it, v_num=1, train_loss_step=910, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 160/400:  40%|███▉      | 159/400 [03:16<04:54,  1.22s/it, v_num=1, train_loss_step=972, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 161/400:  40%|████      | 160/400 [03:18<04:50,  1.21s/it, v_num=1, train_loss_step=908, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 162/400:  40%|████      | 161/400 [03:19<04:50,  1.22s/it, v_num=1, train_loss_step=839, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 163/400:  40%|████      | 162/400 [03:20<04:52,  1.23s/it, v_num=1, train_loss_step=898, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 164/400:  41%|████      | 163/400 [03:21<04:55,  1.25s/it, v_num=1, train_loss_step=947, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 165/400:  41%|████      | 164/400 [03:23<04:52,  1.24s/it, v_num=1, train_loss_step=965, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 166/400:  41%|████▏     | 165/400 [03:24<04:50,  1.23s/it, v_num=1, train_loss_step=998, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 167/400:  42%|████▏     | 166/400 [03:25<04:45,  1.22s/it, v_num=1, train_loss_step=877, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 168/400:  42%|████▏     | 167/400 [03:26<04:42,  1.21s/it, v_num=1, train_loss_step=821, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 169/400:  42%|████▏     | 168/400 [03:27<04:41,  1.21s/it, v_num=1, train_loss_step=827, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 170/400:  42%|████▏     | 169/400 [03:29<04:40,  1.21s/it, v_num=1, train_loss_step=955, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 171/400:  42%|████▎     | 170/400 [03:30<04:40,  1.22s/it, v_num=1, train_loss_step=825, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 172/400:  43%|████▎     | 171/400 [03:31<04:40,  1.22s/it, v_num=1, train_loss_step=998, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 173/400:  43%|████▎     | 172/400 [03:32<04:39,  1.22s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 174/400:  43%|████▎     | 173/400 [03:33<04:38,  1.23s/it, v_num=1, train_loss_step=992, train_loss_epoch=920]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 175/400:  44%|████▎     | 174/400 [03:35<04:36,  1.23s/it, v_num=1, train_loss_step=924, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 176/400:  44%|████▍     | 175/400 [03:36<04:32,  1.21s/it, v_num=1, train_loss_step=873, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 177/400:  44%|████▍     | 176/400 [03:37<04:33,  1.22s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 178/400:  44%|████▍     | 177/400 [03:38<04:32,  1.22s/it, v_num=1, train_loss_step=983, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 179/400:  44%|████▍     | 178/400 [03:40<04:31,  1.22s/it, v_num=1, train_loss_step=875, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 180/400:  45%|████▍     | 179/400 [03:41<04:29,  1.22s/it, v_num=1, train_loss_step=866, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 181/400:  45%|████▌     | 180/400 [03:42<04:27,  1.22s/it, v_num=1, train_loss_step=778, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 182/400:  45%|████▌     | 181/400 [03:43<04:26,  1.22s/it, v_num=1, train_loss_step=894, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 183/400:  46%|████▌     | 182/400 [03:44<04:26,  1.22s/it, v_num=1, train_loss_step=881, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 184/400:  46%|████▌     | 183/400 [03:46<04:25,  1.22s/it, v_num=1, train_loss_step=865, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 185/400:  46%|████▌     | 184/400 [03:47<04:21,  1.21s/it, v_num=1, train_loss_step=786, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 186/400:  46%|████▋     | 185/400 [03:48<04:23,  1.23s/it, v_num=1, train_loss_step=959, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 187/400:  46%|████▋     | 186/400 [03:49<04:22,  1.23s/it, v_num=1, train_loss_step=969, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 188/400:  47%|████▋     | 187/400 [03:51<04:21,  1.23s/it, v_num=1, train_loss_step=906, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 189/400:  47%|████▋     | 188/400 [03:52<04:19,  1.22s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 190/400:  47%|████▋     | 189/400 [03:53<04:17,  1.22s/it, v_num=1, train_loss_step=983, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 191/400:  48%|████▊     | 190/400 [03:54<04:16,  1.22s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 192/400:  48%|████▊     | 191/400 [03:55<04:16,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 193/400:  48%|████▊     | 192/400 [03:57<04:14,  1.22s/it, v_num=1, train_loss_step=981, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 194/400:  48%|████▊     | 193/400 [03:58<04:13,  1.22s/it, v_num=1, train_loss_step=980, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 195/400:  48%|████▊     | 194/400 [03:59<04:14,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 196/400:  49%|████▉     | 195/400 [04:00<04:11,  1.23s/it, v_num=1, train_loss_step=976, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 197/400:  49%|████▉     | 196/400 [04:02<04:08,  1.22s/it, v_num=1, train_loss_step=767, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 198/400:  49%|████▉     | 197/400 [04:03<04:06,  1.21s/it, v_num=1, train_loss_step=950, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 199/400:  50%|████▉     | 198/400 [04:04<04:04,  1.21s/it, v_num=1, train_loss_step=820, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 200/400:  50%|████▉     | 199/400 [04:05<04:02,  1.21s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 201/400:  50%|█████     | 200/400 [04:06<04:01,  1.21s/it, v_num=1, train_loss_step=984, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 202/400:  50%|█████     | 201/400 [04:08<03:59,  1.21s/it, v_num=1, train_loss_step=863, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 203/400:  50%|█████     | 202/400 [04:09<03:58,  1.20s/it, v_num=1, train_loss_step=865, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 204/400:  51%|█████     | 203/400 [04:10<03:57,  1.21s/it, v_num=1, train_loss_step=900, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 205/400:  51%|█████     | 204/400 [04:11<03:58,  1.22s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 206/400:  51%|█████▏    | 205/400 [04:12<03:57,  1.22s/it, v_num=1, train_loss_step=914, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 207/400:  52%|█████▏    | 206/400 [04:14<03:57,  1.22s/it, v_num=1, train_loss_step=984, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 208/400:  52%|█████▏    | 207/400 [04:15<03:55,  1.22s/it, v_num=1, train_loss_step=973, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 209/400:  52%|█████▏    | 208/400 [04:16<03:55,  1.23s/it, v_num=1, train_loss_step=880, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 210/400:  52%|█████▏    | 209/400 [04:17<03:54,  1.23s/it, v_num=1, train_loss_step=977, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 211/400:  52%|█████▎    | 210/400 [04:19<03:51,  1.22s/it, v_num=1, train_loss_step=938, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 212/400:  53%|█████▎    | 211/400 [04:20<03:51,  1.23s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 213/400:  53%|█████▎    | 212/400 [04:21<03:50,  1.23s/it, v_num=1, train_loss_step=930, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 214/400:  53%|█████▎    | 213/400 [04:22<03:49,  1.23s/it, v_num=1, train_loss_step=1.29e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 215/400:  54%|█████▎    | 214/400 [04:24<03:48,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 216/400:  54%|█████▍    | 215/400 [04:25<03:47,  1.23s/it, v_num=1, train_loss_step=935, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 217/400:  54%|█████▍    | 216/400 [04:26<03:46,  1.23s/it, v_num=1, train_loss_step=892, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 218/400:  54%|█████▍    | 217/400 [04:27<03:44,  1.23s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 219/400:  55%|█████▍    | 218/400 [04:28<03:43,  1.23s/it, v_num=1, train_loss_step=969, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 220/400:  55%|█████▍    | 219/400 [04:30<03:39,  1.21s/it, v_num=1, train_loss_step=1.13e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 221/400:  55%|█████▌    | 220/400 [04:31<03:38,  1.21s/it, v_num=1, train_loss_step=1.14e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 222/400:  55%|█████▌    | 221/400 [04:32<03:38,  1.22s/it, v_num=1, train_loss_step=899, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 223/400:  56%|█████▌    | 222/400 [04:33<03:37,  1.22s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 224/400:  56%|█████▌    | 223/400 [04:35<03:36,  1.22s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 225/400:  56%|█████▌    | 224/400 [04:36<03:34,  1.22s/it, v_num=1, train_loss_step=899, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 226/400:  56%|█████▋    | 225/400 [04:37<03:34,  1.22s/it, v_num=1, train_loss_step=960, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 227/400:  56%|█████▋    | 226/400 [04:38<03:32,  1.22s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 228/400:  57%|█████▋    | 227/400 [04:39<03:30,  1.22s/it, v_num=1, train_loss_step=953, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 229/400:  57%|█████▋    | 228/400 [04:41<03:29,  1.22s/it, v_num=1, train_loss_step=956, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 230/400:  57%|█████▋    | 229/400 [04:42<03:28,  1.22s/it, v_num=1, train_loss_step=898, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 231/400:  57%|█████▊    | 230/400 [04:43<03:27,  1.22s/it, v_num=1, train_loss_step=883, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 232/400:  58%|█████▊    | 231/400 [04:44<03:26,  1.22s/it, v_num=1, train_loss_step=964, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 233/400:  58%|█████▊    | 232/400 [04:46<03:24,  1.22s/it, v_num=1, train_loss_step=891, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 234/400:  58%|█████▊    | 233/400 [04:47<03:24,  1.22s/it, v_num=1, train_loss_step=1.14e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 235/400:  58%|█████▊    | 234/400 [04:48<03:23,  1.23s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 236/400:  59%|█████▉    | 235/400 [04:49<03:21,  1.22s/it, v_num=1, train_loss_step=862, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 237/400:  59%|█████▉    | 236/400 [04:50<03:19,  1.22s/it, v_num=1, train_loss_step=979, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 238/400:  59%|█████▉    | 237/400 [04:52<03:17,  1.21s/it, v_num=1, train_loss_step=791, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 239/400:  60%|█████▉    | 238/400 [04:53<03:15,  1.20s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 240/400:  60%|█████▉    | 239/400 [04:54<03:13,  1.20s/it, v_num=1, train_loss_step=901, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 241/400:  60%|██████    | 240/400 [04:55<03:06,  1.16s/it, v_num=1, train_loss_step=889, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 242/400:  60%|██████    | 241/400 [04:56<03:04,  1.16s/it, v_num=1, train_loss_step=905, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 243/400:  60%|██████    | 242/400 [04:57<03:07,  1.18s/it, v_num=1, train_loss_step=913, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 244/400:  61%|██████    | 243/400 [04:59<03:09,  1.21s/it, v_num=1, train_loss_step=920, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 245/400:  61%|██████    | 244/400 [05:00<03:08,  1.21s/it, v_num=1, train_loss_step=958, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 246/400:  61%|██████▏   | 245/400 [05:01<03:08,  1.22s/it, v_num=1, train_loss_step=939, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 247/400:  62%|██████▏   | 246/400 [05:02<03:08,  1.22s/it, v_num=1, train_loss_step=907, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 248/400:  62%|██████▏   | 247/400 [05:04<03:06,  1.22s/it, v_num=1, train_loss_step=775, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 249/400:  62%|██████▏   | 248/400 [05:05<03:05,  1.22s/it, v_num=1, train_loss_step=993, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 250/400:  62%|██████▏   | 249/400 [05:06<03:05,  1.23s/it, v_num=1, train_loss_step=1.19e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 251/400:  62%|██████▎   | 250/400 [05:07<03:04,  1.23s/it, v_num=1, train_loss_step=966, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 252/400:  63%|██████▎   | 251/400 [05:09<03:03,  1.23s/it, v_num=1, train_loss_step=1.28e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 253/400:  63%|██████▎   | 252/400 [05:10<03:02,  1.23s/it, v_num=1, train_loss_step=907, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 254/400:  63%|██████▎   | 253/400 [05:11<03:01,  1.23s/it, v_num=1, train_loss_step=955, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 255/400:  64%|██████▎   | 254/400 [05:12<03:00,  1.24s/it, v_num=1, train_loss_step=967, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 256/400:  64%|██████▍   | 255/400 [05:13<02:58,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 257/400:  64%|██████▍   | 256/400 [05:15<02:57,  1.24s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 258/400:  64%|██████▍   | 257/400 [05:16<02:53,  1.22s/it, v_num=1, train_loss_step=937, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 259/400:  64%|██████▍   | 258/400 [05:17<02:54,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 260/400:  65%|██████▍   | 259/400 [05:18<02:52,  1.22s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 261/400:  65%|██████▌   | 260/400 [05:20<02:52,  1.23s/it, v_num=1, train_loss_step=967, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 262/400:  65%|██████▌   | 261/400 [05:21<02:52,  1.24s/it, v_num=1, train_loss_step=997, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 263/400:  66%|██████▌   | 262/400 [05:22<02:51,  1.24s/it, v_num=1, train_loss_step=831, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 264/400:  66%|██████▌   | 263/400 [05:23<02:49,  1.24s/it, v_num=1, train_loss_step=947, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 265/400:  66%|██████▌   | 264/400 [05:25<02:49,  1.24s/it, v_num=1, train_loss_step=927, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 266/400:  66%|██████▋   | 265/400 [05:26<02:47,  1.24s/it, v_num=1, train_loss_step=1.19e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 267/400:  66%|██████▋   | 266/400 [05:27<02:45,  1.23s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 268/400:  67%|██████▋   | 267/400 [05:28<02:44,  1.23s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 269/400:  67%|██████▋   | 268/400 [05:30<02:43,  1.24s/it, v_num=1, train_loss_step=919, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 270/400:  67%|██████▋   | 269/400 [05:31<02:42,  1.24s/it, v_num=1, train_loss_step=885, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 271/400:  68%|██████▊   | 270/400 [05:32<02:41,  1.24s/it, v_num=1, train_loss_step=971, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 272/400:  68%|██████▊   | 271/400 [05:33<02:39,  1.24s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 273/400:  68%|██████▊   | 272/400 [05:35<02:39,  1.24s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 274/400:  68%|██████▊   | 273/400 [05:36<02:37,  1.24s/it, v_num=1, train_loss_step=945, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 275/400:  68%|██████▊   | 274/400 [05:37<02:36,  1.24s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 276/400:  69%|██████▉   | 275/400 [05:38<02:35,  1.24s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 277/400:  69%|██████▉   | 276/400 [05:39<02:34,  1.24s/it, v_num=1, train_loss_step=962, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 278/400:  69%|██████▉   | 277/400 [05:41<02:32,  1.24s/it, v_num=1, train_loss_step=985, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 279/400:  70%|██████▉   | 278/400 [05:42<02:30,  1.23s/it, v_num=1, train_loss_step=994, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 280/400:  70%|██████▉   | 279/400 [05:43<02:28,  1.23s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 281/400:  70%|███████   | 280/400 [05:44<02:28,  1.23s/it, v_num=1, train_loss_step=825, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 282/400:  70%|███████   | 281/400 [05:46<02:25,  1.23s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 283/400:  70%|███████   | 282/400 [05:47<02:25,  1.23s/it, v_num=1, train_loss_step=975, train_loss_epoch=920]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 284/400:  71%|███████   | 283/400 [05:48<02:24,  1.23s/it, v_num=1, train_loss_step=944, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 285/400:  71%|███████   | 284/400 [05:49<02:22,  1.23s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 286/400:  71%|███████▏  | 285/400 [05:51<02:21,  1.23s/it, v_num=1, train_loss_step=958, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 287/400:  72%|███████▏  | 286/400 [05:52<02:21,  1.24s/it, v_num=1, train_loss_step=965, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 288/400:  72%|███████▏  | 287/400 [05:53<02:19,  1.23s/it, v_num=1, train_loss_step=868, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 289/400:  72%|███████▏  | 288/400 [05:54<02:18,  1.24s/it, v_num=1, train_loss_step=947, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 290/400:  72%|███████▏  | 289/400 [05:55<02:16,  1.23s/it, v_num=1, train_loss_step=998, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 291/400:  72%|███████▎  | 290/400 [05:57<02:15,  1.24s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 292/400:  73%|███████▎  | 291/400 [05:58<02:14,  1.24s/it, v_num=1, train_loss_step=1.15e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 293/400:  73%|███████▎  | 292/400 [05:59<02:13,  1.23s/it, v_num=1, train_loss_step=982, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 294/400:  73%|███████▎  | 293/400 [06:00<02:09,  1.21s/it, v_num=1, train_loss_step=962, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 295/400:  74%|███████▎  | 294/400 [06:02<02:09,  1.22s/it, v_num=1, train_loss_step=908, train_loss_epoch=920]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 296/400:  74%|███████▍  | 295/400 [06:03<02:09,  1.23s/it, v_num=1, train_loss_step=990, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 297/400:  74%|███████▍  | 296/400 [06:04<02:07,  1.23s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 298/400:  74%|███████▍  | 297/400 [06:05<02:06,  1.23s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 299/400:  74%|███████▍  | 298/400 [06:07<02:05,  1.23s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 300/400:  75%|███████▍  | 299/400 [06:08<02:04,  1.23s/it, v_num=1, train_loss_step=966, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 301/400:  75%|███████▌  | 300/400 [06:09<02:02,  1.23s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 302/400:  75%|███████▌  | 301/400 [06:10<02:02,  1.24s/it, v_num=1, train_loss_step=815, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 303/400:  76%|███████▌  | 302/400 [06:11<02:00,  1.23s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 304/400:  76%|███████▌  | 303/400 [06:13<01:59,  1.23s/it, v_num=1, train_loss_step=931, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 305/400:  76%|███████▌  | 304/400 [06:14<01:58,  1.23s/it, v_num=1, train_loss_step=903, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 306/400:  76%|███████▋  | 305/400 [06:15<01:56,  1.23s/it, v_num=1, train_loss_step=977, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 307/400:  76%|███████▋  | 306/400 [06:16<01:54,  1.22s/it, v_num=1, train_loss_step=985, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 308/400:  77%|███████▋  | 307/400 [06:18<01:53,  1.23s/it, v_num=1, train_loss_step=930, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 309/400:  77%|███████▋  | 308/400 [06:19<01:53,  1.23s/it, v_num=1, train_loss_step=963, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 310/400:  77%|███████▋  | 309/400 [06:20<01:51,  1.23s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 311/400:  78%|███████▊  | 310/400 [06:21<01:50,  1.22s/it, v_num=1, train_loss_step=966, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 312/400:  78%|███████▊  | 311/400 [06:23<01:49,  1.23s/it, v_num=1, train_loss_step=997, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 313/400:  78%|███████▊  | 312/400 [06:24<01:47,  1.22s/it, v_num=1, train_loss_step=870, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 314/400:  78%|███████▊  | 313/400 [06:25<01:46,  1.23s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 315/400:  78%|███████▊  | 314/400 [06:26<01:45,  1.22s/it, v_num=1, train_loss_step=984, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 316/400:  79%|███████▉  | 315/400 [06:27<01:44,  1.23s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 317/400:  79%|███████▉  | 316/400 [06:29<01:43,  1.23s/it, v_num=1, train_loss_step=925, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 318/400:  79%|███████▉  | 317/400 [06:30<01:41,  1.22s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 319/400:  80%|███████▉  | 318/400 [06:31<01:40,  1.23s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 320/400:  80%|███████▉  | 319/400 [06:32<01:39,  1.23s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 321/400:  80%|████████  | 320/400 [06:34<01:38,  1.23s/it, v_num=1, train_loss_step=844, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 322/400:  80%|████████  | 321/400 [06:35<01:37,  1.23s/it, v_num=1, train_loss_step=916, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 323/400:  80%|████████  | 322/400 [06:36<01:35,  1.23s/it, v_num=1, train_loss_step=1.07e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 324/400:  81%|████████  | 323/400 [06:37<01:36,  1.26s/it, v_num=1, train_loss_step=953, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 325/400:  81%|████████  | 324/400 [06:39<01:36,  1.27s/it, v_num=1, train_loss_step=950, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 326/400:  81%|████████▏ | 325/400 [06:40<01:35,  1.27s/it, v_num=1, train_loss_step=983, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 327/400:  82%|████████▏ | 326/400 [06:41<01:33,  1.26s/it, v_num=1, train_loss_step=993, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 328/400:  82%|████████▏ | 327/400 [06:42<01:32,  1.26s/it, v_num=1, train_loss_step=905, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 329/400:  82%|████████▏ | 328/400 [06:44<01:31,  1.27s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 330/400:  82%|████████▏ | 329/400 [06:45<01:30,  1.27s/it, v_num=1, train_loss_step=925, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 331/400:  82%|████████▎ | 330/400 [06:46<01:28,  1.27s/it, v_num=1, train_loss_step=895, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 332/400:  83%|████████▎ | 331/400 [06:48<01:27,  1.27s/it, v_num=1, train_loss_step=986, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 333/400:  83%|████████▎ | 332/400 [06:49<01:26,  1.27s/it, v_num=1, train_loss_step=985, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 334/400:  83%|████████▎ | 333/400 [06:50<01:25,  1.27s/it, v_num=1, train_loss_step=959, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 335/400:  84%|████████▎ | 334/400 [06:51<01:23,  1.27s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 336/400:  84%|████████▍ | 335/400 [06:53<01:22,  1.27s/it, v_num=1, train_loss_step=942, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 337/400:  84%|████████▍ | 336/400 [06:54<01:21,  1.27s/it, v_num=1, train_loss_step=964, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 338/400:  84%|████████▍ | 337/400 [06:55<01:20,  1.28s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 339/400:  84%|████████▍ | 338/400 [06:56<01:18,  1.26s/it, v_num=1, train_loss_step=978, train_loss_epoch=922]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 340/400:  85%|████████▍ | 339/400 [06:58<01:17,  1.26s/it, v_num=1, train_loss_step=975, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 341/400:  85%|████████▌ | 340/400 [06:59<01:15,  1.25s/it, v_num=1, train_loss_step=881, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 342/400:  85%|████████▌ | 341/400 [07:00<01:14,  1.25s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 343/400:  86%|████████▌ | 342/400 [07:01<01:13,  1.26s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 344/400:  86%|████████▌ | 343/400 [07:03<01:12,  1.26s/it, v_num=1, train_loss_step=932, train_loss_epoch=921]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 345/400:  86%|████████▌ | 344/400 [07:04<01:10,  1.26s/it, v_num=1, train_loss_step=977, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 346/400:  86%|████████▋ | 345/400 [07:05<01:08,  1.25s/it, v_num=1, train_loss_step=918, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 347/400:  86%|████████▋ | 346/400 [07:06<01:07,  1.25s/it, v_num=1, train_loss_step=982, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 348/400:  87%|████████▋ | 347/400 [07:08<01:05,  1.24s/it, v_num=1, train_loss_step=943, train_loss_epoch=921]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 349/400:  87%|████████▋ | 348/400 [07:09<01:05,  1.25s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 350/400:  87%|████████▋ | 349/400 [07:10<01:04,  1.26s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 351/400:  88%|████████▊ | 350/400 [07:11<01:03,  1.26s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 352/400:  88%|████████▊ | 351/400 [07:13<01:01,  1.26s/it, v_num=1, train_loss_step=931, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 353/400:  88%|████████▊ | 352/400 [07:14<01:00,  1.26s/it, v_num=1, train_loss_step=994, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 354/400:  88%|████████▊ | 353/400 [07:15<00:59,  1.26s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 355/400:  88%|████████▊ | 354/400 [07:16<00:57,  1.25s/it, v_num=1, train_loss_step=906, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 356/400:  89%|████████▉ | 355/400 [07:18<00:56,  1.26s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 357/400:  89%|████████▉ | 356/400 [07:19<00:55,  1.27s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 358/400:  89%|████████▉ | 357/400 [07:20<00:54,  1.26s/it, v_num=1, train_loss_step=826, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 359/400:  90%|████████▉ | 358/400 [07:22<00:52,  1.26s/it, v_num=1, train_loss_step=979, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 360/400:  90%|████████▉ | 359/400 [07:23<00:51,  1.26s/it, v_num=1, train_loss_step=978, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 361/400:  90%|█████████ | 360/400 [07:24<00:50,  1.26s/it, v_num=1, train_loss_step=1.08e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 362/400:  90%|█████████ | 361/400 [07:25<00:49,  1.27s/it, v_num=1, train_loss_step=1.03e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 363/400:  90%|█████████ | 362/400 [07:27<00:47,  1.26s/it, v_num=1, train_loss_step=1.12e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 364/400:  91%|█████████ | 363/400 [07:28<00:46,  1.26s/it, v_num=1, train_loss_step=941, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 365/400:  91%|█████████ | 364/400 [07:29<00:45,  1.25s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 366/400:  91%|█████████▏| 365/400 [07:30<00:43,  1.25s/it, v_num=1, train_loss_step=936, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 367/400:  92%|█████████▏| 366/400 [07:32<00:42,  1.25s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 368/400:  92%|█████████▏| 367/400 [07:33<00:41,  1.26s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 369/400:  92%|█████████▏| 368/400 [07:34<00:40,  1.27s/it, v_num=1, train_loss_step=900, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 370/400:  92%|█████████▏| 369/400 [07:35<00:39,  1.26s/it, v_num=1, train_loss_step=925, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 371/400:  92%|█████████▎| 370/400 [07:37<00:37,  1.27s/it, v_num=1, train_loss_step=983, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 372/400:  93%|█████████▎| 371/400 [07:38<00:36,  1.27s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 373/400:  93%|█████████▎| 372/400 [07:39<00:35,  1.26s/it, v_num=1, train_loss_step=1.02e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 374/400:  93%|█████████▎| 373/400 [07:40<00:33,  1.23s/it, v_num=1, train_loss_step=921, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 375/400:  94%|█████████▎| 374/400 [07:42<00:32,  1.23s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 376/400:  94%|█████████▍| 375/400 [07:43<00:31,  1.24s/it, v_num=1, train_loss_step=984, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 377/400:  94%|█████████▍| 376/400 [07:44<00:30,  1.26s/it, v_num=1, train_loss_step=1.06e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 378/400:  94%|█████████▍| 377/400 [07:45<00:29,  1.26s/it, v_num=1, train_loss_step=970, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 379/400:  94%|█████████▍| 378/400 [07:47<00:27,  1.26s/it, v_num=1, train_loss_step=857, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 380/400:  95%|█████████▍| 379/400 [07:48<00:26,  1.26s/it, v_num=1, train_loss_step=885, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 381/400:  95%|█████████▌| 380/400 [07:49<00:25,  1.26s/it, v_num=1, train_loss_step=1.12e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 382/400:  95%|█████████▌| 381/400 [07:50<00:23,  1.25s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 383/400:  96%|█████████▌| 382/400 [07:52<00:22,  1.25s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 384/400:  96%|█████████▌| 383/400 [07:53<00:21,  1.24s/it, v_num=1, train_loss_step=822, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 385/400:  96%|█████████▌| 384/400 [07:54<00:19,  1.24s/it, v_num=1, train_loss_step=892, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 386/400:  96%|█████████▋| 385/400 [07:55<00:18,  1.25s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 387/400:  96%|█████████▋| 386/400 [07:57<00:17,  1.24s/it, v_num=1, train_loss_step=931, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 388/400:  97%|█████████▋| 387/400 [07:58<00:16,  1.27s/it, v_num=1, train_loss_step=931, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 389/400:  97%|█████████▋| 388/400 [07:59<00:15,  1.28s/it, v_num=1, train_loss_step=1.12e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 390/400:  97%|█████████▋| 389/400 [08:01<00:14,  1.29s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 391/400:  98%|█████████▊| 390/400 [08:02<00:12,  1.28s/it, v_num=1, train_loss_step=942, train_loss_epoch=923]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 392/400:  98%|█████████▊| 391/400 [08:03<00:11,  1.27s/it, v_num=1, train_loss_step=864, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 393/400:  98%|█████████▊| 392/400 [08:04<00:10,  1.27s/it, v_num=1, train_loss_step=972, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 394/400:  98%|█████████▊| 393/400 [08:06<00:08,  1.27s/it, v_num=1, train_loss_step=1.16e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 395/400:  98%|█████████▊| 394/400 [08:07<00:07,  1.27s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=923] 

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 396/400:  99%|█████████▉| 395/400 [08:08<00:06,  1.28s/it, v_num=1, train_loss_step=942, train_loss_epoch=923]   

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 397/400:  99%|█████████▉| 396/400 [08:09<00:05,  1.28s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 398/400:  99%|█████████▉| 397/400 [08:11<00:03,  1.27s/it, v_num=1, train_loss_step=947, train_loss_epoch=922]    

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 399/400: 100%|█████████▉| 398/400 [08:12<00:02,  1.26s/it, v_num=1, train_loss_step=1.04e+3, train_loss_epoch=922]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 400/400: 100%|█████████▉| 399/400 [08:13<00:01,  1.25s/it, v_num=1, train_loss_step=1.11e+3, train_loss_epoch=923]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 400/400: 100%|██████████| 400/400 [08:14<00:00,  1.26s/it, v_num=1, train_loss_step=996, train_loss_epoch=922]    

`Trainer.fit` stopped: `max_epochs=400` reached.


Epoch 400/400: 100%|██████████| 400/400 [08:14<00:00,  1.24s/it, v_num=1, train_loss_step=996, train_loss_epoch=922]
[34mINFO    [0m Training for [1;36m20[0m epochs.                                                                                   


  self.validate_field(adata)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/miniconda3/envs/cmml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 1/20:   0%|          | 0/20 [00:00<?, ?it/s]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 2/20:   5%|▌         | 1/20 [00:02<00:54,  2.89s/it, v_num=1, train_loss_step=1.05e+3, train_loss_epoch=975]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 3/20:  10%|█         | 2/20 [00:05<00:50,  2.82s/it, v_num=1, train_loss_step=1.14e+3, train_loss_epoch=922]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 4/20:  15%|█▌        | 3/20 [00:08<00:46,  2.75s/it, v_num=1, train_loss_step=999, train_loss_epoch=915]    

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 5/20:  20%|██        | 4/20 [00:11<00:44,  2.75s/it, v_num=1, train_loss_step=1.18e+3, train_loss_epoch=912]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 6/20:  25%|██▌       | 5/20 [00:13<00:41,  2.75s/it, v_num=1, train_loss_step=1.36e+3, train_loss_epoch=912]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 7/20:  30%|███       | 6/20 [00:16<00:38,  2.75s/it, v_num=1, train_loss_step=985, train_loss_epoch=911]    

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 8/20:  35%|███▌      | 7/20 [00:19<00:35,  2.74s/it, v_num=1, train_loss_step=876, train_loss_epoch=910]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 9/20:  40%|████      | 8/20 [00:22<00:32,  2.74s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=909]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 10/20:  45%|████▌     | 9/20 [00:24<00:30,  2.74s/it, v_num=1, train_loss_step=907, train_loss_epoch=908]   

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 11/20:  50%|█████     | 10/20 [00:27<00:27,  2.75s/it, v_num=1, train_loss_step=981, train_loss_epoch=908]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 12/20:  55%|█████▌    | 11/20 [00:30<00:24,  2.75s/it, v_num=1, train_loss_step=906, train_loss_epoch=908]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 13/20:  60%|██████    | 12/20 [00:33<00:22,  2.75s/it, v_num=1, train_loss_step=1.1e+3, train_loss_epoch=907]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 14/20:  65%|██████▌   | 13/20 [00:35<00:19,  2.76s/it, v_num=1, train_loss_step=928, train_loss_epoch=907]   

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 15/20:  70%|███████   | 14/20 [00:38<00:16,  2.75s/it, v_num=1, train_loss_step=1e+3, train_loss_epoch=908]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 16/20:  75%|███████▌  | 15/20 [00:41<00:13,  2.75s/it, v_num=1, train_loss_step=1.09e+3, train_loss_epoch=907]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 17/20:  80%|████████  | 16/20 [00:44<00:11,  2.76s/it, v_num=1, train_loss_step=892, train_loss_epoch=907]    

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 18/20:  85%|████████▌ | 17/20 [00:46<00:08,  2.75s/it, v_num=1, train_loss_step=993, train_loss_epoch=906]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 19/20:  90%|█████████ | 18/20 [00:49<00:05,  2.76s/it, v_num=1, train_loss_step=1.01e+3, train_loss_epoch=907]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 20/20:  95%|█████████▌| 19/20 [00:52<00:02,  2.75s/it, v_num=1, train_loss_step=886, train_loss_epoch=906]    

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 20/20: 100%|██████████| 20/20 [00:55<00:00,  2.75s/it, v_num=1, train_loss_step=1.22e+3, train_loss_epoch=907]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 20/20: 100%|██████████| 20/20 [00:55<00:00,  2.76s/it, v_num=1, train_loss_step=1.22e+3, train_loss_epoch=907]
[34mINFO    [0m File result/scanvi_model/model.pt already downloaded                                                      
[34mINFO    [0m Training for [1;36m100[0m epochs.                                                                                  


  model = torch.load(model_path, map_location=map_location)
  self.validate_field(adata)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/miniconda3/envs/cmml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 1/100:   0%|          | 0/100 [00:00<?, ?it/s]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 2/100:   1%|          | 1/100 [00:00<01:00,  1.63it/s, v_num=1, train_loss_step=1.49e+3, train_loss_epoch=1.47e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 3/100:   2%|▏         | 2/100 [00:01<01:01,  1.61it/s, v_num=1, train_loss_step=1.41e+3, train_loss_epoch=1.45e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 4/100:   3%|▎         | 3/100 [00:01<01:00,  1.60it/s, v_num=1, train_loss_step=1.44e+3, train_loss_epoch=1.44e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 5/100:   4%|▍         | 4/100 [00:02<01:00,  1.59it/s, v_num=1, train_loss_step=1.43e+3, train_loss_epoch=1.43e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 6/100:   5%|▌         | 5/100 [00:03<00:59,  1.60it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.42e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 7/100:   6%|▌         | 6/100 [00:03<00:58,  1.60it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.41e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 8/100:   7%|▋         | 7/100 [00:04<00:58,  1.59it/s, v_num=1, train_loss_step=1.51e+3, train_loss_epoch=1.41e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 9/100:   8%|▊         | 8/100 [00:05<00:58,  1.57it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.4e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 10/100:   9%|▉         | 9/100 [00:05<00:58,  1.56it/s, v_num=1, train_loss_step=1.33e+3, train_loss_epoch=1.4e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 11/100:  10%|█         | 10/100 [00:06<00:59,  1.52it/s, v_num=1, train_loss_step=1.46e+3, train_loss_epoch=1.39e+3]

/root/miniconda3/envs/cmml/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 12/100:  11%|█         | 11/100 [00:06<00:57,  1.56it/s, v_num=1, train_loss_step=1.47e+3, train_loss_epoch=1.39e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 13/100:  12%|█▏        | 12/100 [00:07<00:56,  1.56it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.38e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 14/100:  13%|█▎        | 13/100 [00:08<00:55,  1.57it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.38e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 15/100:  14%|█▍        | 14/100 [00:08<00:54,  1.58it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.38e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 16/100:  15%|█▌        | 15/100 [00:09<00:54,  1.57it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.37e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 17/100:  16%|█▌        | 16/100 [00:10<00:53,  1.56it/s, v_num=1, train_loss_step=1.43e+3, train_loss_epoch=1.37e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 18/100:  17%|█▋        | 17/100 [00:10<00:53,  1.56it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.37e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 19/100:  18%|█▊        | 18/100 [00:11<00:52,  1.57it/s, v_num=1, train_loss_step=1.42e+3, train_loss_epoch=1.37e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 20/100:  19%|█▉        | 19/100 [00:12<00:51,  1.57it/s, v_num=1, train_loss_step=1.49e+3, train_loss_epoch=1.37e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 21/100:  20%|██        | 20/100 [00:12<00:52,  1.54it/s, v_num=1, train_loss_step=1.32e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 22/100:  21%|██        | 21/100 [00:13<00:50,  1.56it/s, v_num=1, train_loss_step=1.45e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 23/100:  22%|██▏       | 22/100 [00:14<00:51,  1.52it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 24/100:  23%|██▎       | 23/100 [00:14<00:51,  1.50it/s, v_num=1, train_loss_step=1.45e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 25/100:  24%|██▍       | 24/100 [00:15<00:50,  1.51it/s, v_num=1, train_loss_step=1.45e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 26/100:  25%|██▌       | 25/100 [00:16<00:50,  1.48it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 27/100:  26%|██▌       | 26/100 [00:16<00:49,  1.49it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.36e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 28/100:  27%|██▋       | 27/100 [00:17<00:48,  1.51it/s, v_num=1, train_loss_step=1.49e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 29/100:  28%|██▊       | 28/100 [00:18<00:48,  1.50it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 30/100:  29%|██▉       | 29/100 [00:18<00:49,  1.45it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 31/100:  30%|███       | 30/100 [00:19<00:49,  1.42it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 32/100:  31%|███       | 31/100 [00:20<00:46,  1.48it/s, v_num=1, train_loss_step=1.42e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 33/100:  32%|███▏      | 32/100 [00:20<00:44,  1.51it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.35e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 34/100:  33%|███▎      | 33/100 [00:21<00:44,  1.51it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 35/100:  34%|███▍      | 34/100 [00:22<00:43,  1.53it/s, v_num=1, train_loss_step=1.43e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 36/100:  35%|███▌      | 35/100 [00:22<00:41,  1.55it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 37/100:  36%|███▌      | 36/100 [00:23<00:41,  1.56it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.35e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 38/100:  37%|███▋      | 37/100 [00:24<00:40,  1.57it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.34e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 39/100:  38%|███▊      | 38/100 [00:24<00:38,  1.62it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 40/100:  39%|███▉      | 39/100 [00:25<00:37,  1.61it/s, v_num=1, train_loss_step=1.33e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 41/100:  40%|████      | 40/100 [00:25<00:38,  1.56it/s, v_num=1, train_loss_step=1.43e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 42/100:  41%|████      | 41/100 [00:26<00:37,  1.59it/s, v_num=1, train_loss_step=1.29e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 43/100:  42%|████▏     | 42/100 [00:27<00:36,  1.59it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 44/100:  43%|████▎     | 43/100 [00:27<00:36,  1.58it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 45/100:  44%|████▍     | 44/100 [00:28<00:35,  1.57it/s, v_num=1, train_loss_step=1.48e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 46/100:  45%|████▌     | 45/100 [00:29<00:35,  1.55it/s, v_num=1, train_loss_step=1.42e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 47/100:  46%|████▌     | 46/100 [00:29<00:34,  1.56it/s, v_num=1, train_loss_step=1.36e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 48/100:  47%|████▋     | 47/100 [00:30<00:34,  1.54it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 49/100:  48%|████▊     | 48/100 [00:31<00:33,  1.55it/s, v_num=1, train_loss_step=1.34e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 50/100:  49%|████▉     | 49/100 [00:31<00:32,  1.56it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.34e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 51/100:  50%|█████     | 50/100 [00:32<00:33,  1.51it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 52/100:  51%|█████     | 51/100 [00:32<00:31,  1.55it/s, v_num=1, train_loss_step=1.44e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 53/100:  52%|█████▏    | 52/100 [00:33<00:30,  1.55it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.34e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 54/100:  53%|█████▎    | 53/100 [00:34<00:30,  1.57it/s, v_num=1, train_loss_step=1.29e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 55/100:  54%|█████▍    | 54/100 [00:34<00:29,  1.57it/s, v_num=1, train_loss_step=1.36e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 56/100:  55%|█████▌    | 55/100 [00:35<00:28,  1.57it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 57/100:  56%|█████▌    | 56/100 [00:36<00:29,  1.52it/s, v_num=1, train_loss_step=1.42e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 58/100:  57%|█████▋    | 57/100 [00:36<00:29,  1.48it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 59/100:  58%|█████▊    | 58/100 [00:37<00:28,  1.45it/s, v_num=1, train_loss_step=1.43e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 60/100:  59%|█████▉    | 59/100 [00:38<00:28,  1.44it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 61/100:  60%|██████    | 60/100 [00:39<00:28,  1.41it/s, v_num=1, train_loss_step=1.34e+3, train_loss_epoch=1.34e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 62/100:  61%|██████    | 61/100 [00:39<00:27,  1.42it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 63/100:  62%|██████▏   | 62/100 [00:40<00:26,  1.43it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 64/100:  63%|██████▎   | 63/100 [00:41<00:26,  1.42it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 65/100:  64%|██████▍   | 64/100 [00:41<00:25,  1.42it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.33e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 66/100:  65%|██████▌   | 65/100 [00:42<00:24,  1.41it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 67/100:  66%|██████▌   | 66/100 [00:43<00:24,  1.41it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 68/100:  67%|██████▋   | 67/100 [00:44<00:23,  1.41it/s, v_num=1, train_loss_step=1.34e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 69/100:  68%|██████▊   | 68/100 [00:44<00:22,  1.39it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 70/100:  69%|██████▉   | 69/100 [00:45<00:22,  1.41it/s, v_num=1, train_loss_step=1.42e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 71/100:  70%|███████   | 70/100 [00:46<00:21,  1.42it/s, v_num=1, train_loss_step=1.28e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 72/100:  71%|███████   | 71/100 [00:46<00:20,  1.42it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 73/100:  72%|███████▏  | 72/100 [00:47<00:19,  1.42it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 74/100:  73%|███████▎  | 73/100 [00:48<00:18,  1.43it/s, v_num=1, train_loss_step=1.34e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 75/100:  74%|███████▍  | 74/100 [00:48<00:18,  1.43it/s, v_num=1, train_loss_step=1.27e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 76/100:  75%|███████▌  | 75/100 [00:49<00:17,  1.43it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 77/100:  76%|███████▌  | 76/100 [00:50<00:16,  1.45it/s, v_num=1, train_loss_step=1.31e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 78/100:  77%|███████▋  | 77/100 [00:50<00:15,  1.47it/s, v_num=1, train_loss_step=1.38e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 79/100:  78%|███████▊  | 78/100 [00:51<00:14,  1.51it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.33e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 80/100:  79%|███████▉  | 79/100 [00:52<00:13,  1.51it/s, v_num=1, train_loss_step=1.36e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 81/100:  80%|████████  | 80/100 [00:52<00:13,  1.50it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 82/100:  81%|████████  | 81/100 [00:53<00:12,  1.54it/s, v_num=1, train_loss_step=1.34e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 83/100:  82%|████████▏ | 82/100 [00:54<00:11,  1.55it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 84/100:  83%|████████▎ | 83/100 [00:54<00:10,  1.56it/s, v_num=1, train_loss_step=1.28e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 85/100:  84%|████████▍ | 84/100 [00:55<00:10,  1.55it/s, v_num=1, train_loss_step=1.33e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 86/100:  85%|████████▌ | 85/100 [00:56<00:09,  1.56it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 87/100:  86%|████████▌ | 86/100 [00:56<00:08,  1.59it/s, v_num=1, train_loss_step=1.27e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 88/100:  87%|████████▋ | 87/100 [00:57<00:08,  1.59it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 89/100:  88%|████████▊ | 88/100 [00:57<00:07,  1.57it/s, v_num=1, train_loss_step=1.33e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 90/100:  89%|████████▉ | 89/100 [00:58<00:07,  1.56it/s, v_num=1, train_loss_step=1.35e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 91/100:  90%|█████████ | 90/100 [00:59<00:06,  1.50it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.33e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 92/100:  91%|█████████ | 91/100 [01:00<00:05,  1.52it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 93/100:  92%|█████████▏| 92/100 [01:00<00:05,  1.54it/s, v_num=1, train_loss_step=1.39e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 94/100:  93%|█████████▎| 93/100 [01:01<00:04,  1.55it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.33e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 95/100:  94%|█████████▍| 94/100 [01:01<00:03,  1.55it/s, v_num=1, train_loss_step=1.37e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 96/100:  95%|█████████▌| 95/100 [01:02<00:03,  1.56it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.33e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 97/100:  96%|█████████▌| 96/100 [01:03<00:02,  1.55it/s, v_num=1, train_loss_step=1.33e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 98/100:  97%|█████████▋| 97/100 [01:03<00:01,  1.56it/s, v_num=1, train_loss_step=1.34e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 99/100:  98%|█████████▊| 98/100 [01:04<00:01,  1.56it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.33e+3] 

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 100/100:  99%|█████████▉| 99/100 [01:05<00:00,  1.57it/s, v_num=1, train_loss_step=1.3e+3, train_loss_epoch=1.33e+3]

  reconst_loss = -px.log_prob(x).sum(-1)
  reconst_loss = -px.log_prob(x).sum(-1)


Epoch 100/100: 100%|██████████| 100/100 [01:05<00:00,  1.53it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.33e+3]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|██████████| 100/100 [01:05<00:00,  1.52it/s, v_num=1, train_loss_step=1.4e+3, train_loss_epoch=1.33e+3]
target_adata shape: (3289, 4000)
source_adata shape: (13093, 4000)
adata_full shape: (16382, 4000)
adata_full.obs columns: Index(['study', 'cell_type', '_scvi_batch', '_scvi_labels', 'query'], dtype='object')
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


  self.validate_field(adata)


Recompute neighbors on rep X_pca instead of None
Cluster for cluster_0.2 with leiden



 To achieve the future defaults please pass: flavor="igraph" and n_iterations=2.  directed must also be False to work with igraph's implementation.
  cluster_function(adata, resolution=res, key_added=resolution_key, **kwargs)


Cluster for cluster_0.4 with leiden
Cluster for cluster_0.6 with leiden
Cluster for cluster_0.8 with leiden
Cluster for cluster_1.0 with leiden
Cluster for cluster_1.2 with leiden
Cluster for cluster_1.4 with leiden
Cluster for cluster_1.6 with leiden
Cluster for cluster_1.8 with leiden
Cluster for cluster_2.0 with leiden
NMI...
ARI...
Silhouette score...
PC regression...
Isolated labels F1...
Cluster for iso_label_0.2 with leiden


  batch_per_lab = tmp.groupby(label_key).agg({batch_key: "count"})


Cluster for iso_label_0.4 with leiden
Cluster for iso_label_0.6 with leiden
Cluster for iso_label_0.8 with leiden
Cluster for iso_label_1.0 with leiden
Cluster for iso_label_1.2 with leiden
Cluster for iso_label_1.4 with leiden
Cluster for iso_label_1.6 with leiden
Cluster for iso_label_1.8 with leiden
Cluster for iso_label_2.0 with leiden
Isolated labels ASW...


  batch_per_lab = tmp.groupby(label_key).agg({batch_key: "count"})


Graph connectivity...
Integration Scores:
    NMI_cluster/label  ARI_cluster/label  ASW_label  ASW_label/batch  \
0           0.937513            0.96392   0.667629         0.850226   

   PCR_batch  isolated_label_F1  isolated_label_silhouette  graph_conn  
0   0.670202           0.148936                   0.720116    0.986879  

Results:
Reference Time (seconds): 551.3490998744965
Query Time (seconds): 65.85595846176147
Integration Scores:
    NMI_cluster/label  ARI_cluster/label  ASW_label  ASW_label/batch  \
0           0.937513            0.96392   0.667629         0.850226   

   PCR_batch  isolated_label_F1  isolated_label_silhouette  graph_conn  
0   0.670202           0.148936                   0.720116    0.986879  


  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
  tab = pd.value_counts(labels)
