In [1]:
import pickle, glob, json, os
import numpy as np
from matplotlib import pyplot as plt

## Figure 2

In [2]:
def _read_dist_acc(path):
    """Return the ``'dist'`` and ``'acc'`` arrays stored in one ``.npz`` file.

    The archive is opened in a ``with`` block so its file handle is released
    as soon as both arrays have been read.
    """
    with np.load(path) as res:
        return res['dist'], res['acc']


def load_res(data, t=1e-2, remove='sub', res_root='/home/satohara/nas/StableTree/res_iclr23'):
    """Load the saved 'dist' and 'acc' arrays for the greedy and stable methods.

    Parameters
    ----------
    data : str
        Dataset name; used as a sub-directory of ``res_root`` and to choose
        the removal sizes.
    t : float
        Lower threshold on epsilon. Only values of ``np.logspace(-5, 0, 21)``
        that are ``>= t`` are loaded.
    remove : str
        ``'sub'`` selects the three-entry removal list; any other value
        selects the five-entry list.
    res_root : str
        Directory that contains ``<data>/dist/``. Defaults to the location
        that was previously hard-coded (for a local checkout, ``'../res'``
        is the path kept in the original commented-out line).

    Returns
    -------
    eps : numpy.ndarray
        The epsilon values that were loaded.
    n_remove : list of int
        The removal sizes that were loaded.
    dist, acc : dict
        Each has keys ``'greedy'`` and ``'stable'``. Axis 0 of every array
        indexes ``n_remove``. For ``'stable'`` axis 1 indexes ``eps``; for
        ``'greedy'`` axis 1 has length 1, so both share the same rank.
        The remaining axes are whatever the result files contain.
    """
    eps = np.logspace(-5, 0, 21)
    eps = eps[eps >= t]

    # Removal sizes per dataset; the trailing comments give the nominal
    # fractions noted by the original author.
    sizes = {
        'breast_cancer': {
            'sub': [1, 4, 43],               # 1, 1%, 10%
            'full': [1, 4, 13, 43, 130],     # 1, 1%, 3%, 10%, 30%
        },
        'diabetes': {
            'sub': [1, 4, 49],               # 1, 1%, 10%
            'full': [1, 4, 14, 49, 147],     # 1, 1%, 3%, 10%, 30%
        },
    }
    fallback = {
        'sub': [1, 10, 100],                 # 1, 1%, 10%
        'full': [1, 10, 30, 100, 300],       # 1, 1%, 3%, 10%, 30%
    }
    variant = 'sub' if remove == 'sub' else 'full'
    n_remove = list(sizes.get(data, fallback)[variant])

    res_dir = os.path.join(res_root, data, 'dist')

    def result_path(n, method):
        return os.path.join(res_dir, 'remove_%03d_%s.npz' % (n, method))

    dist, acc = {'greedy': [], 'stable': []}, {'greedy': [], 'stable': []}
    for n in n_remove:
        # greedy: a single result per removal size; add a length-1 axis so it
        # lines up with the epsilon axis of the stable results.
        g_dist, g_acc = _read_dist_acc(result_path(n, 'greedy'))
        dist['greedy'].append(g_dist[np.newaxis])
        acc['greedy'].append(g_acc[np.newaxis])

        # stable: one result file per epsilon.
        dist_s, acc_s = [], []
        for e in eps:
            # NOTE(review): int() truncates 1e6*e; this is kept unchanged
            # because it presumably matches the names used when the files
            # were written -- confirm against the producing script.
            method = 'eps%08d' % (int(1e+6*e),)
            s_dist, s_acc = _read_dist_acc(result_path(n, method))
            dist_s.append(s_dist)
            acc_s.append(s_acc)
        dist['stable'].append(np.array(dist_s))
        acc['stable'].append(np.array(acc_s))

    for table in (dist, acc):
        for key in ('greedy', 'stable'):
            table[key] = np.array(table[key])
    return eps, n_remove, dist, acc

In [11]:
# Figure 2: sensitivity and accuracy of the stable method as a function of
# epsilon (blue, solid), with the greedy baseline drawn as a horizontal
# reference (red, dashed). One PDF per dataset and per panel.
fig_dir = '../fig/eps'
os.makedirs(fig_dir, exist_ok=True)

# Each dataset is listed once; a repeated entry would only regenerate
# identical files.
datasets = ['covtype', 'webspam', 'diabetes', 'breast_cancer', 'ijcnn', 'cod-rna', 'sensorless', 'higgs']
markers = ['o', '^', 'v']
removal_labels = ['1', '1%', '10%']

