In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append(os.getcwd() + '/core')
sys.path.append(os.getcwd() + '/core/train_GLOW') 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import json
from datetime import datetime

import core.config as config
from core.model_loader import load_pretrained_VAE, load_pretrained_GLOW
from core.data_loader import TRAIN_loader, TEST_loader
from core.custom_loss import KL_div, VAE_loss
from core.fisher_utils_VAE import Calculate_fisher_VAE, Calculate_score_VAE
from core.fisher_utils_VAE import AUTO_VAE_CIFAR, AUTO_VAE_FMNIST
from core.fisher_utils_GLOW import Calculate_fisher_GLOW, Calculate_score_GLOW
from core.fisher_utils_GLOW import AUTO_GLOW_CIFAR, AUTO_GLOW_FMNIST
from core.visualize import plot_hist, AUROC

# Define Global Variables & Initialize

In [2]:
# FISHER_invs : Fisher inverse matrices from calculating Fisher score (w.r.t. train-dist)
FISHERs = {'VAE': {'cifar10': 0, 'fmnist': 0},
           'GLOW': {'cifar10': 0, 'fmnist': 0}}
         
# NORMalize_FACtors : Normalizing Factor (w.r.t. train-dist)
NORM_FACs = {'VAE': {'cifar10': 0, 'fmnist': 0},
             'GLOW': {'cifar10': 0, 'fmnist': 0}}
             
# SCOREs : (Scalars) Scores from calculating Fisher score (w.r.t. target-dist)
SCOREs = {'VAE': {'cifar10': {}, 'fmnist': {}},
          'GLOW': {'cifar10': {}, 'fmnist': {}}}

# VAE-CIFAR10

In [3]:
train_dist = 'cifar10'
opt = config.VAE_cifar10
ngf, nz, augment = 64, 100, None
netE, netG = load_pretrained_VAE(option=train_dist, ngf=ngf, nz=nz, augment=augment)

#params = ['mu'] # for this choice, you must change the architecture of VAE (go core/train_VAE/DCGAN_VAE_pixel.py line 42)
params = [netE.conv1.weight, netG.main[0].weight, netG.main[-1].weight]
params_name = ['Econv1', 'Gmain0', 'Gmain-1'] # to assign json filename properly
assert len(params) == len(params_name), 'If you modified params, please modify params_name, too!'
max_iter_list = [[300, 100], [500, 200], [1000, 500], [3000, 1000], [10000, 3000], [30000, 5000]]

start = datetime.today()
for max_iter in max_iter_list:
    print(f'Start! {max_iter}')
    a, b, _scores = AUTO_VAE_CIFAR(netE, netG, params, max_iter=max_iter, loss_type='ELBO')

    for i in range(len(params)):
        
        FISHERs['VAE'][train_dist] = a[params[i]]
        NORM_FACs['VAE'][train_dist] = b[params[i]]
        for ood in opt.ood_list:
            SCOREs['VAE'][train_dist][ood] = _scores[ood][params[i]]

        # AUROC curve
        auroc = {}
        for ood in opt.ood_list:
            args = [
                SCOREs['VAE'][train_dist][train_dist],
                SCOREs['VAE'][train_dist][ood],
            ]
            labels = [train_dist, ood]
            auroc[ood] = AUROC(*args, labels=labels, verbose=False)
        print(pd.Series(auroc))

        filename = f'{params_name[i]}_num_sample_{max_iter[0]}_{max_iter[1]}_ngf_{ngf}_nz_{nz}_augment_{augment}'
        with open(f'./results/VAE_{train_dist}/{filename}.json', 'w') as f:
            json.dump(auroc, f)
            
    now = datetime.today()
    print(f'Elapsed time : {now - start}')
    torch.cuda.empty_cache()
    
    
    
path = f'./results/VAE_{train_dist}/'
df = pd.DataFrame()
for file in os.listdir(path):
    if file[-4:] != 'json':
        continue
    temp = file.split('_num_sample_')
    col = (temp[0], f'{temp[1].split("_")[0]}, {temp[1].split("_")[1]}')
    with open(path+file, 'r') as f:
        data = json.load(f)
        df[col] = pd.Series(data)
temp = [f'{elt[0]}, {elt[1]}' for elt in max_iter_list]
df.columns = pd.MultiIndex.from_product([params_name, temp])
df.to_csv(f'{path}result_table.csv')

display(df)

Start! [300, 100]
Files already downloaded and verified


Calculate Fisher VAE:   1%|▎                                                     | 299/50000 [00:15<44:03, 18.80step/s]


Files already downloaded and verified


Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:03<05:08, 32.05step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   0%|▏                                                       | 99/26032 [00:03<15:19, 28.21step/s]
Calculate Score VAE:   1%|▎                                                       | 99/19141 [00:01<05:17, 59.98step/s]
Calculate Score VAE:  16%|█████████▌                                                | 99/600 [00:01<00:08, 58.22step/s]


