# Playground

In [1]:
from typing import List, Set, Dict, Tuple, Optional, Any
from collections import defaultdict

import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

import math 
import torch
from torch import nn, Tensor
from torch.nn.functional import softplus, relu
from torch.distributions import Distribution, Normal
from torch.utils.data import DataLoader

from gmfpp.utils.data_preparation import *
from gmfpp.utils.data_transformers import *
from gmfpp.utils.plotting import *

from gmfpp.models.ReparameterizedDiagonalGaussian import *
from gmfpp.models.CytoVariationalAutoencoder import *
from gmfpp.models.VariationalAutoencoder import *
from gmfpp.models.ConvVariationalAutoencoder import *
from gmfpp.models.VariationalInference import *

%matplotlib inline

In [2]:
torch.manual_seed(0)
torch.cuda.manual_seed(0)

## Load data

In [3]:
metadata = read_metadata("./data/tiny/metadata.csv")

In [4]:
relative_path = get_relative_image_paths(metadata)
image_paths = ["./data/tiny/" + path for path in relative_path]
images = load_images(image_paths)

In [5]:
len(images)

259

## VAE

In [6]:
train_set = prepare_raw_images(images)
normalize_channels_inplace(train_set)
print(train_set.shape)

torch.Size([259, 3, 68, 68])


In [7]:
channel_first = view_channel_dim_first(train_set)
for i in range(channel_first.shape[0]):
    channel = channel_first[i]
    print("channel {} interval: [{:.2f}; {:.2f}]".format(i, torch.min(channel), torch.max(channel)))

channel 0 interval: [0.02; 1.00]
channel 1 interval: [0.04; 1.00]
channel 2 interval: [0.05; 1.00]


In [8]:
# VAE
image_shape = np.array([3, 68, 68])
latent_features = 256
vae = CytoVariationalAutoencoder(image_shape, latent_features)
#vae = VariationalAutoencoder(image_shape, latent_features)

beta = 1
vi = VariationalInference(beta=beta)

# The Adam optimizer works really well with VAEs.
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-2, weight_decay=10e-4)

# define dictionary to store the training curves
training_data = defaultdict(list)
validation_data = defaultdict(list)

In [9]:
num_epochs = 1000
batch_size = 16

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f">> Using device: {device}")

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

# move the model to the device
vae = vae.to(device)

# training..

for epoch in range(num_epochs):
    print(f"epoch: {epoch}/{num_epochs}")    

    training_epoch_data = defaultdict(list)
    vae.train()

    for x in train_loader:
        x = x.to(device)
        
        # perform a forward pass through the model and compute the ELBO
        loss, diagnostics, outputs = vi(vae, x)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(vae.parameters(), 10_000)
        optimizer.step()

        # gather data for the current batch
        for k, v in diagnostics.items():
            training_epoch_data[k] += [v.mean().item()]

    print("training | elbo: {:2f}, log_px: {:.2f}, kl: {:.2f}:".format(np.mean(training_epoch_data["elbo"]), np.mean(training_epoch_data["log_px"]), np.mean(training_epoch_data["kl"])))

    # gather data for the full epoch
    for k, v in training_epoch_data.items():
        training_data[k] += [np.mean(training_epoch_data[k])]

    # Evaluate on a single batch, do not propagate gradients
    with torch.no_grad():
        vae.eval()

        # Just load a single batch from the test loader
        '''x, y = next(iter(test_loader))'''
        x = x.to(device)

        # perform a forward pass through the model and compute the ELBO
        loss, diagnostics, outputs = vi(vae, x)

        # gather data for the validation step
        for k, v in diagnostics.items():
            validation_data[k] += [v.mean().item()]

    print("validation | elbo: {:2f}, log_px: {:.2f}, kl: {:.2f}:".format(np.mean(validation_data["elbo"]), np.mean(validation_data["log_px"]), np.mean(validation_data["kl"])))    