for data in datasets:
    eps, n_remove, dist, acc = load_res(data, t=1e-2, remove='sub')

    # Average out axes 2 and 3 of dist and axis 2 of acc
    # (presumably repeated runs / removal samples -- TODO confirm).
    dist_avg_g = np.mean(dist['greedy'], axis=(2,3))
    dist_avg_s = np.mean(dist['stable'], axis=(2,3))
    acc_avg_g = np.mean(acc['greedy'], axis=(2,))
    acc_avg_s = np.mean(acc['stable'], axis=(2,))
    # Index along axis 2 where the greedy result at the first removal size
    # is best in column 1 (the column plotted below as test accuracy).
    depth_opt = np.argmax(acc_avg_g[0, 0, :, 1])
    print(data, depth_opt+1)
    n_nodes = 2**(depth_opt + 2)
    norm = 2 * n_nodes  # sensitivity curves are divided by this

    # (greedy, stable, last-axis column, log-scaled y?, y label, file stem)
    panels = [
        # average sensitivity
        (dist_avg_g / norm, dist_avg_s / norm, 0, True, '\n Average Sensitivity', 'sensitivity'),
        # average sensitivity - identical threshold = False
        (dist_avg_g / norm, dist_avg_s / norm, 1, True, '\n Average Sensitivity', 'sensitivity_false'),
        # training accuracy
        (acc_avg_g, acc_avg_s, 0, False, 'Average\n Training Accuracy', 'train'),
        # test accuracy
        (acc_avg_g, acc_avg_s, 1, False, 'Average\n Test Accuracy', 'test'),
    ]

    for greedy, stable, col, log_y, ylabel, stem in panels:
        fig = plt.figure()
        draw = plt.loglog if log_y else plt.semilogx
        for i, (marker, level) in enumerate(zip(markers, removal_labels)):
            # greedy does not depend on epsilon: draw it across the full range
            draw([eps[0], eps[-1]], [greedy[i, 0, depth_opt, col]]*2, 'r%s--' % (marker,))
            draw(eps, stable[i, :, depth_opt, col], 'b%s-' % (marker,), ms=12, label='remove=%s' % (level,))
        if log_y:
            # snap the lower limit down to a power of ten and cap the top at 1
            yl = plt.gca().get_ylim()
            plt.ylim([10**np.floor(np.log10(yl[0])), 1e+0])
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        # raw string: '\e' is not a valid escape in a normal string literal
        plt.xlabel(r'$\epsilon$', fontsize=24)
        plt.ylabel(ylabel, fontsize=24)
        plt.tight_layout()
        plt.savefig('%s/%s_%s.pdf' % (fig_dir, data, stem))

        plt.close(fig)

covtype 7
webspam 7
diabetes 1
breast_cancer 5
ijcnn 1
cod-rna 10
diabetes 1
sensorless 9
higgs 3


## Figure 2 with errorbar

In [4]:
# Figure 2 with error bands: same curves as Figure 2, but one PDF per removal
# level, with a shaded 25th-75th percentile band around each mean curve.
fig_dir = '../fig/eps_error'
os.makedirs(fig_dir, exist_ok=True)

# Axis 2 of the loaded arrays is split into this many groups; the band is the
# spread of the per-group means.
N_GROUPS = 10


def split_into_groups(arr, n_groups=N_GROUPS):
    """Reshape axis 2 of ``arr`` into ``(n_groups, -1)``; other axes are kept."""
    return arr.reshape(arr.shape[:2] + (n_groups, -1) + arr.shape[3:])


def mean_and_quartiles(grouped):
    """Return (mean, 25th percentile, 75th percentile) taken over axis 2."""
    return (np.mean(grouped, axis=2),
            np.percentile(grouped, 25, axis=2),
            np.percentile(grouped, 75, axis=2))


# Each dataset is listed once; a repeated entry would only regenerate
# identical files.
datasets = ['covtype', 'webspam', 'diabetes', 'breast_cancer', 'ijcnn', 'cod-rna', 'sensorless', 'higgs']
markers = ['o', '^', 'v']
removal_labels = ['1', '1%', '10%']

