-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo4ExpA1_AGPvsRank_R28.py
351 lines (311 loc) · 16.7 KB
/
demo4ExpA1_AGPvsRank_R28.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# Standard library
import datetime
import os
import threading
import time
# Pin this process to physical GPU 6.  This is set before importing torch so
# that the CUDA runtime only ever sees that device (it appears as 'cuda:0'
# inside the program).
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
# Third-party
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
# Project-local modules: the factorized t-NN model, the DCT-domain t-product,
# the MNIST 3-vs-7 wrapper dataset, and a seconds-to-text pretty-printer.
from models import facTNNdct3Layers
from layers import t_product_dct
from functions import MNIST37
from functions import f_seconds2sentence
import matplotlib.pyplot as plt
import numpy as np
def main():
    """Experiment A1: adversarial generalization gap vs. training-set size.

    Adversarially trains (FGSM) a fixed-rank factorized t-NN on the MNIST
    3-vs-7 binary task for a sweep of training-set sizes and several
    independent runs, recording generalization gaps and weight-norm products,
    and saving per-run plots and checkpoints.
    """
    # para for task
    num_classes = 2
    # for training
    batch_size = 30
    learning_rate = 0.001
    num_epochs = 20
    # budget of adversarial attack (L-infinity, 20 grey levels out of 255)
    epsilon_adv = 20/255
    # para for model
    num_channel = 28
    dim_input = 28
    dim_latent = 28
    # NOTE(review): v_rank_w is never used below — the rank is fixed to
    # rank_w; presumably left over from the rank-sweep variant of this
    # experiment.  Verify before deleting.
    v_rank_w = torch.tensor(range(3, 25, 3))
    #----------
    rank_w = 28
    num_runs = 20
    #----------
    # Soft upper bound B on the weight norms and the penalty coefficient
    # lambda used by f_custom_loss below.
    bound_norm_w = 1e+5
    lambd_norm_bound = 1e+5
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    """
    Load MNIST dataset and create new datasets containing only the data
    for digits 3 and 7 for binary classification
    """
    # Load the original MNIST dataset
    mnist_train = MNIST(root='../data', train=True, transform=ToTensor(), download=True)
    mnist_test = MNIST(root='../data', train=False, transform=ToTensor(), download=True)
    # Training-set sizes swept by the experiment loop below.
    v_train_sizes = torch.tensor([500,1000, 1500, 2000, 2500,3000,3500,4000,4500, 5000])
    def f_subset_data_loaders(mnist_train, mnist_test, train_size,batch_size):
        """Return (train_loader, test_loader): the train loader draws the
        first `train_size` examples of the 3-vs-7 subset; the test loader
        covers the whole 3-vs-7 test set."""
        # Create a subset of data based on the specified training size
        train_data_subset = torch.utils.data.Subset(MNIST37(mnist_train), range(0, train_size))
        train_loader = DataLoader(train_data_subset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(MNIST37(mnist_test), batch_size=batch_size, shuffle=False)
        return train_loader, test_loader
"""
Adversarial attack based on FGSM (Fast Gradient Sign Method).
Adversarial samples are generated by adding perturbations on the input data's pixels
according to the gradient direction, causing the model to produce incorrect predictions.
"""
# Define the adversarial training loss function
def f_fgsm_attack(model, images, labels, epsilon):
images.requires_grad = True
outputs = model(images)
loss = nn.BCELoss()(outputs, labels.float().view(-1, 1))
model.zero_grad()
loss.backward()
# Adversarial sample generation: perturb input pixels along the gradient direction
images_adv = images + epsilon * images.grad.sign()
images_adv = torch.clamp(images_adv, 0, 1) # Keep pixel values within [0, 1]
return images_adv
# use regularization to softly bound the weights norm
def f_custom_loss(criterion,output, target, model, B, lambd):
    """Classification loss plus quadratic penalties softly bounding norms.

    The three factorized layer pairs (layer1/2, layer3/4, layer5/6) are
    penalized through the norm of their t-product, and the final layer7
    weight through its own norm; each norm exceeding B contributes
    lambd * (norm - B)**2.
    """
    total = criterion(output, target)
    # Norms of the three t-product layer pairs, followed by the final layer.
    norms = [
        torch.norm(t_product_dct(getattr(model, f'layer{2 * k}').weight,
                                 getattr(model, f'layer{2 * k - 1}').weight))
        for k in (1, 2, 3)
    ]
    norms.append(torch.norm(getattr(model, 'layer7').weight))
    for norm in norms:
        if norm > B:
            # Quadratic penalty for exceeding B
            total = total + lambd * (norm - B) ** 2
    return total
def f_prod_weight_norm(model):
    """Product of the norms of the model's effective weights.

    Multiplies the norm of the t-product of each factorized layer pair
    (layer1/2, layer3/4, layer5/6) with the norm of the final layer7
    weight — the same quantities that the norm-bound regularizer penalizes.
    """
    factor_norms = [
        torch.norm(t_product_dct(getattr(model, f'layer{2 * k}').weight,
                                 getattr(model, f'layer{2 * k - 1}').weight))
        for k in (1, 2, 3)
    ]
    factor_norms.append(torch.norm(model.layer7.weight))
    result = 1
    for value in factor_norms:
        result = result * value
    return result
"""
Conduct experiments by setting different rank parameters
"""
time_a = time.time()
results_all_runs = {}
num_train_sizes = v_train_sizes.numel()
m_product_weight_norm = torch.zeros(num_train_sizes,num_runs)
m_generalization_gap_std = torch.zeros(num_train_sizes,num_runs)
m_generalization_gap_adv = torch.zeros(num_train_sizes,num_runs)
for iter_run in range(num_runs):
results_in_run = {}
num_train_sizes = v_train_sizes.numel()
v_product_weight_norm = torch.zeros(num_train_sizes)
v_generalization_gap_std = torch.zeros(num_train_sizes)
v_generalization_gap_adv = torch.zeros(num_train_sizes)
for iter_train_size in range(num_train_sizes):
# Define the binary classification neural network model
train_size = v_train_sizes[iter_train_size]
train_loader, test_loader = f_subset_data_loaders(mnist_train, mnist_test, train_size,batch_size)
model = facTNNdct3Layers(num_channel,dim_input,dim_latent,rank_w,num_classes).to(device)
# Loss and optimizer
criterion = nn.BCELoss().to(device) # Use binary cross-entropy loss for binary classification task
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# record the empirical risks and sample size
risk_emp_std = 0.0
risk_emp_adv = 0.0
sample_size = 0
# Train the model
for epoch in range(num_epochs):
num_correct_adv = 0
num_examples = 0
total_loss_std = 0.0
total_loss_adv = 0.0
for iter_batch, (images, labels) in enumerate(train_loader):
images = images.float().to(device)
labels = labels.float().to(device)
optimizer.zero_grad()
# Forward pass
outputs_std = model(images).squeeze()
# compute standard loss
loss = criterion(outputs_std, labels)
# compute standard loss and regularization of this batch
loss_reg = f_custom_loss(criterion,outputs_std,labels,model,bound_norm_w,lambd_norm_bound)
# compute batch-version gradients of weights in the model
loss_reg.backward() # Question: Is this step necessary?
# use fgsm to generate adversarial examples
images_adv = f_fgsm_attack(model, images, labels, epsilon_adv)
# move attacked images to device
images_adv = images_adv.to(device)
# the predicted probability of 0 and 1
outputs_adv = model(images_adv).squeeze()
# compute the adversarial loss of this batch
loss_adv = criterion(outputs_adv,labels)
# compute the adversarial loss and regularization of this batch
loss_reg_adv = f_custom_loss(criterion,outputs_adv,labels,model,bound_norm_w,lambd_norm_bound)
# backward to compute the gradient
loss_reg_adv.backward()
# optimize to update the weights
optimizer.step()
# the predicted classes of the attacked images
predicted_adv = torch.round(outputs_adv)
# number of total classfied training examples
num_examples += labels.size(0)
# number of correctly classfied attacked training examples
num_correct_adv += (predicted_adv == labels).sum().item()
# total std loss on training set: not used, but interesting
total_loss_std += loss.item()
# total adv loss on training set
total_loss_adv += loss_adv.item()
if (iter_batch+1) % 20 == 0:
time_z = time.time() - time_a
time_r = num_runs*num_train_sizes*num_epochs*len(train_loader)
time_r = time_r/(iter_run*num_train_sizes*num_epochs*len(train_loader)+iter_train_size*num_epochs*len(train_loader) + epoch*len(train_loader) + iter_batch +1) -1
time_r = time_r*time_z
print("Used time:"+f_seconds2sentence(time_z)+" remaining:"+f_seconds2sentence(time_r))
print(f'Run[{iter_run+1}/{num_runs}] Train Size [{iter_train_size+1}/{num_train_sizes}]\n Epoch [{epoch+1}/{num_epochs}], Step [{iter_batch+1}/{len(train_loader)}], Adv Loss: {loss_adv.item()}\n')
# accuray under adv attack on the training data
accuracy_adv = num_correct_adv / num_examples
# update empirical standard risk
risk_emp_std = total_loss_std / num_examples
# update empirical adversarial risk
risk_emp_adv = total_loss_adv / num_examples
# record the sample size
sample_size = num_examples
print(f'Run[{iter_run+1}/{num_runs}] Train Size [{iter_train_size+1}/{num_train_sizes}]\n Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, AdvLoss: {loss_adv.item():.4f}, Adv Accuracy: {accuracy_adv*100:.2f}%\n')
# Test the model
# record expected standard risk
risk_exp_std = 0.0
# record expected adv risk
risk_exp_adv = 0.0
#with torch.no_grad():
# number of clean examples correctly classified
num_correct_std = 0
# number of attacked examples correctly classified
num_correct_adv = 0
# number of test examples
num_examples_test = 0
# total standard test loss of the adversarially trained model
total_test_loss_std = 0.0
# total adversarial test loss of the adversarially trained model
total_test_loss_adv = 0.0
for images, labels in test_loader:
images = images.float().to(device)
labels = labels.float().to(device)
# the sigmoid output of the clean images
outputs_std = model(images).squeeze()
# current number of test examples
num_examples_test += labels.size(0)
# predicted classes of the clean images by the adversarially trained model
predicted_std = torch.round(outputs_std)
# update number of correctly classified clean examples
num_correct_std += (predicted_std == labels).sum().item()
# compute standard test loss on this test batch
loss_std = criterion(outputs_std, labels)
# update the total standard test loss
total_test_loss_std += loss_std
# attack the clean test images by fgsm
images_adv = f_fgsm_attack(model, images, labels, epsilon_adv)
# move the attacked images to device
images_adv = images_adv.to(device)
# output of model input by the attacked images
outputs_adv = model(images_adv).squeeze()
# the adversarial test loss of the current batch
loss_adv = criterion(outputs_adv, labels)
# update the total adversarial loss on the test data
total_test_loss_adv += loss_adv
# predicted classes of the attacked images in this test batch
predicted_adv = torch.round(outputs_adv)
# update the number of correctly adversarially classified images
num_correct_adv += (predicted_adv == labels).sum().item()
print(f'Rank {rank_w}, Run[{iter_run+1}/{num_runs}], Train Size [{iter_train_size+1}/{num_train_sizes}]\n Adversarial Test Accuracy of the network on the 2000 test images: {100 * num_correct_adv / num_examples_test} %')
current_time = datetime.datetime.now()
print("Current time:", current_time)
# # the expected stadard risk
# risk_exp_std += (risk_emp_std * sample_size + total_test_loss_std)/(sample_size + num_examples_test)
# # the expected adv risk
# risk_exp_adv += (risk_emp_adv*sample_size+total_test_loss_adv) / (sample_size + num_examples_test)
# the expected stadard risk
risk_exp_std += total_test_loss_std/num_examples_test
# the expected adv risk
risk_exp_adv += total_test_loss_adv / num_examples_test
# compute the stanard generalization gap
generalization_gap_std = risk_exp_std - risk_emp_std
v_generalization_gap_std[iter_train_size] = generalization_gap_std
m_generalization_gap_std[iter_train_size,iter_run] = generalization_gap_std
# compute the adv generalization gap
generalization_gap_adv = risk_exp_adv - risk_emp_adv
v_generalization_gap_adv[iter_train_size] = generalization_gap_adv
m_generalization_gap_adv[iter_train_size,iter_run]=generalization_gap_adv
# compute the product of norms of the weights in the model
prod_weight_norm = f_prod_weight_norm(model)
v_product_weight_norm[iter_train_size] = prod_weight_norm
m_product_weight_norm[iter_train_size,iter_run] = prod_weight_norm
# results
results_in_run[train_size.item()] = {
#'model': model, # do not save model
'rank_w': rank_w,
'iter_run': iter_run,
'num_runs': num_runs,
'num_epochs':num_epochs,
'v_train_sizes': v_train_sizes,
'iter_train_sizes': iter_train_size,
'v_generalization_gap_std': v_generalization_gap_std,
'v_generalization_gap_adv': v_generalization_gap_adv,
'v_prod_weight_norm':v_product_weight_norm,
'm_generalization_gap_std': m_generalization_gap_std,
'm_generalization_gap_adv': m_generalization_gap_adv,
'm_prod_weight_norm':m_product_weight_norm,
}
if iter_train_size > 0:
x = range(iter_train_size+1)
# average bound vs N^{-0.5}
y = v_generalization_gap_adv.detach().numpy()[x]
plt.clf()
plt.plot(v_train_sizes[x]**(-0.5), y, label='Line', color='blue')
plt.scatter(v_train_sizes[x]**(-0.5), y, label='Point', color='red')
plt.xlabel(r'$\frac{1}{\sqrt{N}}$')
plt.ylabel('Adversarial generalization gap')
# plt.legend()
plt.grid(True)
plt.savefig(f'AGP-invSqrtN-r{rank_w}-ep{num_epochs}-run{iter_run}-new.png')
# bound vs N
plt.clf()
plt.plot(v_train_sizes[x], y, label='Line', color='blue')
plt.scatter(v_train_sizes[x], y, label='Point', color='red')
plt.xlabel(r'$N$')
plt.ylabel('Adversarial generalization gap')
# plt.legend()
plt.grid(True)
plt.savefig(f'AGP-vs-N-r{rank_w}-ep{num_epochs}-run{iter_run}-new.png')
# if iter_run>0:
# # average bound vs N^{-0.5}
# y_avg = np.mean(m_generalization_gap_adv.detach().numpy()[x,range(iter_run+1)],axis=1)
# plt.clf()
# plt.plot(v_train_sizes[x]**(-0.5),y_avg, label='Line', color='blue')
# plt.scatter(v_train_sizes[x]**(-0.5), y_avg, label='Point', color='red')
# plt.xlabel(r'$\frac{1}{\sqrt{N}}$')
# plt.ylabel('Adversarial generalization gap')
# # plt.legend()
# plt.grid(True)
# plt.savefig(f'AGP-invSqrtN-r{rank_w}-ep{num_epochs}-run{iter_run}-avg.png')
# # average bound vs N
# plt.clf()
# plt.plot(v_train_sizes[x], y_avg, label='Line', color='blue')
# plt.scatter(v_train_sizes[x], y_avg, label='Point', color='red')
# plt.xlabel(r'$N$')
# plt.ylabel('Adversarial generalization gap')
# # plt.legend()
# plt.grid(True)
# plt.savefig(f'AGP-vs-N-r{rank_w}-ep{num_epochs}-run{iter_run}-avg.png')
torch.save(results_in_run, f'checkpoints/New-AGP-N={train_size}-r={rank_w}-ep{num_epochs}-run{iter_run}-new.pth')
if __name__ == "__main__":
num_threads = threading.active_count()
print(f"number of current threads: {num_threads}")
torch.set_num_threads(2)
main()