In [1]:
import orbax.checkpoint as ocp
import os
from ml_collections import ConfigDict
from pathlib import Path
from utils import prepare_test_dataset
from dataset_utils import get_dataset
from jax import random
from models.utils import sample_gaussian

import models.ClassifierGFZ as ClassifierGFZ
import models.ClassifierDFZ as ClassifierDFZ

checkpoint_path = "dfz-2-epochs-first-try-1"
path = os.path.join(Path.cwd(), Path(f"checkpoints"), Path(checkpoint_path))
checkpoint = ocp.PyTreeCheckpointer().restore(path, item=None)

config = ConfigDict(checkpoint["config"])
dataset_config = ConfigDict(checkpoint["dataset_config"])

if config.model_name == "GFZ":
    classifier = ClassifierGFZ
elif config.model_name == "DFZ":
    classifier = ClassifierDFZ
else:
    raise NotImplementedError(config.model_name)

_, test_ds = get_dataset(config.dataset)
test_images, test_labels = prepare_test_dataset(
    test_ds, dataset_config
    )

trained_params = checkpoint["params"]

log_likelyhood_fn = classifier.log_likelyhood_A

test_key = random.PRNGKey(config.seed)

test_key, model, _ = classifier.create_and_init(
    test_key, config, dataset_config
)



In [2]:
from flax import linen as nn
import jax
from jax import jacrev
import numpy as np
from functools import partial
from jax.scipy.special import logsumexp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm
import optax
from scipy.optimize import minimize

def init_data(test_key, n_samples=10):
    idx = np.random.choice(range(len(test_images)), n_samples, replace=False)

    all_xs = test_images[idx]
    true_ys = test_labels[idx]
    true_labels = np.argmax(true_ys, axis=1)

    K = model.K
    batch_size = n_samples
    test_key, epsilons = sample_gaussian(test_key, (batch_size, model.n_classes * K, model.d_latent))
    epsilons = epsilons[:n_samples*model.n_classes]
    all_ys = nn.one_hot(jnp.repeat(jnp.arange(model.n_classes), K), model.n_classes, dtype=jnp.float32)
    
    return all_xs, true_labels, epsilons, all_ys, K, test_key

def get_model_output(x, epsilon, y, K):
    z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz = jax.vmap(
            partial(model.apply, {'params': trained_params}, train=False),
            in_axes=(None, 0, 0)
        )(x, y, epsilon)

    ll = log_likelyhood_fn(
            z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz
        ).reshape(model.n_classes, K)
    ll = logsumexp(ll, axis=1) - np.log(K)
    return ll

def get_model_jacobian(x, epsilon, y, K):
    return jacrev(get_model_output, argnums=0)(x, epsilon, y, K)

def map_label_to_name(y):
    labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
              "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
    return labels[y]

In [8]:
# Zeroth order optimization attack
import numpy as np
import jax
import jax.numpy as jnp
import optax