>> Using device: cpu
epoch: 0/1000
training | elbo: -13455.945923, log_px: -12861.16, kl: 594.79:
validation | elbo: -12656540.000000, log_px: -12588970.00, kl: 67570.55:
epoch: 1/1000
training | elbo: -12837.380310, log_px: -12564.46, kl: 272.92:
validation | elbo: -6334754.938477, log_px: -6300692.52, kl: 34062.69:
epoch: 2/1000
training | elbo: -12549.311401, log_px: -12351.60, kl: 197.71:
validation | elbo: -4227277.443034, log_px: -4204527.34, kl: 22750.29:
epoch: 3/1000
training | elbo: -12121.643494, log_px: -11971.73, kl: 149.91:
validation | elbo: -3173421.459229, log_px: -3156323.25, kl: 17098.34:
epoch: 4/1000
training | elbo: -11437.510620, log_px: -11313.97, kl: 123.54:
validation | elbo: -2540923.657617, log_px: -2527223.03, kl: 13700.74:
epoch: 5/1000
training | elbo: -10397.099487, log_px: -10290.65, kl: 106.45:
validation | elbo: -2119069.248372, log_px: -2107640.90, kl: 11428.44:
epoch: 6/1000
training | elbo: -8885.671509, log_px: -8792.54, kl: 93.13:
validation | el

training | elbo: 7655.211212, log_px: 7731.86, kl: 76.65:
validation | elbo: -209670.405483, log_px: -208451.71, kl: 1218.71:
epoch: 59/1000
training | elbo: 9893.940125, log_px: 9970.15, kl: 76.21:
validation | elbo: -205981.936893, log_px: -204781.84, kl: 1200.10:
epoch: 60/1000
training | elbo: 10874.103485, log_px: 10952.52, kl: 78.42:
validation | elbo: -202415.371761, log_px: -201233.83, kl: 1181.56:
epoch: 61/1000
training | elbo: 10433.924034, log_px: 10512.89, kl: 78.97:
validation | elbo: -198962.120364, log_px: -197798.44, kl: 1163.69:
epoch: 62/1000
training | elbo: 11148.658020, log_px: 11228.52, kl: 79.86:
validation | elbo: -195618.209724, log_px: -194472.23, kl: 1145.98:
epoch: 63/1000
training | elbo: 11830.310486, log_px: 11911.99, kl: 81.68:
validation | elbo: -192393.672359, log_px: -191264.04, kl: 1129.65:
epoch: 64/1000
training | elbo: 11921.166267, log_px: 12002.94, kl: 81.78:
validation | elbo: -189227.590049, log_px: -188114.55, kl: 1113.04:
epoch: 65/1000
tra

training | elbo: 15770.076355, log_px: 15882.92, kl: 112.84:
validation | elbo: -99701.149812, log_px: -99042.57, kl: 658.58:
epoch: 117/1000
training | elbo: 11559.078033, log_px: 11671.71, kl: 112.63:
validation | elbo: -98762.016971, log_px: -98108.09, kl: 653.93:
epoch: 118/1000
training | elbo: 10418.291260, log_px: 10530.98, kl: 112.69:
validation | elbo: -97825.237113, log_px: -97175.70, kl: 649.54:
epoch: 119/1000
training | elbo: 14024.111206, log_px: 14139.41, kl: 115.30:
validation | elbo: -96862.654620, log_px: -96217.79, kl: 644.87:
epoch: 120/1000
training | elbo: 15381.466309, log_px: 15498.00, kl: 116.53:
validation | elbo: -95981.275114, log_px: -95340.76, kl: 640.52:
epoch: 121/1000
training | elbo: 11564.894379, log_px: 11680.94, kl: 116.05:
validation | elbo: -95124.914580, log_px: -94488.86, kl: 636.06:
epoch: 122/1000
training | elbo: 11624.286179, log_px: 11740.57, kl: 116.29:
validation | elbo: -94218.743606, log_px: -93587.12, kl: 631.62:
epoch: 123/1000
traini

