In [13]:
import pandas as pd
from pathlib import Path
import numpy as np
from typing import Iterable
from tqdm.auto import tqdm
import pickle
from scipy.spatial import distance
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve

In [2]:
# Directory of pre-extracted FIW face features; the loader below expects one
# .pkl file per face image, laid out like the FIW image tree.
# NOTE(review): absolute, user-specific paths — consider a configurable DATA_DIR.
FIW_FEATURES = Path("/Users/zkhan/Dropbox/rfiw2020-data/FIDs-features/")

# Link kept from the original notebook (presumably where the pair-list CSVs
# below can be obtained — TODO confirm):
# https://1drv.ms/u/s!AkDk_XdBkJ9wgagCPB-UakHehdEACw?e=hBAJz2
training_csv, test_csv = map(
    pd.read_csv,
    (
        "/Users/zkhan/Downloads/sample_train_face_list.csv",
        "/Users/zkhan/Downloads/test_reference.csv",
    ),
)

In [3]:
def read_features_from_iterable_of_pictures(iterable: Iterable[str], feature_dir: Path, feature_len: int = 512):
    """
    For each picture in the iterable, read the corresponding feature
    file from a directory of feature files.

    The feature file for an image name ``img`` is ``feature_dir / img`` with
    its suffix replaced by ``.pkl``; its unpickled contents become one row of
    the returned matrix.

    Parameters
    ------------
    iterable:
        An iterable of face image names (paths relative to ``feature_dir``).
        It is materialised into a list first, so one-shot iterables such as
        generators are accepted.
    feature_dir:
        A Path (or path string) to a directory containing features of faces,
        organized in the same way as FIW.
    feature_len:
        The size of the feature vector.

    Returns
    ------------
    A mxn float matrix, where m is the number of images in the iterable, and n is
    the feature len. Row i holds the features of the i-th image.

    Raises
    ------------
    FileNotFoundError:
        If the feature file for an image does not exist.
    ValueError:
        Typically raised by NumPy if an unpickled vector cannot be stored in a
        row of length ``feature_len``.
    """
    # Materialise once: needed for len() and to support generators.
    pictures = list(iterable)
    features = np.zeros((len(pictures), feature_len))
    feature_dir = Path(feature_dir)
    for idx, img in enumerate(tqdm(pictures)):
        # Resolve against the caller-supplied directory, not a global.
        feature_file_name = (feature_dir / img).with_suffix(".pkl")
        # NOTE: pickle.load can execute arbitrary code — only load trusted feature files.
        with open(feature_file_name, "rb") as f:
            features[idx] = pickle.load(f)
    return features

# Finding the best threshold for kinship classification

In [4]:
# Feature matrices for the two sides of every training pair; row i of each
# matrix corresponds to row i of training_csv (p1 and p2 columns respectively).
person_one_features, person_two_features = (
    read_features_from_iterable_of_pictures(pair_side, FIW_FEATURES)
    for pair_side in (training_csv.p1, training_csv.p2)
)

HBox(children=(FloatProgress(value=0.0, max=502322.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=502322.0), HTML(value='')))




In [7]:
cosine_similarity_train_faces = np.array(
    [distance.cosine(u, v) for u, v in tqdm(zip(person_one_features, person_two_features))]
)

In [12]:
train_labels = training_csv.label.values.copy()

In [37]:
thresholds = np.arange(1, 0, step=-0.0125)
accuracy_scores = []
for thresh in tqdm(thresholds):
    accuracy_scores.append(accuracy_score(train_labels, cosine_similarity_train_faces > thresh))

accuracies = np.array(accuracy_scores)
max_accuracy = accuracies.max() 
max_accuracy_threshold =  thresholds[accuracies.argmax()]

HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))




In [38]:
print(max_accuracy)
print(max_accuracy_threshold)

0.526634310263138
0.025000000000003464


The best performing threshold is 0.025, with an accuracy of 0.5266.

# Evaluating performance on the test set

In [39]:
person_one_features_test = read_features_from_iterable_of_pictures(test_csv.p1, FIW_FEATURES)
person_two_features_test = read_features_from_iterable_of_pictures(test_csv.p2, FIW_FEATURES)

HBox(children=(FloatProgress(value=0.0, max=39743.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=39743.0), HTML(value='')))




In [40]:
cosine_similarity_test_faces = np.array(
    [distance.cosine(u, v) for u, v in tqdm(zip(person_one_features_test, person_two_features_test))]
)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [41]:
# Ground-truth labels for the test pairs, as an independent NumPy array.
test_labels = test_csv["label"].to_numpy(copy=True)

In [42]:
# Test accuracy when the test-pair scores are thresholded at the value chosen on the training pairs.
accuracy_score(y_true=test_labels, y_pred=cosine_similarity_test_faces > max_accuracy_threshold)

0.5004654907782503

On the test set we get an accuracy of about 0.50, i.e. roughly chance level for a binary kin / non-kin decision.

# Finer-grained analysis of verification results
Break down the test accuracy by relationship type (the `ptype` column).

In [46]:
# Per-pair boolean prediction on the test set, using the threshold chosen on the training pairs.
test_csv["pred"] = np.greater(cosine_similarity_test_faces, max_accuracy_threshold)

In [58]:
relationship_types = test_csv['ptype'].unique()

# Accuracy of the thresholded predictions restricted to each relationship type,
# collected in order of first appearance of the type in test_csv.
per_type_accuracy = {}
for rel_type in relationship_types:
    type_rows = test_csv[test_csv['ptype'] == rel_type]
    per_type_accuracy[rel_type] = accuracy_score(type_rows.label, type_rows.pred)

# A single-row frame (index 0) with one float column per relationship type.
accuracy_df = pd.DataFrame([per_type_accuracy], columns=relationship_types, dtype=float)
# Unweighted mean over relationship types: each type counts equally,
# regardless of how many pairs it contains.
accuracy_df["avg"] = accuracy_df.loc[0, :].mean()

In [59]:
# Per-relationship accuracies (plus their average), rounded to three decimals for display.
accuracy_df.round(3)

Unnamed: 0,bb,ss,fd,md,fs,gmgd,gmgs,ms,gfgs,gfgd,sibs,avg
0,0.469,0.514,0.498,0.515,0.5,0.264,0.469,0.533,0.392,0.273,0.566,0.454