Files already downloaded and verified


Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:03<05:06, 32.27step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:02<03:52, 42.59step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:02<03:51, 42.71step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:02<03:54, 42.13step/s]


Files already downloaded and verified


Calculate Score VAE:   1%|▍                                                       | 99/13180 [00:02<06:36, 33.01step/s]
Calculate Score VAE:   1%|▎                                                       | 99/18724 [00:01<04:55, 63.11step/s]
Calculate Score VAE:   1%|▍                                                       | 99/12630 [00:01<03:26, 60.68step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:01<02:26, 67.57step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:01<02:26, 67.57step/s]


{'cifar10': 0.5, 'svhn': 0.8321, 'celeba': 0.7961, 'lsun': 0.6131000000000001, 'cifar100': 0.5513000000000001, 'mnist': 1.0, 'fmnist': 0.9949999999999999, 'kmnist': 0.9994, 'omniglot': 1.0, 'notmnist': 0.9973, 'trafficsign': 0.6910000000000001, 'noise': 0.16649999999999998, 'constant': 0.9631000000000001}
{'cifar10': 0.5, 'svhn': 0.7260000000000001, 'celeba': 0.8520999999999999, 'lsun': 0.7414999999999999, 'cifar100': 0.5687, 'mnist': 0.9954999999999999, 'fmnist': 0.9884999999999999, 'kmnist': 0.9964, 'omniglot': 0.9984, 'notmnist': 0.9984999999999999, 'trafficsign': 0.3794, 'noise': 0.007900000000000008, 'constant': 0.9804999999999999}
{'cifar10': 0.5, 'svhn': 0.8897999999999999, 'celeba': 0.6305000000000001, 'lsun': 0.4662, 'cifar100': 0.522, 'mnist': 0.9488000000000001, 'fmnist': 0.8903, 'kmnist': 0.9259999999999999, 'omniglot': 0.9346, 'notmnist': 0.8664000000000001, 'trafficsign': 0.726, 'noise': 0.0, 'constant': 0.9924999999999999}
Elapsed time : 0:00:49.651846
Start! [500, 200]


Calculate Fisher VAE:   1%|▌                                                     | 499/50000 [00:21<36:00, 22.91step/s]


Files already downloaded and verified


Calculate Score VAE:   2%|█                                                      | 199/10000 [00:04<03:45, 43.43step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   1%|▍                                                      | 199/26032 [00:04<10:45, 40.00step/s]
Calculate Score VAE:   1%|▌                                                      | 199/19141 [00:03<05:15, 60.08step/s]
Calculate Score VAE:  33%|██████████████████▉                                      | 199/600 [00:03<00:06, 58.68step/s]


Files already downloaded and verified


Calculate Score VAE:   2%|█                                                      | 199/10000 [00:04<03:46, 43.24step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:03<03:10, 51.56step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:03<03:09, 51.72step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:03<03:09, 51.82step/s]


Files already downloaded and verified


Calculate Score VAE:   2%|▊                                                      | 199/13180 [00:04<04:54, 44.06step/s]
Calculate Score VAE:   1%|▌                                                      | 199/18724 [00:03<04:56, 62.41step/s]
Calculate Score VAE:   2%|▊                                                      | 199/12630 [00:03<03:24, 60.70step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:02<02:25, 67.36step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:02<02:25, 67.21step/s]


{'cifar10': 0.49999999999999994, 'svhn': 0.76475, 'celeba': 0.72995, 'lsun': 0.507775, 'cifar100': 0.5162, 'mnist': 0.9974, 'fmnist': 0.98665, 'kmnist': 0.994, 'omniglot': 0.998825, 'notmnist': 0.988125, 'trafficsign': 0.598575, 'noise': 0.0677, 'constant': 0.9333}
{'cifar10': 0.49999999999999994, 'svhn': 0.6861499999999999, 'celeba': 0.7648999999999999, 'lsun': 0.65225, 'cifar100': 0.488575, 'mnist': 0.9971749999999999, 'fmnist': 0.987025, 'kmnist': 0.9995499999999999, 'omniglot': 0.998675, 'notmnist': 0.98915, 'trafficsign': 0.324475, 'noise': 0.01050000000000001, 'constant': 0.9701}
{'cifar10': 0.49999999999999994, 'svhn': 0.82195, 'celeba': 0.582975, 'lsun': 0.36572499999999997, 'cifar100': 0.4295, 'mnist': 0.969375, 'fmnist': 0.907125, 'kmnist': 0.9546, 'omniglot': 0.749425, 'notmnist': 0.8599249999999999, 'trafficsign': 0.576075, 'noise': 0.0, 'constant': 0.9921249999999999}
Elapsed time : 0:02:05.274532
Start! [1000, 500]
Files already downloaded and verified


Calculate Fisher VAE:   2%|█                                                     | 999/50000 [00:40<33:08, 24.64step/s]


Files already downloaded and verified


Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:09<02:51, 55.31step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   2%|█                                                      | 499/26032 [00:09<08:01, 53.00step/s]
Calculate Score VAE:   3%|█▍                                                     | 499/19141 [00:08<05:08, 60.39step/s]
Calculate Score VAE:  83%|███████████████████████████████████████████████▍         | 499/600 [00:08<00:01, 59.36step/s]


Files already downloaded and verified


Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:09<02:51, 55.28step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:08<02:37, 60.45step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:08<02:37, 60.20step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:08<02:36, 60.54step/s]


