In [2]:
import orbax.checkpoint as ocp
import os
from ml_collections import ConfigDict
from pathlib import Path
from utils import prepare_test_dataset
from dataset_utils import get_dataset
from jax import random
from models.utils import sample_gaussian

import models.ClassifierGFZ as ClassifierGFZ
import models.ClassifierDFZ as ClassifierDFZ

checkpoint_path = "dfz-2-epochs-first-try-1"
path = os.path.join(Path.cwd(), Path(f"checkpoints"), Path(checkpoint_path))
checkpoint = ocp.PyTreeCheckpointer().restore(path, item=None)

config = ConfigDict(checkpoint["config"])
dataset_config = ConfigDict(checkpoint["dataset_config"])

if config.model_name == "GFZ":
    classifier = ClassifierGFZ
elif config.model_name == "DFZ":
    classifier = ClassifierDFZ
else:
    raise NotImplementedError(config.model_name)

_, test_ds = get_dataset(config.dataset)
test_images, test_labels = prepare_test_dataset(
    test_ds, dataset_config
    )

trained_params = checkpoint["params"]

log_likelyhood_fn = classifier.log_likelyhood_A

test_key = random.PRNGKey(config.seed)

test_key, model, _ = classifier.create_and_init(
    test_key, config, dataset_config
)



In [3]:
from flax import linen as nn
import jax
from jax import jacrev
import numpy as np
from functools import partial
from jax.scipy.special import logsumexp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm
import optax
from scipy.optimize import minimize

def init_data(test_key, n_samples=10):
    idx = np.random.choice(range(len(test_images)), n_samples, replace=False)

    all_xs = test_images[idx]
    true_ys = test_labels[idx]
    true_labels = np.argmax(true_ys, axis=1)

    K = model.K
    batch_size = n_samples
    test_key, epsilons = sample_gaussian(test_key, (batch_size, model.n_classes * K, model.d_latent))
    epsilons = epsilons[:n_samples*model.n_classes]
    all_ys = nn.one_hot(jnp.repeat(jnp.arange(model.n_classes), K), model.n_classes, dtype=jnp.float32)
    
    return all_xs, true_labels, epsilons, all_ys, K, test_key

def get_model_output(x, epsilon, y, K):
    z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz = jax.vmap(
            partial(model.apply, {'params': trained_params}, train=False),
            in_axes=(None, 0, 0)
        )(x, y, epsilon)

    ll = log_likelyhood_fn(
            z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz
        ).reshape(model.n_classes, K)
    ll = logsumexp(ll, axis=1) - np.log(K)
    return ll

def get_model_jacobian(x, epsilon, y, K):
    return jacrev(get_model_output, argnums=0)(x, epsilon, y, K)

def map_label_to_name(y):
    labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
              "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
    return labels[y]

In [4]:
class Constrained_BFGS_B():
    def __init__(self, f, grad, x0, bounds, maxiter=1000, eps=1e-8):
        self.f = f
        self.grad = grad
        self.x0 = x0
        self.bounds = bounds
        self.maxiter = maxiter
        self.eps = eps

    def line_search(self, x, direction, alpha = 0.4, beta=0.8, max_iter=1000):
        step_size = 1
        i = 0
        while i < max_iter:
            if self.f(x + step_size  * direction) <= self.f(x) + step_size * alpha * direction.dot(self.grad(x)):
                break
            step_size  *= beta
            i += 1
            
        return step_size
    
    def determine_active_set(self, x, bounds):
        active_set = np.ones_like(x, dtype=bool)

        for i, (lower_bound, upper_bound) in enumerate(bounds):
            if lower_bound is not None and x[i] <= lower_bound:
                active_set[i] = False  # Variable is at the lower bound
            elif upper_bound is not None and x[i] >= upper_bound:
                active_set[i] = False  # Variable is at the upper bound

        return active_set
    
    def update_inverse_hessian_bfgs_b(self, Bk, sk, yk, active_set):
        sk_active = sk[active_set]
        yk_active = yk[active_set]

        if not any(active_set):
            return Bk

        rho = 1 / np.dot(yk_active, sk_active)
        term1 = np.eye(len(sk_active)) - np.outer(sk_active, yk_active) * rho
        term2 = np.eye(len(sk_active)) - np.outer(yk_active, sk_active) * rho
        Bk1_active = np.dot(term1, np.dot(Bk, term2)) + np.outer(rho * sk_active, sk_active)
        Bk1 = Bk.copy()
        Bk1[np.ix_(active_set, active_set)] = Bk1_active

        return Bk1

    def optimize(self):
        x = self.x0
        B = np.eye(len(x))
        for i in range(self.maxiter):
            g = self.grad(x)
            if np.linalg.norm(g) < self.eps:
                break
            direction = -B.dot(g)
            step_size = self.line_search(x, direction)
            x_new = x + step_size * direction
            s = x_new - x
            y = self.grad(x_new) - g
            active_set = self.determine_active_set(x_new, self.bounds)
            B = self.update_inverse_hessian_bfgs_b(B, s, y, active_set)
            x = x_new
            print(self.f(x))

        return x


