In [11]:
from matplotlib import pyplot as plt
%matplotlib inline
import h5py
import seaborn
import numpy as np
import os
import tensorflow as tf
from sklearn.metrics import pairwise_distances
import sys
sys.path.append("../")
from utils import l2_normalize, prewhiten, read_sampled_identities
from PIL import Image
seaborn.set()

In [12]:
def recall(
    base_embeddings, 
    lookup_embeddings,
    lookup_labels, 
    k
):
    '''
    For each base embedding in base_embeddings,
    compute top k recall as an absolute number
    (the number of ground truth positives in the lookup_embeddings
    that are in the top k)
    lookup_labels must be an array of 1's and 0's indicating the ground truth
    for the lookup set (1 = positive, 0 = negative)
    '''
    dist = pairwise_distances(
        base_embeddings,
        lookup_embeddings
    )
    return [np.sum(np.array(lookup_labels)[np.argsort(d)][:k]) for d in dist]

In [13]:
path_to_adversarial = "/data/vggface/test_perturbed_sampled/{true}/community_naive_same/{target}/epsilon_{epsilon}.h5"
path_to_clean = "/data/vggface/test_preprocessed_sampled/{id}/embeddings.h5"

epsilons = [0.02, 0.04, 0.06, 0.08, 0.1]

id2imnames = read_sampled_identities("../sampled_identities.txt")
identities = id2imnames.keys()

positive = []
negative_clean = []
negative_adv = {eps: [] for eps in epsilons}
top_k_recall = {identity: {k: {eps: [] for eps in epsilons} for k in [1, 5, 10, 100]} for identity in identities}

for adversarial_target in identities:
    clean_embeddings = []
    mod_embeddings = []
    adv = {eps: [] for eps in epsilons}
    
    with h5py.File(path_to_clean.format(id=adversarial_target), "r") as f:
        clean_embeddings.extend(f["embeddings"][:])

    for modified_identity in identities:
        if modified_identity == adversarial_target:
            continue
        # the identity that was modified for the adversarial
        with h5py.File(path_to_clean.format(id=modified_identity), "r") as f:
            mod_embeddings.extend(f["embeddings"][:])

        for indx, epsilon in enumerate(epsilons):
            with h5py.File(path_to_adversarial.format(
                target=adversarial_target,
                true=modified_identity,
                epsilon=epsilon
            ), "r") as f:
                adv[epsilon].extend(f["embeddings"][:])

    positive.extend(
        pairwise_distances(
            clean_embeddings, 
            clean_embeddings)[np.tril_indices(len(clean_embeddings), -1)].flatten()
    )
    negative_clean.extend(
        pairwise_distances(
            clean_embeddings, 
            mod_embeddings).flatten()
    )
    for indx, epsilon in enumerate(epsilons):
        negative_adv[epsilon].extend(
            pairwise_distances(clean_embeddings, adv[epsilon]).flatten()
        )
        
    for k in top_k_recall[adversarial_target].keys():
        for eps in epsilons:
            top_k_recall[adversarial_target][eps] = recall(
                clean_embeddings, 
                adv[eps],
                [1 for _ in range(len(clean_embeddings))] + [0 for _ in range(len(adv[eps]))],
                k
            )

In [15]:
print(top_k_recall)

{'n009288': [0, 1, 10, 5, 7, 3, 5, 5, 3, 5, 8, 8, 5, 5, 1, 6, 6, 6, 4, 4, 3, 0, 1, 6, 7, 3, 1, 0, 0, 5, 3, 3, 5, 3, 2, 8, 1, 1, 5, 3, 3, 4, 2, 5, 5, 1, 0, 8, 3, 4], 'n002763': [2, 5, 7, 6, 6, 7, 6, 8, 6, 7, 7, 5, 3, 4, 6, 8, 6, 7, 10, 9, 4, 11, 4, 12, 8, 7, 6, 8, 8, 8, 5, 7, 9, 9, 2, 4, 2, 11, 7, 4, 3, 1, 7, 3, 8, 2, 6, 8, 7, 6], 'n000958': [6, 14, 7, 7, 9, 7, 8, 13, 10, 8, 4, 12, 14, 16, 13, 10, 18, 8, 9, 4, 4, 11, 6, 7, 8, 11, 5, 12, 6, 14, 9, 2, 10, 6, 10, 7, 13, 8, 14, 14, 8, 9, 6, 12, 10, 20, 9, 10, 12, 9], 'n002647': [12, 1, 0, 9, 5, 6, 5, 5, 3, 9, 3, 5, 0, 2, 4, 4, 4, 0, 13, 2, 9, 2, 2, 1, 3, 4, 4, 1, 10, 3, 4, 2, 3, 3, 3, 5, 7, 6, 8, 6, 6, 6, 2, 5, 2, 7, 7, 1, 9, 7], 'n008655': [0, 17, 11, 8, 13, 8, 3, 3, 10, 2, 8, 2, 1, 6, 10, 12, 10, 1, 2, 11, 4, 8, 16, 5, 12, 4, 9, 18, 6, 1, 8, 11, 6, 5, 12, 11, 15, 24, 10, 4, 11, 17, 10, 4, 14, 5, 9, 6, 5, 5], 'n003356': [9, 14, 14, 11, 17, 15, 11, 19, 16, 7, 8, 13, 14, 19, 16, 6, 17, 18, 17, 17, 13, 20, 15, 19, 19, 20, 15, 15, 16, 17, 15, 