In [66]:
import os
import cv2
import tensorflow as tf
import numpy as np
import pandas as pd
import altair as alt

from tqdm import tqdm
from imageio import imread
from sklearn import metrics
from skimage.transform import resize

import data
import utils

In [2]:
# Root directory holding one sub-directory of saved embeddings per person.
# Defaults to the original cluster location; set the LFW_EMBEDDING_DIR
# environment variable to point elsewhere without editing the notebook.
embedding_dir = os.environ.get('LFW_EMBEDDING_DIR',
                               '/projects/leelab3/image_datasets/lfw/embeddings/')

In [3]:
# Index the dataset and keep only the keys that have more than 10 entries.
data_dict = data.get_dataset_dict()

# Number of entries stored under each key of the dataset dict.
length_dict = {item: len(data_dict[item]) for item in data_dict}

names = np.array(list(length_dict))
counts = np.array([length_dict[key] for key in length_dict])
select_names = names[counts > 10]

# Accumulators filled in by the embedding-loading cell that follows.
embedding_labels = []
embeddings = []

100%|██████████| 5749/5749 [00:01<00:00, 3432.51it/s]


In [4]:
# Load every saved embedding of every selected person, recording the person's
# name as the label of each embedding.
#
# The accumulators are (re)created here so that re-running this cell does not
# append a second copy of the dataset, and so that it still works after
# `embeddings` has been converted to an array further down the notebook.
embeddings = []
embedding_labels = []

for person in tqdm(select_names):
    embedding_top_dir = os.path.join(embedding_dir, person)
    # os.listdir returns entries in arbitrary order; sort so that the dataset
    # order (and everything derived from it) is the same on every run.
    embedding_paths = sorted(os.listdir(embedding_top_dir))

    for embedding_path in embedding_paths:
        embedding = np.load(os.path.join(embedding_top_dir, embedding_path))

        embeddings.append(embedding)
        embedding_labels.append(person)

100%|██████████| 143/143 [00:04<00:00, 34.98it/s]


In [5]:
# Freeze the accumulated Python lists into arrays: one row per embedding and
# one label per row.
embedding_labels = np.array(embedding_labels)
embeddings = np.stack(embeddings)  # stacks along axis 0, the default

print('Dataset size: {}'.format(embeddings.shape))
print('Computing distances...')
# All-pairs Euclidean distances between the stacked embeddings.
embedding_distance_matrix = metrics.pairwise.euclidean_distances(embeddings)

Dataset size: (4125, 128)
Computing distances...


In [15]:
# Row r lists all sample indices ordered by increasing distance from sample r.
sorted_indices = embedding_distance_matrix.argsort(axis=1)

In [42]:
# Settings for the evaluation sweep.
resample_number = 10                     # random repetitions per (num_samples, k) setting
k_values = np.arange(10) + 1             # neighbourhood sizes 1..10
num_samples_values = np.arange(10) + 1   # faces of the queried person left visible, 1..10

In [56]:
# Sanity check: how many embeddings the last selected person has.
# Computed straight from the labels so the cell runs top-to-bottom on a fresh
# kernel; the loop variable `indices_of_person` is only created by the
# evaluation cell further down.
# NOTE(review): presumably this is the value the original scratch cell showed
# (the variable left over after the loop's final iteration) - confirm.
int(np.count_nonzero(embedding_labels == select_names[-1]))

15

In [55]:
# Echo the resample count configured above.
resample_number

10

In [60]:
# Leave-one-out k-NN identification sweep.
#
# For every query face of every selected person: hide all but `num_samples` of
# that person's other faces, then classify the query by majority vote among the
# labels of its `k` nearest remaining neighbours. Every (num_samples, k)
# setting is repeated `resample_number` times with a fresh random choice of
# hidden faces.
#
# Result: total_accuracy_matrix[p, i, j] is the fraction of (query face,
# resample) trials of person select_names[p] that were classified correctly
# with num_samples_values[i] visible faces and k_values[j] neighbours.

# Fixed seed so the matrix that gets saved to disk is reproducible.
rng = np.random.RandomState(0)