Files already downloaded and verified


Calculate Score VAE:   4%|██                                                     | 499/13180 [00:08<03:48, 55.48step/s]
Calculate Score VAE:   3%|█▍                                                     | 499/18724 [00:07<04:47, 63.29step/s]
Calculate Score VAE:   4%|██▏                                                    | 499/12630 [00:08<03:19, 60.73step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:07<02:19, 68.17step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:07<02:20, 67.49step/s]


{'cifar10': 0.5, 'svhn': 0.849244, 'celeba': 0.816932, 'lsun': 0.616844, 'cifar100': 0.566144, 'mnist': 0.9971, 'fmnist': 0.987972, 'kmnist': 0.99566, 'omniglot': 0.999112, 'notmnist': 0.996104, 'trafficsign': 0.591376, 'noise': 0.1299, 'constant': 0.974024}
{'cifar10': 0.5, 'svhn': 0.750124, 'celeba': 0.8158240000000001, 'lsun': 0.704372, 'cifar100': 0.563544, 'mnist': 0.995976, 'fmnist': 0.98376, 'kmnist': 0.998772, 'omniglot': 0.9974239999999999, 'notmnist': 0.993024, 'trafficsign': 0.353776, 'noise': 0.009840000000000008, 'constant': 0.9764919999999998}
{'cifar10': 0.5, 'svhn': 0.644068, 'celeba': 0.708228, 'lsun': 0.5311199999999999, 'cifar100': 0.535664, 'mnist': 0.94922, 'fmnist': 0.891128, 'kmnist': 0.924028, 'omniglot': 0.7575480000000001, 'notmnist': 0.895568, 'trafficsign': 0.680724, 'noise': 0.150736, 'constant': 0.9533479999999999}
Elapsed time : 0:04:39.110443
Start! [3000, 1000]
Files already downloaded and verified


Calculate Fisher VAE:   6%|███▏                                                 | 2999/50000 [01:56<30:25, 25.75step/s]


Files already downloaded and verified


Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:16<02:30, 59.75step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   4%|██                                                     | 999/26032 [00:17<07:06, 58.70step/s]
Calculate Score VAE:   5%|██▊                                                    | 999/19141 [00:16<05:02, 59.99step/s]
Calculate Score VAE: 100%|█████████████████████████████████████████████████████████| 600/600 [00:09<00:00, 60.14step/s]


Files already downloaded and verified


Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:16<02:29, 60.20step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:15<02:22, 63.16step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:15<02:22, 63.25step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:15<02:22, 63.23step/s]


Files already downloaded and verified


Calculate Score VAE:   8%|████▏                                                  | 999/13180 [00:16<03:22, 60.28step/s]
Calculate Score VAE:   5%|██▉                                                    | 999/18724 [00:15<04:41, 62.92step/s]
Calculate Score VAE:   8%|████▎                                                  | 999/12630 [00:16<03:11, 60.87step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:14<02:13, 67.34step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:14<02:13, 67.35step/s]


{'cifar10': 0.49999999999999994, 'svhn': 0.8378779999999999, 'celeba': 0.8111149999999999, 'lsun': 0.619965, 'cifar100': 0.564045, 'mnist': 0.998404, 'fmnist': 0.9936860000000001, 'kmnist': 0.998098, 'omniglot': 0.999132, 'notmnist': 0.997911, 'trafficsign': 0.571463, 'noise': 0.130098, 'constant': 0.991186}
{'cifar10': 0.49999999999999994, 'svhn': 0.7467069999999999, 'celeba': 0.8029749999999999, 'lsun': 0.7045266666666666, 'cifar100': 0.557819, 'mnist': 0.993978, 'fmnist': 0.9807989999999999, 'kmnist': 0.9972989999999999, 'omniglot': 0.996725, 'notmnist': 0.9919150000000001, 'trafficsign': 0.38707400000000003, 'noise': 0.008037000000000008, 'constant': 0.977772}
{'cifar10': 0.49999999999999994, 'svhn': 0.6439779999999999, 'celeba': 0.7367900000000001, 'lsun': 0.540425, 'cifar100': 0.535627, 'mnist': 0.879964, 'fmnist': 0.824901, 'kmnist': 0.866356, 'omniglot': 0.835383, 'notmnist': 0.904039, 'trafficsign': 0.59343, 'noise': 0.5778789999999999, 'constant': 0.98673}
Elapsed time : 0:10

Calculate Fisher VAE:  20%|██████████▌                                          | 9999/50000 [06:22<25:28, 26.17step/s]


