-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_merged_sim06.py
117 lines (85 loc) · 2.77 KB
/
plot_merged_sim06.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# %% Import packages
import numpy as np
from matplotlib import pyplot as plt
from pathlib import Path
from sim_setup import output_path, sim06_setup
# %% Create output path if it does not exist
# Create the output directory together with any missing parents; an already
# existing directory is not an error.
output_path.mkdir(exist_ok=True, parents=True)
# %% Load numerical output
# Both files are comma-delimited and live in output_path; their names are the
# run name from sim06_setup followed by a per-method suffix.
_tron_csv = output_path.joinpath(sim06_setup['name'] + '_tron.csv')
_sgd_csv = output_path.joinpath(sim06_setup['name'] + '_sgd.csv')
tron_error_loaded = np.loadtxt(_tron_csv, delimiter=',')
sgd_error_loaded = np.loadtxt(_sgd_csv, delimiter=',')
# %% Set font size
# Shared by titles, tick labels and legends of every panel.
fontsize = 13
# %% Set transparency
# Forwarded to plt.savefig(transparent=...) when the figure is written.
transparent = False
# %% Set y axis limits and ticks
# One row per beta index (the plotting code looks rows up by that index).
# All values are base-10 exponents because the curves are drawn as log10 of
# the loaded errors. Row 0 spans 1e-16 .. 1e+2; the remaining rows all share
# the narrower 1e-4 .. 1e+1 range. Every row is built as its own list object
# so that editing one row never affects another.
_N_NARROW_ROWS = 6

# Axis limits: slightly wider than the outermost ticks.
ylims = [[-16.2, 2.2]] + [[-4.2, 1.2] for _ in range(_N_NARROW_ROWS)]

# Tick positions: every second exponent for the wide row, every exponent
# for the narrow rows.
yticks = [list(range(-16, 3, 2))] + [
    list(range(-4, 2)) for _ in range(_N_NARROW_ROWS)
]

# Tick labels, position-for-position matching yticks.
ylabels = [
    ['1e-16', '1e-14', '1e-12', '1e-10', '1e-8', '1e-6', '1e-4', '1e-2', '1e-0', '1e+2']
] + [
    ['1e-4', '1e-3', '1e-2', '1e-1', '1e-0', '1e+1']
    for _ in range(_N_NARROW_ROWS)
]
# %% Selection of beta values to plot
# Each entry is an index into sim06_setup['betalist']. The same index picks
# the column of the loaded error arrays and the row of ylims/yticks/ylabels.
beta_vals = [1, 3, 4, 5, 6]
# %%
# Draw one panel per selected beta, comparing the two methods on a log10
# scale, and optionally write the figure to output_path as a PNG.
save = True
# 1-based x coordinates, one per row of the loaded array.
xrange = range(1, tron_error_loaded.shape[0]+1)
labels = ['Neuro-Tron', 'SGD']
# squeeze=False makes subplots() always return a 2-D array, so the code below
# also works when only a single beta value is selected.
fig, axes = plt.subplots(
    nrows=len(beta_vals), ncols=1, sharex=True, figsize=(8, 18), squeeze=False
)
axes = axes[:, 0]
plt.subplots_adjust(hspace=0.15)
for ax, beta_idx in zip(axes, beta_vals):
    # Same order as `labels`: Neuro-Tron first, then SGD.
    for errors, label in zip((tron_error_loaded, sgd_error_loaded), labels):
        ax.plot(
            xrange,
            np.log10(errors[:, beta_idx]),
            linewidth=2.,
            label=label
        )
    ax.set_ylim(ylims[beta_idx])
    # The negative pad shifts the title down from the top edge of the axes.
    ax.set_title(
        r'$\beta$ = {}'.format(sim06_setup['betalist'][beta_idx]), y=1.0, pad=-23, fontsize=fontsize
    )
    ax.set_yticks(yticks[beta_idx])
    ax.set_yticklabels(ylabels[beta_idx], fontsize=fontsize)
    # The legend entries come from the label= given to each plotted line.
    ax.legend(loc='upper right', ncol=1, fontsize=fontsize, frameon=False)
# NOTE(review): the 0..40000 tick range is hard-coded; presumably it matches
# the number of rows in the loaded CSVs - confirm against the simulation setup.
xticks = np.linspace(0, 40000, num=9)
xticklabels = [str(round(tick)) for tick in xticks]
# Tick labels are set on the last panel only, whatever the number of panels
# (the x axis is shared across panels).
axes[-1].set_xticks(xticks)
axes[-1].set_xticklabels(xticklabels, rotation=0, fontsize=fontsize)
if save:
    # NOTE(review): 'quality' looks like a JPEG-oriented Pillow option and
    # presumably has no effect on PNG output - confirm before relying on it.
    plt.savefig(
        output_path.joinpath(
            sim06_setup['name']+'_tron_vs_sgd_merged_beta_vals.png'
        ),
        dpi=300,
        pil_kwargs={'quality': 100},
        transparent=transparent,
        bbox_inches='tight',
        pad_inches=0.1
    )