In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append(os.getcwd() + '/core')
sys.path.append(os.getcwd() + '/core/train_GLOW') 
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from datetime import datetime

import core.config as config
from core.model_loader import load_pretrained_VAE
from core.data_loader import TRAIN_loader, TEST_loader
from core.custom_loss import KL_div, VAE_loss_pixel
from core.fisher_utils_VAE import Calculate_fisher_VAE, Calculate_score_VAE, AUTO_VAE
from core.fisher_utils_VAE import Calculate_fisher_VAE_ekfac, Calculate_score_VAE_ekfac
from core.visualize import plot_hist, AUROC, plot_scores_all_layers

# Seed Python's `random`, NumPy's global RNG and PyTorch (torch.manual_seed)
# with one shared value so a fresh kernel reproduces the same random draws.
seed = 2021
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(seed)

<torch._C.Generator at 0x210b284b030>

In [2]:
# Module names forwarded to the Fisher/score helpers as `select_modules`.
# Left empty here; presumably an empty list means "use the helper's default
# module set" -- confirm in core.fisher_utils_VAE.
modules = []

# Experiment configuration for the `VAE_fmnist` setting.
opt = config.VAE_fmnist

# Restore the pretrained network pair for this configuration
# (presumably encoder = netE, decoder/generator = netG).
netE, netG = load_pretrained_VAE(
    option=opt.train_dist,
    num=1, ngf=32, nz=100, beta=1,
    augment='hflip', epoch=100,
)

# Put both networks in inference mode before any Fisher/score computation.
netE.eval()
netG.eval()

In [None]:
# Repeat the EK-FAC Fisher estimation + OOD scoring N_RUNS times per sampling
# size to measure run-to-run variance of the 'overall' AUROC. The AUROC list for
# each sampling size is checkpointed to RESULT_DIR after every run, so an
# interrupted sweep keeps the runs it already finished.
N_RUNS = 30            # independent repetitions per sampling size
SCORE_MAX_ITER = 5000  # max_iter passed to Calculate_score_VAE_ekfac
RESULT_DIR = './temp'


def _aggregate_layer_scores(layer_scores, mean, std):
    """Combine per-module scores into a single score per sample.

    Each module's scores are standardized with that module's ``mean``/``std``
    (the statistics returned by ``Calculate_fisher_VAE_ekfac``), the
    standardized arrays are concatenated along axis 1, and the maximum along
    axis 1 is returned.

    Args:
        layer_scores: mapping module name -> array-like of scores. Each entry is
            assumed to be 2-D (samples, k) so entries can be concatenated on
            axis 1 -- TODO confirm against Calculate_score_VAE_ekfac.
        mean, std: mappings module name -> normalization statistics,
            broadcastable against the matching ``layer_scores`` entry.

    Returns:
        numpy array of aggregated scores (1-D under the 2-D assumption above).

    Raises:
        ValueError: if ``layer_scores`` is empty (nothing to aggregate).
    """
    if not layer_scores:
        raise ValueError('no module scores to aggregate; check `modules` / select_modules')
    standardized = [
        (np.array(layer_scores[name]) - mean[name]) / std[name]
        for name in layer_scores
    ]
    return np.max(np.concatenate(standardized, axis=1), axis=1)


# np.save does not create missing directories; make sure the checkpoint dir exists.
os.makedirs(RESULT_DIR, exist_ok=True)
start = datetime.now()
for sampling in [10000]:
    result = []
    for i in range(N_RUNS):
        torch.cuda.empty_cache()
        print(i)
        auroc = {}
        SCOREs = {}
        # Distinct but reproducible seed per (run, sampling size), derived from
        # the notebook-wide `seed` instead of a duplicated literal.
        run_seed = seed + i + sampling

        U_A, U_B, S, mean, std = Calculate_fisher_VAE_ekfac(
            netE, netG, opt, select_modules=modules, max_iter=sampling, seed=run_seed)

        # opt.train_dist must be scored first: every AUROC below compares
        # against SCOREs[opt.train_dist]. (train_dist vs. itself is a sanity
        # check and is expected to print 0.5.)
        for ood in [opt.train_dist, 'overall']:
            layer_scores = Calculate_score_VAE_ekfac(
                netE, netG, opt, U_A, U_B, S, ood, max_iter=SCORE_MAX_ITER, seed=run_seed)
            SCOREs[ood] = _aggregate_layer_scores(layer_scores, mean, std)
            auroc[ood] = AUROC(SCOREs[opt.train_dist], SCOREs[ood],
                               labels=[opt.train_dist, ood], verbose=False)
            print(f'{opt.train_dist}/{ood} {auroc[ood]}')
        result.append(auroc['overall'])
        now = datetime.now()  # one timestamp so "Now" and "Elapsed" agree
        print(f'Now {now} Elapsed Time {now - start}')
        print(result)
        np.save(f'{RESULT_DIR}/{opt.train_dist}_sampling_{sampling}.npy', np.array(result))


