/
visual_func.py
217 lines (184 loc) · 9.46 KB
/
visual_func.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import torch.nn.functional as nfunc
from numpy import linalg
import matplotlib.pyplot as plt
import matplotlib.patches as plot_patch
from ExpUtils import *
from torch_func.evaluate import evaluate_classifier
from Loss import VAT
def visualize_all(model, d_i, x_data, y_data, ul_x, ul_y, adv_x, idx, basis, val_loader, it, rate, args,
                  save_filename='prob_cont'):
    """Render a two-panel xVAT figure for iteration ``it`` and save it to ``save_filename``.

    Left panel: ``visualize_contour_semi`` over all labelled/unlabelled points (no title or
    colour bar). Right panel: ``visualize_adv_points`` for the unlabelled rows selected by
    ``idx`` together with ``adv_x``, which is plotted with the labels ``ul_y[idx]``.

    Figure number 12 is (re)used and is always closed before returning, even when plotting
    or saving raises.
    """
    adv_y = ul_y[idx]  # labels of the selected unlabelled rows, reused for adv_x
    fig = plt.figure(12)
    try:
        fig.set_size_inches(12, 6)
        plt.suptitle('xVAT\nIteration %d' % it)
        plt.subplot(1, 2, 1)
        visualize_contour_semi(model, d_i, x_data, y_data, ul_x, ul_y, basis, val_loader, args, show_contour=False)
        plt.subplot(1, 2, 2)
        visualize_adv_points(model, d_i, x_data, y_data, ul_x[idx], ul_y[idx], adv_x, adv_y, basis, it, rate, args)
        fig.savefig(save_filename)
    finally:
        # Release figure 12 unconditionally: if a sub-plot fails, leaving it open would make
        # the next call draw on top of this figure's stale contents.
        plt.close(fig)