for data in datasets:
    eps, n_remove, dist, acc = load_res(data, t=1e-2, remove='sub')

    # dist: average within each group (new axis 3) and over the following
    # axis (old axis 3), leaving one value per group on axis 2.
    dist_avg_g = np.mean(split_into_groups(dist['greedy']), axis=(3,4))
    dist_avg_g_avg, dist_avg_g_low, dist_avg_g_high = mean_and_quartiles(dist_avg_g)
    dist_avg_s = np.mean(split_into_groups(dist['stable']), axis=(3,4))
    dist_avg_s_avg, dist_avg_s_low, dist_avg_s_high = mean_and_quartiles(dist_avg_s)

    # acc: average within each group only.
    acc_avg_g = np.mean(split_into_groups(acc['greedy']), axis=3)
    acc_avg_g_avg, acc_avg_g_low, acc_avg_g_high = mean_and_quartiles(acc_avg_g)
    acc_avg_s = np.mean(split_into_groups(acc['stable']), axis=3)
    acc_avg_s_avg, acc_avg_s_low, acc_avg_s_high = mean_and_quartiles(acc_avg_s)

    # Index along axis 2 where the greedy result at the first removal size
    # is best in column 1 (the column plotted below as test accuracy).
    depth_opt = np.argmax(acc_avg_g_avg[0, 0, :, 1])
    print(data, depth_opt+1)
    n_nodes = 2**(depth_opt + 2)
    norm = 2 * n_nodes  # applied to the sensitivity curves only

    sens_g = (dist_avg_g_avg / norm, dist_avg_g_low / norm, dist_avg_g_high / norm)
    sens_s = (dist_avg_s_avg / norm, dist_avg_s_low / norm, dist_avg_s_high / norm)
    # Accuracy is plotted as-is: it is not divided by the node-count factor.
    accu_g = (acc_avg_g_avg, acc_avg_g_low, acc_avg_g_high)
    accu_s = (acc_avg_s_avg, acc_avg_s_low, acc_avg_s_high)

    # (greedy stats, stable stats, last-axis column, log-scaled y?, y label, file pattern)
    panels = [
        # average sensitivity
        (sens_g, sens_s, 0, True, '\n Average Sensitivity', '%s/%s_sensitivity_%s.pdf'),
        # average sensitivity - identical_threshold = False
        (sens_g, sens_s, 1, True, '\n Average Sensitivity', '%s/%s_sensitivity_%s_false.pdf'),
        # training accuracy
        (accu_g, accu_s, 0, False, 'Average\n Training Accuracy', '%s/%s_train_%s.pdf'),
        # test accuracy
        (accu_g, accu_s, 1, False, 'Average\n Test Accuracy', '%s/%s_test_%s.pdf'),
    ]

    for (g_avg, g_low, g_high), (s_avg, s_low, s_high), col, log_y, ylabel, pattern in panels:
        draw = plt.loglog if log_y else plt.semilogx
        for i, (marker, level) in enumerate(zip(markers, removal_labels)):
            fig = plt.figure()
            span = [eps[0], eps[-1]]
            # greedy does not depend on epsilon: constant line and band
            draw(span, [g_avg[i, 0, depth_opt, col]]*2, 'r%s--' % (marker,))
            plt.fill_between(span, [g_low[i, 0, depth_opt, col]]*2, [g_high[i, 0, depth_opt, col]]*2, alpha=0.2, facecolor='r')
            draw(eps, s_avg[i, :, depth_opt, col], 'b%s-' % (marker,), ms=12)
            plt.fill_between(eps, s_low[i, :, depth_opt, col], s_high[i, :, depth_opt, col], alpha=0.2, facecolor='b')
            if log_y:
                # snap the lower limit down to a power of ten and cap the top at 1
                yl = plt.gca().get_ylim()
                plt.ylim([10**np.floor(np.log10(yl[0])), 1e+0])
            plt.xticks(fontsize=18)
            plt.yticks(fontsize=18)
            # raw string: '\e' is not a valid escape in a normal string literal
            plt.xlabel(r'$\epsilon$', fontsize=24)
            plt.ylabel(ylabel, fontsize=24)
            plt.tight_layout()
            plt.savefig(pattern % (fig_dir, data, level.replace('%', 'p')))

            plt.close(fig)

covtype 7
webspam 7
diabetes 1
breast_cancer 5
ijcnn 1
cod-rna 10
diabetes 1
sensorless 9
higgs 3


## Figure 4

In [12]:
# Figure 4: training accuracy against average sensitivity. Each coloured line
# traces the stable method over all epsilon values for one removal level; the
# white marker with a black edge is the greedy result for that level.
fig_dir = '../fig/tradeoff_train'
os.makedirs(fig_dir, exist_ok=True)

# Each dataset is listed once; a repeated entry would only regenerate
# identical files.
datasets = ['covtype', 'webspam', 'diabetes', 'breast_cancer', 'ijcnn', 'cod-rna', 'sensorless', 'higgs']
cl = ['bo', 'g^', 'rv', 'ms', 'cd']
ratio = ['1', '1%', '3%', '10%', '30%']
ACC_COL = 0  # column of the accuracy arrays shown on the y axis (training)

