In [1]:
import orbax.checkpoint as ocp
from ml_collections import ConfigDict
from pathlib import Path
from utils import prepare_test_dataset
from dataset_utils import get_dataset
from jax import random
from models.utils import sample_gaussian

import models.ClassifierGFZ as ClassifierGFZ
import models.ClassifierDFZ as ClassifierDFZ

# --- Restore the saved run -------------------------------------------------
# The checkpoint is looked up under ./checkpoints relative to the current
# working directory.
checkpoint_path = "dfz-2-epochs-first-try-1"
path = Path.cwd() / "checkpoints" / checkpoint_path
checkpoint = ocp.PyTreeCheckpointer().restore(path, item=None)

# The checkpoint carries the configs it was trained with; wrap them so the
# rest of the notebook can use attribute access.
config = ConfigDict(checkpoint["config"])
dataset_config = ConfigDict(checkpoint["dataset_config"])

# --- Pick the classifier module matching the stored model name --------------
classifier_by_name = {
    "GFZ": ClassifierGFZ,
    "DFZ": ClassifierDFZ,
}
if config.model_name not in classifier_by_name:
    raise NotImplementedError(config.model_name)
classifier = classifier_by_name[config.model_name]

# --- Test data ----------------------------------------------------------------
# Only the second element returned by get_dataset (the test split) is used.
_, test_ds = get_dataset(config.dataset)
test_images, test_labels = prepare_test_dataset(test_ds, dataset_config)

# --- Trained parameters and scoring function ----------------------------------
trained_params = checkpoint["params"]
log_likelyhood_fn = classifier.log_likelyhood_A

# --- Model definition ---------------------------------------------------------
# The freshly initialised parameters returned by create_and_init are discarded;
# the restored `trained_params` are used instead.
test_key = random.PRNGKey(config.seed)
test_key, model, _ = classifier.create_and_init(test_key, config, dataset_config)



In [2]:
from flax import linen as nn
import jax
from jax import jacrev
import numpy as np
from functools import partial
from jax.scipy.special import logsumexp
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm

def init_data(test_key, n_samples=10):
    """Draw a random subset of the test set plus the noise/labels used to score it.

    The subset indices are drawn from a split of ``test_key`` rather than from
    NumPy's global (unseeded) RNG, so the selection is reproducible from the
    key passed in.

    Args:
        test_key: JAX PRNG key. It is split/consumed here and the advanced key
            is returned.
        n_samples: Number of distinct test images to draw (sampling is without
            replacement).

    Returns:
        Tuple ``(all_xs, true_labels, epsilons, all_ys, K, test_key)``:
            all_xs: the selected rows of ``test_images``.
            true_labels: argmax over axis 1 of the selected ``test_labels``
                rows (assumes one-hot labels -- TODO confirm).
            epsilons: output of ``sample_gaussian`` requested with shape
                ``(n_samples, model.n_classes * K, model.d_latent)``.
            all_ys: one-hot matrix with ``model.n_classes * K`` rows, each
                class index repeated ``K`` consecutive times.
            K: ``model.K``.
            test_key: the PRNG key to use for subsequent sampling.
    """
    # Derive a dedicated key for the subset selection so the caller's key
    # stream stays independent of how many indices are drawn.
    test_key, idx_key = jax.random.split(test_key)
    # Convert to a NumPy index array so it can index either NumPy or JAX data.
    idx = np.asarray(
        jax.random.choice(
            idx_key, len(test_images), shape=(n_samples,), replace=False
        )
    )

    all_xs = test_images[idx]
    true_ys = test_labels[idx]
    true_labels = np.argmax(true_ys, axis=1)

    K = model.K
    batch_size = n_samples
    test_key, epsilons = sample_gaussian(
        test_key, (batch_size, model.n_classes * K, model.d_latent)
    )
    # NOTE(review): if sample_gaussian returns an array of the requested shape,
    # axis 0 already has length n_samples and this slice is a no-op -- confirm
    # and remove.
    epsilons = epsilons[:n_samples*model.n_classes]
    all_ys = nn.one_hot(
        jnp.repeat(jnp.arange(model.n_classes), K),
        model.n_classes,
        dtype=jnp.float32,
    )

    return all_xs, true_labels, epsilons, all_ys, K, test_key

