In [None]:
import os
import pickle
import sys
import warnings

import numpy as np
import pandas as pd
from scipy.stats import sem
from tqdm import tqdm

sys.path.append('..')
from src.experiment_tagging import get_model_path
from src.lesion import perform_lesion_experiment, do_lesion_hypo_tests
from src.lesion.experimentation import do_lesion_hypo_tests_imagenet
from src.pointers import DATA_PATHS
from src.utils import bates_quantile

In [None]:
# Small (MNIST / CIFAR-10) models to summarize: (dataset name, model tag).
models = [
    ('mnist', 'MNIST'),
    ('mnist', 'CNN-MNIST'),
    ('cifar10_full', 'CNN-VGG-CIFAR10+DROPOUT+L2REG'),
]

n_clust = 16         # not read by the summary cells in this notebook
n_shuffles = 19      # passed as `n_shuffles` to do_lesion_hypo_tests_imagenet
n_workers = 5        # not read by the summary cells in this notebook
n_reps = 5           # runs per tag: the last n_reps paths returned by get_model_path
is_unpruned = True
# Directory holding the lesion_data_*.pkl files. Set NN_CLUSTERING_RESULTS_DIR to
# point elsewhere instead of editing this hard-coded default.
results_dir = os.environ.get('NN_CLUSTERING_RESULTS_DIR', '/project/nn_clustering/results/')

# One pd.Series per model configuration, appended by the summary cells below.
all_results = []

In [None]:
def _collect(hypo_runs, key):
    """Return ``[run[key] for run in hypo_runs]``: one statistic gathered across runs."""
    return [run[key] for run in hypo_runs]


# Small models (MNIST / CIFAR-10): each pickle is indexed by run (0 .. n_runs-1).
# Every run is re-tested with `do_lesion_hypo_tests`, and the per-run statistics
# are aggregated into one summary row per (model, activations, local) setting.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    for _dataset, tag in tqdm(models):
        # Only the number of runs is needed from the model paths, and it does not
        # depend on the activations/local flags, so look it up once per model.
        # `[-n_reps:]` may yield fewer than n_reps paths if fewer runs exist.
        n_runs = len(get_model_path(tag, filter_='all')[-n_reps:])
        assert n_runs > 0, f'no trained models found for tag {tag!r}'

        for use_activations in [False, True]:
            for do_local in [False, True]:
                pkl_path = os.path.join(
                    results_dir,
                    f'lesion_data_{tag}_activations={use_activations}_local={do_local}.pkl')
                # NOTE(review): pickle.load can execute arbitrary code -- only load
                # result files produced by this project.
                with open(pkl_path, 'rb') as f:
                    net_pkl_results = pickle.load(f)

                # One hypothesis-test result dict per run. It also carries chi2
                # p-values, which are intentionally not reported.
                hypo_runs = [
                    do_lesion_hypo_tests(net_pkl_results[p_i]['evaluation'],
                                         net_pkl_results[p_i]['true_results'],
                                         net_pkl_results[p_i]['all_random_results'])
                    for p_i in range(n_runs)
                ]

                # Effect factors from all runs are pooled and scaled by 2 before
                # taking the mean and its standard error.
                effect_means = np.concatenate(_collect(hypo_runs, 'effect_factors_means')) * 2
                effect_ranges = np.concatenate(_collect(hypo_runs, 'effect_factors_range')) * 2

                # The averaged p-values go through bates_quantile with the number of
                # p-values actually averaged (n_runs), not the nominal n_reps.
                # (Presumably the Bates distribution of a mean of n uniforms -- confirm.)
                model_results = {'is_unpruned': is_unpruned,
                                 'model_tag': tag,
                                 'activations': use_activations,
                                 'local': do_local,
                                 'fisher_p_means': bates_quantile(
                                     np.mean(np.array(_collect(hypo_runs, 'fisher_p_means'))), n_runs),
                                 'fisher_stat_means': np.mean(np.array(_collect(hypo_runs, 'fisher_stat_means'))),
                                 'effect_means': np.mean(effect_means),
                                 'effect_means_sem': sem(effect_means),
                                 'fisher_p_ranges': bates_quantile(
                                     np.mean(np.array(_collect(hypo_runs, 'fisher_p_ranges'))), n_runs),
                                 'fisher_stat_ranges': np.mean(np.array(_collect(hypo_runs, 'fisher_stat_ranges'))),
                                 'effect_ranges': np.mean(effect_ranges),
                                 'effect_ranges_sem': sem(effect_ranges)}
                all_results.append(pd.Series(model_results))

In [None]:
# ImageNet models: one pickle per (net, activations, local) setting, tested as a
# whole by `do_lesion_hypo_tests_imagenet`. Its p-values are reported as returned,
# without the bates_quantile step used for the small models.
nets = ['resnet18', 'vgg16']

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    for net in nets:
        for use_activations in [False, True]:
            for do_local in [False, True]:
                pkl_path = os.path.join(
                    results_dir,
                    f'lesion_data_{net}_activations={use_activations}_local={do_local}.pkl')
                # NOTE(review): pickle.load can execute arbitrary code -- only load
                # result files produced by this project.
                with open(pkl_path, 'rb') as f:
                    results_dict = pickle.load(f)

                hypo_test_results = do_lesion_hypo_tests_imagenet(results_dict['results'], n_shuffles=n_shuffles)

                # np.asarray makes `* 2` scale the values. On a plain list `* 2` would
                # repeat the list instead, leaving the mean unscaled and shrinking the SEM.
                effect_means = np.asarray(hypo_test_results['effect_factors_means']) * 2
                effect_ranges = np.asarray(hypo_test_results['effect_factors_range']) * 2

                model_results = {'is_unpruned': True,  # ImageNet nets are always recorded as unpruned
                                 'model_tag': net,
                                 'activations': use_activations,
                                 'local': do_local,
                                 'fisher_p_means': hypo_test_results['fisher_p_means'],
                                 'fisher_stat_means': hypo_test_results['fisher_stat_means'],
                                 'effect_means': np.mean(effect_means),
                                 # axis=None: SEM over all elements, flattened
                                 'effect_means_sem': sem(effect_means, axis=None),
                                 'fisher_p_ranges': hypo_test_results['fisher_p_ranges'],
                                 'fisher_stat_ranges': hypo_test_results['fisher_stat_ranges'],
                                 'effect_ranges': np.mean(effect_ranges),
                                 'effect_ranges_sem': sem(effect_ranges, axis=None)}
                all_results.append(pd.Series(model_results))

In [None]:
# Stack every per-configuration summary Series into one table (one row each),
# write it to CSV, and display it.
savepath = '../results/lesion_results_all.csv'
result_df = pd.DataFrame(data=all_results)
result_df.to_csv(path_or_buf=savepath)
result_df