In [30]:
import torch
import os
import pandas as pd
import numpy as np

In [41]:
def calc_active(model, x, gate = 0.01):
        # z1 ~ q(z1 | x)
        z1_q_mean, z1_q_logvar = model.q_z1(x)
        z1_q = model.reparameterize(z1_q_mean, z1_q_logvar)

        # z2 ~ q(z2 | x, z1)
        z2_q_mean, _ = model.q_z2(z1_q)

        active_z1 = (z1_q_mean.var(0) > gate).sum()
        active_z2 = (z2_q_mean.var(0) > gate).sum()
        return active_z1, active_z2

In [40]:

# start processing
def lines_to_np_array(lines):
    return np.array([[int(i) for i in line.split()] for line in lines])
with open(os.path.join('datasets', 'MNIST_static', 'binarized_mnist_train.amat')) as f:
    lines = f.readlines()
x_train = lines_to_np_array(lines).astype('float32')
with open(os.path.join('datasets', 'MNIST_static', 'binarized_mnist_valid.amat')) as f:
    lines = f.readlines()
x_val = lines_to_np_array(lines).astype('float32')
with open(os.path.join('datasets', 'MNIST_static', 'binarized_mnist_test.amat')) as f:
    lines = f.readlines()
x_test = lines_to_np_array(lines).astype('float32')
x_test = torch.from_numpy(x_test)
# shuffle train data
np.random.shuffle(x_train)

type(x_test)

torch.Tensor

In [51]:
results_table = pd.DataFrame(index=['test_kl', 'test_ll', 'test_loss', 'test_re', 'train_time', 'train_kl', 'train_loss', 'train_re', 'val_kl', 'val_loss', 'val_re', 'z1_active', 'z2_active'])

for mod in os.listdir('./snapshots'):
    # model
    try:
        model = torch.load('./snapshots/' + mod + '/vae_2level.model', map_location=torch.device('cpu'))
    except:
        model = torch.load('./snapshots/' + mod + '/hvae_2level.model', map_location=torch.device('cpu'))
    model.args.cuda = False
    prior = model.args.prior
    wu = model.args.warmup
    if model.args.model_name == 'hvae_2level':
        name = model.args.model_name
        # test stats
        test_kl = torch.load('./snapshots/' + mod + '/hvae_2level.test_kl', map_location=torch.device('cpu'))
        test_ll = torch.load('./snapshots/' + mod + '/hvae_2level.test_log_likelihood', map_location=torch.device('cpu'))
        test_loss = torch.load('./snapshots/' + mod + '/hvae_2level.test_loss', map_location=torch.device('cpu'))
        test_re = torch.load('./snapshots/' + mod + '/hvae_2level.test_re', map_location=torch.device('cpu'))

        # train stats
        train_time = torch.load('./snapshots/' + mod + '/hvae_2level.train_time', map_location=torch.device('cpu'))
        train_kl = torch.load('./snapshots/' + mod + '/hvae_2level.train_kl', map_location=torch.device('cpu'))
        train_loss = torch.load('./snapshots/' + mod + '/hvae_2level.train_loss', map_location=torch.device('cpu'))
        train_re = torch.load('./snapshots/' + mod + '/hvae_2level.train_re', map_location=torch.device('cpu'))

        # validation stats
        val_kl = torch.load('./snapshots/' + mod + '/hvae_2level.val_kl', map_location=torch.device('cpu'))
        val_loss = torch.load('./snapshots/' + mod + '/hvae_2level.val_loss', map_location=torch.device('cpu'))
        val_re = torch.load('./snapshots/' + mod + '/hvae_2level.val_re', map_location=torch.device('cpu'))
        # active units
        x_mean, x_logvar, z1_q, z1_q_mean, z1_q_logvar, z2_q, z2_q_mean, z2_q_logvar, z1_p_mean, z1_p_logvar = model.forward(x_test)
        a_z1 = (z1_q_mean.var(0) > 0.01).sum().detach().numpy()
        a_z2 = (z2_q_mean.var(0) > 0.01).sum().detach().numpy()
    else:
        name = str(prior) + ' ' + str(wu)
        # test stats
        test_kl = torch.load('./snapshots/' + mod + '/vae_2level.test_kl', map_location=torch.device('cpu'))
        test_ll = torch.load('./snapshots/' + mod + '/vae_2level.test_log_likelihood', map_location=torch.device('cpu'))
        test_loss = torch.load('./snapshots/' + mod + '/vae_2level.test_loss', map_location=torch.device('cpu'))
        test_re = torch.load('./snapshots/' + mod + '/vae_2level.test_re', map_location=torch.device('cpu'))

        # train stats
        train_time = torch.load('./snapshots/' + mod + '/vae_2level.train_time', map_location=torch.device('cpu'))
        train_kl = torch.load('./snapshots/' + mod + '/vae_2level.train_kl', map_location=torch.device('cpu'))
        train_loss = torch.load('./snapshots/' + mod + '/vae_2level.train_loss', map_location=torch.device('cpu'))
        train_re = torch.load('./snapshots/' + mod + '/vae_2level.train_re', map_location=torch.device('cpu'))

        # validation stats
        val_kl = torch.load('./snapshots/' + mod + '/vae_2level.val_kl', map_location=torch.device('cpu'))
        val_loss = torch.load('./snapshots/' + mod + '/vae_2level.val_loss', map_location=torch.device('cpu'))
        val_re = torch.load('./snapshots/' + mod + '/vae_2level.val_re', map_location=torch.device('cpu'))
        # active units
        a_z1, a_z2 = calc_active(model, x_test)

    
    results_table[name] = [test_kl, test_ll, test_loss, test_re, train_time, train_kl, train_loss, train_re, val_kl, val_loss, val_re, a_z1, a_z2]