def get_model_output(x, epsilon, y, K):
    """Score one input against every class and return one value per class.

    The model is applied in evaluation mode (``train=False``) with the restored
    ``trained_params``. ``x`` is broadcast (not mapped) while ``y`` and
    ``epsilon`` are mapped along their leading axis. The resulting
    log-likelihoods are reshaped to ``(model.n_classes, K)`` and reduced over
    the ``K`` axis with a log-mean-exp.

    Args:
        x: A single input, shared across all mapped evaluations.
        epsilon: Noise, mapped along axis 0.
        y: Labels, mapped along axis 0.
        K: Number of entries per class along the mapped axis.

    Returns:
        Array with ``model.n_classes`` entries.
    """
    apply_eval = partial(model.apply, {'params': trained_params}, train=False)
    apply_over_samples = jax.vmap(apply_eval, in_axes=(None, 0, 0))
    z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz = apply_over_samples(x, y, epsilon)

    flat_ll = log_likelyhood_fn(z, logit_q_z_xy, logit_p_x_z, logit_p_y_xz)
    ll_per_class = flat_ll.reshape(model.n_classes, K)

    # log( mean_k exp(ll) ) == logsumexp_k(ll) - log(K)
    return logsumexp(ll_per_class, axis=1) - np.log(K)

def get_model_jacobian(x, epsilon, y, K):
    """Reverse-mode Jacobian of ``get_model_output`` with respect to ``x``.

    Only the first argument is differentiated; ``epsilon``, ``y`` and ``K`` are
    passed through unchanged.
    """
    output_wrt_x = jacrev(get_model_output, argnums=0)
    return output_wrt_x(x, epsilon, y, K)

def map_label_to_name(y):
    """Translate an integer class index into its human-readable category name.

    Indexing follows normal list semantics, so negative indices count from the
    end and an out-of-range index raises ``IndexError``.
    """
    class_names = [
        "T-shirt/top",  # 0
        "Trouser",      # 1
        "Pullover",     # 2
        "Dress",        # 3
        "Coat",         # 4
        "Sandal",       # 5
        "Shirt",        # 6
        "Sneaker",      # 7
        "Bag",          # 8
        "Ankle boot",   # 9
    ]
    return class_names[y]

In [3]:
class DeepFool():
    """Iterative DeepFool-style search for a small label-changing perturbation.

    At every step the per-class scores are linearised around the current point,
    the closest (in L2 distance under that linearisation) decision boundary to
    the originally predicted class is selected, and the input is moved towards
    it.

    Args:
        model: Object exposing ``n_classes``.
        max_iter: Maximum number of update steps per input.
        learning_rate: Multiplier applied to every step. With a value of 1 a
            step lands on the linearised boundary; values above 1 overshoot it.
    """

    def __init__(self, model, max_iter=100, learning_rate=1):
        self.model = model
        self.n_classes = model.n_classes
        self.max_iter = max_iter
        self.learning_rate = learning_rate

    def get_label(self, x, epsilon):
        """Return the index of the highest-scoring class for ``x``.

        Reads ``self.y`` and ``self.K``, which are set by ``get_perturbation``.
        """
        val = get_model_output(x, epsilon, self.y, self.K)
        return np.argmax(val)

    def get_likelihoods_and_gradients(self, x, epsilon):
        """Return ``(scores, jacobian)`` of the model output at ``x``.

        ``jacobian[j]`` is the gradient of class ``j``'s score w.r.t. ``x``.
        """
        ## ToDo: check why this becomes very slow after a few iterations
        # NOTE(review): the Jacobian and the forward pass are evaluated as two
        # separate calls here -- a possible contributor; confirm by profiling.
        J = get_model_jacobian(x, epsilon, self.y, self.K)

        val = get_model_output(x, epsilon, self.y, self.K)
        return val, J

    def get_perturbation(self, x, epsilon, all_ys, K):
        """Perturb ``x`` until its predicted label changes or ``max_iter`` is hit.

        Args:
            x: Input to attack. It is not modified.
            epsilon: Noise forwarded to the model.
            all_ys: Labels forwarded to the model (stored on ``self.y``).
            K: Forwarded to the model (stored on ``self.K``).

        Returns:
            Tuple ``(corrupted_x, new_label, perturbation_norm)`` where
            ``perturbation_norm`` is ``||corrupted_x - x|| / ||x||``. If the
            label never changes, ``new_label`` equals the original prediction.
        """
        corrupted_x = x.copy()
        self.y = all_ys
        self.K = K
        # corrupted_x starts as a copy of x, so one forward pass gives both the
        # reference label and the initial "current" label.
        k = self.get_label(x, epsilon)
        new_label = k
        for _ in range(self.max_iter):
            likelihoods, gradients = self.get_likelihoods_and_gradients(corrupted_x, epsilon)

            # Find the class whose linearised decision boundary is closest.
            best_pert = np.inf
            w = None
            f = None
            for j in range(self.n_classes):
                if j == k:
                    continue
                w_j = gradients[j] - gradients[k]
                f_j = likelihoods[j] - likelihoods[k]
                w_norm = np.linalg.norm(w_j)
                if w_norm == 0:
                    # Flat direction: this boundary cannot be reached by a
                    # first-order step, and dividing would produce inf/nan.
                    continue
                pert = np.abs(f_j) / w_norm
                if pert < best_pert:
                    # Record the new minimum so later, farther classes do not
                    # overwrite the closest one.
                    best_pert = pert
                    w = w_j
                    f = f_j

            if w is None:
                # No usable direction (single class, or all gradients equal /
                # non-finite); further iterations could not make progress.
                break

            r_i = (np.abs(f) / np.linalg.norm(w)**2) * w
            corrupted_x = corrupted_x + self.learning_rate*r_i
            new_label = self.get_label(corrupted_x, epsilon)
            if new_label != k:
                break
        perturbation_norm = np.linalg.norm(corrupted_x - x)/np.linalg.norm(x)
        return corrupted_x, new_label, perturbation_norm

