# [ VAE ] OOD Detection using FSS

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append(os.getcwd() + '/core')
sys.path.append(os.getcwd() + '/core/train_GLOW') 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

import core.config as config
from core.model_loader import load_pretrained_VAE
from core.data_loader import TRAIN_loader, TEST_loader
from core.custom_loss import KL_div, VAE_loss_pixel
from core.fisher_utils_VAE import Calculate_fisher_VAE, Calculate_score_VAE, AUTO_VAE
from core.visualize import plot_hist, AUROC, plot_scores_all_layers

# Define Global Variables & Initialize

In [2]:
# FISHER_invs : Fisher inverse matrices from calculating Fisher score (w.r.t. train-dist)
FISHERs = {'cifar10': {}, 'fmnist': {}}
         
# NORMalize_FACtors : Normalizing Factor (w.r.t. train-dist)
NORM_FACs = {'cifar10': {}, 'fmnist': {}}
             
# SCOREs : (Scalars) Scores from calculating Fisher score (w.r.t. target-dist)
SCOREs = {'cifar10': {}, 'fmnist': {}}


# VAE-CIFAR10

In [5]:
opt = config.VAE_cifar10
netE, netG = load_pretrained_VAE(option=opt.train_dist, ngf=64, nz=200, beta=1, augment='hflip', epoch=100)
netE.eval()
netG.eval()

params = {
    #'Emain0_w': netE.main[0].weight,
    #'Emain1_w': netE.main[1].weight,
    #'Emain1_b': netE.main[1].bias,
    #'Emain3_w': netE.main[3].weight,
    #'Emain4_w': netE.main[4].weight,
    #'Emain4_b': netE.main[4].bias,
    #'Emain6_w': netE.main[6].weight,
    #'Emain7_w': netE.main[7].weight,
    #'Emain7_b': netE.main[7].bias,
    'Econv1_w': netE.conv1.weight,
    #'Econv1_b': netE.conv1.bias,
    #'Econv2_w': netE.conv2.weight,
    #'Econv2_b': netE.conv2.bias,
    #'Gmain0_w': netG.main[0].weight,
    #'Gmain1_w': netG.main[1].weight,
    #'Gmain1_b': netG.main[1].bias,
    #'Gmain3_w': netG.main[3].weight,
    #'Gmain4_w': netG.main[4].weight,
    #'Gmain4_b': netG.main[4].bias,
    #'Gmain6_w': netG.main[6].weight,
    #'Gmain7_w': netG.main[7].weight,
    #'Gmain7_b': netG.main[7].bias,
    #'Gmain9_w': netG.main[9].weight,
}


# 추가훈련

In [4]:
for param in netE.parameters():
    param.requires_grad_(False)
for param in netG.parameters():
    param.requires_grad_(False)
for param in netE.conv1.parameters():
    param.requires_grad_(True)
#for param in netE.conv2.parameters():
#    param.requires_grad_(True)

import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime
from tqdm import tqdm

device = 'cuda:0'
optimizer = optim.Adam(netE.conv1.parameters(), lr=5e-5, weight_decay=0)
loss_fn = nn.CrossEntropyLoss(reduction='none')
rec_l, kl = [], []
loader = TRAIN_loader('cifar10', augment=True, batch_size=64)
start = datetime.now()

for epoch in range(10):
    mean_loss = 0.
    for i, (x, _) in enumerate(tqdm(loader)):
        x = x.to(device)
        b = x.size(0)
        target = Variable(x.data.view(-1) * 255).long()
        [z, mu, logvar] = netE(x)
        recon = netG(z)
        recon = recon.contiguous()
        recon = recon.view(-1, 256)
        recl = loss_fn(recon, target)
        recl = torch.sum(recl) / b
        kld = KL_div(mu, logvar)
        loss = recl + 1 * kld.mean()
        
        optimizer.zero_grad()
        total_loss = loss
        loss.backward(retain_graph=True)
        optimizer.step()
        rec_l.append(recl.detach().item())
        kl.append(kld.mean().detach().item())
        mean_loss = (mean_loss * i + loss.detach().item()) / (i + 1)
        
    now = datetime.now()
    print(f'Epoch {epoch+1:02d} recon {np.mean(rec_l):.2f} kl {np.mean(kl):.2f} Elapsed time {now - start}')
    
    

Files already downloaded and verified


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:16<00:00, 48.87it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 01 recon 11375.67 kl 430.14 Elapsed time 0:00:16.005411


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:15<00:00, 48.98it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 02 recon 11371.84 kl 430.75 Elapsed time 0:00:31.971946


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 53.15it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 03 recon 11369.52 kl 431.22 Elapsed time 0:00:46.685615


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 53.08it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 04 recon 11367.82 kl 431.60 Elapsed time 0:01:01.419246


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.97it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 05 recon 11366.53 kl 431.91 Elapsed time 0:01:16.184542


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.91it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 06 recon 11365.53 kl 432.18 Elapsed time 0:01:30.965257


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.74it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 07 recon 11364.68 kl 432.41 Elapsed time 0:01:45.793521


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.76it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 08 recon 11363.96 kl 432.61 Elapsed time 0:02:00.617246


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.80it/s]
  0%|                                                                                          | 0/782 [00:00<?, ?it/s]

