In [1]:
import orbax.checkpoint as ocp
import os
from ml_collections import ConfigDict
from pathlib import Path
from utils import prepare_test_dataset
from dataset_utils import get_dataset
from jax import random
from models.utils import sample_gaussian

import models.ClassifierGFZ as ClassifierGFZ
import models.ClassifierDFZ as ClassifierDFZ

# Restore the training checkpoint stored under ./checkpoints/<checkpoint_path>.
checkpoint_path = "dfz-2-epochs-first-try-1"
path = os.path.join(Path.cwd(), "checkpoints", checkpoint_path)
checkpoint = ocp.PyTreeCheckpointer().restore(path, item=None)

# The checkpoint carries both the model and the dataset configuration.
config = ConfigDict(checkpoint["config"])
dataset_config = ConfigDict(checkpoint["dataset_config"])

# Select the classifier module named in the config; unknown names are rejected.
classifier = {"GFZ": ClassifierGFZ, "DFZ": ClassifierDFZ}.get(config.model_name)
if classifier is None:
    raise NotImplementedError(config.model_name)

# Only the test split is used in this notebook.
_, test_ds = get_dataset(config.dataset)
test_images, test_labels = prepare_test_dataset(test_ds, dataset_config)

trained_params = checkpoint["params"]
log_likelyhood_fn = classifier.log_likelyhood_A

# Build the model object; the parameters returned by the initialiser are
# discarded because the restored `trained_params` are used instead.
test_key = random.PRNGKey(config.seed)
test_key, model, _ = classifier.create_and_init(test_key, config, dataset_config)



In [2]:
from flax import linen as nn
import jax
from jax import jacrev
import numpy as np
from functools import partial
from jax.scipy.special import logsumexp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm
import optax
from scipy.optimize import minimize

def init_data(test_key, n_samples=10):
    """Draw a random evaluation batch plus the latent noise used to score it.

    Args:
        test_key: JAX PRNG key handed to ``sample_gaussian``.
        n_samples: Number of distinct test images to draw.

    Returns:
        Tuple ``(all_xs, true_labels, epsilons, all_ys, K, test_key)``:
        the selected images, the argmax over axis 1 of their label rows
        (presumably one-hot labels -- confirm against ``prepare_test_dataset``),
        noise of shape ``(n_samples, n_classes * K, d_latent)``, one one-hot
        row per (class, latent sample) pair, ``model.K``, and the key returned
        by ``sample_gaussian``.
    """
    # NOTE: the indices come from NumPy's global RNG, not from `test_key`, so
    # the selection is only reproducible if np.random is seeded by the caller.
    idx = np.random.choice(len(test_images), n_samples, replace=False)

    all_xs = test_images[idx]
    true_labels = np.argmax(test_labels[idx], axis=1)

    K = model.K
    # One block of n_classes * K noise vectors per selected image.
    test_key, epsilons = sample_gaussian(
        test_key, (n_samples, model.n_classes * K, model.d_latent)
    )
    # Class c occupies rows [c * K, (c + 1) * K), matching the
    # reshape(n_classes, K) done when the likelihoods are aggregated.
    all_ys = nn.one_hot(
        jnp.repeat(jnp.arange(model.n_classes), K),
        model.n_classes,
        dtype=jnp.float32,
    )

    return all_xs, true_labels, epsilons, all_ys, K, test_key

def get_model_output(x, epsilon, y, K):
    """Per-class log-likelihood estimate for a single input ``x``.

    The model is applied once per row of ``y`` / ``epsilon`` (``x`` is shared
    across rows), the resulting log-likelihoods are grouped into
    ``(n_classes, K)`` and reduced with a log-mean-exp over the ``K`` axis.

    Returns:
        Array with one value per class.
    """
    apply_fn = partial(model.apply, {'params': trained_params}, train=False)
    # x is broadcast; y and epsilon are mapped over their leading axis.
    batched_apply = jax.vmap(apply_fn, in_axes=(None, 0, 0))
    z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz = batched_apply(x, y, epsilon)

    per_sample_ll = log_likelyhood_fn(z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz)
    per_class_ll = per_sample_ll.reshape(model.n_classes, K)
    # logsumexp minus log(K) is the log of the mean over the K samples.
    return logsumexp(per_class_ll, axis=1) - np.log(K)

