In [1]:
import numpy as np
import cv2
import pickle
import matplotlib.pyplot as plt
import time
import pandas as pd
import argparse

import torch
import torch.nn as nn
from torch.autograd import Variable

import torchvision
import torchvision.transforms as transforms

from numpy import linalg as LA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn import linear_model

from fashion_model import FashionCNN 
from manifold_sampling import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# def arg_parse():
#     parser = argparse.ArgumentParser(description="UMAP discriminator")
#     parser.add_argument(
#             "--exp", dest="exp", help="Experiments"
#         )
#     parser.add_argument(
#             "--std", dest="std", type=float, help="Perturbation std"
#         )
#     parser.add_argument(
#         "--multiplier", dest="multiplier", type=int, help="Number of times an image is perturbed"
#     )
#     parser.add_argument(
#         "--perturbations", dest="num_perturbations", type=int, help="Number of perturbations"
#     )
#     parser.add_argument(
#         "--dim", dest="dim", type=int, help="Number of low dim"
#     )
#     parser.add_argument(
#         "--pivots", dest="pivots", type=int, help="Number of pivots"
#     )
#     parser.add_argument(
#         "--shuffle", dest="shuffle", type=bool, help="Shuffle the pivots"
#     )
    
        
#     parser.set_defaults(
#         exp = 'fashion_mnist',
#         std = 0.1,
#         num_perturbations = 1,
# #         runs = 10,
#         multiplier = 100,
#         dim = 2,
#         pivots = 10,
#         shuffle = True
#     )
#     return parser.parse_args()

# prog_args = arg_parse()


EXPERIMENT = 'fashion_mnist'
PERTURBATION_STD = 0.1
NUM_PERTURBATIONS = 1
MULTIPLIER = 100
DIM = 2
PIVOTS = 10
SHUFFLE = True

print("EXPERIMENT: ", EXPERIMENT)
print("MULTIPLIER: ", MULTIPLIER)
print("PERTURBATION_STD: ", PERTURBATION_STD)
print("DIM: ", DIM)
print("PIVOTS: ", PIVOTS)
print("SHUFFLE: ", SHUFFLE)

# EXPERIMENT = 'fashion_mnist'
# EXPERIMENT = 'mnist'
# EXPERIMENT = 'compass'
# EXPERIMENT = 'german'

if EXPERIMENT == 'fashion_mnist':
    print("Loading fashion mnist")
    train_set = torchvision.datasets.FashionMNIST("./data", download=True, transform=
                                                    transforms.Compose([transforms.ToTensor()]))
    test_set = torchvision.datasets.FashionMNIST("./data", download=True, train=False, transform=
                                                   transforms.Compose([transforms.ToTensor()]))
elif EXPERIMENT == 'mnist':
    print("Loading mnist")
    train_set = torchvision.datasets.MNIST("./data", download=True, transform=
                                                    transforms.Compose([transforms.ToTensor()]))
    test_set = torchvision.datasets.MNIST("./data", download=True, train=False, transform=
                                                   transforms.Compose([transforms.ToTensor()]))
else:
    print("Nothing to do.")
    
print("Done loading")
    
train_loader = torch.utils.data.DataLoader(train_set, 
                                           batch_size=100)
test_loader = torch.utils.data.DataLoader(test_set,
                                          batch_size=100)


all_loader = torch.utils.data.DataLoader(train_set, batch_size=train_set.__len__())
all_images, all_labels = next(iter(all_loader))

start_time = time.time()
manifold_sampler = Manifold_Image_Sampler(all_images, dim = DIM, labels = all_labels)
duration = time.time() - start_time
print("Initialize duration: ", duration)


start_time = time.time()
manifold_sampler.train_multiplier = MULTIPLIER
manifold_sampler.std_train = PERTURBATION_STD
manifold_sampler.train_pivot(no_pivots_per_label = PIVOTS, shuffle = SHUFFLE)
duration = time.time() - start_time
print("Train duration: ", duration)