In [25]:
import numpy as np
import jax
import jax.numpy as jnp
import optax

class L_BGFS_Attack():
    def __init__(self, model, max_iter=100, learning_rate=1, p=2):
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        assert p > 1 
        self.p = p
        if self.p == np.inf:
            self.q = 1
        else:
            self.q = self.p / (self.p - 1)

    def qnorm(self, x):
        return jnp.linalg.norm(x.flatten(), self.q)

    def get_label(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return jnp.argmax(val)

    def get_likelihoods(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return val

    def get_gradients(self, x):
        J = get_model_jacobian(x, self.epsilon, self.y, self.K)
        return J.flatten()

    def loss(self, val, label):
        label_one_hot_encoding = jax.nn.one_hot(jnp.array([label]), self.n_classes)
        return optax.softmax_cross_entropy(val, label_one_hot_encoding)
    
    def project_to_bounds(self, x, bounds):
        bounds_min = jnp.zeros_like(x)
        bounds_max = jnp.ones_like(x)
        return jnp.clip(x, bounds_min, bounds_max)

    def get_perturbation(self, x, epsilon, all_ys, K):
        corrupted_x = x.copy()
        self.y = all_ys
        self.epsilon = epsilon
        self.K = K
        true_label = self.get_label(corrupted_x)
        bounds = np.array([[0, 1] * len(corrupted_x.flatten())])
        r = np.ones_like(corrupted_x)
        max_perturbation_norm = jnp.linalg.norm(jax.device_put(r))
        best_label = true_label
        best_corrupted_x = corrupted_x

        # Line search for c
        for label in range(self.n_classes):
            if label != true_label:
                def get_problem(r):
                    corrupted_x = self.project_to_bounds(x + r.reshape(x.shape), bounds)
                    val = self.get_likelihoods(corrupted_x)
                    return jnp.sum(self.qnorm(r) + self.loss(val, label))

                i = 0
                optimizer = optax.adam(learning_rate=self.learning_rate)
                state = optimizer.init(jax.device_put(r.flatten()))
                while i < self.max_iter:
                    grad = jax.grad(get_problem)(jax.device_put(r.flatten()))
                    updates, state = optimizer.update(grad, state)
                    r = optax.apply_updates(jax.device_put(r.flatten()), updates).reshape(corrupted_x.shape)
                    corrupted_x = self.project_to_bounds(corrupted_x + r, bounds)
                    if self.get_label(corrupted_x) == label:
                        break
                    else:
                        i += 1

                new_label = self.get_label(corrupted_x)
                if new_label != label:
                    print("Warning: did not find a perturbation")
                    perturbation_norm = -1
                else:
                    perturbation_norm = jnp.linalg.norm(jax.device_put(r))

                # Choose minimal perturbation
                if perturbation_norm < max_perturbation_norm:
                    max_perturbation_norm = perturbation_norm
                    best_label = new_label
                    best_corrupted_x = corrupted_x

        return best_corrupted_x, best_label, max_perturbation_norm



In [6]:
def get_average_performance(corruption_model, all_xs, epsilons, all_ys, K):
    perturbation_norms = []
    n_samples = len(all_xs)
    for i in tqdm(range(n_samples)):
        x = all_xs[i]
        epsilon = epsilons[i]
        _, _, perturbation_norm = corruption_model.get_perturbation(x, epsilon, all_ys, K)
        perturbation_norms.append(perturbation_norm)
    return np.array(perturbation_norms)

In [7]:
#Testing optimization algorithm
def f(x):
    return x**2

def grad(x):
    return 2*x
f = f
g = grad
bounds = [(0, 1)] * 1
i = 0
corrupted_x = np.array([1])
print(f(corrupted_x))
BFGS = Constrained_BFGS_B(f, g, corrupted_x, bounds)
r = BFGS.optimize()
print(f(r))

[1]
[0.000576]
[3.31776e-07]
[1.17549435e-38]
[1.17549435e-38]


In [8]:
n_samples = 10
all_xs, true_labels, epsilons, all_ys, K, test_key = init_data(test_key, n_samples=n_samples)

corruption_model = L_BGFS_Attack(model)

perturbation_norms_BFGS = get_average_performance(corruption_model, all_xs, epsilons, all_ys, K)
perturbation_norms_successful_BFGS = perturbation_norms_BFGS[perturbation_norms_BFGS != -1]
n_successful_BFGS = len(perturbation_norms_successful_BFGS)
n_successful_BFGS
print(f'Average perturbation norm of L_BGFS Attack model (on {} successful samples): {np.mean(perturbation_norms_successful_BFGS):>.4f}')

100%|██████████| 10/10 [02:39<00:00, 15.90s/it]

Average perturbation norm of L_BGFS Attack model (on 10 successful samples): 17.2475





In [42]:
# Wagner and carlini attack
import numpy as np
import jax
import jax.numpy as jnp
import optax

class WG_Attack():
    def __init__(self, model, max_iter=100, learning_rate=1, p=2):
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        assert p > 1 
        self.p = p
        if self.p == np.inf:
            self.q = 1
        else:
            self.q = self.p / (self.p - 1)

    def qnorm(self, x):
        return jnp.linalg.norm(x.flatten(), self.q)

    def get_label(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return jnp.argmax(val)

    def get_likelihoods(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return val

    def get_gradients(self, x):
        J = get_model_jacobian(x, self.epsilon, self.y, self.K)
        return J.flatten()

    def loss(self, val, label):
        label_one_hot_encoding = jax.nn.one_hot(jnp.array([label]), self.n_classes)
        return optax.softmax_cross_entropy(val, label_one_hot_encoding)
    
    def f(self, x, target_label, k = 0):
        val = self.get_likelihoods(x)
        max_logit = jnp.max(val[jnp.arange(self.n_classes) != target_label])
        logit_diff = jnp.maximum(max_logit - val[target_label], - k)
        return logit_diff
    
    def get_objective(self, w, x, target_label, c, k = 0, type = "L2"):
        if type == "L2":
            norm = self.qnorm(1/2 * jnp.tanh(w) + 1/2 - x)
        penalty = c * self.f(1/2 * jnp.tanh(w) + 1/2, target_label, k = k)
        return norm + penalty
    
    def project_to_bounds(self, x):
        bounds_min = jnp.zeros_like(x)
        bounds_max = jnp.ones_like(x)
        return jnp.clip(x, bounds_min, bounds_max)

    def get_perturbation(self, x, epsilon, all_ys, K):
        corrupted_x = x.copy()
        self.y = all_ys
        self.epsilon = epsilon
        self.K = K
        true_label = self.get_label(corrupted_x)
        w = np.ones_like(corrupted_x)
        max_perturbation_norm = jnp.linalg.norm(jax.device_put(r))
        best_label = true_label
        best_corrupted_x = corrupted_x
        # define the wagner and carlini attack problem
        for label in range(self.n_classes): # to do : optimize this loop
            if label != true_label:
                def get_problem(w):
                    w = self.project_to_bounds(w)
                    target_label = label
                    return self.get_objective(w, x, target_label, 1, k = 0, type = "L2")
                # use adam optimizer to find minimum of the problem
                optimizer = optax.adam(learning_rate=self.learning_rate)
                state = optimizer.init(jax.device_put(w))
                for i in range(self.max_iter):
                    grad = jax.grad(get_problem)(jax.device_put(w))
                    updates, state = optimizer.update(grad, state)
                    w = optax.apply_updates(jax.device_put(w), updates)
                    if self.get_label(1/2 * jnp.tanh(w) + 1/2) == label:
                        break
                # check if the attack was successful
                new_label = self.get_label(1/2 * jnp.tanh(w) + 1/2)
                if new_label != label:
                    print("Warning: did not find a perturbation")
                    perturbation_norm = -1
                else:
                    perturbation_norm = jnp.linalg.norm(jax.device_put(w))
                # Choose minimal perturbation
                if perturbation_norm < max_perturbation_norm:
                    max_perturbation_norm = perturbation_norm
                    best_label = new_label
                    best_corrupted_x = 1/2 * jnp.tanh(w) + 1/2

        return best_corrupted_x, best_label, max_perturbation_norm

In [43]:
n_samples = 10
all_xs, true_labels, epsilons, all_ys, K, test_key = init_data(test_key, n_samples=n_samples)

corruption_model = WG_Attack(model)

perturbation_norms_WG = get_average_performance(corruption_model, all_xs, epsilons, all_ys, K)
perturbation_norms_successful_WG = perturbation_norms_WG[perturbation_norms_BFGS != -1]
n_successful_WG = len(perturbation_norms_successful_WG)
n_successful_WG
print(f'Average perturbation norm of Wagner & Carlini Attack model (on {n_successful_BFGS} successful samples): {np.mean(perturbation_norms_successful_BFGS):>.4f}')

  0%|          | 0/10 [00:00<?, ?it/s]



 10%|█         | 1/10 [07:22<1:06:23, 442.66s/it]



 20%|██        | 2/10 [15:10<1:00:59, 457.50s/it]



 30%|███       | 3/10 [22:13<51:32, 441.74s/it]  



 40%|████      | 4/10 [30:18<45:53, 458.89s/it]



 50%|█████     | 5/10 [37:31<37:27, 449.40s/it]



 60%|██████    | 6/10 [46:03<31:22, 470.71s/it]



 70%|███████   | 7/10 [54:23<24:00, 480.22s/it]



 80%|████████  | 8/10 [1:04:02<17:03, 511.66s/it]



 90%|█████████ | 9/10 [1:14:19<09:04, 544.76s/it]



100%|██████████| 10/10 [1:23:30<00:00, 501.01s/it]

Average perturbation norm of Wagner & Carlini Attack model (on 10 successful samples): 17.2475