Epoch 09 recon 11363.32 kl 432.80 Elapsed time 0:02:15.429313


100%|████████████████████████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 52.68it/s]

Epoch 10 recon 11362.75 kl 432.96 Elapsed time 0:02:30.273580





In [7]:
a, b, _scores = AUTO_VAE(opt, netE, netG, params, max_iter=[10, 1000], loss_type='ELBO_pixel', method='Vanilla')

Files already downloaded and verified


Calculate Fisher VAE:   0%|                                                      | 9/50000 [00:02<4:17:01,  3.24step/s]


Files already downloaded and verified


Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:23<03:29, 42.98step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:   4%|██                                                     | 999/26032 [00:23<09:51, 42.31step/s]
Calculate Score VAE:   5%|██▊                                                    | 999/19141 [00:22<06:54, 43.75step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:23<03:33, 42.15step/s]


Files already downloaded and verified


Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:23<03:32, 42.30step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:22<03:25, 43.87step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:22<03:25, 43.76step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:22<03:26, 43.67step/s]


Files already downloaded and verified


Calculate Score VAE:   8%|████▏                                                  | 999/13180 [00:23<04:46, 42.46step/s]
Calculate Score VAE:   5%|██▉                                                    | 999/18724 [00:22<06:38, 44.50step/s]
Calculate Score VAE:   8%|████▎                                                  | 999/12630 [00:22<04:25, 43.75step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:21<03:16, 45.88step/s]
Calculate Score VAE:  10%|█████▍                                                 | 999/10000 [00:21<03:15, 45.93step/s]


In [11]:
for pname in params.keys():
    #FISHERs[opt.train_dist][pname] = a[pname] 
    #NORM_FACs[opt.train_dist][pname] = b[pname]
    pass
for ood in opt.ood_list:
    SCOREs[opt.train_dist][ood] = _scores[ood]

In [12]:
# AUROC curve
auroc = {}
for pname in params.keys():
    _auroc = {}
    for ood in opt.ood_list:
        args = [
            SCOREs[opt.train_dist][opt.train_dist][pname],
            SCOREs[opt.train_dist][ood][pname],
        ]
        labels = [opt.train_dist, ood]
        _auroc[ood] = AUROC(*args, labels=labels, verbose=False)
    auroc[pname] = _auroc
    
auroc

{'Econv1_w': {'cifar10': 0.49999999999999994,
  'svhn': 0.7574559999999999,
  'celeba': 0.581585,
  'lsun': 0.5073019999999999,
  'cifar100': 0.494632,
  'mnist': 0.907841,
  'fmnist': 0.869656,
  'kmnist': 0.9121130000000001,
  'omniglot': 0.993495,
  'notmnist': 0.928892,
  'trafficsign': 0.533332,
  'noise': 0.591475,
  'constant': 0.912907}}

In [6]:
# Just show scores
# plot_scores_all_layers(train_dist, params, SCOREs, opt, save=True)


# VAE-FMNIST

In [3]:
opt = config.VAE_fmnist
netE, netG = load_pretrained_VAE(option=opt.train_dist, ngf=32, nz=100, beta=1, augment='hflip', epoch=100)
netE.eval()
netG.eval()

params = {
    #'Emain0_w': netE.main[0].weight,
    #'Emain1_w': netE.main[1].weight,
    #'Emain1_b': netE.main[1].bias,
    #'Emain3_w': netE.main[3].weight,
    #'Emain4_w': netE.main[4].weight,
    #'Emain4_b': netE.main[4].bias,
    #'Emain6_w': netE.main[6].weight,
    #'Emain7_w': netE.main[7].weight,
    #'Emain7_b': netE.main[7].bias,
    'Econv1_w': netE.conv1.weight,
    #'Econv1_b': netE.conv1.bias,
    #'Econv2_w': netE.conv2.weight,
    #'Econv2_b': netE.conv2.bias,
    #'Gmain0_w': netG.main[0].weight,
    #'Gmain1_w': netG.main[1].weight,
    #'Gmain1_b': netG.main[1].bias,
    #'Gmain3_w': netG.main[3].weight,
    #'Gmain4_w': netG.main[4].weight,
    #'Gmain4_b': netG.main[4].bias,
    #'Gmain6_w': netG.main[6].weight,
    #'Gmain7_w': netG.main[7].weight,
    #'Gmain7_b': netG.main[7].bias,
    #'Gmain9_w': netG.main[9].weight,
}

In [5]:
for param in netE.parameters():
    param.requires_grad_(False)
for param in netG.parameters():
    param.requires_grad_(False)
for param in netE.conv1.parameters():
    param.requires_grad_(True)
#for param in netE.conv2.parameters():
#    param.requires_grad_(True)