training | elbo: 11079.349949, log_px: 11221.65, kl: 142.30:
validation | elbo: -61922.970233, log_px: -61441.22, kl: 481.76:
epoch: 175/1000
training | elbo: 11308.817078, log_px: 11451.93, kl: 143.11:
validation | elbo: -61485.237619, log_px: -61005.45, kl: 479.79:
epoch: 176/1000
training | elbo: 15739.684570, log_px: 15883.47, kl: 143.79:
validation | elbo: -61084.046465, log_px: -60606.14, kl: 477.91:
epoch: 177/1000
training | elbo: 9480.737923, log_px: 9624.10, kl: 143.37:
validation | elbo: -60705.482459, log_px: -60229.59, kl: 475.90:
epoch: 178/1000
training | elbo: 14528.037048, log_px: 14672.29, kl: 144.25:
validation | elbo: -60271.044489, log_px: -59797.07, kl: 473.98:
epoch: 179/1000
training | elbo: 10520.283653, log_px: 10663.37, kl: 143.08:
validation | elbo: -59843.779789, log_px: -59371.66, kl: 472.12:
epoch: 180/1000
training | elbo: 15545.506836, log_px: 15688.78, kl: 143.28:
validation | elbo: -59417.005539, log_px: -58946.69, kl: 470.31:
epoch: 181/1000
training

training | elbo: 13964.750549, log_px: 14103.67, kl: 138.92:
validation | elbo: -43144.471196, log_px: -42747.09, kl: 397.39:
epoch: 233/1000
training | elbo: 13000.812195, log_px: 13139.99, kl: 139.17:
validation | elbo: -42888.505908, log_px: -42492.27, kl: 396.24:
epoch: 234/1000
training | elbo: 14133.722046, log_px: 14274.68, kl: 140.96:
validation | elbo: -42644.905829, log_px: -42249.76, kl: 395.15:
epoch: 235/1000
training | elbo: 14417.607727, log_px: 14555.92, kl: 138.31:
validation | elbo: -42413.629584, log_px: -42019.62, kl: 394.02:
epoch: 236/1000
training | elbo: 15012.026306, log_px: 15151.32, kl: 139.29:
validation | elbo: -42181.823558, log_px: -41788.91, kl: 392.92:
epoch: 237/1000
training | elbo: 14946.853210, log_px: 15085.70, kl: 138.85:
validation | elbo: -41934.950596, log_px: -41543.16, kl: 391.80:
epoch: 238/1000
training | elbo: 14934.600342, log_px: 15073.18, kl: 138.58:
validation | elbo: -41692.180221, log_px: -41301.43, kl: 390.76:
epoch: 239/1000
traini

training | elbo: 10359.397864, log_px: 10505.39, kl: 145.99:
validation | elbo: -31509.422589, log_px: -31162.04, kl: 347.39:
epoch: 291/1000
training | elbo: 14546.325562, log_px: 14696.06, kl: 149.73:
validation | elbo: -31343.675683, log_px: -30996.95, kl: 346.73:
epoch: 292/1000
training | elbo: 14229.192749, log_px: 14377.18, kl: 147.99:
validation | elbo: -31195.403059, log_px: -30849.21, kl: 346.19:
epoch: 293/1000
training | elbo: 13691.469116, log_px: 13840.74, kl: 149.27:
validation | elbo: -31050.227262, log_px: -30704.78, kl: 345.45:
epoch: 294/1000
training | elbo: 14621.269531, log_px: 14770.41, kl: 149.14:
validation | elbo: -30892.170447, log_px: -30547.41, kl: 344.77:
epoch: 295/1000
training | elbo: 15073.347534, log_px: 15222.79, kl: 149.44:
validation | elbo: -30732.928799, log_px: -30388.89, kl: 344.04:
epoch: 296/1000
training | elbo: 16271.272095, log_px: 16421.50, kl: 150.22:
validation | elbo: -30580.937591, log_px: -30237.62, kl: 343.32:
epoch: 297/1000
traini