In [4]:
def get_average_performance(corruption_model, all_xs, epsilons, all_ys, K):
    """Attack every input in ``all_xs`` and collect the perturbation norms.

    Despite the name, no averaging happens here: the per-sample norms are
    returned as a list (in input order) and the caller aggregates them.

    Args:
        corruption_model: Object with a
            ``get_perturbation(x, epsilon, all_ys, K)`` method returning a
            3-tuple whose last element is the perturbation norm.
        all_xs: Inputs to attack; indexed by sample.
        epsilons: Per-sample noise; indexed in lockstep with ``all_xs``.
        all_ys: Labels forwarded unchanged to every attack.
        K: Forwarded unchanged to every attack.

    Returns:
        List with one perturbation norm per input.
    """
    collected_norms = []
    for sample_idx in tqdm(range(len(all_xs))):
        _adv_x, _adv_label, norm = corruption_model.get_perturbation(
            all_xs[sample_idx], epsilons[sample_idx], all_ys, K
        )
        collected_norms.append(norm)
    return collected_norms

In [5]:
# Number of test images to attack in this run.
n_samples = 100

# Draw the evaluation subset together with the noise/labels needed to score it.
(
    all_xs,
    true_labels,
    epsilons,
    all_ys,
    K,
    test_key,
) = init_data(test_key, n_samples=n_samples)

# Attack every selected image with the default DeepFool settings.
corruption_model = DeepFool(model)
perturbation_norms = get_average_performance(
    corruption_model,
    all_xs,
    epsilons,
    all_ys,
    K,
)

print(f'Average perturbation norm: {np.mean(perturbation_norms):>.4f}')

 37%|███▋      | 37/100 [03:21<05:42,  5.44s/it]


KeyboardInterrupt: 

In [None]:
# Pick one of the previously drawn samples at random.
i = np.random.randint(n_samples)
x, epsilon, true_label = all_xs[i], epsilons[i], true_labels[i]

# Attack it with a smaller step multiplier than the default.
corruption_model = DeepFool(model, learning_rate=0.2)
corrupted_x, new_label, perturbation_norm = corruption_model.get_perturbation(
    x, epsilon, all_ys, K
)

# Show the original and the perturbed image side by side.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (x, f"Original image (label = '{map_label_to_name(true_label)}')"),
    (corrupted_x, f"Corrupted image (label = '{map_label_to_name(new_label)}')"),
)
for panel_ax, (panel_image, panel_title) in zip(axs, panels):
    panel_ax.imshow(panel_image.reshape(28, 28), cmap="gray")
    panel_ax.set_title(panel_title)
fig.suptitle(f"DeepFool example - perturbation norm: {perturbation_norm:.4f}")
plt.show()