total_accuracy_matrix = []
for name in tqdm(select_names):
    indices_of_person = np.where(embedding_labels == name)[0]
    accuracy_matrix = np.zeros((len(num_samples_values), len(k_values)))

    for index in indices_of_person:
        non_index_indices = indices_of_person[indices_of_person != index]
        # Remove the query itself by value instead of assuming it sorts into
        # column 0: an exact duplicate embedding ties at distance 0 and may
        # take that slot, which would leave the query among its own neighbours.
        neighbor_order = sorted_indices[index]
        nearest_neighbors = neighbor_order[neighbor_order != index]

        for i, num_samples in enumerate(num_samples_values):
            for j, k in enumerate(k_values):
                for _ in range(resample_number):
                    # Hide all but `num_samples` of this person's other faces.
                    blacklisted_samples = rng.choice(non_index_indices,
                                                     size=len(non_index_indices) - num_samples,
                                                     replace=False)
                    visible = np.isin(nearest_neighbors, blacklisted_samples, invert=True)
                    top_k_labels = embedding_labels[nearest_neighbors[visible][:k]]

                    # Vote on the neighbours' LABELS. Counting the neighbour
                    # indices (all distinct) gives every candidate one vote and
                    # always selects the smallest index, not the majority class.
                    # Ties go to the label whose closest neighbour ranks first.
                    labels, first_rank, votes = np.unique(top_k_labels,
                                                          return_index=True,
                                                          return_counts=True)
                    winner = np.lexsort((first_rank, -votes))[0]
                    predicted_class = labels[winner]

                    accuracy_matrix[i, j] += int(predicted_class == name)

    # Each cell accumulated one trial per (query face, resample).
    accuracy_matrix /= (len(indices_of_person) * resample_number)
    total_accuracy_matrix.append(accuracy_matrix)
total_accuracy_matrix = np.stack(total_accuracy_matrix, axis=0)

100%|██████████| 143/143 [18:48<00:00,  7.89s/it]


In [61]:
# Persist the per-person accuracy grid in the current working directory.
np.save(file='total_accuracy_matrix.npy', arr=total_accuracy_matrix)

In [121]:
# Long-form table of accuracy averaged over people: one row per
# (Number of Faces, k) pair.
# Row and column labels come from the sweep configuration rather than a
# hard-coded 1..10 range, so the table cannot drift out of sync with the sweep.
mean_data = pd.DataFrame(np.mean(total_accuracy_matrix, axis=0), columns=k_values)
mean_data['Number of Faces'] = num_samples_values
mean_data = pd.melt(mean_data, id_vars=['Number of Faces'], var_name='k', value_name='Mean Accuracy')

In [122]:
# Preview the first five rows of the long-form mean-accuracy table.
mean_data.head(n=5)

Unnamed: 0,Number of Faces,k,Mean Accuracy
0,1,1,0.851353
1,2,1,0.924091
2,3,1,0.93692
3,4,1,0.941433
4,5,1,0.944326


In [158]:
# Per-person accuracies in long form for the first three k values: one row per
# (person, number of samples, k). This is the input of the confidence-band plot.
# The positional index and the column labels both come from the sweep
# configuration, so each slice of the matrix is labelled with the k it holds.
datas = []
for k_index, k in enumerate(k_values[:3]):
    per_k_data = pd.DataFrame(total_accuracy_matrix[:, :, k_index], columns=num_samples_values)
    per_k_data = pd.melt(per_k_data, var_name='Number_of_Samples', value_name='Mean_Accuracy')
    per_k_data['k'] = k
    datas.append(per_k_data)
band_data = pd.concat(datas, axis=0)

In [159]:
# Remove Altair's cap on the number of data rows a chart may embed, ahead of
# plotting the long-form band_data frame.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [160]:
# Shaded confidence band ('ci' extent) of per-person accuracy, one band per k.
band = (
    alt.Chart(band_data)
    .mark_errorband(extent='ci')
    .encode(
        x=alt.X('Number_of_Samples', title='Number of Faces Collected'),
        y=alt.Y('Mean_Accuracy', title='Accuracy on LFW'),
        color='k:O',
    )
)

# Mean accuracy across people, one line per k.
line = (
    alt.Chart(band_data)
    .mark_line()
    .encode(
        x=alt.X('Number_of_Samples'),
        y=alt.Y('mean(Mean_Accuracy)'),
        color='k:O',
    )
)

# Same layering order as before: the line first, the band drawn over it.
alt.layer(line, band)