Files already downloaded and verified


Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:46<01:47, 65.01step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:  12%|██████▏                                               | 2999/26032 [00:46<05:57, 64.48step/s]
Calculate Score VAE:  16%|████████▍                                             | 2999/19141 [00:49<04:26, 60.63step/s]
Calculate Score VAE: 100%|█████████████████████████████████████████████████████████| 600/600 [00:09<00:00, 60.74step/s]


Files already downloaded and verified


Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:46<01:47, 64.94step/s]
Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:45<01:45, 66.27step/s]
Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:45<01:45, 66.07step/s]
Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:45<01:45, 66.10step/s]


Files already downloaded and verified


Calculate Score VAE:  23%|████████████▎                                         | 2999/13180 [00:46<02:36, 64.88step/s]
Calculate Score VAE:  16%|████████▋                                             | 2999/18724 [00:47<04:07, 63.57step/s]
Calculate Score VAE:  24%|████████████▊                                         | 2999/12630 [00:48<02:35, 61.85step/s]
Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:44<01:43, 67.96step/s]
Calculate Score VAE:  30%|████████████████▏                                     | 2999/10000 [00:44<01:42, 68.05step/s]


{'cifar10': 0.5, 'svhn': 0.8437492222222222, 'celeba': 0.8050104444444444, 'lsun': 0.6165705555555555, 'cifar100': 0.5503763333333334, 'mnist': 0.9996476666666666, 'fmnist': 0.9947028888888889, 'kmnist': 0.9993703333333334, 'omniglot': 0.998864, 'notmnist': 0.9955044444444444, 'trafficsign': 0.5376074444444444, 'noise': 0.11822199999999998, 'constant': 0.991867888888889}
{'cifar10': 0.5, 'svhn': 0.7506692222222222, 'celeba': 0.7960245555555555, 'lsun': 0.6993177777777777, 'cifar100': 0.5411162222222222, 'mnist': 0.9941862222222221, 'fmnist': 0.9816225555555556, 'kmnist': 0.9971534444444444, 'omniglot': 0.9949636666666666, 'notmnist': 0.9828705555555557, 'trafficsign': 0.3688166666666667, 'noise': 0.005133000000000002, 'constant': 0.9724265555555556}
{'cifar10': 0.5, 'svhn': 0.774337888888889, 'celeba': 0.6617907777777777, 'lsun': 0.46310166666666663, 'cifar100': 0.5178450555555556, 'mnist': 0.8513563333333334, 'fmnist': 0.7743997777777778, 'kmnist': 0.839169, 'omniglot': 0.772733111111

Calculate Fisher VAE:  60%|███████████████████████████████▏                    | 29999/50000 [19:04<12:43, 26.21step/s]


Files already downloaded and verified


Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:17<01:17, 64.17step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:  19%|██████████▎                                           | 4999/26032 [01:17<05:27, 64.24step/s]
Calculate Score VAE:  26%|██████████████                                        | 4999/19141 [01:23<03:56, 59.77step/s]
Calculate Score VAE: 100%|█████████████████████████████████████████████████████████| 600/600 [00:10<00:00, 57.82step/s]


Files already downloaded and verified


Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:17<01:17, 64.59step/s]
Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:16<01:16, 65.20step/s]
Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:17<01:17, 64.85step/s]
Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:16<01:16, 64.97step/s]


Files already downloaded and verified


Calculate Score VAE:  38%|████████████████████▍                                 | 4999/13180 [01:17<02:07, 64.24step/s]
Calculate Score VAE:  27%|██████████████▍                                       | 4999/18724 [01:20<03:41, 62.02step/s]
Calculate Score VAE:  40%|█████████████████████▎                                | 4999/12630 [01:22<02:05, 60.93step/s]
Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:15<01:15, 65.96step/s]
Calculate Score VAE:  50%|██████████████████████████▉                           | 4999/10000 [01:15<01:15, 66.41step/s]


{'cifar10': 0.5, 'svhn': 0.8421926800000001, 'celeba': 0.8030542599999999, 'lsun': 0.613054, 'cifar100': 0.54201796, 'mnist': 0.9993709199999999, 'fmnist': 0.9933366399999999, 'kmnist': 0.99908524, 'omniglot': 0.99818548, 'notmnist': 0.9968381199999999, 'trafficsign': 0.53232746, 'noise': 0.12112828, 'constant': 0.98965776}
{'cifar10': 0.5, 'svhn': 0.7468202199999999, 'celeba': 0.79835148, 'lsun': 0.6943673333333332, 'cifar100': 0.53726764, 'mnist': 0.99233168, 'fmnist': 0.9791108800000001, 'kmnist': 0.9960083999999999, 'omniglot': 0.9933364, 'notmnist': 0.98513628, 'trafficsign': 0.36535038, 'noise': 0.006196360000000007, 'constant': 0.9693365599999999}
{'cifar10': 0.5, 'svhn': 0.81900026, 'celeba': 0.62912912, 'lsun': 0.42180966666666664, 'cifar100': 0.49758038, 'mnist': 0.8383122, 'fmnist': 0.78797048, 'kmnist': 0.8244875199999999, 'omniglot': 0.6947161800000001, 'notmnist': 0.92652578, 'trafficsign': 0.64576118, 'noise': 0.12147918, 'constant': 0.99954736}
Elapsed time : 1:00:53.12

