<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Choose-a-person-to-find" data-toc-modified-id="Choose-a-person-to-find-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Choose a person to find</a></span></li><li><span><a href="#Get-starting-image-examples" data-toc-modified-id="Get-starting-image-examples-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Get starting image examples</a></span><ul class="toc-item"><li><span><a href="#Option-1:-fetch-them-from-Google" data-toc-modified-id="Option-1:-fetch-them-from-Google-2.1"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Option 1: fetch them from Google</a></span></li><li><span><a href="#Option-2:-specify-face-ids-from-the-dataset" data-toc-modified-id="Option-2:-specify-face-ids-from-the-dataset-2.2"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Option 2: specify face ids from the dataset</a></span></li></ul></li><li><span><a href="#Your-selected-reference-images" data-toc-modified-id="Your-selected-reference-images-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Your selected reference images</a></span></li><li><span><a href="#Curating-an-initial-dataset" data-toc-modified-id="Curating-an-initial-dataset-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Curating an initial dataset</a></span><ul class="toc-item"><li><span><a href="#Getting-negative-examples-(via-sampling)" data-toc-modified-id="Getting-negative-examples-(via-sampling)-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>Getting negative examples (via sampling)</a></span><ul class="toc-item"><li><span><a href="#Optional:-cleaning-the-negative-samples" data-toc-modified-id="Optional:-cleaning-the-negative-samples-4.1.1"><span class="toc-item-num">4.1.1&nbsp;&nbsp;</span>Optional: cleaning the negative samples</a></span></li></ul></li><li><span><a href="#Getting-an-initial-set-of-positive-examples" data-toc-modified-id="Getting-an-initial-set-of-positive-examples-4.2"><span 
class="toc-item-num">4.2&nbsp;&nbsp;</span>Getting an initial set of positive examples</a></span></li></ul></li><li><span><a href="#Model-Training" data-toc-modified-id="Model-Training-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Model Training</a></span><ul class="toc-item"><li><span><a href="#Train-the-model" data-toc-modified-id="Train-the-model-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Train the model</a></span></li><li><span><a href="#Visualize-predictions" data-toc-modified-id="Visualize-predictions-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>Visualize predictions</a></span></li></ul></li></ul></div>

In [None]:
%matplotlib inline

from IPython.display import display, clear_output
from IPython.core.pylabtools import figsize
figsize(12, 5)
import ipywidgets as widgets
import time
import random
import numpy as np
np.warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
from sklearn import metrics

from esper.prelude import *
from esper.stdlib import *
from esper.identity import *
from esper.plot_util import *
from esper.major_canonical_shows import MAJOR_CANONICAL_SHOWS
from esper import embed_google_images

import esper.face_embeddings as face_embeddings

In [None]:
def query_faces(ids):
    """Look up the faces whose primary key is in ``ids`` as dict rows.

    Returns the ``.values(...)`` queryset restricted to the bounding-box,
    frame/video and enclosing-shot columns that ``query_faces_result`` reads.
    No ordering is applied here.
    """
    columns = (
        'id', 'bbox_y1', 'bbox_y2', 'bbox_x1', 'bbox_x2',
        'frame__number', 'frame__video__id', 'frame__video__fps',
        'shot__min_frame', 'shot__max_frame',
    )
    return Face.objects.filter(id__in=ids).values(*columns)

def query_sample(qs, n):
    """Return the first ``n`` rows of ``qs`` after ordering it by ``'?'``.

    NOTE(review): with a Django queryset ``order_by('?')`` asks the database
    for a random ordering -- assumes ``qs`` is such a queryset.
    """
    shuffled = qs.order_by('?')
    return shuffled[:n]

