Skip to content

Commit

Permalink
Merge pull request #1218 from LPirch/master
Browse files Browse the repository at this point in the history
fix CarliniWagnerL2 dtype bug
  • Loading branch information
alkaet committed Sep 23, 2021
2 parents b54cb0a + 6d49cb4 commit 95d5bc9
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions cleverhans/tf2/attacks/carlini_wagner_l2.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ def _attack(self, x):
f"The input is greater than the maximum value of {self.clip_max}!"
)

y, _ = get_or_guess_labels(self.model_fn, x, y=self.y, targeted=self.targeted)

# cast to tensor if provided as numpy array
original_x = tf.cast(x, tf.float32)
shape = original_x.shape

y, _ = get_or_guess_labels(
self.model_fn, original_x, y=self.y, targeted=self.targeted
)

if not y.shape.as_list()[0] == original_x.shape.as_list()[0]:
raise CarliniWagnerL2Exception("x and y do not have the same shape!")

Expand Down

0 comments on commit 95d5bc9

Please sign in to comment.