# VAE-FMNIST

In [4]:
train_dist = 'fmnist'
opt = config.VAE_fmnist
ngf, nz, augment = 32, 100, None
netE, netG = load_pretrained_VAE(option=train_dist, ngf=ngf, nz=nz, augment=augment)

#params = ['mu'] # for this choice, you must change the architecture of VAE (go core/train_VAE/DCGAN_VAE_pixel.py line 42)
params = [netE.conv1.weight, netG.main[0].weight, netG.main[-1].weight]
params_name = ['Econv1', 'Gmain0', 'Gmain-1'] # to assign json filename properly
assert len(params) == len(params_name), 'If you modified params, please modify params_name, too!'
max_iter_list = [[300, 100], [500, 200], [1000, 500], [3000, 1000], [10000, 3000], [30000, 5000]]

start = datetime.today()
for max_iter in max_iter_list:
    print(f'Start! {max_iter}')
    a, b, _scores = AUTO_VAE_FMNIST(netE, netG, params, max_iter=max_iter, loss_type='ELBO')
    
    for i in range(len(params)):

        FISHERs['VAE'][train_dist] = a[params[i]]
        NORM_FACs['VAE'][train_dist] = b[params[i]]
        for ood in opt.ood_list:
            SCOREs['VAE'][train_dist][ood] = _scores[ood][params[i]]

        # AUROC curve
        auroc = {}
        for ood in opt.ood_list:
            args = [
                SCOREs['VAE'][train_dist][train_dist],
                SCOREs['VAE'][train_dist][ood],
            ]
            labels = [train_dist, ood]
            auroc[ood] = AUROC(*args, labels=labels, verbose=False)
        print(pd.Series(auroc))

        filename = f'{params_name[i]}_num_sample_{max_iter[0]}_{max_iter[1]}_ngf_{ngf}_nz_{nz}_augment_{augment}'
        with open(f'./results/VAE_{train_dist}/{filename}.json', 'w') as f:
            json.dump(auroc, f)
    
    now = datetime.today()
    print(f'Elapsed time : {now - start}')
    torch.cuda.empty_cache()
    
    
    
path = f'./results/VAE_{train_dist}/'
df = pd.DataFrame()
for file in os.listdir(path):
    if file[-4:] != 'json':
        continue
    temp = file.split('_num_sample_')
    col = (temp[0], f'{temp[1].split("_")[0]}, {temp[1].split("_")[1]}')
    with open(path+file, 'r') as f:
        data = json.load(f)
        df[col] = pd.Series(data)
temp = [f'{elt[0]}, {elt[1]}' for elt in max_iter_list]
df.columns = pd.MultiIndex.from_product([params_name, temp])
df.to_csv(f'{path}result_table.csv')

display(df)

Calculate Fisher VAE:   0%|                                                                | 0/60000 [00:00<?, ?step/s]

Start! [300, 100]


Calculate Fisher VAE:   0%|▎                                                     | 299/60000 [00:05<17:13, 57.76step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:01<02:53, 56.92step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   0%|▏                                                       | 99/26032 [00:03<13:21, 32.38step/s]
Calculate Score VAE:   1%|▎                                                       | 99/19141 [00:01<03:28, 91.32step/s]
Calculate Score VAE:  16%|█████████▌                                                | 99/600 [00:01<00:05, 87.46step/s]


Files already downloaded and verified


Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:02<04:16, 38.53step/s]


Files already downloaded and verified


Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:02<04:18, 38.27step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:01<02:51, 57.61step/s]
Calculate Score VAE:   1%|▌                                                       | 99/10000 [00:01<02:51, 57.64step/s]


Files already downloaded and verified


Calculate Score VAE:   1%|▍                                                       | 99/13180 [00:02<05:16, 41.30step/s]
Calculate Score VAE:   1%|▎                                                       | 99/18724 [00:01<03:08, 98.87step/s]
Calculate Score VAE:   1%|▌                                                      | 99/10000 [00:00<01:28, 112.04step/s]
Calculate Score VAE:   1%|▌                                                      | 99/10000 [00:00<01:28, 111.78step/s]
Calculate Fisher VAE:   0%|                                                                | 0/60000 [00:00<?, ?step/s]

fmnist      0.5000
svhn        1.0000
celeba      0.9986
lsun        0.9991
cifar10     0.9987
cifar100    0.9983
mnist       0.9837
kmnist      0.9925
omniglot    1.0000
notmnist    0.9998
noise       1.0000
constant    0.9999
dtype: float64
fmnist      0.5000
svhn        1.0000
celeba      1.0000
lsun        1.0000
cifar10     1.0000
cifar100    0.9996
mnist       0.9631
kmnist      0.9954
omniglot    1.0000
notmnist    1.0000
noise       0.9705
constant    1.0000
dtype: float64
fmnist      0.5000
svhn        0.6223
celeba      0.4917
lsun        0.4318
cifar10     0.4889
cifar100    0.4961
mnist       0.4943
kmnist      0.5071
omniglot    1.0000
notmnist    0.8810
noise       0.2530
constant    0.9471
dtype: float64
Elapsed time : 0:00:28.480157
Start! [500, 200]