def get_model_jacobian(x, epsilon, y, K):
    """Reverse-mode Jacobian of ``get_model_output`` with respect to ``x``.

    The remaining arguments are passed through unchanged and are not
    differentiated.
    """
    jacobian_fn = jacrev(get_model_output, argnums=0)
    return jacobian_fn(x, epsilon, y, K)

def map_label_to_name(y):
    """Return the human-readable class name for integer class index ``y``.

    The ten names look like the Fashion-MNIST categories; the index is used
    directly, so out-of-range values raise ``IndexError``.
    """
    class_names = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    return class_names[y]

In [4]:
def get_average_performance(corruption_model, all_xs, epsilons, all_ys, K):
    """Run ``corruption_model`` on every sample and collect the perturbation norms.

    Sample ``i`` is attacked with ``all_xs[i]`` and ``epsilons[i]``; only the
    third element returned by ``get_perturbation`` is kept.

    Returns:
        NumPy array with one entry per sample.
    """
    norms = [
        corruption_model.get_perturbation(all_xs[i], epsilons[i], all_ys, K)[2]
        for i in tqdm(range(len(all_xs)))
    ]
    return np.array(norms)

In [7]:
# Carlini & Wagner (C&W) attack -- targeted variant
import numpy as np
import jax
import jax.numpy as jnp
import optax