def query_faces_result(faces):
    """Build the widget payload for an iterable of face rows (replaces qs_to_result).

    Each row is a dict as produced by ``query_faces``. Every face becomes its
    own ``'flat'`` group holding a single element with a single bbox object.
    The element's ``min_frame`` is the midpoint of the enclosing shot when both
    shot bounds are present, otherwise the face's own frame number.

    Returns a dict with ``'type': 'Face'``, ``'count': 0`` (always zero here)
    and the list of groups under ``'result'``.
    """
    def display_frame(row):
        # Only fall back to the face's own frame when a shot bound is missing.
        shot_start = row.get('shot__min_frame')
        shot_end = row.get('shot__max_frame')
        if shot_start is None or shot_end is None:
            return row['frame__number']
        return int((shot_start + shot_end) / 2)

    def as_group(row):
        frame_no = display_frame(row)
        bbox = {
            'id': row['id'],
            'background': False,
            'type': 'bbox',
            'bbox_y1': row['bbox_y1'],
            'bbox_y2': row['bbox_y2'],
            'bbox_x1': row['bbox_x1'],
            'bbox_x2': row['bbox_x2'],
        }
        element = {
            'objects': [bbox],
            'min_frame': frame_no,
            'video': row['frame__video__id'],
        }
        return {'type': 'flat', 'label': '', 'elements': [element]}

    groups = [as_group(row) for row in faces]
    return {'type': 'Face', 'count': 0, 'result': groups}

# Choose a person to find

In [None]:
# Ask for the person to search for; surrounding whitespace is stripped.
# `name` is used below for the Google image fetch and for plot labels.
name = input('Enter a name: ').strip()
# Stop here on an empty answer rather than searching for nothing.
assert name != '', 'Name cannot be the empty string'

# Get starting image examples

## Option 1: fetch them from Google

In [None]:
# Fetch example images for `name`; the returned value is passed on as the
# image directory to load_and_select_faces_from_images below.
img_dir = embed_google_images.fetch_images(name)

# If the images returned are not satisfactory, rerun the call above with
# extra keyword arguments:
#     query_extras='' # additional keywords to add to the search
#     force=True      # ignore cached images

