In [None]:
%matplotlib inline

print('Loading libraries... Please wait.')

from IPython.display import display, clear_output
import ipywidgets as widgets
import random
import os
import sys
import pickle
from tqdm import tqdm
import sklearn.externals.joblib as joblib
from sklearn.metrics import confusion_matrix, accuracy_score, roc_curve, auc
from sklearn.calibration import calibration_curve
from multiprocessing import Pool
from subprocess import check_call

from esper.prelude import *
from esper.widget import *
import esper.face_embeddings as face_embeddings


def query_faces(ids):
    """Return a ``values()`` queryset over the Face rows whose id is in `ids`.

    Each row is a dict with the face id, its bbox coordinates, the frame
    number, the owning video's id and fps, and the shot's min/max frame
    (fields suitable for `query_faces_result`).
    """
    fields = (
        'id',
        'bbox_y1', 'bbox_y2', 'bbox_x1', 'bbox_x2',
        'frame__number', 'frame__video__id', 'frame__video__fps',
        'shot__min_frame', 'shot__max_frame',
    )
    return Face.objects.filter(id__in=ids).values(*fields)


def query_sample(qs, n):
    """Return at most `n` rows of queryset `qs`, in Django's random ordering ('?')."""
    shuffled = qs.order_by('?')
    return shuffled[:n]


def query_faces_result(faces, expand_bbox=0.05):
    """Replaces qs_to_result

    Convert face rows (dicts shaped like those from `query_faces`) into a
    'Face' result payload with one 'flat' group per face, each holding a
    single bbox object.

    The bbox is grown by `expand_bbox` on every side; the low edges are
    clamped at 0 and the high edges at 1. The frame shown is the midpoint of
    the face's shot when both shot bounds are present, otherwise the face's
    own frame number.
    """
    def padded_box(face):
        # Grow the box, keeping it inside the unit square.
        return {
            'bbox_y1': max(face['bbox_y1'] - expand_bbox, 0),
            'bbox_y2': min(face['bbox_y2'] + expand_bbox, 1),
            'bbox_x1': max(face['bbox_x1'] - expand_bbox, 0),
            'bbox_x2': min(face['bbox_x2'] + expand_bbox, 1),
        }

    groups = []
    for face in faces:
        shot_start = face.get('shot__min_frame')
        shot_end = face.get('shot__max_frame')
        if shot_start is None or shot_end is None:
            display_frame = face['frame__number']
        else:
            display_frame = int((shot_start + shot_end) / 2)

        bbox_object = {'id': face['id'], 'background': False, 'type': 'bbox'}
        bbox_object.update(padded_box(face))

        groups.append({
            'type': 'flat',
            'label': '',
            'elements': [{
                'objects': [bbox_object],
                'min_frame': display_frame,
                'video': face['frame__video__id'],
            }],
        })
    return {'type': 'Face', 'count': 0, 'result': groups}


HANDLABELER_NAME = 'handlabeled-gender-validation'
MODEL_LABELER_NAME = 'rudecarnie'


def print_gender_validation_stats(normalize=False, threshold=0.5):
    """Compare MODEL_LABELER_NAME's gender predictions against hand labels.

    Prints how many faces carry a HANDLABELER_NAME label (with the male /
    female split). For the hand-labeled faces that also have a
    MODEL_LABELER_NAME prediction, it then plots a confusion matrix and
    prints the overall accuracy.

    Args:
        normalize: If True, show each confusion-matrix row as fractions of
            that hand label's total instead of raw counts.
        threshold: A face is predicted male when the model's male probability
            is >= threshold, otherwise female.
    """
    labeler = Labeler.objects.get(name=HANDLABELER_NAME)
    hand_face_genders = {
        fg['face__id']: fg['gender__id']
        for fg in FaceGender.objects.filter(
            labeler=labeler
        ).values('face__id', 'gender__id')
    }
    gender_id_dict = {g.name: g.id for g in Gender.objects.all()}
    male_id = gender_id_dict['M']
    female_id = gender_id_dict['F']
    male_count = sum(1 for g in hand_face_genders.values() if g == male_id)
    female_count = sum(1 for g in hand_face_genders.values() if g == female_id)
    print('{} faces have been hand-labeled ({} male, {} female)'.format(
          len(hand_face_genders), male_count, female_count))

    y_pred = []
    y_truth = []
    for fg in FaceGender.objects.filter(
        face__id__in=list(hand_face_genders.keys()),
        labeler__name=MODEL_LABELER_NAME
    ).values('face__id', 'gender__id', 'probability'):
        # `probability` is treated as the confidence in the stored gender;
        # convert it to P(male) before thresholding.
        if fg['gender__id'] == male_id:
            male_probability = fg['probability']
        else:
            male_probability = 1 - fg['probability']
        y_pred.append(male_id if male_probability >= threshold else female_id)
        y_truth.append(hand_face_genders[fg['face__id']])

    if not y_truth:
        # sklearn cannot build a confusion matrix from zero samples.
        print('No hand-labeled faces have a {} prediction'.format(
              MODEL_LABELER_NAME))
        return

    # Pin the label order to [male, female] so matrix rows/columns always match
    # the 'Male'/'Female' tick labels, whatever the Gender ids are and even if
    # only one class occurs.
    show_confusion_matrix(y_truth, y_pred, normalize=normalize,
                          labels=[male_id, female_id])