class WG_Attack():
    """Targeted Carlini & Wagner style attack using a tanh change of variables.

    For every class other than the predicted one, the attack searches for an
    image ``x' = (tanh(w) + 1) / 2`` that the classifier assigns to that class
    while staying close to the clean image ``x``; the closest successful
    candidate is returned.

    Class scores and their Jacobian come from the module-level helpers
    ``get_model_output`` / ``get_model_jacobian``; ``self.model`` is only used
    to read ``n_classes``. The tanh parametrisation keeps ``x'`` inside
    ``[0, 1]``, so inputs are assumed to be scaled to that range.
    """

    def __init__(self, model, max_iter=10, learning_rate=0.1, c = 1, p=2):
        """Store the attack hyper-parameters.

        Args:
            model: Object exposing ``n_classes``.
            max_iter: Number of Adam steps per target class.
            learning_rate: Adam learning rate.
            c: Weight of the misclassification penalty relative to the distance.
            p: Norm exponent, must be > 1; ``q`` is derived from it.
        """
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.c = c
        assert p > 1
        self.p = p
        # q satisfies 1/p + 1/q = 1; p = inf is paired with q = 1.
        self.q = 1 if self.p == np.inf else self.p / (self.p - 1)

    @staticmethod
    def _to_image(w):
        """Map the unconstrained variable ``w`` to an image in ``[0, 1]``."""
        return 0.5 * (jnp.tanh(w) + 1)

    def _strongest_other_class(self, val, label):
        """Index, in the full class range, of the best-scoring class != ``label``."""
        # Masking (instead of boolean filtering) keeps the argmax aligned with
        # the original class indices.
        masked = jnp.where(jnp.arange(self.n_classes) == label, -jnp.inf, val)
        return jnp.argmax(masked)

    def qnorm(self, x):
        """q-norm of ``x`` flattened to a vector."""
        return jnp.linalg.norm(x.flatten(), self.q)

    def get_label(self, x):
        """Predicted class of ``x`` (argmax of the class scores)."""
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return jnp.argmax(val)

    def get_likelihoods(self, x):
        """Per-class scores of ``x`` for the noise/labels set by ``get_perturbation``."""
        return get_model_output(x, self.epsilon, self.y, self.K)

    def get_gradients(self, x):
        """Jacobian of the per-class scores with respect to ``x``."""
        return get_model_jacobian(x, self.epsilon, self.y, self.K)

    def loss(self, val, label):
        """Softmax cross-entropy between scores ``val`` and class ``label``.

        The one-hot target is built from ``jnp.array([label])`` and therefore
        has a leading axis of length 1.
        """
        label_one_hot_encoding = jax.nn.one_hot(jnp.array([label]), self.n_classes)
        return optax.softmax_cross_entropy(val, label_one_hot_encoding)

    def f(self, x, target_label, k = 0):
        """Targeted margin ``max(max_{i != t} Z_i - Z_t, -k)`` evaluated at ``x``."""
        val = self.get_likelihoods(x)
        rival = self._strongest_other_class(val, target_label)
        return jnp.maximum(val[rival] - val[target_label], - k)

    def get_objective(self, w, x, target_label, k = 0):
        """Distance (q-norm) plus ``c`` times the targeted margin at ``x'(w)``."""
        corrupted_x = self._to_image(w)
        norm = self.qnorm(corrupted_x - x)
        penalty = self.c * self.f(corrupted_x, target_label, k = k)
        return norm + penalty

    def get_obj_grad(self, w, x, target_label):
        """Gradient with respect to ``w`` of ``||x'(w) - x||_2^2 + c * margin``.

        NOTE: the distance term differentiated here is the squared L2 distance,
        whereas ``get_objective`` reports ``qnorm``.

        Returns:
            Array with the same shape as ``w`` (one entry per pixel).
        """
        tanh_w = jnp.tanh(w)
        corrupted_x = 0.5 * (tanh_w + 1)
        # d x' / d w for the tanh parametrisation.
        dx_dw = 0.5 * (1 - tanh_w ** 2)

        # d/dw ||x' - x||^2 = 2 (x' - x) * dx_dw = (1 - tanh^2 w)(x' - x)
        norm_grad = 2 * (corrupted_x - x) * dx_dw

        val = self.get_likelihoods(corrupted_x)
        grad_model = self.get_gradients(corrupted_x)
        rival = self._strongest_other_class(val, target_label)
        margin = val[rival] - val[target_label]
        # The penalty is only active while the target class is not yet winning;
        # the model Jacobian is w.r.t. x', hence the chain-rule factor dx_dw.
        penalty_grad = jnp.where(
            margin > 0,
            (grad_model[rival] - grad_model[target_label]) * dx_dw,
            0.0,
        )

        return norm_grad + self.c * penalty_grad

    def get_perturbation(self, x, epsilon, all_ys, K):
        """Attack ``x`` towards every other class and keep the closest success.

        Args:
            x: Clean image, assumed to lie in ``[0, 1]``.
            epsilon: Latent noise forwarded to ``get_model_output``.
            all_ys: Label rows forwarded to ``get_model_output``.
            K: Number of latent samples per class.

        Returns:
            ``(best_corrupted_x, best_label, perturbation_norm)``. The norm is
            ``||x' - x|| / ||x||`` of the closest successful candidate, or
            ``-1`` (with ``x`` and its predicted label) when no target class
            could be reached.
        """
        self.y = all_ys
        self.epsilon = epsilon
        self.K = K
        x = jax.device_put(x)
        true_label = self.get_label(x)

        # Start the search at the clean image: w0 = arctanh(2x - 1). The clip
        # keeps arctanh finite for pixels that are exactly 0 or 1 (margin
        # chosen for float32 inputs).
        edge = 1e-6
        w_init = jnp.arctanh(jnp.clip(2 * x - 1, -1 + edge, 1 - edge))
        optimizer = optax.adam(learning_rate=self.learning_rate)

        best_norm = None
        best_label = true_label
        best_corrupted_x = x
        for label in range(self.n_classes):
            if label == true_label:
                continue

            w = w_init
            state = optimizer.init(w)
            for _ in range(self.max_iter):
                grad = self.get_obj_grad(w, x, label)
                updates, state = optimizer.update(grad, state)
                w = optax.apply_updates(w, updates)
            corrupted_x = self._to_image(w)

            # The attempt counts only if the classifier now predicts the target.
            new_label = self.get_label(corrupted_x)
            if new_label != label:
                print("Warning: did not find a perturbation")
                continue

            perturbation_norm = jnp.linalg.norm(corrupted_x - x)/jnp.linalg.norm(x)
            if best_norm is None or perturbation_norm < best_norm:
                best_norm = perturbation_norm
                best_label = new_label
                best_corrupted_x = corrupted_x

        return best_corrupted_x, best_label, (-1 if best_norm is None else best_norm)

In [6]:
# Carlini & Wagner (C&W) attack -- untargeted variant
import numpy as np
import jax
import jax.numpy as jnp
import optax

