# Create VAE with Intermediate Feature Activations


In [1]:
import os
import sys
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from IPython.display import display, HTML
%reload_ext autoreload
%autoreload 2
%matplotlib inline

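# Put the parent of the current working directory on sys.path so the `utils` and
# `models` packages imported below can be found. TODO: find a nicer way to import them.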
sys.path.append(str(Path.cwd().parent))
from utils.loading import load_net, vae_from_args
from utils.train_val import validate_epoch
from utils.data import make_generators_DF_MNIST

from models.FeatureVAE import FEAT_VAE_MNIST

import torch
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from torch.nn import functional as F
import torch.optim as optim

import foolbox
import json
from PIL import Image

# Pick the compute device: the second GPU when more than one is present, the
# first GPU when there is only one, and the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(device)
# torch.cuda.current_device() raises when CUDA is unavailable, so only query
# it when a GPU was actually selected.
if device.type == "cuda":
    print(torch.cuda.current_device())

# Directory from which this notebook loads model checkpoints and result pickles.
BASE_PATH = Path('/media/rene/data/adv_gen/MNIST/mnist_normal/feature_models')

cuda:1
0


## Classifier Performance

In [2]:
# Pickled file table and the settings handed to the data-loader factory.
files_df_loc = '/media/rene/data/adv_gen/MNIST/mnist_normal/files_df.pkl'
batch_size, num_workers, IM_SIZE = 64, 2, 28

# Read the pickled file table from disk.
with open(files_df_loc, mode='rb') as f:
    files_df = pickle.load(f)

# Keyword options for make_generators_DF_MNIST, gathered in one place.
loader_options = dict(
    size=IM_SIZE,
    path_colname='path',
    adv_path_colname=None,
    return_loc=False,
    normalize=True,
)
# NOTE(review): the result is indexed with 'val' elsewhere in this notebook,
# so it is presumably a mapping from split name to loader - confirm in utils.data.
dataloaders = make_generators_DF_MNIST(files_df, batch_size, num_workers, **loader_options)

In [5]:
# Checkpoint file names all follow '<architecture tag>_model_best.pth.tar'.
model_name_list = [
    f'{arch_tag}_model_best.pth.tar'
    for arch_tag in (
        'SimpleNetMNIST-10',
        'SimpleNetMNIST-12',
        'SimpleNetMNIST-16',
        'SimpleNetMNIST-20',
        'TopkNetMNIST-16-10',
        'TopkNetMNIST-16-8',
        'TopkNetMNIST-12-6',
        'TopkNetMNIST-10-5',
    )
]

# Load each checkpoint from BASE_PATH, run it on the 'val' loader and report
# the accuracy returned by validate_epoch.
for model_name in model_name_list:
    model_loc = BASE_PATH / model_name
    model = load_net(model_loc)
    model = model.to(device)
    model = model.eval()
    validation_stats = validate_epoch(dataloaders['val'], model, device)
    acc, loss = validation_stats
    print(f'Model Name: {model_name}: Accuracy: {acc}')

VALID:  * TOP1 98.820 TOP5 100.000 Loss (0.0000)	 Time (0.012)	
Model Name: SimpleNetMNIST-10_model_best.pth.tar: Accuracy: 98.81999969482422
VALID:  * TOP1 98.810 TOP5 100.000 Loss (0.0000)	 Time (0.013)	
Model Name: SimpleNetMNIST-12_model_best.pth.tar: Accuracy: 98.80999755859375
VALID:  * TOP1 98.740 TOP5 100.000 Loss (0.0000)	 Time (0.012)	
Model Name: SimpleNetMNIST-16_model_best.pth.tar: Accuracy: 98.73999786376953
VALID:  * TOP1 98.880 TOP5 100.000 Loss (0.0000)	 Time (0.012)	
Model Name: SimpleNetMNIST-20_model_best.pth.tar: Accuracy: 98.87999725341797
VALID:  * TOP1 98.520 TOP5 100.000 Loss (0.0000)	 Time (0.013)	
Model Name: TopkNetMNIST-16-10_model_best.pth.tar: Accuracy: 98.5199966430664
VALID:  * TOP1 42.730 TOP5 87.050 Loss (0.0000)	 Time (0.013)	
Model Name: TopkNetMNIST-16-8_model_best.pth.tar: Accuracy: 42.72999954223633
VALID:  * TOP1 56.720 TOP5 92.900 Loss (0.0000)	 Time (0.013)	
Model Name: TopkNetMNIST-12-6_model_best.pth.tar: Accuracy: 56.71999740600586
VALID:  

## Feature VAE Classification Performance

In [3]:
# Load the result pickles of the iter50 / nsamp200 runs, keyed by the `nt`
# value that appears in the file name. Add 100 and 1000 to the tuple below to
# also load the nt100 / nt1000 result files (currently left out).
results = {}
for n_t in (10, 25, 50):
    results_file = BASE_PATH / f'FEAT_VAE_MNIST-6-10_iter50_nt{n_t}_nsamp200_deter_results.pkl'
    # Context manager closes the file handle as soon as the pickle is read.
    with open(results_file, "rb") as f:
        results[n_t] = pickle.load(f)

for n_t, result in results.items():
    # Accuracy = entries whose predicted label equals the true label, divided
    # by the total number of entries.
    # NOTE(review): assumes `result` is a pandas DataFrame with these two columns - confirm.
    n_correct = len(result[result['predicted_label'] == result['true_label']])
    acc = n_correct / len(result)
    print(f'Number of times: {n_t}, Accuracy {acc}')

Number of times: 10, Accuracy 0.88
Number of times: 25, Accuracy 0.93
Number of times: 50, Accuracy 0.97


In [8]:
# Directory with the trained models, and the SimpleNetMNIST-10 checkpoint inside it.
model_loc = '/media/rene/data/adv_gen/MNIST/mnist_normal/feature_models'
encoding_model_loc = f'{model_loc}/SimpleNetMNIST-10_model_best.pth.tar'

# Sizes passed to the feature VAE as `num_features` and `latent_size`.
num_features, latent_size = 10, 6

In [9]:
# Load the network stored at encoding_model_loc and move it to the target
# device; it is handed to the VAE as its encoding model.
encoding_model = load_net(encoding_model_loc).to(device)

# Build the feature VAE around that network and move the whole model to the device.
model = FEAT_VAE_MNIST(
    encoding_model=encoding_model,
    num_features=num_features,
    latent_size=latent_size,
).to(device)

In [18]:
# Show the object returned by parameters() on the encoding model.
encoder_params = model.encoding_model.parameters()
print(encoder_params)

# Print one requires_grad flag per parameter of the full model, in iteration order.
requires_grad_flags = [p.requires_grad for p in model.parameters()]
for flag in requires_grad_flags:
    print(flag)

<generator object Module.parameters at 0x7fc6bbc1e938>
False
False
False
False
False
False
False
False
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [21]:
# Example checkpoint name; extract the integer that sits between the first and
# second '-' (6 here).
# NOTE(review): presumably the latent size encoded in the file name - confirm.
model_file = 'FEAT_VAE_MNIST-6-10-MNIST_label_4_model_best.pth.tar'
second_field = model_file.split('-')[1]
int(second_field.partition('_')[0])

6

In [2]:
# Load the result pickle of the iter100 / nt50 / nsamp20 run.
# NOTE(review): the key '20_50' presumably encodes nsamp=20 and nt=50 - confirm.
results = {}
results_file = BASE_PATH / 'FEAT_VAE_MNIST-6-10_iter100_nt50_nsamp20_deter_results.pkl'
# Context manager closes the file handle as soon as the pickle is read.
with open(results_file, "rb") as f:
    results['20_50'] = pickle.load(f)

for n_t, result in results.items():
    # Accuracy = entries whose predicted label equals the true label, divided
    # by the total number of entries.
    # NOTE(review): assumes `result` is a pandas DataFrame with these two columns - confirm.
    n_correct = len(result[result['predicted_label'] == result['true_label']])
    acc = n_correct / len(result)
    print(f'Number of times: {n_t}, Accuracy {acc}')

Number of times: 20_50, Accuracy 0.95