def show_confusion_matrix(y_truth, y_pred, normalize=False, labels=None):
    """Plot a Male/Female confusion matrix and print the overall accuracy.

    Args:
        y_truth: Ground-truth (hand) labels.
        y_pred: Predicted labels, aligned element-wise with `y_truth`.
        normalize: If True, divide each row by its total so each cell shows the
            fraction of that true label. Rows with no samples are shown as 0
            instead of NaN.
        labels: Optional label values in [male, female] order, forwarded to
            sklearn's `confusion_matrix`. When None, sklearn uses the sorted
            union of the observed labels, so the smallest label is drawn as
            'Male'.
    """
    import itertools

    cm = confusion_matrix(y_truth, y_pred, labels=labels)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True).astype('float')
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_totals != 0)

    classes = ['Male', 'Female']
    plt.figure(figsize=(5, 5))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Reds)
    plt.title('Gender confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Use white text on dark (high-value) cells, black on light ones.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('Hand label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

    print('Overall accuracy: {:0.2f}'.format(
          accuracy_score(y_truth, y_pred)))
    
# Validate the model labeler against the hand labels using raw counts and a 0.5 cut-off.
print_gender_validation_stats(normalize=False, threshold=0.5)

In [None]:
ID_SIZE = 8
BYTEORDER = 'little'


def read_ids_file(id_file, id_size=ID_SIZE, byteorder=BYTEORDER):
    """Read a headerless binary file of fixed-width unsigned integer ids.

    The file is a plain concatenation of `id_size`-byte unsigned integers
    in `byteorder` byte order.

    Args:
        id_file: Path of the file to read.
        id_size: Width of each id in bytes (default ID_SIZE).
        byteorder: 'little' or 'big' (default BYTEORDER).

    Returns:
        List of ints in file order (empty for an empty file).

    Raises:
        ValueError: If the file length is not a multiple of `id_size`, i.e.
            the file is truncated or corrupt. This is an explicit check rather
            than an assert, so it still fires under `python -O`.
    """
    with open(id_file, 'rb') as f:
        data = f.read()
    if len(data) % id_size != 0:
        raise ValueError(
            '{}: size {} is not a multiple of the id size {}'.format(
                id_file, len(data), id_size))
    return [int.from_bytes(data[offset:offset + id_size], byteorder)
            for offset in range(0, len(data), id_size)]

# Load the face ids that accompany the embeddings in /app/data/embs.
print('Reading ids')
face_ids = read_ids_file(os.path.join('/app/data/embs', 'face_ids.bin'))
print('Done reading ids')

In [None]:
from multiprocessing.pool import ThreadPool

BATCH_SIZE = 10000

# One prediction file per batch, named '<base_idx>.pkl'.
RESULT_DIR = '/app/data/face_genders_tmp/'
os.makedirs(RESULT_DIR, exist_ok=True)


def worker(outfile, base_idx, n):
    """Run gender_knn.py for one batch: `n` faces starting at index `base_idx`.

    The script receives (outfile, base_idx, n) on its command line and is
    expected to write its predictions to `outfile`.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    # sys.executable runs the script with this kernel's interpreter rather
    # than whichever `python3` happens to come first on PATH.
    check_call([sys.executable, 'gender_knn.py', outfile, str(base_idx), str(n)])


# Each task only blocks on a child process, so threads are enough. A thread
# pool also avoids having to pickle `worker` (defined in the notebook's
# __main__) into pool processes.
# Batches whose output file already exists are skipped, so re-running the cell
# resumes. NOTE(review): if gender_knn.py does not write its output atomically,
# a file left half-written by a crashed run would also be skipped.
with ThreadPool() as workers:
    results = []
    for base_idx in range(0, len(face_ids), BATCH_SIZE):
        outfile = os.path.join(RESULT_DIR, '{}.pkl'.format(base_idx))
        if os.path.exists(outfile):
            continue
        n = min(BATCH_SIZE, len(face_ids) - base_idx)
        results.append((outfile, workers.apply_async(worker, args=(outfile, base_idx, n))))

    for outfile, future in tqdm(results):
        future.get()  # re-raises a failure from worker()
        assert os.path.exists(outfile)

In [None]:
# Fetch the labeler that owns the k-NN predictions, creating it on the first
# run. Also fetch the two Gender rows the predictions are mapped onto.
labeler, created = Labeler.objects.get_or_create(name='knn-gender')
print('created:' if created else 'exists:', labeler.name)
male = Gender.objects.get(name='M')
female = Gender.objects.get(name='F')

In [None]:
# Load the per-batch k-NN predictions from RESULT_DIR and store them as
# FaceGender rows under `labeler`, with one bulk insert per batch. Re-running
# the cell resumes: batches whose faces are already labeled by `labeler` are
# skipped.
progress = 0  # first face index to process; must be a multiple of BATCH_SIZE
# Seed with the face just before `progress` so a repeat straddling the resume
# point is still recognized as a duplicate.
prev_face_id = face_ids[progress - 1] if progress > 0 else None
duplicates = 0
missing = 0
for base_idx in range(progress, len(face_ids), BATCH_SIZE):
    outfile = os.path.join(RESULT_DIR, '{}.pkl'.format(base_idx))
    assert os.path.exists(outfile)
    batch_ids = face_ids[base_idx:base_idx + BATCH_SIZE]
    # pickle.load can execute arbitrary code; these files are assumed to be the
    # trusted output of gender_knn.py.
    with open(outfile, 'rb') as f:
        batch_preds = pickle.load(f)
    assert len(batch_ids) == len(batch_preds)

    batch_face_genders = []
    batch_duplicates = 0
    for face_id, face_pred in zip(batch_ids, batch_preds):
        # Only back-to-back repeats are caught. prev_face_id carries across
        # batches, so repeats at a batch boundary are caught too. This
        # presumes face_ids is sorted/grouped by id -- TODO confirm.
        if face_id == prev_face_id:
            batch_duplicates += 1
            continue
        # face_pred[1] is used as P(male) and face_pred[0] as P(female); the
        # stored probability is the one for the chosen gender.
        is_male = face_pred[1] >= 0.5
        batch_face_genders.append(FaceGender(
            face_id=face_id, gender=male if is_male else female,
            labeler=labeler, probability=face_pred[1] if is_male else face_pred[0]
        ))
        prev_face_id = face_id
    batch_face_ids = [fg.face_id for fg in batch_face_genders]

    # Resume check over the faces this batch would insert. Checking only
    # batch_ids[0] fails in two ways:
    #   - that id may duplicate the previous batch's (already saved) last face,
    #     which would wrongly skip the batch;
    #   - it may be a missing face that was filtered out, so a saved batch
    #     would be inserted twice.
    if FaceGender.objects.filter(face__id__in=batch_face_ids, labeler=labeler).exists():
        continue
    duplicates += batch_duplicates

    try:
        FaceGender.objects.bulk_create(batch_face_genders)
    except Exception:
        # Presumably an integrity error caused by ids that have no Face row.
        # Drop those rows and retry once; a second failure propagates.
        valid_batch_ids = set(
            Face.objects.filter(id__in=batch_face_ids).values_list('id', flat=True))
        n_before = len(batch_face_genders)
        batch_face_genders = [fg for fg in batch_face_genders
                              if fg.face_id in valid_batch_ids]
        missing += n_before - len(batch_face_genders)
        FaceGender.objects.bulk_create(batch_face_genders)
    print('Saved: {}, Duplicates: {}, Missing: {}'.format(
          base_idx + len(batch_ids), duplicates, missing))

In [None]:
# Hard-coded duplicate/missing totals. These were presumably recorded from a
# complete import run, so the summary survives a resumed run, whose counters
# only cover the batches it processed. NOTE(review): this overwrites whatever
# the import cell just computed -- confirm that is intended.
duplicates, missing = 10947, 5191