class untargeted_WG_Attack():
    """Untargeted Carlini & Wagner style attack.

    Searches for an image close to ``x`` that the classifier no longer assigns
    to the originally predicted class. Two parametrisations are supported:

    * additive (default): ``x' = x + w``, clipped to ``[0, 1]`` at the end;
    * tanh change of variables (``change_var=True``): ``x' = (tanh(w) + 1) / 2``.

    Class scores and their Jacobian come from the module-level helpers
    ``get_model_output`` / ``get_model_jacobian``; ``self.model`` is only used
    to read ``n_classes``.
    """

    def __init__(self, model, max_iter=10, learning_rate=0.1, c = 1, p=2):
        """Store the attack hyper-parameters.

        Args:
            model: Object exposing ``n_classes``.
            max_iter: Number of Adam steps.
            learning_rate: Adam learning rate.
            c: Weight of the misclassification penalty relative to the distance.
            p: Norm exponent, must be > 1; ``q`` is derived from it.
        """
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        self.c = c
        assert p > 1
        self.p = p
        # q satisfies 1/p + 1/q = 1; p = inf is paired with q = 1.
        self.q = 1 if self.p == np.inf else self.p / (self.p - 1)

    @staticmethod
    def _to_image(w):
        """Map the unconstrained variable ``w`` to an image in ``[0, 1]``."""
        return 0.5 * (jnp.tanh(w) + 1)

    def _strongest_other_class(self, val, label):
        """Index, in the full class range, of the best-scoring class != ``label``."""
        # Masking (instead of boolean filtering) keeps the argmax aligned with
        # the original class indices.
        masked = jnp.where(jnp.arange(self.n_classes) == label, -jnp.inf, val)
        return jnp.argmax(masked)

    def qnorm(self, x):
        """q-norm of ``x`` flattened to a vector."""
        return jnp.linalg.norm(x.flatten(), self.q)

    def get_label(self, x):
        """Predicted class of ``x`` (argmax of the class scores)."""
        val = get_model_output(x, self.epsilon, self.y, self.K)
        return jnp.argmax(val)

    def get_likelihoods(self, x):
        """Per-class scores of ``x`` for the noise/labels set by ``get_perturbation``."""
        return get_model_output(x, self.epsilon, self.y, self.K)

    def get_gradients(self, x):
        """Jacobian of the per-class scores with respect to ``x``."""
        return get_model_jacobian(x, self.epsilon, self.y, self.K)

    def loss(self, val, label):
        """Softmax cross-entropy between scores ``val`` and class ``label``.

        The one-hot target is built from ``jnp.array([label])`` and therefore
        has a leading axis of length 1.
        """
        label_one_hot_encoding = jax.nn.one_hot(jnp.array([label]), self.n_classes)
        return optax.softmax_cross_entropy(val, label_one_hot_encoding)

    def f(self, x, target_label, k = 0):
        """Margin ``max(max_{i != t} Z_i - Z_t, -k)`` evaluated at ``x``."""
        val = self.get_likelihoods(x)
        rival = self._strongest_other_class(val, target_label)
        return jnp.maximum(val[rival] - val[target_label], - k)

    def get_objective(self, w, x, target_label, k = 0):
        """Distance (q-norm) plus ``c`` times the margin at ``x' = (tanh(w) + 1) / 2``."""
        corrupted_x = self._to_image(w)
        norm = self.qnorm(corrupted_x - x)
        penalty = self.c * self.f(corrupted_x, target_label, k = k)
        return norm + penalty

    def get_obj_grad(self, w, x, change_var = False):
        """Gradient w.r.t. ``w`` of ``||x' - x||_2^2 + c * (Z_true - max_{i != true} Z_i)``.

        The penalty term is dropped once the true class is no longer winning.
        Reads ``self.true_label``, which ``get_perturbation`` sets.

        Args:
            w: Optimisation variable.
            x: Clean image.
            change_var: Use the tanh parametrisation instead of ``x' = x + w``.

        Returns:
            Array with the same shape as ``w`` (one entry per pixel).
        """
        if change_var:
            tanh_w = jnp.tanh(w)
            corrupted_x = 0.5 * (tanh_w + 1)
            # d x' / d w for the tanh parametrisation.
            dx_dw = 0.5 * (1 - tanh_w ** 2)
        else:
            corrupted_x = x + w
            dx_dw = 1.0

        # Gradient of the squared L2 distance, chained through x'(w).
        norm_grad = 2 * (corrupted_x - x) * dx_dw

        val = self.get_likelihoods(corrupted_x)
        grad_model = self.get_gradients(corrupted_x)
        rival = self._strongest_other_class(val, self.true_label)
        margin = val[self.true_label] - val[rival]
        # The model Jacobian is w.r.t. x', hence the chain-rule factor dx_dw.
        penalty_grad = jnp.where(
            margin > 0,
            (grad_model[self.true_label] - grad_model[rival]) * dx_dw,
            0.0,
        )

        return norm_grad + self.c * penalty_grad

    def project_to_bounds(self, x):
        """Clip ``x`` element-wise to ``[0, 1]``."""
        return jnp.clip(x, jnp.zeros_like(x), jnp.ones_like(x))

    def get_perturbation(self, x, epsilon, all_ys, K, change_var = False):
        """Run the attack on a single image.

        Args:
            x: Clean image, assumed to lie in ``[0, 1]``.
            epsilon: Latent noise forwarded to ``get_model_output``.
            all_ys: Label rows forwarded to ``get_model_output``.
            K: Number of latent samples per class.
            change_var: Use the tanh parametrisation instead of ``x' = x + w``.

        Returns:
            ``(corrupted_x, new_label, perturbation_norm)`` where the norm is
            ``||x' - x|| / ||x||``, or ``-1`` when the predicted label did not
            change.
        """
        self.y = all_ys
        self.epsilon = epsilon
        self.K = K
        x = jax.device_put(x)
        self.true_label = self.get_label(x)

        if change_var:
            # Start at the clean image: w0 = arctanh(2x - 1). The clip keeps
            # arctanh finite for pixels that are exactly 0 or 1 (margin chosen
            # for float32 inputs).
            edge = 1e-6
            w = jnp.arctanh(jnp.clip(2 * x - 1, -1 + edge, 1 - edge))
        else:
            w = jnp.zeros_like(x)

        optimizer = optax.adam(learning_rate=self.learning_rate)
        state = optimizer.init(w)
        for _ in range(self.max_iter):
            # The gradient must use the same parametrisation as the final image.
            grad = self.get_obj_grad(w, x, change_var)
            updates, state = optimizer.update(grad, state)
            w = optax.apply_updates(w, updates)

        if change_var:
            corrupted_x = self._to_image(w)
        else:
            corrupted_x = self.project_to_bounds(x + w)

        # The attack succeeded only if the prediction moved away from the
        # originally predicted class.
        new_label = self.get_label(corrupted_x)
        if new_label == self.true_label:
            print("Warning: did not find a perturbation")
            perturbation_norm = -1
        else:
            perturbation_norm = jnp.linalg.norm(corrupted_x - x)/jnp.linalg.norm(x)

        return corrupted_x, new_label, perturbation_norm