Calculate Fisher VAE:   1%|▍                                                     | 499/60000 [00:06<13:15, 74.79step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:02<02:09, 75.67step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   1%|▍                                                      | 199/26032 [00:03<08:33, 50.31step/s]
Calculate Score VAE:   1%|▌                                                      | 199/19141 [00:02<03:24, 92.85step/s]
Calculate Score VAE:  33%|██████████████████▉                                      | 199/600 [00:02<00:04, 88.44step/s]


Files already downloaded and verified


Calculate Score VAE:   2%|█                                                      | 199/10000 [00:03<02:53, 56.48step/s]


Files already downloaded and verified


Calculate Score VAE:   2%|█                                                      | 199/10000 [00:03<02:52, 56.78step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:02<02:09, 75.75step/s]
Calculate Score VAE:   2%|█                                                      | 199/10000 [00:02<02:11, 74.51step/s]


Files already downloaded and verified


Calculate Score VAE:   2%|▊                                                      | 199/13180 [00:03<03:36, 60.03step/s]
Calculate Score VAE:   1%|▌                                                      | 199/18724 [00:02<03:07, 98.58step/s]
Calculate Score VAE:   2%|█                                                     | 199/10000 [00:01<01:29, 110.06step/s]
Calculate Score VAE:   2%|█                                                     | 199/10000 [00:01<01:28, 110.98step/s]
Calculate Fisher VAE:   0%|                                                                | 0/60000 [00:00<?, ?step/s]

fmnist      0.500000
svhn        0.999250
celeba      0.998025
lsun        0.997975
cifar10     0.998325
cifar100    0.998275
mnist       0.988200
kmnist      0.990400
omniglot    1.000000
notmnist    0.997950
noise       0.995000
constant    0.999275
dtype: float64
fmnist      0.500000
svhn        0.999925
celeba      0.999600
lsun        0.999300
cifar10     0.999325
cifar100    0.999025
mnist       0.958025
kmnist      0.991075
omniglot    1.000000
notmnist    0.999800
noise       0.977800
constant    0.999500
dtype: float64
fmnist      0.500000
svhn        0.785950
celeba      0.626575
lsun        0.641800
cifar10     0.655725
cifar100    0.651175
mnist       0.779400
kmnist      0.789625
omniglot    1.000000
notmnist    0.972425
noise       0.671950
constant    0.975850
dtype: float64
Elapsed time : 0:01:09.900467
Start! [1000, 500]


Calculate Fisher VAE:   2%|▉                                                     | 999/60000 [00:12<12:17, 80.04step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:05<01:42, 92.31step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   2%|█                                                      | 499/26032 [00:06<05:42, 74.50step/s]
Calculate Score VAE:   3%|█▍                                                     | 499/19141 [00:05<03:22, 92.11step/s]
Calculate Score VAE:  83%|███████████████████████████████████████████████▍         | 499/600 [00:05<00:01, 89.65step/s]


Files already downloaded and verified


Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:06<01:58, 79.89step/s]


Files already downloaded and verified


Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:06<01:59, 79.84step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:05<01:42, 92.70step/s]
Calculate Score VAE:   5%|██▋                                                    | 499/10000 [00:05<01:42, 92.35step/s]


Files already downloaded and verified


Calculate Score VAE:   4%|██                                                     | 499/13180 [00:06<02:35, 81.73step/s]
Calculate Score VAE:   3%|█▍                                                     | 499/18724 [00:05<03:06, 97.58step/s]
Calculate Score VAE:   5%|██▋                                                   | 499/10000 [00:04<01:24, 112.38step/s]
Calculate Score VAE:   5%|██▋                                                   | 499/10000 [00:04<01:25, 110.72step/s]
Calculate Fisher VAE:   0%|                                                                | 0/60000 [00:00<?, ?step/s]

fmnist      0.500000
svhn        1.000000
celeba      0.999696
lsun        0.999872
cifar10     0.999616
cifar100    0.999384
mnist       0.980548
kmnist      0.988400
omniglot    1.000000
notmnist    0.999944
noise       0.996068
constant    0.994004
dtype: float64
fmnist      0.500000
svhn        1.000000
celeba      0.999888
lsun        0.999952
cifar10     0.999748
cifar100    0.999624
mnist       0.959728
kmnist      0.990324
omniglot    1.000000
notmnist    0.999948
noise       0.971348
constant    0.993828
dtype: float64
fmnist      0.500000
svhn        0.927172
celeba      0.832008
lsun        0.794196
cifar10     0.774052
cifar100    0.778064
mnist       0.884112
kmnist      0.909604
omniglot    1.000000
notmnist    0.989720
noise       0.973504
constant    0.988356
dtype: float64
Elapsed time : 0:02:31.436575
Start! [3000, 1000]