import torch.optim as optim
from torch.autograd import Variable
from datetime import datetime

device = 'cuda:0'
optimizer = optim.Adam(netE.conv1.parameters(), lr=5e-5, weight_decay=0)
loss_fn = nn.CrossEntropyLoss(reduction='none')
rec_l, kl = [], []
loader = TRAIN_loader('fmnist', 'fmnist')
start = datetime.now()

for epoch in range(10):
    mean_loss = 0.
    for i, (x, _) in enumerate(loader):
        x = x.to(device)
        b = x.size(0)
        target = Variable(x.data.view(-1) * 255).long()
        [z, mu, logvar] = netE(x)
        recon = netG(z)
        recon = recon.contiguous()
        recon = recon.view(-1, 256)
        recl = loss_fn(recon, target)
        recl = torch.sum(recl) / b
        kld = KL_div(mu, logvar)
        loss = recl + 1 * kld.mean()
        
        optimizer.zero_grad()
        total_loss = loss
        loss.backward(retain_graph=True)
        optimizer.step()
        rec_l.append(recl.detach().item())
        kl.append(kld.mean().detach().item())
        mean_loss = (mean_loss * i + loss.detach().item()) / (i + 1)
        
    now = datetime.now()
    print(f'Epoch {epoch+1:02d} recon {np.mean(rec_l):.2f} kl {np.mean(kl):.2f} Elapsed time {now - start}')
    
    

Epoch 01 recon 2145.25 kl 146.29 Elapsed time 0:00:32.395625
Epoch 02 recon 2144.09 kl 146.47 Elapsed time 0:01:04.396823
Epoch 03 recon 2143.23 kl 146.61 Elapsed time 0:01:37.149608
Epoch 04 recon 2142.58 kl 146.72 Elapsed time 0:02:09.070494
Epoch 05 recon 2142.03 kl 146.82 Elapsed time 0:02:41.094448
Epoch 06 recon 2141.61 kl 146.90 Elapsed time 0:03:12.810576
Epoch 07 recon 2141.14 kl 146.98 Elapsed time 0:03:44.436634
Epoch 08 recon 2140.74 kl 147.05 Elapsed time 0:04:15.897414
Epoch 09 recon 2140.37 kl 147.12 Elapsed time 0:04:47.259621
Epoch 10 recon 2140.01 kl 147.18 Elapsed time 0:05:18.747113


In [4]:
a, b, _scores = AUTO_VAE(opt, netE, netG, params, max_iter=[10000, 5000], loss_type='ELBO_pixel', method='SMW')

Calculate Fisher VAE:  17%|████████▊                                            | 9999/60000 [03:12<16:05, 51.81step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:36<00:36, 137.65step/s]


Using downloaded and verified file: ../data\test_32x32.mat


Calculate Score VAE:  19%|██████████▏                                          | 4999/26032 [00:37<02:37, 133.34step/s]
Calculate Score VAE:  26%|█████████████▊                                       | 4999/19141 [00:39<01:53, 125.15step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:42<00:42, 117.00step/s]


Files already downloaded and verified


Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:37<00:37, 134.57step/s]


Files already downloaded and verified


Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:37<00:37, 134.36step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:36<00:36, 138.28step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:36<00:36, 138.25step/s]


Files already downloaded and verified


Calculate Score VAE:  38%|████████████████████                                 | 4999/13180 [00:37<01:00, 134.43step/s]
Calculate Score VAE:  27%|██████████████▏                                      | 4999/18724 [00:37<01:43, 132.16step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:34<00:34, 144.35step/s]
Calculate Score VAE:  50%|██████████████████████████▍                          | 4999/10000 [00:34<00:34, 144.27step/s]


In [7]:
for pname in params.keys():
    FISHERs[opt.train_dist][pname] = a[pname] 
    NORM_FACs[opt.train_dist][pname] = b[pname]
for ood in opt.ood_list:
    SCOREs[opt.train_dist][ood] = _scores[ood]

In [8]:
# AUROC curve
auroc = {}
for pname in params.keys():
    _auroc = {}
    for ood in opt.ood_list:
        args = [
            SCOREs[opt.train_dist][opt.train_dist][pname],
            SCOREs[opt.train_dist][ood][pname],
        ]
        labels = [opt.train_dist, ood]
        _auroc[ood] = AUROC(*args, labels=labels, verbose=False)
    auroc[pname] = _auroc
    
auroc

{'Econv1_w': {'fmnist': 0.5,
  'svhn': 0.9969616800000001,
  'celeba': 0.99908208,
  'lsun': 0.99571948,
  'cifar10': 0.9981044400000001,
  'cifar100': 0.99747032,
  'mnist': 0.99382972,
  'kmnist': 0.9967750400000001,
  'omniglot': 1.0,
  'notmnist': 0.9998774,
  'noise': 0.9905497600000001,
  'constant': 0.9966937200000001}}

In [9]:
# Just show scores
# plot_scores_all_layers(train_dist, params, SCOREs, opt, save=True)