class ZOO_Attack():
    def __init__(self, model, max_iter=10, learning_rate=0.1, c=1, p=2):
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.c = c
        assert p > 1 
        self.p = p
        if self.p == np.inf:
            self.q = 1
        else:
            self.q = self.p / (self.p - 1)

    def qnorm(self, x):
        return jnp.linalg.norm(x.flatten(), self.q)

    def get_label(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return jnp.argmax(val)

    def get_likelihoods(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return val

    def get_gradients(self, x, epsilon=1e-5):
        # Estimate gradients using finite differences
        perturbed_x_plus = x + epsilon
        perturbed_x_minus = x - epsilon

        output_plus = self.get_likelihoods(perturbed_x_plus)
        output_minus = self.get_likelihoods(perturbed_x_minus)

        gradients = (output_plus - output_minus) / (2 * epsilon)

        return gradients
    
    def project_to_bounds(self, x):
        bounds_min = jnp.zeros_like(x)
        bounds_max = jnp.ones_like(x)
        return jnp.clip(x, bounds_min, bounds_max)
    
    def f(self, x, target_label, k = 0):
        x = self.project_to_bounds(x)
        val = self.get_likelihoods(x)
        max_logit = jnp.max(val[jnp.arange(self.n_classes) != target_label])
        logit_diff = jnp.maximum(max_logit - val[target_label], - k)
        return logit_diff
    
    def get_objective(self, w, x, target_label, k = 0):
        norm = self.qnorm(w)
        penalty = self.c * self.f(x + w, target_label, k = k)
        return norm + penalty
    
    def get_obj_grad(self, w, x, target_label):
        # Compute gradient of the objective function
        corrupted_x = x + w
        norm_grad = (2) * (corrupted_x - x)

        val = self.get_likelihoods(corrupted_x)
        grad_model = self.get_gradients(corrupted_x)
        max_label = jnp.argmax(val[jnp.arange(self.n_classes) != target_label])
        max_logit = val[max_label]
        logit_diff = max_logit - val[target_label]
        if logit_diff <= 0:
            penalty_grad = 0
        else:
            penalty_grad = grad_model[max_label] - grad_model[target_label]
        
        return norm_grad + self.c * penalty_grad
    
    def project_to_bounds(self, x):
        bounds_min = jnp.zeros_like(x)
        bounds_max = jnp.ones_like(x)
        return jnp.clip(x, bounds_min, bounds_max)

    def get_perturbation(self, x, epsilon, all_ys, K):
        self.y = all_ys
        self.epsilon = epsilon
        self.K = K
        true_label = self.get_label(x)
        max_perturbation_norm = -1
        best_label = true_label
        best_corrupted_x = jnp.zeros_like(x)
        # for label in range(self.n_classes): # to do : optimize this loop
        for label in range(self.n_classes): # to do : optimize this loop
            if label != true_label:
                w = jnp.zeros_like(x)
                # use gradient descent to find minimum of the problem
                for i in tqdm(range(self.max_iter)):
                    grad = self.get_obj_grad(w, x, label)
                    w = w - self.learning_rate * grad
                    corrupted_x = x + w
                    corrupted_x = self.project_to_bounds(corrupted_x)
                    if self.get_label(corrupted_x) == label:
                        break
                    
                # check if the attack was successful
                new_label = self.get_label(corrupted_x)
                if new_label != label:
                    print("Warning: did not find a perturbation for this label")
                    perturbation_norm = -1
                else:
                    perturbation_norm = np.linalg.norm(corrupted_x - x)/np.linalg.norm(x)
                    print("Found a perturbation for label", label, "with norm", perturbation_norm)
                    print(corrupted_x)

                # Choose minimal perturbation
                if max_perturbation_norm == -1 and perturbation_norm != -1:
                    max_perturbation_norm = perturbation_norm
                    best_label = new_label
                    best_corrupted_x = corrupted_x
                else : 
                    if perturbation_norm != -1 and perturbation_norm < max_perturbation_norm:
                        max_perturbation_norm = perturbation_norm
                        best_label = new_label
                        best_corrupted_x = corrupted_x

        return best_corrupted_x, best_label, max_perturbation_norm

In [None]:
# Zeroth order optimization attack
import numpy as np
import jax
import jax.numpy as jnp
import optax

class ZOO_Attack():
    def __init__(self, model, max_iter=10, learning_rate=0.1, c=1, p=2):
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.c = c
        assert p > 1 
        self.p = p
        if self.p == np.inf:
            self.q = 1
        else:
            self.q = self.p / (self.p - 1)

    def qnorm(self, x):
        return jnp.linalg.norm(x.flatten(), self.q)

    def get_label(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return jnp.argmax(val)

    def get_likelihoods(self, x):
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return val

    def get_gradients(self, x, epsilon=1e-5):
        # Estimate gradients using finite differences
        perturbed_x_plus = x + epsilon
        perturbed_x_minus = x - epsilon

        output_plus = self.get_likelihoods(perturbed_x_plus)
        output_minus = self.get_likelihoods(perturbed_x_minus)

        gradients = (output_plus - output_minus) / (2 * epsilon)

        return gradients
    
    def project_to_bounds(self, x):
        bounds_min = jnp.zeros_like(x)
        bounds_max = jnp.ones_like(x)
        return jnp.clip(x, bounds_min, bounds_max)
    
    def f(self, x, target_label, k = 0):
        x = self.project_to_bounds(x)
        val = self.get_likelihoods(x)
        max_logit = jnp.max(val[jnp.arange(self.n_classes) != target_label])
        logit_diff = jnp.maximum(max_logit - val[target_label], - k)
        return logit_diff
    
    def get_objective(self, w, x, target_label, k = 0):
        norm = self.qnorm(w)
        penalty = self.c * self.f(x + w, target_label, k = k)
        return norm + penalty
    
    def get_obj_grad(self, w, x, target_label):
        # Compute gradient of the objective function
        corrupted_x = x + w
        norm_grad = (2) * (corrupted_x - x)

        val = self.get_likelihoods(corrupted_x)
        grad_model = self.get_gradients(corrupted_x)
        max_label = jnp.argmax(val[jnp.arange(self.n_classes) != target_label])
        max_logit = val[max_label]
        logit_diff = max_logit - val[target_label]
        if logit_diff <= 0:
            penalty_grad = 0
        else:
            penalty_grad = grad_model[max_label] - grad_model[target_label]
        
        return norm_grad + self.c * penalty_grad
    
    def project_to_bounds(self, x):
        bounds_min = jnp.zeros_like(x)
        bounds_max = jnp.ones_like(x)
        return jnp.clip(x, bounds_min, bounds_max)

    def get_perturbation(self, x, epsilon, all_ys, K):
        self.y = all_ys
        self.epsilon = epsilon
        self.K = K
        true_label = self.get_label(x)
        max_perturbation_norm = -1
        best_label = true_label
        best_corrupted_x = jnp.zeros_like(x)
        # for label in range(self.n_classes): # to do : optimize this loop
        w = jnp.zeros_like(x)
        # use gradient descent to find minimum of the problem
        for i in tqdm(range(self.max_iter)):
            grad = self.get_obj_grad(w, x)
            w = w - self.learning_rate * grad
            corrupted_x = x + w
            corrupted_x = self.project_to_bounds(corrupted_x)
            if self.get_label(corrupted_x) == label:
                break
            
        # check if the attack was successful
        new_label = self.get_label(corrupted_x)
        if new_label != label:
            print("Warning: did not find a perturbation for this label")
            perturbation_norm = -1
        else:
            perturbation_norm = np.linalg.norm(corrupted_x - x)/np.linalg.norm(x)
            print("Found a perturbation for label", label, "with norm", perturbation_norm)
            print(corrupted_x)

        return corrupted_x, new_label, perturbation_norm

In [4]:
import numpy as np

def find_argmax_except_diagonal(matrix):
    # Create an array of column indices
    matrix = np.array([matrix for i in range(matrix.shape[0])])
    matrix = np.transpose(matrix)
    col_indices = np.arange(matrix.shape[1])
    values = matrix[np.arange(matrix.shape[1]), np.arange(matrix.shape[1])]
    mask = col_indices[:, np.newaxis] == np.arange(matrix.shape[0])
    matrix[mask] = np.min(matrix) - 1
    argmax_indices = np.argmax(matrix, axis=1)
    max_logits = matrix[argmax_indices, np.arange(matrix.shape[1])]

    diff = max_logits - values
    print((diff>0)*1)

    return argmax_indices

# Example usage:
matrix = np.array([1, 2, 3])

argmax_indices = find_argmax_except_diagonal(matrix)
print("Argmax indices for each column (excluding diagonal):", argmax_indices)


[1 0 0]
Argmax indices for each column (excluding diagonal): [1 0 0]


In [5]:
def get_average_performance(corruption_model, all_xs, epsilons, all_ys, K):
    perturbation_norms = []
    n_samples = len(all_xs)
    for i in tqdm(range(n_samples)):
        x = all_xs[i]
        epsilon = epsilons[i]
        _, _, perturbation_norm = corruption_model.get_perturbation(x, epsilon, all_ys, K)
        perturbation_norms.append(perturbation_norm)
    return np.array(perturbation_norms)

In [9]:
n_samples = 1
all_xs, true_labels, epsilons, all_ys, K, test_key = init_data(test_key, n_samples=n_samples)

corruption_model = ZOO_Attack(model, max_iter=10, learning_rate=0.1, c=1, p=2)

perturbation_norms_ZOO = get_average_performance(corruption_model, all_xs, epsilons, all_ys, K)
perturbation_norms_successful_ZOO = perturbation_norms_ZOO[perturbation_norms_ZOO != -1]
n_successful_ZOO = len(perturbation_norms_successful_ZOO)
n_successful_ZOO
print(f'Average perturbation norm of ZOO Attack model (on {n_successful_ZOO} successful samples): {np.mean(perturbation_norms_successful_ZOO):>.4f}')

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:12<00:00,  1.26s/it]




100%|██████████| 10/10 [00:11<00:00,  1.17s/it]




100%|██████████| 10/10 [00:13<00:00,  1.38s/it]






In [7]:
i = np.random.choice(range(n_samples))
x = all_xs[i]
true_label = true_labels[i]
test_key, epsilons = sample_gaussian(test_key, (1, model.n_classes * K, model.d_latent))
epsilon = epsilons[0]

corruption_model = ZOO_Attack(model, max_iter=100, learning_rate=0.1, c = 10, p=2)
corrupted_x, new_label, perturbation_norm = corruption_model.get_perturbation(x, epsilon, all_ys, K)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

axs[0].imshow(x.reshape(28, 28), cmap="gray")
axs[0].set_title(f"Original image (label = '{map_label_to_name(true_label)}')")

axs[1].imshow(corrupted_x.reshape(28, 28), cmap="gray")
axs[1].set_title(f"ZOO perturbated image (label = '{map_label_to_name(new_label)}')")

plt.show()

 40%|████      | 40/100 [01:05<01:38,  1.64s/it]


KeyboardInterrupt: 

In [None]:
print(jnp.linalg.norm(x - corrupted_x))

35.840046


In [None]:
print(x[0,:10])

[[0.        ]
 [0.        ]
 [0.        ]
 [0.        ]
 [0.        ]
 [0.02352941]
 [0.        ]
 [0.        ]
 [0.        ]
 [0.20392157]]


In [None]:
print(corrupted_x[0,:10])

[[0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]