Calculate Fisher VAE:   5%|██▋                                                  | 2999/60000 [00:35<11:23, 83.34step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:10<01:30, 99.57step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   4%|██                                                     | 999/26032 [00:11<04:41, 88.90step/s]
Calculate Score VAE:   5%|██▊                                                    | 999/19141 [00:10<03:16, 92.29step/s]
Calculate Score VAE: 100%|█████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 91.49step/s]


Files already downloaded and verified


Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:10<01:36, 92.90step/s]


Files already downloaded and verified


Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:10<01:37, 92.38step/s]
Calculate Score VAE:  10%|█████▍                                                | 999/10000 [00:09<01:29, 100.48step/s]
Calculate Score VAE:  10%|█████▍                                                | 999/10000 [00:09<01:29, 100.68step/s]


Files already downloaded and verified


Calculate Score VAE:   8%|████▏                                                  | 999/13180 [00:10<02:10, 93.61step/s]
Calculate Score VAE:   5%|██▉                                                    | 999/18724 [00:10<02:58, 99.23step/s]
Calculate Score VAE:  10%|█████▍                                                | 999/10000 [00:08<01:19, 112.92step/s]
Calculate Score VAE:  10%|█████▍                                                | 999/10000 [00:08<01:20, 111.53step/s]
Calculate Fisher VAE:   0%|                                                                | 0/60000 [00:00<?, ?step/s]

fmnist      0.500000
svhn        0.999922
celeba      0.999403
lsun        0.999557
cifar10     0.999386
cifar100    0.999286
mnist       0.982518
kmnist      0.992844
omniglot    1.000000
notmnist    0.999700
noise       0.995570
constant    0.995895
dtype: float64
fmnist      0.500000
svhn        0.999844
celeba      0.999638
lsun        0.999447
cifar10     0.999384
cifar100    0.999356
mnist       0.966647
kmnist      0.994334
omniglot    1.000000
notmnist    0.999737
noise       0.980620
constant    0.995500
dtype: float64
fmnist      0.500000
svhn        0.946797
celeba      0.825745
lsun        0.783483
cifar10     0.757908
cifar100    0.757121
mnist       0.791038
kmnist      0.893809
omniglot    1.000000
notmnist    0.984748
noise       0.998992
constant    0.994776
dtype: float64
Elapsed time : 0:05:08.608319
Start! [10000, 3000]


Calculate Fisher VAE:  17%|████████▊                                            | 9999/60000 [01:56<09:45, 85.47step/s]
Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:28<01:05, 106.09step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:  12%|██████                                               | 2999/26032 [00:29<03:47, 101.28step/s]
Calculate Score VAE:  16%|████████▍                                             | 2999/19141 [00:32<02:52, 93.57step/s]
Calculate Score VAE: 100%|█████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 92.44step/s]


Files already downloaded and verified


Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:29<01:07, 103.14step/s]


Files already downloaded and verified


Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:28<01:06, 105.03step/s]
Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:27<01:04, 107.89step/s]
Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:27<01:04, 108.06step/s]


Files already downloaded and verified


Calculate Score VAE:  23%|████████████                                         | 2999/13180 [00:28<01:37, 104.50step/s]
Calculate Score VAE:  16%|████████▋                                             | 2999/18724 [00:30<02:37, 99.95step/s]
Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:26<01:02, 112.59step/s]
Calculate Score VAE:  30%|███████████████▉                                     | 2999/10000 [00:26<01:02, 112.33step/s]
Calculate Fisher VAE:   0%|                                                                | 0/60000 [00:00<?, ?step/s]

fmnist      0.500000
svhn        0.999949
celeba      0.999553
lsun        0.999566
cifar10     0.999336
cifar100    0.999108
mnist       0.981063
kmnist      0.992745
omniglot    1.000000
notmnist    0.999726
noise       0.992457
constant    0.997128
dtype: float64
fmnist      0.500000
svhn        0.999928
celeba      0.999743
lsun        0.999576
cifar10     0.999369
cifar100    0.999192
mnist       0.963354
kmnist      0.993406
omniglot    1.000000
notmnist    0.999800
noise       0.975404
constant    0.996973
dtype: float64
fmnist      0.500000
svhn        0.953179
celeba      0.821770
lsun        0.741524
cifar10     0.745331
cifar100    0.736726
mnist       0.672422
kmnist      0.856436
omniglot    1.000000
notmnist    0.981541
noise       0.999634
constant    0.997023
dtype: float64
Elapsed time : 0:12:29.800508
Start! [30000, 5000]


Calculate Fisher VAE:  50%|█████████████████████████▉                          | 29999/60000 [05:46<05:46, 86.64step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:45<00:45, 109.19step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:  19%|██████████▏                                          | 4999/26032 [00:46<03:17, 106.48step/s]
Calculate Score VAE:  26%|██████████████                                        | 4999/19141 [00:52<02:29, 94.46step/s]
Calculate Score VAE: 100%|█████████████████████████████████████████████████████████| 600/600 [00:06<00:00, 91.96step/s]


Files already downloaded and verified


Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:46<00:46, 106.85step/s]