def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one flat list."""
    merged = []
    for group in l:
        merged.extend(group)
    return merged

# TODO: use Esper to select images
# Collect the face crops found in the fetched images (one flat list) and
# embed each of them.
face_imgs = flatten(load_and_select_faces_from_images(img_dir))
face_embs = embed_google_images.embed_images(face_imgs)
# Google faces have no dataset ids; the empty list also makes
# select_negative_samples skip adding selected faces to the positive set.
face_ids = []
# Sanity check: exactly one embedding per face image.
assert(len(face_embs) == len(face_imgs))

## Option 2: specify face ids from the dataset

In [None]:
# Hand-picked ids of faces already in the dataset, used as reference examples.
# Running this cell replaces the empty face_ids left by Option 1.
face_ids = [
    644710, 4686364, 2678025, 62032, 13248, 4846879, 4804861, 561270, 2651257,
    2083010, 2117202, 1848221, 2495606, 4465870, 3801638, 865102, 3861979, 4146727,
    3358820, 2087225, 1032403, 1137346, 2220864, 5384396, 3885087, 5107580, 2856632,
    335131, 4371949, 533850, 5384760, 3335516
]
# Alternative (disabled): prompt for the ids interactively.
# NOTE(review): this variant returns a set, but select_negative_samples calls
# face_ids.append(...); convert to a list before enabling it.
# def prompt_for_face_ids():
#     l = input('Enter face ids (separated by commas): ')
#     return {int(x.strip()) for x in l.split(',')}
# face_ids = prompt_for_face_ids()

In [None]:
def load_face_img(face):
    """Load the frame ``face`` belongs to and return ``crop`` applied to it for that face."""
    owning_frame = face.frame
    frame_img = load_frame(owning_frame.video, owning_frame.number, [])
    return crop(frame_img, face)

def confirm_selected_faces():
    """Show the faces listed in ``face_ids`` and let the user deselect bad ones.

    Displays a selection widget plus a confirm button. On confirm, the faces
    the user marked as ignored are dropped and the globals ``face_ids``,
    ``face_imgs`` and ``face_embs`` are replaced with the ids, crops and
    embeddings of the remaining faces.

    When ``face_ids`` is empty (e.g. the reference images came from Google in
    Option 1) nothing is displayed: confirming an empty selection would
    otherwise overwrite the existing ``face_imgs`` / ``face_embs`` with empty
    lists.
    """
    if not face_ids:
        print('No face ids to confirm; keeping the current reference images.')
        return

    submit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Confirm selection',
        disabled=False,
        button_style='',
        tooltip='Submit labels',
        icon='check'
    )

    # Materialize once so the indices reported by the widget and the indices
    # used in on_submit refer to the same rows in the same order.
    example_faces = list(query_faces(sorted(face_ids)))
    example_selection_widget = esper_widget(
        query_faces_result(example_faces),
        crop_bboxes=True
    )

    def on_submit(b):
        ignored_example_face_idxs = set(example_selection_widget.ignored)
        example_selection_widget.close()
        clear_output()

        print('You deselected {} faces.'.format(len(ignored_example_face_idxs)))
        global face_ids, face_imgs, face_embs
        face_ids = [
            f['id'] for i, f in enumerate(example_faces)
            if i not in ignored_example_face_idxs
        ]
        face_imgs = par_for(load_face_img, Face.objects.filter(id__in=face_ids))
        face_embs = [x for _, x in face_embeddings.get(face_ids)]

    submit_button.on_click(on_submit)

    display(submit_button)
    display(example_selection_widget)

In [None]:
# Review the candidate reference faces; click the button to apply the selection.
confirm_selected_faces()

# Your selected reference images

In [None]:
# Tile the reference face crops (each resized to 100x100) into a single image.
# NOTE: this is evaluated when the cell runs, so rerun the cell whenever
# face_imgs changes (e.g. after confirming the selection above).
reference_imgs = tile_images(
    [cv2.resize(x, (100, 100)) for x in face_imgs], 
    cols=10, blank_value=255)
def show_reference_imgs():
    """Print a caption naming the person and plot the tiled reference images."""
    caption = 'Your selected reference images for {}.'.format(name)
    print(caption)

    plt.figure()
    imshow(reference_imgs)
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the reference faces chosen above.
show_reference_imgs()

# Curating an initial dataset

## Getting negative examples (via sampling)

In [None]:
# Draw 10000 face ids from the embedding store as presumed negatives.
# NOTE(review): nothing here excludes faces of the target person; the optional
# cleaning step below exists to remove those.
neg_examples = face_embeddings.sample(10000)

### Optional: cleaning the negative samples

In [None]:
def sort_by_distance(ids):
    """Return ``ids`` ordered by ascending ``face_embeddings.dist`` to ``face_embs``.

    Ties on distance are broken by the id itself, because the pairs are sorted
    as ``(distance, id)`` tuples.
    """
    distances = face_embeddings.dist(ids, targets=face_embs)
    ranked_pairs = sorted(zip(distances, ids))
    return [face_id for _, face_id in ranked_pairs]

def select_negative_samples(neg_samples):
    """Let the user clean the sampled negatives, closest-to-reference first.

    The sampled ids are ordered by distance to the reference embeddings and the
    ones that have a row in the database are shown in a selection widget. On
    confirm:

    * faces the user *selected* are removed from the negatives and, when
      ``face_ids`` is non-empty, appended to ``face_ids`` / ``face_embs``
      (i.e. they become additional reference examples);
    * faces the user *ignored* are removed from the negatives;
    * everything else, including sampled ids that had no database row and were
      therefore never displayed, stays in the global ``neg_examples``.

    Args:
        neg_samples: iterable of face ids to review.
    """
    submit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Confirm selections',
        disabled=False,
        button_style='',
        tooltip='Submit labels',
        icon='check'
    )

    neg_samples_ord = sort_by_distance(neg_samples)
    neg_samples_ord_idxs = {b: a for a, b in enumerate(neg_samples_ord)}

    neg_samples_faces = list(query_faces(neg_samples_ord))
    neg_samples_faces.sort(key=lambda f: neg_samples_ord_idxs[f['id']])
    # Ids in exactly the order the widget displays them. The database may return
    # fewer rows than were sampled, so widget indices must be resolved against
    # this list and not against neg_samples_ord.
    displayed_ids = [f['id'] for f in neg_samples_faces]
    selection_widget = esper_widget(
        query_faces_result(neg_samples_faces)
    )

    def on_submit(b):
        ignored_idxs = set(selection_widget.ignored)
        selected_idxs = set(selection_widget.selected)
        selection_widget.close()
        clear_output()

        # Add to positive set
        global face_ids, face_embs, neg_examples
        if face_ids:
            for i in sorted(selected_idxs):
                face_id = displayed_ids[i]
                if face_id not in face_ids:
                    face_ids.append(face_id)
                    _id, emb = face_embeddings.get([face_id])[0]
                    assert _id == face_id
                    face_embs.append(emb)

        # Filter negative set: drop every face the user selected or ignored.
        removed_ids = {displayed_ids[i] for i in ignored_idxs | selected_idxs}
        neg_examples = [
            _id for _id in neg_samples_ord if _id not in removed_ids
        ]

        print('You selected {} and ignored {} faces. There are now {} negative samples.'.format(
              len(selected_idxs), len(ignored_idxs),
              len(neg_examples)))

    submit_button.on_click(on_submit)
    display(submit_button)
    display(selection_widget)

In [None]:
# Optional: review the sampled negatives; the button click updates neg_examples.
select_negative_samples(neg_examples)

## Getting an initial set of positive examples

In [None]:
def get_positive_examples(k):
    """Show the nearest neighbours of the reference embeddings for labelling.

    Queries ``face_embeddings.knn`` for up to ``k`` neighbours of ``face_embs``
    (``max_threshold=1.``), displays them in neighbour order and, on confirm,
    stores the accepted ids in the global ``pos_examples``.

    Acceptance rule: rows at positions before the highest selected index are
    kept unless they were ignored; with no selection, all rows are candidates.
    NOTE(review): the slice stops *before* the highest selected row, so that
    row itself is not accepted -- confirm this is the intended convention.
    """
    confirm_button = widgets.Button(
        description='Confirm selections',
        tooltip='Submit labels',
        icon='check',
        button_style='',
        disabled=False,
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'}
    )

    neighbors = face_embeddings.knn(targets=face_embs, k=k, max_threshold=1.)
    rank_of = {pair[0]: rank for rank, pair in enumerate(neighbors)}
    neighbor_ids = [pair[0] for pair in neighbors]
    # Database rows come back in arbitrary order; restore the neighbour ranking.
    pos_samples_faces = sorted(
        query_faces(neighbor_ids), key=lambda row: rank_of[row['id']])
    for row in pos_samples_faces:
        _, row['dist'] = neighbors[rank_of[row['id']]]

    selection_widget = esper_widget(query_faces_result(pos_samples_faces))

    def on_submit(b):
        global pos_examples
        picked = selection_widget.selected
        rejected = set(selection_widget.ignored)
        if len(picked) > 0:
            cutoff = max(picked)
        else:
            cutoff = len(pos_samples_faces)
        clear_output()

        accepted = []
        for position, row in enumerate(pos_samples_faces[:cutoff]):
            if position not in rejected:
                accepted.append(row['id'])
        pos_examples = accepted
        print('Accepted {} positive labels'.format(len(pos_examples)))      

    confirm_button.on_click(on_submit)
    display(confirm_button)
    display(selection_widget)

In [None]:
# Reset so the check before training can tell the selection was never confirmed.
pos_examples = None
# k is the number of nearest neighbours requested from face_embeddings.knn.
get_positive_examples(k=10000)

# Model Training

In [None]:
# Guard: pos_examples is still None until the confirm button above is clicked.
if pos_examples is None:
    raise ValueError('No positive training examples! Did you confirm the selection above?')
if neg_examples is None:
    raise ValueError('No negative training examples!')
# Report the label counts that train_model() will use.
print('Proceeding with {} positive and {} negative training examples'.format(
    len(pos_examples), len(neg_examples)))

## Train the model

In [None]:
# Class labels used to build the training/validation target vectors.
POS_LABEL = 1
NEG_LABEL = 0
# Hyperparameters forwarded to face_embeddings.logreg by train_model().
NUM_EPOCHS = 40
LEARNING_RATE = 1
L2_PENALTY = 0.00001

def plot_roc(y_true, y_pred, title='Receiver Operating Characteristic'):
    """Plot the ROC curve of ``y_pred`` against ``y_true`` with its AUC in the legend.

    A dashed red diagonal from (0, 0) to (1, 1) is drawn for reference and both
    axes are limited to [0, 1].
    """
    false_pos_rate, true_pos_rate, _ = metrics.roc_curve(y_true, y_pred)
    area = metrics.auc(false_pos_rate, true_pos_rate)
    auc_label = 'AUC = %0.2f' % area

    plt.figure()
    plt.title(title)
    plt.plot(false_pos_rate, true_pos_rate, 'b', label=auc_label)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.show()
    
def plot_binary_score_histograms(y_true, y_pred, y_max=None, 
                                 title='Score Distribution by Class'):
    """Overlay the score histograms of the positive and the negative class.

    Args:
        y_true: labels, indexable by position; compared against POS_LABEL / NEG_LABEL.
        y_pred: predicted scores, iterated in the same order as ``y_true``.
        y_max: optional upper limit for the y axis.
        title: figure title.
    """
    def scores_with_label(label):
        return [score for pos, score in enumerate(y_pred) if y_true[pos] == label]

    edges = np.linspace(0, 1, 100)
    plt.figure()
    plt.hist(scores_with_label(POS_LABEL), edges, alpha=0.5, label=name)
    plt.hist(scores_with_label(NEG_LABEL), edges, alpha=0.5,
             label='Not {}'.format(name))
    plt.title(title)
    plt.xlabel('Predicted Score')
    if y_max is not None:
        plt.ylim(0, y_max)
    plt.legend()
    plt.show()
    
def plot_score_histogram(values, title='Score Distribution'):
    """Plot a histogram of ``values`` over [0, 1] with a log-scaled y axis.

    Args:
        values: scores to bin (100 bin edges spanning 0..1).
        title: figure title.
    """
    bins = np.linspace(0, 1, 100)
    plt.figure()
    plt.hist(values, bins, alpha=1)
    plt.title(title)
    plt.xlabel('Predicted Score')
    # Empty bins have count 0, which a log axis cannot show; clip them instead.
    # The keyword was renamed from `nonposy` to `nonpositive` in newer
    # matplotlib and the old spelling was later removed, so try the new name
    # first and fall back for older installations.
    try:
        plt.yscale('log', nonpositive='clip')
    except (TypeError, ValueError):
        plt.yscale('log', nonposy='clip')
    plt.show()
    
def split_list(l, idx):
    """Cut ``l`` at position ``idx`` and return the ``(before, from_idx_on)`` pair."""
    head = l[:idx]
    tail = l[idx:]
    return head, tail
    
def train_model(train_val_ratio=10):
    """Fit the logistic classifier on the current labels and show diagnostics.

    ``pos_examples`` and ``neg_examples`` are each shuffled (on copies) and
    split so that ``int(len(examples) / train_val_ratio)`` ids per class are
    held out for validation; the rest is used for training. The predictions
    returned by ``face_embeddings.logreg`` are then looked up for the training
    and validation ids and plotted in a three-tab widget (training set,
    validation set, score histogram over a random subsample of all
    predictions).

    Args:
        train_val_ratio: divisor that sizes the held-out validation slice of
            each class.

    Returns:
        The ``(weights, predictions)`` pair returned by
        ``face_embeddings.logreg``; ``predictions`` is consumed here as a
        sequence of ``(face_id, score)`` pairs.
    """
    print('Training logistic classifier with {}:1 train to validation split'.format(
          train_val_ratio))

    pos_examples_copy = pos_examples.copy()
    random.shuffle(pos_examples_copy)
    pos_split_idx = int(len(pos_examples_copy) / train_val_ratio)
    val_pos, train_pos = split_list(pos_examples_copy, pos_split_idx)

    neg_examples_copy = neg_examples.copy()
    random.shuffle(neg_examples_copy)
    neg_split_idx = int(len(neg_examples_copy) / train_val_ratio)
    val_neg, train_neg = split_list(neg_examples_copy, neg_split_idx)

    train_ids = train_pos + train_neg
    train_y = ([POS_LABEL] * len(train_pos)) + ([NEG_LABEL] * len(train_neg))

    val_ids = val_pos + val_neg
    val_y = ([POS_LABEL] * len(val_pos)) + ([NEG_LABEL] * len(val_neg))

    # NOTE(review): the positional 0, 1 arguments are passed through unchanged;
    # their meaning is defined by face_embeddings.logreg -- confirm there.
    weights, predictions = face_embeddings.logreg(
        train_ids, train_y,
        0, 1, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE,
        l2_penalty=L2_PENALTY)

    # Ids that never show up in `predictions` keep a score of 0.
    train_id_to_idx = {v: i for i, v in enumerate(train_ids)}
    train_pred_y = [0] * len(train_ids)
    val_id_to_idx = {v: i for i, v in enumerate(val_ids)}
    val_pred_y = [0] * len(val_ids)

    for v, s in predictions:
        if v in train_id_to_idx:
            train_pred_y[train_id_to_idx[v]] = s
        if v in val_id_to_idx:
            val_pred_y[val_id_to_idx[v]] = s

    num_tabs = 3
    outputs = [widgets.Output() for _ in range(num_tabs)]
    tabs = widgets.Tab(children=outputs)

    with outputs[0]:
        tabs.set_title(0, 'Training Set')
        plot_roc(train_y, train_pred_y)
        plot_binary_score_histograms(train_y, train_pred_y)

    with outputs[1]:
        tabs.set_title(1, 'Validation Set')
        plot_roc(val_y, val_pred_y)
        plot_binary_score_histograms(val_y, val_pred_y)

    with outputs[2]:
        tabs.set_title(2, 'Entire Dataset')
        # random.sample raises ValueError when asked for more items than the
        # population holds, so cap the subsample at the number of predictions.
        subsample_size = min(len(predictions), 100000)
        plot_score_histogram(
            [x[1] for x in random.sample(predictions, subsample_size)],
            title='Score Distribution (Random Subsample: {:,})'.format(
                subsample_size))

    display(tabs)
    return weights, predictions

In [None]:
# Train on the current labels; `predictions` is reused by visualize() below.
weights, predictions = train_model()

## Visualize predictions

In [None]:
# Filter controls for the prediction browser. Their current values are
# collected by get_vis_args() and applied by visualize().

# Maximum number of faces shown at once.
sample_size_text = widgets.BoundedIntText(
    style={'description_width': 'initial'},
    value=100,
    min=1,
    max=10000,
    description='Sample size:',
    disabled=False
)

# Display order of the sampled faces.
sample_sort_button = widgets.ToggleButtons(
    style={'description_width': 'initial'},
    options=['random', 'descending distance', 'ascending distance'],
    value='descending distance',
    description='Sample sort:',
    disabled=False,
    orientation='horizontal'
)

# Only faces whose predicted score lies in this range are shown.
score_range_slider = widgets.FloatRangeSlider(
    layout=widgets.Layout(width='100%'),
    style={'description_width': 'initial'},
    value=[0.45, 0.55],
    min=0,
    max=1,
    step=0.05,
    description='Predicted scores:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# 'select' keeps faces in commercials, 'exclude' drops them.
commercial_filter_button = widgets.ToggleButtons(
    style={'description_width': 'initial'},
    options=['disabled', 'select', 'exclude'],
    value='disabled',
    description='Commercial filter:',
    disabled=False,
    orientation='horizontal'
)

# Not implemented yet: visualize() only prints a warning when this is set.
gender_filter_button = widgets.ToggleButtons(
    style={'description_width': 'initial'},
    options=['disabled', 'male', 'female'],
    value='disabled',
    description='Gender filter:',
    disabled=False,
    orientation='horizontal'
)

# A height range equal to [MIN_HEIGHT, MAX_HEIGHT] disables the height filter.
MAX_HEIGHT = 1.
MIN_HEIGHT = 0.
face_height_slider = widgets.FloatRangeSlider(
    layout=widgets.Layout(width='100%'),
    style={'description_width': 'initial'},
    value=[MIN_HEIGHT, MAX_HEIGHT],
    min=0,
    max=1,
    step=0.05,
    description='Face height (proportion):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

# A sharpness bound left at MIN_SHARPNESS / MAX_SHARPNESS is not applied.
MAX_SHARPNESS = 1000.
MIN_SHARPNESS = 0.
face_sharpness_slider = widgets.FloatRangeSlider(
    layout=widgets.Layout(width='100%'),
    style={'description_width': 'initial'},
    value=[MIN_SHARPNESS, MAX_SHARPNESS],
    min=MIN_SHARPNESS,
    max=MAX_SHARPNESS,
    step=0.5,
    description='Face sharpness:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
)

# Not implemented yet: visualize() only prints a warning when this is set.
caption_filter_button = widgets.ToggleButtons(
    style={'description_width': 'initial'},
    options=['disabled', 'mentioned', 'not mentioned'],
    value='disabled',
    description='Captions filter:',
    disabled=False,
    orientation='horizontal'
)

# Restrict to a single canonical show; 'All' disables the filter.
canonical_show_dropdown = widgets.Dropdown(
    layout=widgets.Layout(width='100%'),
    style={'description_width': 'initial'},
    options=['All'] + list(sorted(MAJOR_CANONICAL_SHOWS)),
    value='All',
    description='Show filter:',
    disabled=False,
)

def get_vis_args():
    """Snapshot the current values of the filter widgets into a plain dict.

    The keys are the argument names that ``visualize`` looks up.
    """
    return {
        'sample_size': sample_size_text.value,
        'sample_sort': sample_sort_button.value,
        'score_range': score_range_slider.value,
        'commercial_filter': commercial_filter_button.value,
        'gender_filter': gender_filter_button.value,
        'height_range': face_height_slider.value,
        'sharpness_range': face_sharpness_slider.value,
        'caption_filter': caption_filter_button.value,
        'canonical_show': canonical_show_dropdown.value
    }

In [None]:
# Render the filter controls; visualize() reads their values via get_vis_args().
display(sample_size_text)
display(sample_sort_button)
display(score_range_slider)
display(commercial_filter_button)
display(gender_filter_button) # TODO: not implemented; visualize() only prints a warning
display(canonical_show_dropdown)
display(face_height_slider)
display(face_sharpness_slider)
display(caption_filter_button)

In [None]:
def sort_faces_by_distance(faces, ascending=False):
    """Sort the face rows in place by ``face_embeddings.dist`` to ``face_embs``.

    Args:
        faces: list of dict rows with an ``'id'`` key; it is reordered in place.
        ascending: smallest distance first when true, largest first otherwise.

    Returns:
        The same ``faces`` list, for convenience.
    """
    row_ids = [row['id'] for row in faces]
    distances = face_embeddings.dist(row_ids, targets=face_embs)
    dist_by_id = dict(zip(row_ids, distances))
    # Negating the key (rather than reverse=True) keeps tie order stable.
    sign = 1 if ascending else -1
    faces.sort(key=lambda row: sign * dist_by_id[row['id']])
    return faces
    
def visualize():
    """Show a filtered sample of unlabeled predictions and collect new labels.

    Reads the filter widgets via ``get_vis_args``, keeps the predictions that
    are not yet labeled and whose score is inside the chosen range, applies the
    database-side filters (commercial, face height, sharpness, canonical show),
    samples at most ``sample_size`` faces and displays them in a selection
    widget. On confirm, selected faces are appended to ``pos_examples`` and
    ignored faces to ``neg_examples``; the function then calls itself again to
    show a fresh sample.

    The gender and caption filters are not implemented; choosing them only
    prints a warning.
    """
    vis_args = get_vis_args()

    submit_button = widgets.Button(
        layout=widgets.Layout(width='auto'),
        style={'description_width': 'initial'},
        description='Confirm selections',
        disabled=False,
        button_style='',
        tooltip='Submit labels',
        icon='check'
    )

    labeled_ids_set = set(pos_examples) | set(neg_examples)

    min_score, max_score = vis_args['score_range']
    def pre_query_filter_fn(face_id_and_score):
        """Keep unlabeled predictions whose score is inside the chosen range."""
        face_id, score = face_id_and_score
        if face_id in labeled_ids_set:
            return False
        return min_score <= score <= max_score

    def query_filter_fn(qs):
        """Apply the database-side filters selected in the widgets to ``qs``."""
        if vis_args['commercial_filter'] != 'disabled':
            qs = qs.filter(
                shot__in_commercial=vis_args['commercial_filter'] == 'select')

        # Only annotate/filter on height when the slider was actually narrowed.
        min_height, max_height = vis_args['height_range']
        if min_height > MIN_HEIGHT or max_height < MAX_HEIGHT:
            qs = qs.annotate(height=BoundingBox.height_expr())
            if min_height > MIN_HEIGHT:
                qs = qs.filter(height__gte=min_height)
            if max_height < MAX_HEIGHT:
                qs = qs.filter(height__lte=max_height)

        # NOTE(review): the sharpness slider filters on the `blurriness` field;
        # confirm that larger values of that field mean sharper faces.
        min_sharpness, max_sharpness = vis_args['sharpness_range']
        if min_sharpness > MIN_SHARPNESS:
            qs = qs.filter(blurriness__gte=min_sharpness)
        if max_sharpness < MAX_SHARPNESS:
            qs = qs.filter(blurriness__lte=max_sharpness)

        if vis_args['canonical_show'] != 'All':
            qs = qs.filter(
                frame__video__show__canonical_show__name=vis_args['canonical_show'])

        if vis_args['gender_filter'] != 'disabled':
            print('Warning: gender filter is not implemented yet')

        if vis_args['caption_filter'] != 'disabled':
            print('Warning: caption filter is not implemented yet')
        return qs

    filtered_pred = list(filter(pre_query_filter_fn, predictions))
    filtered_pred_faces = query_faces([x[0] for x in filtered_pred])
    filtered_pred_faces = query_filter_fn(filtered_pred_faces)
    filtered_count = filtered_pred_faces.count()
    sample_size = vis_args['sample_size']
    if filtered_count > sample_size:
        filtered_pred_faces = query_sample(filtered_pred_faces, sample_size)
    filtered_pred_faces = list(filtered_pred_faces)

    print('Showing {} of {} faces'.format(
          min(sample_size, filtered_count), filtered_count))

    # Reorder the samples. 'random' leaves the sampled order untouched; the two
    # distance options sort by distance to the reference embeddings.
    if vis_args['sample_sort'] != 'random':
        filtered_pred_faces = sort_faces_by_distance(
            filtered_pred_faces,
            'ascending' in vis_args['sample_sort'])

    selection_widget = esper_widget(
        query_faces_result(filtered_pred_faces), crop_bboxes=True)

    def on_submit(b):
        """Record the new labels and show the next sample."""
        selected_idxs = set(selection_widget.selected)
        ignored_idxs = set(selection_widget.ignored)
        clear_output()

        # Translate widget positions back to face ids.
        selected_face_ids = []
        ignored_face_ids = []
        for i, f in enumerate(filtered_pred_faces):
            if i in selected_idxs:
                selected_face_ids.append(f['id'])
            if i in ignored_idxs:
                ignored_face_ids.append(f['id'])

        new_pos_labels = 0
        for i in selected_face_ids:
            if i not in labeled_ids_set:
                pos_examples.append(i)
                new_pos_labels += 1
        new_neg_labels = 0
        for i in ignored_face_ids:
            if i not in labeled_ids_set:
                neg_examples.append(i)
                new_neg_labels += 1

        print('Added {} new positive and {} new negative examples'.format(
              new_pos_labels, new_neg_labels))
        visualize()

    submit_button.on_click(on_submit)
    display(submit_button)
    display(selection_widget)

# Start the labelling loop; each confirm click shows a fresh sample.
visualize()