In [52]:
results_table

Unnamed: 0,hvae_2level,clust_kmeans 0,clust_kmeans 100,mbap_prior 0,mbap_prior 100,standard 0,standard 100,vampprior 0,vampprior 100,vampprior_data 0,vampprior_data 100
test_kl,30.257203,5.752224,34.546187,5.819802,32.662582,6.44629,28.062474,5.740097,27.170851,5.622086,34.349338
test_ll,84.967995,152.454656,94.440438,153.19038,93.37371,152.787488,88.465729,152.456069,86.335088,152.453715,94.684212
test_loss,91.181697,159.829739,103.314187,160.235465,103.207789,159.769806,97.745801,159.399064,94.358362,159.1005,103.352434
test_re,60.924493,154.077517,68.767999,154.415663,70.545207,153.323515,69.683327,153.658967,67.187511,153.478414,69.003096
train_time,2609.013701,1832.59012,3172.190336,1647.113178,3629.188145,1258.879041,3027.803665,1496.481309,2885.57992,1450.843928,3391.062953
train_kl,"[181.54531829833985, 187.7400152282715, 178.56...","[16.275208751678466, 4.178750826358796, 4.0597...","[480.34860348510745, 333.8680689086914, 223.81...","[16.263694442749024, 4.139534892559052, 4.0178...","[488.93616830444336, 347.2642878417969, 231.23...","[21.558244366645813, 5.826872514724731, 4.7427...","[507.2054025878906, 354.03310748291017, 241.80...","[16.140918332576753, 4.157890078544617, 4.0165...","[478.51795349121096, 316.3365305175781, 197.81...","[16.53932852125168, 4.19503119468689, 4.043787...","[486.5944922485352, 341.7694044189453, 225.124..."
train_loss,"[108.16027680969238, 55.27256597137451, 48.164...","[235.6432289428711, 187.93206024169922, 184.44...","[159.42949459838866, 94.51091096496582, 83.776...","[235.19637060546876, 187.55228057861328, 184.0...","[158.13955870056154, 95.1709345703125, 84.6222...","[247.26938970947265, 190.48711450195313, 185.5...","[157.75515055847168, 94.40207637023926, 84.100...","[236.41710397338866, 187.79476635742188, 184.4...","[157.51761471557617, 94.11974325561523, 84.003...","[235.7508332519531, 187.675071685791, 184.2381...","[159.6823553161621, 94.4829789428711, 84.00370..."
train_re,"[106.34482405853271, 51.51776564788818, 42.807...","[219.36801998901367, 183.75330908203125, 180.3...","[154.6260079345703, 87.83354974365234, 77.0619...","[218.93267639160158, 183.4127462158203, 180.04...","[153.25019789123536, 88.22564868164062, 77.685...","[225.7111459350586, 184.66024301147462, 180.79...","[152.6830968170166, 87.32141418457032, 76.8461...","[220.2761849975586, 183.63687603759766, 180.40...","[152.73243563842775, 87.79301266479492, 78.068...","[219.21150497436523, 183.4800398864746, 180.19...","[154.81641052246093, 87.64759107971192, 77.249..."
val_kl,"[188.20921966552734, 189.12368270874023, 171.4...","[4.642967605590821, 4.032795147895813, 4.23125...","[469.2695767211914, 293.2316912841797, 203.085...","[4.320464270114899, 3.981866137981415, 3.90628...","[500.4926885986328, 285.8754916381836, 203.974...","[7.090683732032776, 5.550388746261596, 4.68908...","[501.89875762939454, 301.15468658447264, 218.6...","[4.378635857105255, 4.085849161148071, 4.08691...","[469.8648129272461, 259.19372665405274, 175.76...","[4.368508944511413, 4.0654066371917725, 4.2670...","[509.11985870361326, 290.25888580322265, 205.1..."
val_loss,"[248.18038619995116, 235.23638534545898, 213.1...","[190.91449325561524, 186.40136978149414, 183.1...","[565.9614434814453, 374.2937957763672, 277.336...","[190.49543090820313, 185.88784072875976, 182.4...","[597.2965222167969, 367.4736706542969, 278.351...","[194.9666860961914, 188.40767593383788, 184.11...","[598.0629504394532, 381.67625549316404, 293.17...","[190.60305755615235, 186.4501174926758, 182.98...","[566.2886431884766, 340.5574176025391, 250.753...","[190.44394149780274, 186.5197654724121, 183.26...","[606.4493835449218, 371.15963470458985, 279.50..."