training | elbo: 17585.391296, log_px: 17761.60, kl: 176.20:
validation | elbo: -23916.779232, log_px: -23599.12, kl: 317.66:
epoch: 349/1000
training | elbo: 16989.079529, log_px: 17165.13, kl: 176.06:
validation | elbo: -23795.815376, log_px: -23478.56, kl: 317.26:
epoch: 350/1000
training | elbo: 16126.056213, log_px: 16302.07, kl: 176.02:
validation | elbo: -23685.639058, log_px: -23368.83, kl: 316.82:
epoch: 351/1000
training | elbo: 14892.025543, log_px: 15069.63, kl: 177.60:
validation | elbo: -23578.426815, log_px: -23261.97, kl: 316.46:
epoch: 352/1000
training | elbo: 15527.745728, log_px: 15707.15, kl: 179.41:
validation | elbo: -23473.272051, log_px: -23157.07, kl: 316.21:
epoch: 353/1000
training | elbo: 15738.363464, log_px: 15916.38, kl: 178.01:
validation | elbo: -23357.530331, log_px: -23041.78, kl: 315.75:
epoch: 354/1000
training | elbo: 18004.201233, log_px: 18183.16, kl: 178.96:
validation | elbo: -23239.625211, log_px: -22924.25, kl: 315.38:
epoch: 355/1000
traini

training | elbo: 15136.102631, log_px: 15319.78, kl: 183.68:
validation | elbo: -18218.084189, log_px: -17920.04, kl: 298.05:
epoch: 407/1000
training | elbo: 18295.731445, log_px: 18478.65, kl: 182.92:
validation | elbo: -18124.966512, log_px: -17827.18, kl: 297.79:
epoch: 408/1000
training | elbo: 19413.751526, log_px: 19596.52, kl: 182.77:
validation | elbo: -18027.273813, log_px: -17729.81, kl: 297.47:
epoch: 409/1000
training | elbo: 17951.628845, log_px: 18134.00, kl: 182.37:
validation | elbo: -17933.292419, log_px: -17636.16, kl: 297.13:
epoch: 410/1000
training | elbo: 18391.897461, log_px: 18573.72, kl: 181.82:
validation | elbo: -17856.461242, log_px: -17559.64, kl: 296.83:
epoch: 411/1000
training | elbo: 17372.850525, log_px: 17556.00, kl: 183.15:
validation | elbo: -17781.934440, log_px: -17485.30, kl: 296.64:
epoch: 412/1000
training | elbo: 13425.617249, log_px: 13609.22, kl: 183.61:
validation | elbo: -17700.616762, log_px: -17404.31, kl: 296.31:
epoch: 413/1000
traini

training | elbo: 16279.527466, log_px: 16471.17, kl: 191.65:
validation | elbo: -13947.264261, log_px: -13663.05, kl: 284.22:
epoch: 465/1000
training | elbo: 18899.961243, log_px: 19090.14, kl: 190.18:
validation | elbo: -13886.030364, log_px: -13602.07, kl: 283.96:
epoch: 466/1000
training | elbo: 15767.877075, log_px: 15957.81, kl: 189.93:
validation | elbo: -13822.358629, log_px: -13538.65, kl: 283.71:
epoch: 467/1000
training | elbo: 17599.136902, log_px: 17789.19, kl: 190.05:
validation | elbo: -13748.884694, log_px: -13465.41, kl: 283.48:
epoch: 468/1000
training | elbo: 18039.508606, log_px: 18230.74, kl: 191.23:
validation | elbo: -13677.666831, log_px: -13394.40, kl: 283.26:
epoch: 469/1000
training | elbo: 19579.533691, log_px: 19768.59, kl: 189.06:
validation | elbo: -13612.414988, log_px: -13329.32, kl: 283.09:
epoch: 470/1000
training | elbo: 17307.842987, log_px: 17496.57, kl: 188.72:
validation | elbo: -13546.626852, log_px: -13263.79, kl: 282.84:
epoch: 471/1000
traini