Files already downloaded and verified


Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:46<00:46, 106.66step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:45<00:45, 109.29step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:45<00:45, 109.40step/s]


Files already downloaded and verified


Calculate Score VAE:  38%|████████████████████                                 | 4999/13180 [00:46<01:16, 106.99step/s]
Calculate Score VAE:  27%|██████████████▏                                      | 4999/18724 [00:49<02:16, 100.30step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:44<00:44, 112.46step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:44<00:44, 112.82step/s]

fmnist      0.500000
svhn        0.999993
celeba      0.999726
lsun        0.999844
cifar10     0.999556
cifar100    0.999465
mnist       0.983381
kmnist      0.993261
omniglot    1.000000
notmnist    0.999776
noise       0.995179
constant    0.997080
dtype: float64
fmnist      0.500000
svhn        0.999980
celeba      0.999862
lsun        0.999832
cifar10     0.999632
cifar100    0.999529
mnist       0.966025
kmnist      0.994394
omniglot    1.000000
notmnist    0.999831
noise       0.977015
constant    0.997009
dtype: float64
fmnist      0.500000
svhn        0.955290
celeba      0.817419
lsun        0.721367
cifar10     0.718826
cifar100    0.724923
mnist       0.542106
kmnist      0.764547
omniglot    1.000000
notmnist    0.976339
noise       0.999499
constant    0.997161
dtype: float64
Elapsed time : 0:27:01.269822





Unnamed: 0_level_0,Econv1,Econv1,Econv1,Econv1,Econv1,Econv1,Gmain0,Gmain0,Gmain0,Gmain0,Gmain0,Gmain0,Gmain-1,Gmain-1,Gmain-1,Gmain-1,Gmain-1,Gmain-1
Unnamed: 0_level_1,"300, 100","500, 200","1000, 500","3000, 1000","10000, 3000","30000, 5000","300, 100","500, 200","1000, 500","3000, 1000","10000, 3000","30000, 5000","300, 100","500, 200","1000, 500","3000, 1000","10000, 3000","30000, 5000"
fmnist,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5
svhn,0.999949,1.0,0.999993,0.999922,1.0,0.99925,0.953179,0.927172,0.95529,0.946797,0.6223,0.78595,0.999928,1.0,0.99998,0.999844,1.0,0.999925
celeba,0.999553,0.999696,0.999726,0.999403,0.9986,0.998025,0.82177,0.832008,0.817419,0.825745,0.4917,0.626575,0.999743,0.999888,0.999862,0.999638,1.0,0.9996
lsun,0.999566,0.999872,0.999844,0.999557,0.9991,0.997975,0.741524,0.794196,0.721367,0.783483,0.4318,0.6418,0.999576,0.999952,0.999832,0.999447,1.0,0.9993
cifar10,0.999336,0.999616,0.999556,0.999386,0.9987,0.998325,0.745331,0.774052,0.718826,0.757908,0.4889,0.655725,0.999369,0.999748,0.999632,0.999384,1.0,0.999325
cifar100,0.999108,0.999384,0.999465,0.999286,0.9983,0.998275,0.736726,0.778064,0.724923,0.757121,0.4961,0.651175,0.999192,0.999624,0.999529,0.999356,0.9996,0.999025
mnist,0.981063,0.980548,0.983381,0.982518,0.9837,0.9882,0.672422,0.884112,0.542106,0.791038,0.4943,0.7794,0.963354,0.959728,0.966025,0.966647,0.9631,0.958025
kmnist,0.992745,0.9884,0.993261,0.992844,0.9925,0.9904,0.856436,0.909604,0.764547,0.893809,0.5071,0.789625,0.993406,0.990324,0.994394,0.994334,0.9954,0.991075
omniglot,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
notmnist,0.999726,0.999944,0.999776,0.9997,0.9998,0.99795,0.981541,0.98972,0.976339,0.984748,0.881,0.972425,0.9998,0.999948,0.999831,0.999737,1.0,0.9998


# GLOW-CIFAR10

In [None]:
train_dist = 'cifar10'
opt = config.GLOW_cifar10
model = load_pretrained_GLOW(option=train_dist)

dicts = [model.flow.layers[-1]]

In [None]:
a, b, c, d, e, f, g = AUTO_GLOW_CIFAR(model, dicts)
GRADs['GLOW']['cifar10'] = a
NORM_FACs['GLOW']['cifar10'] = b
Gradients['GLOW']['cifar10']['cifar10'] = c
Gradients['GLOW']['cifar10']['svhn'] = d
Gradients['GLOW']['cifar10']['celeba'] = e
Gradients['GLOW']['cifar10']['lsun'] = f
Gradients['GLOW']['cifar10']['noise'] = g

# GLOW-FMNIST