In [None]:
# Evaluate the untargeted attack on a random batch of test images.
n_samples = 10
all_xs, true_labels, epsilons, all_ys, K, test_key = init_data(test_key, n_samples=n_samples)

corruption_model = untargeted_WG_Attack(model, max_iter=100, learning_rate=0.01, c = 0.10, p=2)

perturbation_norms_WG = get_average_performance(corruption_model, all_xs, epsilons, all_ys, K)
# get_perturbation reports -1 for samples whose predicted label did not change.
perturbation_norms_successful_WG = perturbation_norms_WG[perturbation_norms_WG != -1]
n_successful_WG = len(perturbation_norms_successful_WG)

if n_successful_WG > 0:
    print(f'Average perturbation norm of Wagner & Carlini Attack model (on {n_successful_WG} successful samples): {np.mean(perturbation_norms_successful_WG):>.4f}')
else:
    # Averaging an empty array would print nan and emit a RuntimeWarning.
    print(f'Wagner & Carlini Attack model: no successful attack on {n_samples} samples')

In [None]:
# Attack one randomly chosen image from the batch and show it next to the original.
i = np.random.choice(range(n_samples))
x = all_xs[i]
true_label = true_labels[i]
# Fresh noise for this single example. It is kept under its own name so the
# batch-level `epsilons` returned by init_data is not overwritten (re-running
# the evaluation cell afterwards would otherwise index a length-1 array).
test_key, single_epsilons = sample_gaussian(test_key, (1, model.n_classes * K, model.d_latent))
epsilon = single_epsilons[0]

corruption_model = untargeted_WG_Attack(model, max_iter=100, learning_rate=0.01, c = 1, p=2)
corrupted_x, new_label, perturbation_norm = corruption_model.get_perturbation(x, epsilon, all_ys, K)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))

# Images are displayed as 28 x 28 greyscale.
axs[0].imshow(x.reshape(28, 28), cmap="gray")
axs[0].set_title(f"Original image (label = '{map_label_to_name(true_label)}')")

axs[1].imshow(corrupted_x.reshape(28, 28), cmap="gray")
axs[1].set_title(f"Carlini and Wagner perturbated image (label = '{map_label_to_name(new_label)}')")

plt.show()