def get_discriminator(X,y,n_estimators = 100, test_ratio = 0.5):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_ratio)
    the_rf = RandomForestClassifier(n_estimators=n_estimators).fit(X_train, y_train)
    y_pred = the_rf.predict(X_test)
    the_rf_result = (y_pred == y_test).sum()
    return the_rf, the_rf_result/y_test.shape[0], X_train.shape[0]

def get_discriminator_performance(X,y,rf):
    y_pred = rf.predict(X)
    the_rf_result = (y_pred == y).sum()
    return the_rf_result/y.shape[0], y.shape[0]

X_in = manifold_sampler.pivots.numpy()
X_per = np.expand_dims(np.vstack([perturbs[0] for perturbs in manifold_sampler.perturbs]), axis = 1)
X_plane = np.zeros_like(X_in)
X_ortho = np.zeros_like(X_in)
for i in range(X_plane.shape[0]):
    X_plane[i] = manifold_sampler.pivots[i] + manifold_sampler.plane_noise[i][0]
    X_ortho[i] = manifold_sampler.pivots[i] + manifold_sampler.ortho_noise[i][0]

X_discriminator_per = np.vstack((X_in, X_per))
X_discriminator_plane = np.vstack((X_in, X_plane))
X_discriminator_ortho = np.vstack((X_in, X_ortho))
y_discriminator = np.concatenate((np.zeros(X_in.shape[0]), np.ones(X_per.shape[0])))

the_rf_per, test_acc_per, no_trains = get_discriminator(manifold_sampler.to_1d(X_discriminator_per),y_discriminator, n_estimators = 100, test_ratio = 0.5)
print(test_acc, no_trains)
the_rf_plane, test_acc_plane, no_trains = get_discriminator(manifold_sampler.to_1d(X_discriminator_plane),y_discriminator, n_estimators = 100, test_ratio = 0.5)
print(test_acc, no_trains)
the_rf_ortho, test_acc_ortho, no_trains = get_discriminator(manifold_sampler.to_1d(X_discriminator_ortho),y_discriminator, n_estimators = 100, test_ratio = 0.5)
print(test_acc, no_trains)

print("Create Testing environment")

start_time = time.time()
explanation_sampler = Manifold_Image_Sampler(all_images, dim = DIM, labels = all_labels)
duration = time.time() - start_time
print("Initialize duration: ", duration)

start_time = time.time()
explanation_sampler.train_multiplier = MULTIPLIER
explanation_sampler.std_train = PERTURBATION_STD
explanation_sampler.train_pivot(no_pivots_per_label = PIVOTS, shuffle = True)
duration = time.time() - start_time
print("Train duration: ", duration)

Z_in = explanation_sampler.pivots.numpy()
acc_per = 0
acc_plane = 0
acc_ortho = 0
var_per = 0 
var_plane = 0
var_ortho = 0
for p in range(NUM_PERTURBATIONS):
    Z_per = np.expand_dims(np.vstack([perturbs[p] for perturbs in explanation_sampler.perturbs]), axis = 1)
    Z_plane = np.zeros_like(Z_in)
    Z_ortho = np.zeros_like(Z_in)
    for i in range(Z_plane.shape[0]):
        Z_plane[i] = explanation_sampler.pivots[i] + explanation_sampler.plane_noise[i][p]
        Z_ortho[i] = explanation_sampler.pivots[i] + explanation_sampler.ortho_noise[i][p]

    Z_discriminator_per = np.vstack((Z_in, Z_per))
    Z_discriminator_plane = np.vstack((Z_in, Z_plane))
    Z_discriminator_ortho = np.vstack((Z_in, Z_plane))
    y_discriminator = np.concatenate((np.zeros(Z_in.shape[0]), np.ones(Z_per.shape[0])))
    
    test_acc_per, no_test = get_discriminator_performance(explanation_sampler.to_1d(Z_discriminator_per), y_discriminator, the_rf_per)
    test_acc_plane, no_test = get_discriminator_performance(explanation_sampler.to_1d(Z_discriminator_plane), y_discriminator, the_rf_plane)
    test_acc_ortho, no_test = get_discriminator_performance(explanation_sampler.to_1d(Z_discriminator_ortho), y_discriminator, the_rf_ortho)
    
    acc_per = acc_per + test_acc_per
    acc_plane = acc_plane + test_acc_plane
    acc_ortho = acc_ortho + test_acc_ortho
    var_per = var_per + np.var(Z_per-Z_in)
    var_plane = var_plane + np.var(Z_plane-Z_in)
    var_ortho = var_ortho + np.var(Z_ortho-Z_in)
    