training | elbo: 20046.806763, log_px: 20238.45, kl: 191.64:
validation | elbo: -10364.894745, log_px: -10091.17, kl: 273.73:
epoch: 523/1000
training | elbo: 18220.057953, log_px: 18410.68, kl: 190.62:
validation | elbo: -10308.436640, log_px: -10034.88, kl: 273.56:
epoch: 524/1000
training | elbo: 15956.311325, log_px: 16147.92, kl: 191.61:
validation | elbo: -10252.486478, log_px: -9979.11, kl: 273.38:
epoch: 525/1000
training | elbo: 20234.178955, log_px: 20426.96, kl: 192.78:
validation | elbo: -10196.348478, log_px: -9923.16, kl: 273.19:
epoch: 526/1000
training | elbo: 20418.713379, log_px: 20610.33, kl: 191.62:
validation | elbo: -10141.384893, log_px: -9868.35, kl: 273.04:
epoch: 527/1000
training | elbo: 18401.565125, log_px: 18593.70, kl: 192.14:
validation | elbo: -10081.229493, log_px: -9808.33, kl: 272.90:
epoch: 528/1000
training | elbo: 18697.040192, log_px: 18888.87, kl: 191.83:
validation | elbo: -10042.104236, log_px: -9769.36, kl: 272.74:
epoch: 529/1000
training | 

training | elbo: 19490.783081, log_px: 19705.31, kl: 214.52:
validation | elbo: -7429.552659, log_px: -7163.74, kl: 265.81:
epoch: 582/1000
training | elbo: 20343.000610, log_px: 20558.25, kl: 215.25:
validation | elbo: -7379.095346, log_px: -7113.39, kl: 265.70:
epoch: 583/1000
training | elbo: 22221.108154, log_px: 22434.79, kl: 213.68:
validation | elbo: -7327.712090, log_px: -7062.14, kl: 265.57:
epoch: 584/1000
training | elbo: 21585.474854, log_px: 21799.22, kl: 213.74:
validation | elbo: -7285.813854, log_px: -7020.35, kl: 265.46:
epoch: 585/1000
training | elbo: 18197.841553, log_px: 18411.81, kl: 213.96:
validation | elbo: -7248.696817, log_px: -6983.34, kl: 265.36:
epoch: 586/1000
training | elbo: 18112.634445, log_px: 18326.84, kl: 214.21:
validation | elbo: -7211.309953, log_px: -6946.06, kl: 265.25:
epoch: 587/1000
training | elbo: 16108.307190, log_px: 16322.21, kl: 213.90:
validation | elbo: -7162.727165, log_px: -6897.57, kl: 265.16:
epoch: 588/1000
training | elbo: 220

training | elbo: 19028.442871, log_px: 19236.56, kl: 208.11:
validation | elbo: -4958.762138, log_px: -4698.46, kl: 260.30:
epoch: 641/1000
training | elbo: 15759.974339, log_px: 15967.05, kl: 207.08:
validation | elbo: -4920.555453, log_px: -4660.36, kl: 260.20:
epoch: 642/1000
training | elbo: 18602.719299, log_px: 18809.40, kl: 206.68:
validation | elbo: -4883.774484, log_px: -4623.70, kl: 260.07:
epoch: 643/1000
training | elbo: 20058.567749, log_px: 20262.74, kl: 204.18:
validation | elbo: -4839.977850, log_px: -4579.99, kl: 259.99:
epoch: 644/1000
training | elbo: 19566.736847, log_px: 19769.02, kl: 202.29:
validation | elbo: -4801.596580, log_px: -4541.70, kl: 259.90:
epoch: 645/1000
training | elbo: 20860.491821, log_px: 21063.05, kl: 202.56:
validation | elbo: -4761.017676, log_px: -4501.23, kl: 259.79:
epoch: 646/1000
training | elbo: 21448.437012, log_px: 21650.72, kl: 202.28:
validation | elbo: -4719.310938, log_px: -4459.61, kl: 259.70:
epoch: 647/1000
training | elbo: 201