for data in datasets:
    eps, n_remove, dist, acc = load_res(data, t=1e-5, remove='full')

    dist_avg_g = np.mean(dist['greedy'], axis=(2,3))
    dist_avg_s = np.mean(dist['stable'], axis=(2,3))
    acc_avg_g = np.mean(acc['greedy'], axis=(2,))
    acc_avg_s = np.mean(acc['stable'], axis=(2,))
    # Index along axis 2 where the greedy result at the first removal size
    # is best in column 1 of the accuracy array.
    depth_opt = np.argmax(acc_avg_g[0, 0, :, 1])
    n_nodes = 2**(depth_opt + 2)
    norm = 2 * n_nodes  # sensitivity values are divided by this

    # One figure per sensitivity column: column 0 -> '<data>.pdf',
    # column 1 -> '<data>_false.pdf'.
    for dist_col, suffix in [(0, ''), (1, '_false')]:
        fig = plt.figure()
        # stable curves first, so the greedy markers are drawn on top
        for i, (r, c) in enumerate(zip(ratio, cl)):
            plt.plot(dist_avg_s[i, :, depth_opt, dist_col]/norm, acc_avg_s[i, :, depth_opt, ACC_COL], c+'-', label='remove = %s' % (r,), ms=10)
        for i, c in enumerate(cl):
            # c[1] is the marker character of the matching stable curve
            plt.plot(dist_avg_g[i, 0, depth_opt, dist_col]/norm, acc_avg_g[i, 0, depth_opt, ACC_COL], 'w'+c[1], markeredgecolor='k', markeredgewidth=2, ms=15)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel('Average Sensitivity', fontsize=24)
        plt.ylabel('Average\n Training Accuracy', fontsize=24)
        plt.tight_layout()
        plt.savefig('%s/%s%s.pdf' % (fig_dir, data, suffix))

        plt.close(fig)

## Figure 5

In [13]:
# Figure 5: test accuracy against average sensitivity. Each coloured line
# traces the stable method over all epsilon values for one removal level; the
# white marker with a black edge is the greedy result for that level.
fig_dir = '../fig/tradeoff_test'
os.makedirs(fig_dir, exist_ok=True)

# Each dataset is listed once; a repeated entry would only regenerate
# identical files.
datasets = ['covtype', 'webspam', 'diabetes', 'breast_cancer', 'ijcnn', 'cod-rna', 'sensorless', 'higgs']
cl = ['bo', 'g^', 'rv', 'ms', 'cd']
ratio = ['1', '1%', '3%', '10%', '30%']
ACC_COL = 1  # column of the accuracy arrays shown on the y axis (test)

for data in datasets:
    eps, n_remove, dist, acc = load_res(data, t=1e-5, remove='full')

    dist_avg_g = np.mean(dist['greedy'], axis=(2,3))
    dist_avg_s = np.mean(dist['stable'], axis=(2,3))
    acc_avg_g = np.mean(acc['greedy'], axis=(2,))
    acc_avg_s = np.mean(acc['stable'], axis=(2,))
    # Index along axis 2 where the greedy result at the first removal size
    # is best in column 1 of the accuracy array.
    depth_opt = np.argmax(acc_avg_g[0, 0, :, 1])
    n_nodes = 2**(depth_opt + 2)
    norm = 2 * n_nodes  # sensitivity values are divided by this

    # One figure per sensitivity column: column 0 -> '<data>.pdf',
    # column 1 -> '<data>_false.pdf'.
    for dist_col, suffix in [(0, ''), (1, '_false')]:
        fig = plt.figure()
        # stable curves first, so the greedy markers are drawn on top
        for i, (r, c) in enumerate(zip(ratio, cl)):
            plt.plot(dist_avg_s[i, :, depth_opt, dist_col]/norm, acc_avg_s[i, :, depth_opt, ACC_COL], c+'-', label='remove = %s' % (r,), ms=10)
        for i, c in enumerate(cl):
            # c[1] is the marker character of the matching stable curve
            plt.plot(dist_avg_g[i, 0, depth_opt, dist_col]/norm, acc_avg_g[i, 0, depth_opt, ACC_COL], 'w'+c[1], markeredgecolor='k', markeredgewidth=2, ms=15)
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlabel('Average Sensitivity', fontsize=24)
        plt.ylabel('Average\n Test Accuracy', fontsize=24)
        plt.tight_layout()
        plt.savefig('%s/%s%s.pdf' % (fig_dir, data, suffix))

        plt.close(fig)