In [None]:
import numpy as np
import pandas as pd
import sys
import seaborn
import matplotlib
import matplotlib.ticker
from matplotlib import pyplot as plt
import experiments

# Use font type 42 (TrueType) for PDF and PostScript output so text in saved
# figures is embedded as TrueType rather than the default Type 3 fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Apply seaborn's default theme first, then override the style and context.
seaborn.set()
seaborn.set_style("white")
seaborn.set_context("poster")
from matplotlib import pyplot

In [None]:
# plot accuracy
# Gather one train and one test performance value per experiment ID (2..6).

train_accs, test_accs = [], []

for ID in range(2, 7):

    opt = experiments.opt[ID]
    csv_file = opt.csv_dir + opt.name + '_redundancy.csv'
    df = pd.read_csv(csv_file)

    # Only the first matching row of each evaluation set is used.
    is_train = df['evaluation_set'] == 'train'
    is_test = df['evaluation_set'] == 'test'
    train_accs.append(df.loc[is_train, 'performance'].iloc[0])
    test_accs.append(df.loc[is_test, 'performance'].iloc[0])

# x positions paired by index with the values collected per experiment ID.
x = [0.25, 0.5, 1, 2, 4]

# One figure per evaluation set, drawn in the order train, then test.
for title, accs in (('train_accs', train_accs), ('test_accs', test_accs)):
    fig, ax = plt.subplots()
    ax.plot(x, accs, 'bo-')
    ax.set_xscale('log')
    ax.set_xlabel('size_factor')
    # Label the ticks with the raw values instead of log-scale notation.
    ax.set_xticks(x)
    ax.set_xticklabels([str(x_val) for x_val in x])
    ax.set_ylim([0.5, 1.1])
    ax.set_title(title)
    plt.show()

In [None]:
# plot prunability

# Number of cross-validation runs read from each robustness CSV
# (matched against the 'cross_validation' column).
crosses = 3
# Number of knockout proportions sampled per run.
range_len = 7
# Evenly spaced knockout proportions from 0.0 to 1.0 inclusive; each value is
# matched against the 'perturbation_amount' column.
knockout_range = np.linspace(0.0, 1.0, num=range_len)
# One mean prunability value per experiment ID, appended in loop order.
prunability_means = []

def get_threshold(ab_data, th=0.8):
    """Estimate, per row, the ablation proportion at which performance drops below ``th``.

    Each row of ``ab_data`` holds performance values measured at evenly spaced
    ablation proportions from 0 to 1 (column ``r`` corresponds to proportion
    ``r / (ranges - 1)``). For the first pair of adjacent columns where the
    performance goes from ``>= th`` to ``< th``, the crossing point is found by
    linear interpolation between the two samples.

    Parameters
    ----------
    ab_data : array-like, shape (crosses, ranges)
        Performance per cross-validation run (rows) and ablation level (columns).
        The input is never modified.
    th : float, default 0.8
        Performance threshold.

    Returns
    -------
    numpy.ndarray, shape (crosses,)
        Interpolated ablation proportion per row, as a new float array that
        does not share memory with ``ab_data``.

    Notes
    -----
    NOTE(review): a row with no ``>= th`` -> ``< th`` crossing keeps the value
    of its first column (the performance at the first ablation level), exactly
    as before. That fallback is a performance value, not a proportion —
    confirm this is intended for rows that never cross the threshold.
    """
    ab_data = np.asarray(ab_data)
    crosses, ranges = ab_data.shape
    # Start from a float copy of column 0: it supplies the fallback for rows
    # without a crossing and keeps the caller's array untouched.
    thresholds = ab_data[:, 0].astype(float)
    for c in range(crosses):
        for r in range(ranges - 1):
            y1 = ab_data[c, r]
            y2 = ab_data[c, r + 1]
            if y1 >= th and y2 < th:
                x1 = r / (ranges - 1)
                x2 = (r + 1) / (ranges - 1)
                # Inverse slope of the segment; y2 < th <= y1 guarantees y2 != y1.
                m_inv = (x2 - x1) / (y2 - y1)
                thresholds[c] = (m_inv * (th - y1)) + x1
                break  # only the first crossing in each row counts
    return thresholds