Calculate A, B:   0%|                                                                              | 0/60000 [00:00<?, ?step/s]

0
(0): Conv2d(128, 100, kernel_size=(4, 4), stride=(1, 1))
(1): Conv2d(128, 100, kernel_size=(4, 4), stride=(1, 1))
(2): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)


Calculate A, B:  17%|██████████▉                                                       | 9999/60000 [00:52<04:20, 192.02step/s]
Calculate Fisher Inverse:  17%|█████████▎                                              | 9999/60000 [00:57<04:47, 173.78step/s]
Calculate Score of fmnist(train):  17%|███████▉                                        | 9999/60000 [01:04<05:20, 155.87step/s]
Calculate Score of fmnist:  50%|███████████████████████████▍                           | 4999/10000 [00:32<00:32, 154.62step/s]
Calculate Score of overall:   0%|                                                        | 15/10000 [00:00<01:08, 145.97step/s]

Average Inference Time : 0.0064710105999999995 seconds
Average #Images Processed : 154.53536731959613 Images
fmnist/fmnist 0.5


Calculate Score of overall:  50%|██████████████████████████▉                           | 4999/10000 [00:34<00:34, 146.25step/s]
Calculate A, B:   0%|                                                                    | 15/60000 [00:00<06:58, 143.24step/s]

fmnist/overall 0.99867764
Now 2021-05-27 13:55:48.936185 Elapsed Time 0:04:00.975401
[0.99867764]
1
(0): Conv2d(128, 100, kernel_size=(4, 4), stride=(1, 1))
(1): Conv2d(128, 100, kernel_size=(4, 4), stride=(1, 1))
(2): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)


Calculate A, B:  17%|██████████▉                                                       | 9999/60000 [01:01<05:09, 161.30step/s]
Calculate Fisher Inverse:  17%|█████████▎                                              | 9999/60000 [01:09<05:45, 144.58step/s]
Calculate Score of fmnist(train):  17%|███████▉                                        | 9999/60000 [01:15<06:16, 132.84step/s]
Calculate Score of fmnist:  50%|███████████████████████████▍                           | 4999/10000 [00:37<00:37, 132.56step/s]
Calculate Score of overall:   0%|                                                        | 12/10000 [00:00<01:23, 119.16step/s]

Average Inference Time : 0.0075471717999999995 seconds
Average #Images Processed : 132.4999650862592 Images
fmnist/fmnist 0.5


Calculate Score of overall:  50%|██████████████████████████▉                           | 4999/10000 [00:40<00:40, 123.36step/s]
Calculate A, B:   0%|                                                                    | 13/60000 [00:00<08:03, 124.17step/s]

fmnist/overall 0.9984162799999999
Now 2021-05-27 14:00:34.030957 Elapsed Time 0:08:46.070173
[0.99867764, 0.9984162799999999]
2
(0): Conv2d(128, 100, kernel_size=(4, 4), stride=(1, 1))
(1): Conv2d(128, 100, kernel_size=(4, 4), stride=(1, 1))
(2): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)


Calculate A, B:  17%|██████████▉                                                       | 9999/60000 [01:13<06:05, 136.73step/s]
Calculate Fisher Inverse:  17%|█████████▎                                              | 9999/60000 [01:20<06:40, 124.82step/s]
Calculate Score of fmnist(train):  17%|███████▉                                        | 9999/60000 [01:25<07:09, 116.35step/s]
Calculate Score of fmnist:  50%|███████████████████████████▍                           | 4999/10000 [00:43<00:43, 116.13step/s]
Calculate Score of overall:   0%|                                                        | 12/10000 [00:00<01:30, 110.42step/s]

Average Inference Time : 0.0086140692 seconds
Average #Images Processed : 116.08915331211874 Images
fmnist/fmnist 0.5


Calculate Score of overall:  17%|█████████▏                                            | 1692/10000 [00:15<01:14, 111.59step/s]