training | elbo: 18464.208069, log_px: 18699.76, kl: 235.55:
validation | elbo: -2670.282001, log_px: -2413.98, kl: 256.30:
epoch: 700/1000
training | elbo: 20288.505798, log_px: 20523.83, kl: 235.32:
validation | elbo: -2633.767673, log_px: -2377.50, kl: 256.26:
epoch: 701/1000
training | elbo: 20022.921387, log_px: 20255.68, kl: 232.76:
validation | elbo: -2597.599593, log_px: -2341.36, kl: 256.24:
epoch: 702/1000
training | elbo: 21812.725830, log_px: 22043.43, kl: 230.71:
validation | elbo: -2560.186764, log_px: -2304.00, kl: 256.19:
epoch: 703/1000
training | elbo: 22604.292114, log_px: 22834.26, kl: 229.96:
validation | elbo: -2527.265847, log_px: -2271.13, kl: 256.14:
epoch: 704/1000
training | elbo: 22882.332031, log_px: 23110.72, kl: 228.39:
validation | elbo: -2486.340437, log_px: -2230.24, kl: 256.10:
epoch: 705/1000
training | elbo: 23149.151611, log_px: 23376.93, kl: 227.78:
validation | elbo: -2449.777820, log_px: -2193.71, kl: 256.07:
epoch: 706/1000
training | elbo: 204

training | elbo: 21141.775032, log_px: 21374.70, kl: 232.93:
validation | elbo: -737.523996, log_px: -483.66, kl: 253.87:
epoch: 759/1000
training | elbo: 21552.824707, log_px: 21784.60, kl: 231.78:
validation | elbo: -709.478418, log_px: -455.65, kl: 253.83:
epoch: 760/1000
training | elbo: 14504.248199, log_px: 14734.72, kl: 230.47:
validation | elbo: -701.753418, log_px: -447.95, kl: 253.80:
epoch: 761/1000
training | elbo: 17387.881500, log_px: 17616.39, kl: 228.50:
validation | elbo: -670.414162, log_px: -416.65, kl: 253.76:
epoch: 762/1000
training | elbo: 22879.746704, log_px: 23104.47, kl: 224.72:
validation | elbo: -639.551479, log_px: -385.77, kl: 253.78:
epoch: 763/1000
training | elbo: 23487.026367, log_px: 23711.65, kl: 224.62:
validation | elbo: -605.339751, log_px: -351.58, kl: 253.76:
epoch: 764/1000
training | elbo: 23858.787964, log_px: 24085.06, kl: 226.27:
validation | elbo: -574.231673, log_px: -320.52, kl: 253.72:
epoch: 765/1000
training | elbo: 22115.464355, log

training | elbo: 24458.932251, log_px: 24683.45, kl: 224.52:
validation | elbo: 926.347166, log_px: 1178.82, kl: 252.47:
epoch: 820/1000
training | elbo: 20735.385254, log_px: 20962.71, kl: 227.32:
validation | elbo: 955.599707, log_px: 1208.04, kl: 252.44:
epoch: 821/1000
training | elbo: 20870.034363, log_px: 21098.16, kl: 228.13:
validation | elbo: 983.514185, log_px: 1235.92, kl: 252.41:
epoch: 822/1000
training | elbo: 22916.785645, log_px: 23145.04, kl: 228.25:
validation | elbo: 1012.021126, log_px: 1264.39, kl: 252.37:
epoch: 823/1000
training | elbo: 22555.182678, log_px: 22781.97, kl: 226.79:
validation | elbo: 1034.157383, log_px: 1286.50, kl: 252.34:
epoch: 824/1000
training | elbo: 23820.937012, log_px: 24048.09, kl: 227.15:
validation | elbo: 1066.095055, log_px: 1318.40, kl: 252.30:
epoch: 825/1000
training | elbo: 25048.706177, log_px: 25277.21, kl: 228.50:
validation | elbo: 1097.563971, log_px: 1349.85, kl: 252.29:
epoch: 826/1000
training | elbo: 24867.663696, log_px

