-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_plot_overall.py
37 lines (30 loc) · 1.46 KB
/
main_plot_overall.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import examples.problems as pblm
from examples.heatmap import *
import pandas as pd
import os
from matplotlib.ticker import FuncFormatter
# `plt` is presumably provided by `from examples.heatmap import *` above (it is
# not imported explicitly in this file) -- confirm if that module changes.
# Switch to the non-interactive Agg backend: this script only saves figures to
# files and never opens a window.
plt.switch_backend('agg')
# Restore matplotlib's built-in default rcParams so the output does not depend
# on the user's matplotlibrc. (rcdefaults leaves the backend alone, so the
# switch above stays in effect.)
plt.rcdefaults()
def load_robust_prob_matrix(filepath, num_classes):
    """Load a pairwise robust-error matrix and mask its diagonal.

    The file is read as tab-separated text. Its first line is skipped, then
    ``num_classes`` rows are read from columns ``1..num_classes``. Column 0 is
    dropped -- presumably row labels; confirm against whatever writes the
    file. Cells may carry a trailing ``%`` sign, which is stripped before the
    value is converted to float, so values stay in percent units.

    Args:
        filepath: Path of the tab-separated matrix file.
        num_classes: Number of rows and columns to read.

    Returns:
        A ``numpy.ma.MaskedArray`` of shape ``(num_classes, num_classes)``
        with the diagonal -- and any NaN/inf cell -- masked.
    """
    df = pd.read_csv(filepath, sep='\t', skiprows=[0], nrows=num_classes,
                     usecols=list(range(1, num_classes + 1)), header=None)
    # Strip '%' column-wise. Converting to str first also accepts columns that
    # pandas already parsed as numbers (no '%' present) and NaN cells, where
    # the old per-cell ``x.strip('%')`` raised AttributeError. It also avoids
    # ``DataFrame.applymap``, deprecated since pandas 2.1.
    values = df.astype(str).apply(lambda col: col.str.strip('%')).astype(float)
    # Explicit copy: the array is mutated below and must not alias
    # (possibly read-only, under Copy-on-Write) DataFrame storage.
    mat = values.to_numpy(dtype=float, copy=True)
    # Diagonal cells become NaN, which masked_invalid then masks.
    np.fill_diagonal(mat, np.nan)
    return np.ma.masked_invalid(mat)


if __name__ == '__main__':
    # Robust-error heatmap for MNIST.
    num_classes = 10  # MNIST has 10 digit classes
    args = pblm.argparser(prefix='mnist', method='overall_robust',
                          opt='adam', starting_epsilon=0.05, epsilon=0.2)
    labels = [f'digit {i}' for i in range(num_classes)]
    filepath = f'results/{args.proctitle}_robustProbs.csv'
    # Load the pairwise robust error matrix (diagonal masked).
    robust_prob_mat = load_robust_prob_matrix(filepath, num_classes)

    fig, ax = plt.subplots()
    # Colorbar ticks every 2.0 (percent) from 0, stopping below the largest
    # unmasked value.
    ticks = np.arange(0, robust_prob_mat.max(), step=2.0, dtype=float)
    im = heatmap(robust_prob_mat, labels, labels, ax=ax, cmap="OrRd",
                 cbarlabel="Robust error rate",
                 cbar_kw={'ticks': ticks, 'format': '%.0f%%'})
    annotate_heatmap(im, data=robust_prob_mat, valfmt="{x:.1f}%", fontsize=8)
    fig.tight_layout()
    save_filepath = f'results/{os.path.dirname(args.proctitle)}/robust_heatmap.pdf'
    fig.savefig(save_filepath)
    # NOTE(review): extra PNG written to the current working directory; looks
    # like a debugging preview -- confirm it is still wanted.
    plt.savefig('test.png')
    plt.close(fig)