def visualize_contour_semi(model, d_i, x_data, y_data, ul_x, ul_y, basis, val_loader, args, with_lds=False,
                           save_filename='prob_cont', show_contour=True):
    """Plot P(y=1 | x) contours of ``model`` on a 2-D slice of input space, plus the data points.

    Everything is drawn into the *current* matplotlib axes. The slice is the plane spanned by the
    two rows of ``basis`` (shape (2, D), as required by ``np.dot(grid, basis)``). Data points are
    mapped onto it by least-squares projection. The model is evaluated on a regular grid over
    roughly [-2, 2]^2, lifted back into input space through ``basis``.

    Args:
        model: torch module. Its output is softmax-ed and column 1 is plotted.
            Side effect: ``model.eval()`` is called and not undone.
        d_i: dataset id. ``"1"`` draws two reference arcs; anything else draws two reference circles.
        x_data, y_data: labelled inputs (large markers) and their 0/1 labels.
        ul_x, ul_y: unlabelled inputs (small markers) and their 0/1 labels.
        basis: (2, D) array whose rows span the visualisation plane.
        val_loader: passed to ``evaluate_classifier``. The error rate goes into the title.
        args: namespace read for ``device`` and ``exp_marker``. When ``with_lds`` is set, its
            ``eps``/``k`` are temporarily overridden and then restored (also on error).
        with_lds: if True, append the mean of 20 VAT LDS evaluations on ``x_data``
            (eps=0.5, k=5) to the title. ``x_data`` must then support ``.to(device)``.
        save_filename: if None, show the figure (non-blocking) and skip contours and axis
            formatting. Otherwise draw them. The file itself is NOT written here.
        show_contour: if True, add a title and a colour bar (only when contours are drawn).
    """
    line_width = 10
    font_size = 20
    rescale = 1.0  # /np.sqrt(500)
    rc = 'r'
    bc = 'b'

    # Least-squares 2-D coordinates of points in the plane spanned by the rows of `basis`.
    to_plane = np.dot(basis.T, linalg.inv(np.dot(basis, basis.T)))
    train_x_org = np.dot(x_data, to_plane)
    ul_x_org = np.dot(ul_x, to_plane)
    train_x_1_ind = np.where(y_data == 1)[0]
    train_x_0_ind = np.where(y_data == 0)[0]
    ul_x_1_ind = np.where(ul_y == 1)[0]
    ul_x_0_ind = np.where(ul_y == 0)[0]

    # Evaluation grid: row n*i + j is (range_x[i], range_x[j]), so z[i, j] = P(y=1) at (x[i, j], y[i, j]).
    range_x = np.arange(-2.0, 2.1, 0.05)
    n = range_x.shape[0]
    x, y = np.meshgrid(range_x, range_x, indexing='ij')
    test_x = np.dot(np.stack([x.ravel(), y.ravel()], axis=1), basis)
    model.eval()
    with torch.no_grad():  # plotting only; no autograd graph needed
        logits = model(torch.FloatTensor(test_x).to(args.device))
    z = nfunc.softmax(logits, dim=1)[:, 1].cpu().numpy().reshape(n, n)

    # Faint reference shapes for the two synthetic datasets.
    frame = plt.gca()
    if d_i == "1":
        frame.add_artist(plot_patch.Arc(xy=(0.5 * rescale, -0.25 * rescale), width=2.0 * rescale, height=2.0 * rescale,
                                        angle=0, theta1=270, theta2=180, linewidth=line_width, alpha=0.15, color=rc))
        frame.add_artist(plot_patch.Arc(xy=(-0.5 * rescale, +0.25 * rescale), width=2.0 * rescale, height=2.0 * rescale,
                                        angle=0, theta1=90, theta2=360, linewidth=line_width, alpha=0.15, color=bc))
    else:
        frame.add_artist(plot_patch.Circle((0, 0), 1.0 * rescale, color=rc, alpha=0.2, fill=False, linewidth=line_width))
        frame.add_artist(plot_patch.Circle((0, 0), 0.15 * rescale, color=bc, alpha=0.2, fill=False, linewidth=line_width))

    plt.scatter(ul_x_org[ul_x_1_ind, 0] * rescale, ul_x_org[ul_x_1_ind, 1] * rescale, s=2, marker='o', c=rc, label='$y=1$')
    plt.scatter(ul_x_org[ul_x_0_ind, 0] * rescale, ul_x_org[ul_x_0_ind, 1] * rescale, s=2, marker='o', c=bc, label='$y=0$')
    plt.scatter(train_x_org[train_x_1_ind, 0] * rescale, train_x_org[train_x_1_ind, 1] * rescale, s=50, marker='o', c=rc,
                label='$y=1$', edgecolor='black', linewidth=1)
    plt.scatter(train_x_org[train_x_0_ind, 0] * rescale, train_x_org[train_x_0_ind, 1] * rescale, s=50, marker='o', c=bc,
                label='$y=0$', edgecolor='black', linewidth=1)

    err_num, _ = evaluate_classifier(model, val_loader, args.device)
    err_rate = 1.0 * err_num / len(val_loader.dataset)

    lds_part = ""
    if with_lds:
        saved_eps = args.eps
        saved_k = getattr(args, 'k', 1)  # the old code always left k == 1; keep that as the fallback
        args.eps = 0.5
        args.k = 5
        try:
            reg_component = VAT(args)
            x_dev = x_data.to(args.device)
            n_draws = 20
            # float() per draw so the 20 autograd graphs are not kept alive by a running tensor sum.
            ave_lds = sum(float(reg_component(model, x_dev, kl_way=1)) for _ in range(n_draws)) / n_draws
        finally:
            # Restore the caller's settings even if VAT fails.
            args.k = saved_k
            args.eps = saved_eps
        lds_part = r' $\widetilde{\rm LDS}=%.3f$' % ave_lds

    if save_filename is None:
        plt.show(block=False)
        return

    levels = [0.05, 0.2, 0.35, 0.5, 0.65, 0.8, 0.95]
    # Thin coloured iso-probability lines. linewidths=1.0 is passed directly instead of the former
    # plt.setp(cs.collections, ...), because ContourSet.collections was removed in Matplotlib 3.10.
    cs = plt.contour(x * rescale, y * rescale, z, cmap='bwr', vmin=0., vmax=1.0, linewidths=1.0, levels=levels)
    # Single bold black line (the decision boundary).
    plt.contour(x * rescale, y * rescale, z, 1, cmap='binary', vmin=0, vmax=0.5, linewidths=2.0)
    if show_contour:
        plt.title('%s\nError %g%s' % (args.exp_marker, err_rate, lds_part))
    plt.xlim([-2. * rescale, 2. * rescale])
    plt.ylim([-2. * rescale, 2. * rescale])
    plt.xticks([-2.0, -1.0, 0, 1, 2.0], fontsize=font_size)
    plt.yticks([-2.0, -1.0, 0, 1, 2.0], fontsize=font_size)
    if show_contour:
        color_bar = plt.colorbar(cs)
        color_bar.ax.tick_params(labelsize=font_size)