# For each experiment ID, read its robustness CSV, build a
# (cross-validation run x knockout proportion) matrix of test performance,
# and record the mean scaled threshold proportion.
for ID in range(2, 7):

    opt = experiments.opt[ID]
    csv_file = opt.csv_dir + opt.name + '_robustness.csv'
    df = pd.read_csv(csv_file)

    # Conditions shared by every lookup. They are combined into one mask on the
    # full frame, rather than chaining df[...][...], which applies full-frame
    # masks to already-filtered frames and triggers pandas' reindex warning.
    base_mask = (
        (df['evaluation_set'] == 'test')
        & (df['perturbation_layer'] == 'all')
        & (df['perturbation_name'] == 'Activation Knockout')
    )

    ablation_results = np.zeros((crosses, range_len))
    for cross in range(crosses):
        cross_mask = base_mask & (df['cross_validation'] == cross)
        for amount in range(range_len):
            # knockout_range comes from np.linspace (e.g. 1/6), so compare with
            # a tolerance: exact float equality can miss values parsed from CSV.
            amount_mask = np.isclose(df['perturbation_amount'], knockout_range[amount])
            # First matching row, as before; raises IndexError if none matches.
            ablation_results[cross, amount] = \
                df.loc[cross_mask & amount_mask, 'performance'].iloc[0]

    # NOTE(review): the (1/4) * neuron_multiplier[0] factor presumably rescales
    # the proportion relative to the largest model — confirm.
    prunability_per_cross = get_threshold(ablation_results) * (1/4) * opt.dnn.neuron_multiplier[0]
    prunability_means.append(np.mean(prunability_per_cross))

# x positions paired by index with the per-experiment means.
x = [0.25, 0.5, 1, 2, 4]

# Log-log plot of mean prunability against size factor.
fig, ax = plt.subplots()
ax.plot(x, prunability_means, 'bo-')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('size_factor')
# Label the ticks with the raw values instead of log-scale notation.
ax.set_xticks(x)
ax.set_xticklabels([str(x_val) for x_val in x])
ax.set_title('prunability_means')
plt.show()

# Ratio between each pair of consecutive means (next / current).
slopes = [nxt / cur for cur, nxt in zip(prunability_means, prunability_means[1:])]
# Arithmetic mean of those ratios, compared against 2.
g = sum(slopes) / len(slopes)
print(f'g={g} is greater than 2.0' if g > 2 else f'g={g} is less than or equal to 2.0')

In [None]:
# plot redundancy

# Number of cross-validation runs read from each redundancy CSV
# (matched against the 'cross_validation' column).
crosses = 3
# Only the length of this list is used below (one column per layer index);
# the values look like per-layer unit counts — TODO confirm.
alexnet_units = [96, 256, 384, 192]
# One mean 'compressability_95' value per experiment ID, appended in loop order.
redundancy_means = []

# For each experiment ID, read its redundancy CSV, build a
# (cross-validation run x layer) matrix of test-set 'compressability_95'
# values, and record the mean over the whole matrix.
for ID in range(2, 7):

    opt = experiments.opt[ID]

    csv_file = opt.csv_dir + opt.name + '_redundancy.csv'
    df = pd.read_csv(csv_file)

    # Combine conditions into one mask on the full frame, rather than chaining
    # df[...][...], which applies full-frame masks to already-filtered frames
    # and triggers pandas' boolean-reindex warning.
    test_mask = df['evaluation_set'] == 'test'

    comp = np.zeros((crosses, len(alexnet_units)))
    for lyr in range(len(alexnet_units)):
        # The 'layer' column is compared as a string, as before.
        # NOTE(review): this assumes pandas reads that column as text — confirm.
        layer_mask = test_mask & (df['layer'] == str(lyr))
        for cross in range(crosses):
            row_mask = layer_mask & (df['cross_validation'] == cross)
            # First matching row, as before; raises IndexError if none matches.
            comp[cross, lyr] = df.loc[row_mask, 'compressability_95'].iloc[0]
    redundancy_means.append(np.mean(comp))

# x positions paired by index with the per-experiment means.
x = [0.25, 0.5, 1, 2, 4]

# Log-log plot of mean redundancy against size factor.
fig, ax = plt.subplots()
ax.plot(x, redundancy_means, 'bo-')
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('size_factor')
# Label the ticks with the raw values instead of log-scale notation.
ax.set_xticks(x)
ax.set_xticklabels([str(x_val) for x_val in x])
ax.set_title('redundancy_means')
plt.show()

# Ratio between each pair of consecutive means (next / current).
slopes = [nxt / cur for cur, nxt in zip(redundancy_means, redundancy_means[1:])]
# Arithmetic mean of those ratios, compared against 2.
g = sum(slopes) / len(slopes)
print(f'g={g} is greater than 2.0' if g > 2 else f'g={g} is less than or equal to 2.0')