acc_per = acc_per/NUM_PERTURBATIONS
acc_plane = acc_plane/NUM_PERTURBATIONS
acc_ortho = acc_ortho/NUM_PERTURBATIONS
var_per = var_per/NUM_PERTURBATIONS
var_plane = var_plane/NUM_PERTURBATIONS
var_ortho = var_ortho/NUM_PERTURBATIONS

df = pd.DataFrame({'per': [acc_per, test_acc_per, var_per],
                   'plane': [acc_plane, test_acc_plane, var_plane],
                   'ortho': [acc_ortho, test_acc_ortho, var_ortho]})

discriminator_file = 'results/discriminator/' + EXPERIMENT + '_dim_' + str(DIM) + '_noise_' + str(PERTURBATION_STD) +'_.pickle'
print("Save file to ", discriminator_file)
with open(discriminator_file, 'wb') as output:
    pickle.dump(df, output)

# result = []
# # for test_ratio in list(np.arange(0.5,0.99,0.05)):
# #     accs_umap = []
# #     accs_base = []
# #     for _ in range(NUM_RUNS):
# #         acc_umap, _ = get_discriminator_performance(get_1d(all_x),all_y, test_ratio = test_ratio)
# #         acc_base, n = get_discriminator_performance(get_1d(all_x_base),all_y, test_ratio = test_ratio)
# #         accs_umap.append(acc_umap)
# #         accs_base.append(acc_base)
# #     mean_umap = np.mean(np.asarray(accs_umap))
# #     std_umap = np.std(np.asarray(accs_umap))
# #     mean_base = np.mean(np.asarray(accs_base))
# #     std_base = np.std(np.asarray(accs_base))
#     result.append((acc_per, std_per, ))
    
# df = pd.DataFrame.from_records(result, columns =['NoTrain', 'Base', 'std_base', 'Manifold', 'std_manifold'])

# discriminator_file = 'results/discriminator/accuracy_on_' + EXPERIMENT + '_dim_' + str(DIM) + '_noise_' + str(PERTURBATION_STD) +'_.pickle'
# print("Save file to ", discriminator_file)
# with open(discriminator_file, 'wb') as output:
#     pickle.dump(df, output)

EXPERIMENT:  fashion_mnist
MULTIPLIER:  100
PERTURBATION_STD:  0.1
DIM:  2
PIVOTS:  10
SHUFFLE:  True
Loading fashion mnist
Done loading
Initialize duration:  73.38422083854675
Train duration:  161.80371475219727
1.0 100
1.0 100
1.0 100
Create Testing environment
Initialize duration:  57.322988510131836
Train duration:  156.87672019004822
0.995
0.995
0.955


In [18]:
df = pd.DataFrame({'per': [acc_per, test_acc_per, var_per],
                   'plane': [acc_plane, test_acc_plane, var_plane],
                   'ortho': [acc_ortho, test_acc_ortho, var_ortho]})

In [19]:
discriminator_file = 'results/discriminator/' + EXPERIMENT + '_dim_' + str(DIM) + '_noise_' + str(PERTURBATION_STD) +'_.pickle'
print("Save file to ", discriminator_file)
with open(discriminator_file, 'wb') as output:
    pickle.dump(df, output)

Unnamed: 0,per,plane,ortho
0,0.995,0.995,0.955
1,1.0,1.0,1.0
2,0.000582,3.3e-05,0.006541


In [68]:
load_file = 'results/discriminator/' + 'mnist' + '_dim_' + str(2) + '_std_' + str(0.1) +'_.pickle'

In [69]:
with open(load_file, 'rb') as file:
    load_data = pickle.load(file)

df = load_data

In [70]:
df

Unnamed: 0,per,plane,ortho
0,0.99995,0.9916,0.74595
1,1.0,0.995,0.71
2,0.000639,3.9e-05,0.004385
