Skip to content

Commit

Permalink
🐛 fra31#70
Browse files Browse the repository at this point in the history
  • Loading branch information
userElaina committed May 21, 2024
1 parent 65e19aa commit 7c58c37
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion autoattack/autoattack.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
class AutoAttack():
def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=False,
attacks_to_run=[], version='standard', is_tf_model=False,
device='cuda', log_path=None, n_iter=40, san=500, sa_re=True):
device='cuda', log_path=None, n_iter=40, san=500, sa_re=True,
outputs_size=10):
self.model = model
self.norm = norm
assert norm in ['Linf', 'L2', 'L1']
Expand All @@ -24,6 +25,7 @@ def __init__(self, model, norm='Linf', eps=.3, seed=None, verbose=False,
self.is_tf_model = False
self.device = device
self.logger = Logger(log_path)
self.outputs_size = outputs_size

if version in ['standard', 'plus', 'rand'] and attacks_to_run != []:
raise ValueError("attacks_to_run will be overridden unless you use version='custom'")
Expand Down Expand Up @@ -323,3 +325,9 @@ def set_version(self, version='standard'):
self.apgd.n_restarts = 1
self.apgd.eot_iter = 20

if self.outputs_size < 4:
if 'apgd-t' in self.attacks_to_run:
self.attacks_to_run.remove('apgd-t')
if self.outputs_size < 3:
if 'apgd-dlr' in self.attacks_to_run:
self.attacks_to_run.remove('apgd-dlr')

0 comments on commit 7c58c37

Please sign in to comment.