def visualize_adv_points(model, d_i, x_data, y_data, ul_x, ul_y, adv_x, adv_y, basis, it, rate, args,
                         save_filename='prob_cont', show_contour=True):
    """Plot P(y=1 | x) contours with labelled, selected unlabelled and adversarial points.

    Draws into the *current* matplotlib axes. Points are projected by least squares onto the
    plane spanned by the two rows of ``basis`` (shape (2, D)). The model is evaluated on a grid
    over roughly [-2, 2]^2 lifted back through ``basis``. The y-axis of the current axes is
    hidden.

    Args:
        model: torch module. Its output is softmax-ed and column 1 is plotted.
            Side effect: ``model.eval()`` is called and not undone.
        d_i: dataset id. ``"1"`` draws two reference arcs; anything else draws two reference circles.
        x_data, y_data: labelled inputs and 0/1 labels (y=1 green, y=0 gray, black edge).
        ul_x, ul_y: unlabelled inputs (small markers) and their 0/1 labels.
        adv_x, adv_y: adversarial inputs (medium markers) and their 0/1 labels.
        basis: (2, D) array whose rows span the visualisation plane.
        it, rate, show_contour: not used by this function; kept for interface compatibility.
        args: namespace read for ``device``.
        save_filename: if None, show the figure (non-blocking) and skip contours. Otherwise draw
            contours, axis formatting and a colour bar. The bar is placed in a new axes at fixed
            figure coordinates after ``subplots_adjust(right=0.88)``. The file itself is NOT
            written here.
    """
    line_width = 10
    font_size = 20
    rescale = 1.0  # /np.sqrt(500)
    rc, bc = 'r', 'b'      # reference-shape colours
    gc, grc = 'g', 'gray'  # labelled points: y=1 green, y=0 gray
    # Palette for unlabelled/adversarial points, cycled by row index. Presumably adv_x[k] is the
    # perturbation of ul_x[k] (TODO confirm with caller), so each pair shares a colour.
    colors = ['b', 'r', 'orange', 'purple', 'cyan', 'yellow', 'brown', 'y', 'c', 'lime']

    def _colors_for(indices):
        """One palette colour per selected row, so scatter's ``c`` always matches the point count."""
        return [colors[k % len(colors)] for k in indices]

    # Least-squares 2-D coordinates of points in the plane spanned by the rows of `basis`.
    to_plane = np.dot(basis.T, linalg.inv(np.dot(basis, basis.T)))
    train_x_org = np.dot(x_data, to_plane)
    ul_x_org = np.dot(ul_x, to_plane)
    adv_x_org = np.dot(adv_x, to_plane)
    train_x_1_ind = np.where(y_data == 1)[0]
    train_x_0_ind = np.where(y_data == 0)[0]
    ul_x_1_ind = np.where(ul_y == 1)[0]
    ul_x_0_ind = np.where(ul_y == 0)[0]
    adv_x_1_ind = np.where(adv_y == 1)[0]
    adv_x_0_ind = np.where(adv_y == 0)[0]

    # Evaluation grid: row n*i + j is (range_x[i], range_x[j]), so z[i, j] = P(y=1) at (x[i, j], y[i, j]).
    range_x = np.arange(-2.0, 2.1, 0.05)
    n = range_x.shape[0]
    x, y = np.meshgrid(range_x, range_x, indexing='ij')
    test_x = np.dot(np.stack([x.ravel(), y.ravel()], axis=1), basis)
    model.eval()
    with torch.no_grad():  # plotting only; no autograd graph needed
        logits = model(torch.FloatTensor(test_x).to(args.device))
    z = nfunc.softmax(logits, dim=1)[:, 1].cpu().numpy().reshape(n, n)

    # Faint reference shapes for the two synthetic datasets.
    fig = plt.gcf()
    frame = fig.gca()
    if d_i == "1":
        frame.add_artist(plot_patch.Arc(xy=(0.5 * rescale, -0.25 * rescale), width=2.0 * rescale, height=2.0 * rescale,
                                        angle=0, theta1=270, theta2=180, linewidth=line_width, alpha=0.15, color=rc))
        frame.add_artist(plot_patch.Arc(xy=(-0.5 * rescale, +0.25 * rescale), width=2.0 * rescale, height=2.0 * rescale,
                                        angle=0, theta1=90, theta2=360, linewidth=line_width, alpha=0.15, color=bc))
    else:
        frame.add_artist(plot_patch.Circle((0, 0), 1.0 * rescale, color=rc, alpha=0.2, fill=False, linewidth=line_width))
        frame.add_artist(plot_patch.Circle((0, 0), 0.15 * rescale, color=bc, alpha=0.2, fill=False, linewidth=line_width))
    frame.get_yaxis().set_visible(False)

    # Per-point colours: passing the fixed 10-colour list as `c` raises ValueError in Matplotlib
    # whenever a class subset does not have exactly 0, 1 or 10 points.
    plt.scatter(adv_x_org[adv_x_1_ind, 0] * rescale, adv_x_org[adv_x_1_ind, 1] * rescale, s=25, marker='o',
                c=_colors_for(adv_x_1_ind), label='$y=1$')
    plt.scatter(adv_x_org[adv_x_0_ind, 0] * rescale, adv_x_org[adv_x_0_ind, 1] * rescale, s=25, marker='o',
                c=_colors_for(adv_x_0_ind), label='$y=0$')
    plt.scatter(ul_x_org[ul_x_1_ind, 0] * rescale, ul_x_org[ul_x_1_ind, 1] * rescale, s=10, marker='o',
                c=_colors_for(ul_x_1_ind), label='$y=1$')
    plt.scatter(ul_x_org[ul_x_0_ind, 0] * rescale, ul_x_org[ul_x_0_ind, 1] * rescale, s=10, marker='o',
                c=_colors_for(ul_x_0_ind), label='$y=0$')
    plt.scatter(train_x_org[train_x_1_ind, 0] * rescale, train_x_org[train_x_1_ind, 1] * rescale, s=50, marker='o', c=gc,
                label='$y=1$', edgecolor='black', linewidth=1)
    plt.scatter(train_x_org[train_x_0_ind, 0] * rescale, train_x_org[train_x_0_ind, 1] * rescale, s=50, marker='o', c=grc,
                label='$y=0$', edgecolor='black', linewidth=1)

    if save_filename is None:
        plt.show(block=False)
        return

    levels = [0.05, 0.2, 0.35, 0.5, 0.65, 0.8, 0.95]
    # Thin coloured iso-probability lines. linewidths=1.0 is passed directly instead of the former
    # plt.setp(cs.collections, ...), because ContourSet.collections was removed in Matplotlib 3.10.
    cs = plt.contour(x * rescale, y * rescale, z, cmap='bwr', vmin=0., vmax=1.0, linewidths=1.0, levels=levels)
    # Single bold black line (the decision boundary).
    plt.contour(x * rescale, y * rescale, z, 1, cmap='binary', vmin=0, vmax=0.5, linewidths=2.0)
    plt.xlim([-2. * rescale, 2. * rescale])
    plt.ylim([-2. * rescale, 2. * rescale])
    plt.xticks([-2.0, -1.0, 0, 1, 2.0], fontsize=font_size)
    plt.yticks([-2.0, -1.0, 0, 1, 2.0], fontsize=font_size)
    # Dedicated colour-bar axes on the right edge of the whole figure.
    fig.subplots_adjust(right=0.88)
    cbar_ax = fig.add_axes([0.90, 0.05, 0.02, 0.7])
    color_bar = plt.colorbar(cs, cax=cbar_ax)
    color_bar.ax.tick_params(labelsize=font_size)