validation | elbo: 2471.457917, log_px: 2723.25, kl: 251.80:
epoch: 879/1000
training | elbo: 24597.222046, log_px: 24842.80, kl: 245.58:
validation | elbo: 2499.977940, log_px: 2751.76, kl: 251.78:
epoch: 880/1000
training | elbo: 23206.447632, log_px: 23452.66, kl: 246.22:
validation | elbo: 2525.155338, log_px: 2776.94, kl: 251.79:
epoch: 881/1000
training | elbo: 25230.036499, log_px: 25476.00, kl: 245.96:
validation | elbo: 2543.447719, log_px: 2795.20, kl: 251.76:
epoch: 882/1000
training | elbo: 22907.006836, log_px: 23152.75, kl: 245.75:
validation | elbo: 2565.493738, log_px: 2817.25, kl: 251.76:
epoch: 883/1000
training | elbo: 23417.155884, log_px: 23663.34, kl: 246.18:
validation | elbo: 2586.227179, log_px: 2837.97, kl: 251.74:
epoch: 884/1000
training | elbo: 24976.334106, log_px: 25222.21, kl: 245.88:
validation | elbo: 2608.065543, log_px: 2859.80, kl: 251.73:
epoch: 885/1000
training | elbo: 23944.497925, log_px: 24189.52, kl: 245.03:
validation | elbo: 2635.877552, lo

training | elbo: 24190.776123, log_px: 24448.85, kl: 258.07:
validation | elbo: 3832.563739, log_px: 4084.25, kl: 251.69:
epoch: 939/1000
training | elbo: 23541.829956, log_px: 23799.10, kl: 257.27:
validation | elbo: 3855.582993, log_px: 4107.30, kl: 251.72:
epoch: 940/1000
training | elbo: 23270.382935, log_px: 23527.58, kl: 257.20:
validation | elbo: 3879.266872, log_px: 4130.99, kl: 251.72:
epoch: 941/1000
training | elbo: 24616.811157, log_px: 24873.77, kl: 256.96:
validation | elbo: 3902.529402, log_px: 4154.26, kl: 251.73:
epoch: 942/1000
training | elbo: 23723.571472, log_px: 23979.07, kl: 255.50:
validation | elbo: 3924.769235, log_px: 4176.51, kl: 251.75:
epoch: 943/1000
training | elbo: 21629.895752, log_px: 21886.40, kl: 256.51:
validation | elbo: 3946.870582, log_px: 4198.64, kl: 251.77:
epoch: 944/1000
training | elbo: 26352.904419, log_px: 26608.77, kl: 255.86:
validation | elbo: 3970.816259, log_px: 4222.61, kl: 251.79:
epoch: 945/1000
training | elbo: 26827.471191, log

training | elbo: 24190.223022, log_px: 24448.75, kl: 258.52:
validation | elbo: 5142.235342, log_px: 5394.23, kl: 252.00:
epoch: 999/1000
training | elbo: 25782.489746, log_px: 26041.05, kl: 258.56:
validation | elbo: 5165.770759, log_px: 5417.78, kl: 252.01:


In [None]:
plt.plot(training_data["elbo"])

## Compare reconstruction and original image

In [None]:
x = train_set[0]

In [None]:
plot_image(x)

In [None]:
# vae.eval() # because of batch normalization
outputs = vae(x[None,:,:,:])
px = outputs["px"]

x_reconstruction = px.sample()
x_reconstruction = x_reconstruction[0]
plot_image_channels(x_reconstruction)

In [None]:
plot_image_channels(x)

In [None]:
x_reconstruction = px.sample()
x_reconstruction = x_reconstruction[0]
plot_image(clip_image_to_zero_one(x_reconstruction))