In [7]:
from IPython.display import display, clear_output
import ipywidgets as widgets
from collections import Counter, namedtuple
import sys
import math
import json
from datetime import datetime
from pytz import timezone
import numpy as np
import matplotlib.pyplot as plt

print('Initializing notebook. Please wait...', file=sys.stderr)

import esper.captions as captions
from captions.util import PostingUtil
from esper.major_canonical_shows import MAJOR_CANONICAL_SHOWS
from esper.widget import *
from esper.rekall import *
from rekall.interval_list import IntervalList
from captions import CaptionIndex

# Shared ``style`` kwargs so long widget descriptions are not truncated.
WIDGET_STYLE_ARGS = dict(description_width='initial')

# Human labels. Each field maps video_id -> set of (start, end) pairs in
# seconds (see display_segments / on_update for how they are filled).
GroundTruth = namedtuple('GroundTruth', 'positive negative')
        

def extend_postings(postings, threshold):
    """Coalesce ``postings`` via ``PostingUtil.deoverlap`` using ``threshold``.

    Presumably merges postings that overlap or are separated by at most
    ``threshold`` (a "merge with threshold"); see PostingUtil for the exact
    semantics.
    """
    merged = PostingUtil.deoverlap(postings, threshold)
    return merged


def extend_postings_with_context(keys, contexts, threshold):
    """Grow each key posting by absorbing nearby context postings.

    For every key posting: first, context postings that start at or after
    it and begin no more than ``threshold`` after its (growing) end are
    merged in; then, scanning ``contexts`` backwards, postings that start at
    or before it and end no more than ``threshold`` before its (growing)
    start are merged in. The grown postings are finally coalesced with
    ``extend_postings(..., threshold)``.
    """
    extended = []
    for seed in keys:
        grown = seed
        # Forward pass: absorb context that follows the key.
        for ctx in contexts:
            if ctx.start >= grown.start and ctx.start - grown.end <= threshold:
                grown = PostingUtil.merge(grown, ctx)
        # Backward pass: absorb context that precedes the key.
        for ctx in reversed(contexts):
            if ctx.start <= grown.start and grown.start - ctx.end <= threshold:
                grown = PostingUtil.merge(grown, ctx)
        extended.append(grown)
    return extend_postings(extended, threshold)


def filter_dict(d, keys):
    """Return a new dict holding only the entries of ``d`` whose key is in ``keys``."""
    return dict((key, value) for key, value in d.items() if key in keys)


# Result bundle of a segment search; each field is a dict keyed by video id
# whose values are that video's postings.
TopicSegments = namedtuple(
    'TopicSegments',
    'video_to_key_phrases video_to_context_phrases video_to_segments',
)


def or_queries(queries):
    """Join ``queries`` into one disjunction with each term parenthesised.

    Example: ``['a', 'b c']`` -> ``'(a)|(b c)'``; an empty input gives ``''``.
    """
    return '|'.join(map('({})'.format, queries))


def filter_video_qs(video_qs, filters):
    """Narrow a Video queryset by the optional entries in ``filters``.

    Recognized keys, each applied only when present: 'show' (canonical show
    name), 'channel' (channel name), 'start' / 'end' (inclusive bounds on the
    video's ``time``). Other keys are ignored.
    """
    # (filters key, ORM lookup) pairs, applied in this order.
    lookups = (
        ('show', 'show__canonical_show__name'),
        ('channel', 'channel__name'),
        ('start', 'time__gte'),
        ('end', 'time__lte'),
    )
    for key, lookup in lookups:
        if key in filters:
            video_qs = video_qs.filter(**{lookup: filters[key]})
    return video_qs


def _find_segments(key_phrases, context_phrases, filters):
    """Uncached search for candidate story segments.

    Key-phrase hits seed the segments; context-phrase hits (searched only
    within videos that matched a key phrase) extend them. Reads the
    module-level constants KEY_PHRASE_WINDOW_SIZE,
    CONTEXT_PHRASE_EXTEND_THRESH and MIN_PROPOSED_SEGMENT_LEN, which must be
    defined before this is called.

    Args:
        key_phrases: iterable of caption-index query strings; they are
            OR-ed together and upper-cased before searching.
        context_phrases: iterable of query strings handled the same way.
            If empty, no context search is run.
        filters: dict as produced by ``get_filters`` (see ``filter_video_qs``).

    Returns:
        TopicSegments whose three dicts are keyed by video id.
    """
    # Candidate videos: non-duplicate, non-corrupted, matching the filters.
    video_qs = Video.objects.filter(duplicate=False, corrupted=False)
    video_qs = filter_video_qs(video_qs, filters)
    video_ids = [x['id'] for x in video_qs.values('id')]
    
    # Find the key locations
    video_key_locations = {}
    for d in captions.query_search(or_queries(key_phrases).upper(), video_ids=video_ids):
        doc = captions.get_document(d.id)
        doc_duration = captions.INDEX.document_duration(doc)
        # Widen each hit by KEY_PHRASE_WINDOW_SIZE (doc_duration is passed,
        # presumably to clamp at the end of the video), then merge windows
        # that overlap (threshold 0).
        video_key_locations[d.id] = extend_postings(
            PostingUtil.dilate(
                d.postings, KEY_PHRASE_WINDOW_SIZE, doc_duration), 0)
    
    # Search for context locations
    # Only videos that already contain a key phrase are searched.
    # NOTE(review): if no key phrase matched, this passes an empty id
    # collection; confirm query_search treats that as "no videos" rather
    # than "all videos".
    video_context_locations = {}
    if len(context_phrases) > 0:
        for d in captions.query_search(or_queries(context_phrases).upper(), 
                                       video_ids=video_key_locations.keys()):
            video_context_locations[d.id] = list(d.postings)
    
    # Extend the key locations
    video_topic_segments = {}
    for video_id, key_postings in video_key_locations.items():
        story_segments = extend_postings_with_context(
            key_postings, video_context_locations.get(video_id, []),
            CONTEXT_PHRASE_EXTEND_THRESH)
        # Drop proposals shorter than MIN_PROPOSED_SEGMENT_LEN.
        story_segments = list(filter(
            lambda p: p.end - p.start >= MIN_PROPOSED_SEGMENT_LEN,
            story_segments))
        video_topic_segments[video_id] = story_segments
    return TopicSegments(
        video_key_locations, video_context_locations, video_topic_segments)


# Single-entry memo of the most recent find_segments call: a snapshot of the
# (key_phrases, context_phrases, filters) arguments and its TopicSegments.
CACHED_SEGMENTS_QUERY = None
CACHED_SEGMENTS_RESULT = None


def find_segments(key_phrases, context_phrases, filters):
    """Search for story segments, reusing the result of an identical last query.

    Prints a summary (segment count, video count, minutes covered) to
    stderr and returns the TopicSegments from ``_find_segments``.

    Args:
        key_phrases: collection of key-phrase query strings.
        context_phrases: collection of context-phrase query strings.
        filters: dict as produced by ``get_filters``.
    """
    global CACHED_SEGMENTS_QUERY, CACHED_SEGMENTS_RESULT
    print('Searching for segments...', file=sys.stderr)

    # Snapshot the arguments instead of storing the caller's objects:
    # KEY_PHRASES / CONTEXT_PHRASES are sets that get mutated in place
    # (e.g. CONTEXT_PHRASES.update in propose_context_phrases). A stored
    # reference would always compare equal to the new arguments and return
    # a stale result.
    query = (frozenset(key_phrases), frozenset(context_phrases), dict(filters))
    if CACHED_SEGMENTS_QUERY == query:
        result = CACHED_SEGMENTS_RESULT
    else:
        result = _find_segments(key_phrases, context_phrases, filters)
        CACHED_SEGMENTS_QUERY = query
        CACHED_SEGMENTS_RESULT = result

    segment_lists = result.video_to_segments.values()
    coverage_seconds = sum(sum(p.end - p.start for p in segs)
                           for segs in segment_lists)
    print('Found {} segments in {} videos covering {:0.2f} minutes.'.format(
        sum(len(segs) for segs in segment_lists),
        len(result.video_to_segments),
        coverage_seconds / 60
    ), file=sys.stderr)
    return result
    

# Minimum corpus-wide count for a token to be proposed as a context word.
MIN_TOKEN_COUNT = 10000


def propose_context_phrases(k=192, ncols=8, default_threshold=5.):
    """Suggest new context words for the current story and let the user pick.

    Tokens inside the currently proposed segments are ranked by log
    pointwise mutual information with the story:

        log(topic_count(t) / topic_total) - log(corpus_count(t) / corpus_total)

    Only tokens seen more than MIN_TOKEN_COUNT times in the corpus and not
    already in CONTEXT_PHRASES are considered. The top ``k`` are shown as
    toggle buttons, ``ncols`` per row. Those scoring at least
    ``default_threshold`` start selected. "Submit" adds the selected tokens
    to CONTEXT_PHRASES and refreshes the context text area.

    Args:
        k: maximum number of candidate words to display.
        ncols: toggle buttons per row.
        default_threshold: log-PMI score at or above which a word is
            pre-selected.
    """
    topic_result = find_segments(KEY_PHRASES, CONTEXT_PHRASES, get_filters())

    # Count every caption token that falls inside a proposed segment.
    topic_word_counts = Counter()
    for video_id, segments in topic_result.video_to_segments.items():
        d = captions.get_document(video_id)
        for p in segments:
            topic_word_counts.update(captions.INDEX.tokens(d, p.idx, p.len))

    topic_words_total = sum(topic_word_counts.values())
    if topic_words_total == 0:
        # Nothing to rank, and math.log(0) below would raise ValueError.
        print('No tokens found in the proposed segments; nothing to propose.',
              file=sys.stderr)
        return
    all_words_total = sum(w.count for w in captions.LEXICON)

    def filter_cond(t):
        # Keep corpus-frequent words that are not already context phrases.
        if t not in captions.LEXICON:
            return False
        w = captions.LEXICON[t]
        return w.count > MIN_TOKEN_COUNT and w.token not in CONTEXT_PHRASES

    # The two totals are the same for every token, so fold them once.
    const_expr = math.log(all_words_total) - math.log(topic_words_total)
    log_pmis = [
        (t, math.log(topic_word_counts[t]) - math.log(captions.LEXICON[t].count) + const_expr)
        for t in topic_word_counts.keys() if filter_cond(t)
    ]
    log_pmis.sort(key=lambda x: -x[1])
    log_pmis = log_pmis[:k]

    # One toggle per candidate; high scorers start pre-selected.
    selections = []
    for t, score in log_pmis:
        token = captions.LEXICON[t].token
        w = widgets.ToggleButton(
            value=score >= default_threshold,
            description=token,
            disabled=False,
            button_style='',
            icon=''
        )
        selections.append((t, w))

    submit_button = widgets.Button(
        description='Submit',
        disabled=False,
        button_style='danger'
    )

    def on_submit(b):
        selected_words = [
            captions.LEXICON[t].token for t, w in selections if w.value
        ]
        clear_output()
        print('Added {} words to the context.'.format(len(selected_words)))

        global CONTEXT_PHRASES
        CONTEXT_PHRASES.update(selected_words)
        # Push the enlarged set back into the context text area.
        sync_context_widget()

    submit_button.on_click(on_submit)

    cancel_button = widgets.Button(
        description='Cancel',
        disabled=False,
        button_style=''
    )

    def on_cancel(b):
        clear_output()
    cancel_button.on_click(on_cancel)

    hboxes = []
    for i in range(0, len(selections), ncols):
        hboxes.append(widgets.HBox([w for _, w in selections[i:i + ncols]]))
    vbox = widgets.VBox(hboxes)
    display(widgets.HBox([
        widgets.Label(
            'Instructions: Select new context words and hit submit. '
            '(Likely words may already be highlighted.) '),
        submit_button, cancel_button
    ]))
    display(vbox)
    

def display_segments(topic_results, ground_truth, limit=1000, results_per_page=50,
                     selection=None):
    """Show proposed story segments in an esper widget and collect labels.

    Args:
        topic_results: TopicSegments from find_segments; all dicts are keyed
            by video id and hold postings with ``start``/``end`` (seconds, as
            implied by the fps conversion below).
        ground_truth: GroundTruth whose ``positive``/``negative`` dicts map
            video id -> set of (start, end) pairs. Mutated in place when the
            "Update ground truth" button is clicked.
        limit: maximum number of videos loaded into the widget.
        results_per_page: widget page size.
        selection: 'unlabeled', 'positive' or 'negative' restricts videos by
            labeling status; any other value (e.g. 'all', None) keeps all.
    """
    def is_selected(video_id):
        # Filter by labeling status; unrecognized selections keep everything.
        if selection == 'unlabeled':
            return (video_id not in ground_truth.positive and 
                    video_id not in ground_truth.negative)
        elif selection == 'positive':
            return video_id in ground_truth.positive
        elif selection == 'negative':
            return video_id in ground_truth.negative
        return True
    
    # Time covered by key-phrase hits / proposed segments, per selected video.
    video_to_key_time = {
        video_id : sum(p.end - p.start for p in postings)
        for video_id, postings in topic_results.video_to_key_phrases.items()
        if is_selected(video_id)
    }
    video_to_topic_time = {
        video_id : sum(p.end - p.start for p in postings)
        for video_id, postings in topic_results.video_to_segments.items()
        if is_selected(video_id)
    }
    # Drop duplicate/corrupted videos and fetch fps for time -> frame conversion.
    # NOTE(review): 'channel__name' is fetched but never used here.
    video_qs = Video.objects.filter(id__in=list(video_to_key_time.keys()), 
                                    duplicate=False, corrupted=False)
    video_to_fps = {
        v['id']: v['fps'] for v in video_qs.values('id', 'fps', 'channel__name')
    }
    if len(video_to_fps) == 0:
        print('No videos to display', file=sys.stderr)
        return
    video_to_key_time = filter_dict(video_to_key_time, video_to_fps)
    video_to_topic_time = filter_dict(video_to_topic_time, video_to_fps)
    
    # For display
    # Most proposed-segment time first; only the first `limit` are loaded.
    video_order_all = list(sorted(
        video_to_fps.keys(), 
        key=lambda x: -video_to_topic_time.get(x, 0)
    ))
    video_order = video_order_all[:limit]
    video_ids = set(video_order)

    def convert_time(v, t):
        # Time t -> (truncated) frame number in video v.
        return int(t * video_to_fps[v])
    
    def to_intervallist(video_to_postings):
        # Per-video IntervalList in frames, restricted to the loaded videos.
        return {
            video_id : IntervalList([
                (convert_time(video_id, p.start), convert_time(video_id, p.end), None)
                for p in postings
            ]) 
            for video_id, postings in video_to_postings.items() 
            if video_id in video_ids
        }
    
    def compute_true_time(video_id):
        # Total labeled-positive time; overlapping labels are not merged.
        intervals = ground_truth.positive.get(video_id, [])
        return sum(b - a for a, b in intervals)
    
    # Plot distribution of topic times in videos
    def plot_dist_of_videos(video_order):
        # Purple line: proposed minutes per video; blue x: labeled-positive
        # minutes; black line (right axis): cumulative % of proposed minutes.
        fig, ax1 = plt.subplots(figsize=(7,2))
        x = np.arange(len(video_order))
        y_pred = np.array([video_to_topic_time.get(v, 0) for v in video_order]) / 60

        ax1.plot(x, y_pred, color='purple')
        y_true_tmp = [compute_true_time(v) for v in video_order]
        if sum(y_true_tmp) > 0:
            x_true = np.array([i for i, y in enumerate(y_true_tmp) if y > 0])
            y_true = np.array([y for y in y_true_tmp if y > 0]) / 60
            ax1.plot(x_true, y_true, 'x', color='blue')
            y_max = max(np.max(y_pred), np.max(y_true))
        else:
            y_max = np.max(y_pred)
        # Gray band spans the videos that will actually be loaded.
        ax1.fill_betweenx([0, y_max], len(video_ids), alpha=0.2, color='gray')
        ax1.set_ylabel('Minutes', color='purple')
        ax1.tick_params('y', colors='purple')
        ax1.set_ylim(0, y_max)
        ax1.set_xlabel('Video Number')
        ax1.set_xlim(0, len(video_order))
        y_prop = np.cumsum(y_pred)
        # NOTE(review): produces NaNs if no video has any proposed time.
        y_prop *= 100. / y_prop[-1]
        ax2 = ax1.twinx()
        ax2.plot(x, y_prop, color='black')
        ax2.set_ylabel('Cumulative % Minutes', color='black')
        ax2.tick_params('y', colors='black')
        plt.show()
    
    if DEBUG:
        print('Videos (ordered by descending segment time)')
        plot_dist_of_videos(video_order_all)
    print('Loading {} of {} videos{}... Please wait.'.format(
        len(video_ids), len(video_to_key_time), ' (shaded region)' if DEBUG else ''))
    
    # Convert to intervallists
    video_to_key_intervals = to_intervallist(topic_results.video_to_key_phrases)
    video_to_context_intervals = to_intervallist({
        k: extend_postings(v, 15) 
        # Coalesce context words to reduce memory usage
        for k, v in topic_results.video_to_context_phrases.items()
    })
    video_to_topic_intervals = to_intervallist(topic_results.video_to_segments)
    # Commercials as labeled by 'haotian-commercials', for the loaded videos.
    video_to_commerical_intervals = qs_to_intrvllists(
        Commercial.objects.filter(labeler__name='haotian-commercials',
                                  video__id__in=video_ids))
    
    def ranges_to_intrvllist(v, ranges):
        # (start, end) pairs -> IntervalList in frames of video v.
        return IntervalList([
            (convert_time(v, a), convert_time(v, b), None) 
            for a, b in ranges
        ])
    
    # Existing human labels, restricted to the loaded videos.
    video_to_labeled_pos_intervals = {
        v: ranges_to_intrvllist(v, labels)
        for v, labels in ground_truth.positive.items()
        if v in video_ids
    }
    video_to_labeled_neg_intervals = {
        v: ranges_to_intrvllist(v, labels)
        for v, labels in ground_truth.negative.items()
        if v in video_ids
    }
    
    # Display results
    # Each call adds one timeline track. on_update reads tracks 4 (positive)
    # and 5 (negative) by index; these presumably follow this insertion order
    # (key, context, topic, commercial, pos, neg) and must stay in sync with
    # timeline_annotation_keys below.
    result = intrvllists_to_result(
        video_to_key_intervals, color='green', video_order=video_order)
    add_intrvllists_to_result(result, video_to_context_intervals, color='orange')
    add_intrvllists_to_result(result, video_to_topic_intervals, color='purple')
    add_intrvllists_to_result(result, video_to_commerical_intervals, color='black')
    add_intrvllists_to_result(result, video_to_labeled_pos_intervals, color='blue')
    add_intrvllists_to_result(result, video_to_labeled_neg_intervals, color='red')
    
    video_widget = esper_widget(result, jupyter_keybindings=True,
                                timeline_annotation_keys={';': 4, '\'': 5},
                                results_per_page=results_per_page,
                                show_inline_metadata=True)
    update_button = widgets.Button(
        description='Update ground truth',
        disabled=False,
        button_style='warning'
    )
    def on_update(b):
        # Copy labels from the widget into ground_truth (mutated in place):
        # hand-drawn annotations on tracks 4/5, plus every proposed segment
        # of videos accepted (selected) or rejected (ignored) wholesale.
        selected_idxs = set(video_widget.selected)
        ignored_idxs = set(video_widget.ignored)
        n_pos_segs = 0
        n_neg_segs = 0
        
        def segment_is_ok(seg):
            # Skip annotations missing either endpoint.
            return 'min_frame' in seg and 'max_frame' in seg
        
        # Widget group i presumably corresponds to video_order[i].
        for i, video_id in enumerate(video_order):
            video_fps = video_to_fps[video_id]
            
            pos_segments = []
            neg_segments = []
            if len(video_widget.groups) > 0:
                # Annotations are in frames; convert back to time via fps.
                pos_segments.extend([
                    (
                        int(seg['min_frame']) / video_fps, 
                        int(seg['max_frame']) / video_fps
                    )
                    for seg in video_widget.groups[i]['elements'][4]['segments'] 
                    if segment_is_ok(seg)
                ])
                neg_segments.extend([
                    (
                        int(seg['min_frame']) / video_fps, 
                        int(seg['max_frame']) / video_fps
                    )
                    for seg in video_widget.groups[i]['elements'][5]['segments'] 
                    if segment_is_ok(seg)
                ])
                
            if i in selected_idxs:
                pos_segments.extend([
                    (p.start, p.end)
                    for p in topic_results.video_to_segments[video_id]
                ])
            if i in ignored_idxs:
                neg_segments.extend([
                    (p.start, p.end)
                    for p in topic_results.video_to_segments[video_id]
                ])
            
            n_pos_segs += len(pos_segments)
            if len(pos_segments) > 0:
                if video_id not in ground_truth.positive:
                    ground_truth.positive[video_id] = set()
                ground_truth.positive[video_id].update(pos_segments)

            n_neg_segs += len(neg_segments)
            if len(neg_segments) > 0:
                if video_id not in ground_truth.negative:
                    ground_truth.negative[video_id] = set()
                ground_truth.negative[video_id].update(neg_segments)

        clear_output()
        print('Added {} positive segments and {} negative segments.'.format(
            n_pos_segs, n_neg_segs))

    update_button.on_click(on_update)
    display(update_button)
    display(video_widget)

    
def show_filter_widgets():
    """Render channel/show/date filter controls and publish them as FILTER_WIDGETS.

    ``get_filters`` reads the current selections from the module-level
    FILTER_WIDGETS dict populated here.
    """
    global FILTER_WIDGETS

    def make_dropdown(description, options):
        # All dropdowns default to 'All' (no filtering).
        return widgets.Dropdown(
            style=WIDGET_STYLE_ARGS, options=options, value='All',
            description=description, disabled=False)

    def make_date_picker(description):
        return widgets.DatePicker(
            style=WIDGET_STYLE_ARGS, description=description, disabled=False)

    channel_choice = make_dropdown('Channel:', ['All', 'CNN', 'FOXNEWS', 'MSNBC'])
    show_choice = make_dropdown('Show:', ['All'] + sorted(MAJOR_CANONICAL_SHOWS))
    start_choice = make_date_picker('Start date:')
    end_choice = make_date_picker('End date:')

    FILTER_WIDGETS = {
        'show': show_choice,
        'channel': channel_choice,
        'start_date': start_choice,
        'end_date': end_choice,
    }
    display(widgets.HBox([channel_choice, show_choice, start_choice, end_choice]))

    
def show_story_widgets():
    """Render the filter, key-phrase and context-phrase editors.

    Editing either text area rebinds the module global KEY_PHRASES /
    CONTEXT_PHRASES to the set of stripped, non-blank lines. The nested
    sync_key_widget / sync_context_widget functions are published as module
    globals so other code (e.g. propose_context_phrases) can push the sets
    back into the text areas.
    """
    # Parse errors from the observers are printed here instead of raised.
    status_output = widgets.Output()
    key_widget = widgets.Textarea(
        style=WIDGET_STYLE_ARGS,
        value='',
        layout=widgets.Layout(width='100%'),
        placeholder='Phrases (one per line)',
        description='Key phrases:',
        disabled=False
    )
    global sync_key_widget
    def sync_key_widget():
        # Rewrite the text area from KEY_PHRASES (sorted) and size it to fit.
        # Unlike the context widget, the height is not capped.
        key_widget.value = '\n'.join(sorted(KEY_PHRASES))
        computed_height = 20 * (len(KEY_PHRASES) + 2)
        key_widget.layout = widgets.Layout(
            width='100%', 
            height='{}px'.format(computed_height)
        )
    def on_key_changed(b):
        # Fired on every change of key_widget.value; re-parse the text.
        with status_output:
            clear_output()
            try:
                global KEY_PHRASES
                KEY_PHRASES = {
                    t.strip() for t in key_widget.value.split('\n')
                    if len(t.strip()) > 0
                }
            except Exception as e:
                print(e)
    key_widget.observe(on_key_changed, names='value')

    context_widget = widgets.Textarea(
        value='',
        style=WIDGET_STYLE_ARGS,
        layout=widgets.Layout(width='100%'),
        placeholder='Phrases (one per line)',
        description='Context phrases:',
        disabled=False
    )
    global sync_context_widget
    def sync_context_widget():
        # Rewrite the text area from CONTEXT_PHRASES (sorted); height grows
        # with the number of phrases up to max_height pixels.
        context_widget.value = '\n'.join(sorted(CONTEXT_PHRASES))
        max_height = 250
        computed_height = 20 * (len(CONTEXT_PHRASES) + 2)
        context_widget.layout = widgets.Layout(
            width='100%', 
            height='{}px'.format(min(max_height, computed_height))
        )
    def on_context_changed(b):
        # Fired on every change of context_widget.value; re-parse the text.
        with status_output:
            clear_output()
            try:
                global CONTEXT_PHRASES
                CONTEXT_PHRASES = {
                    t.strip() for t in context_widget.value.split('\n') 
                    if len(t.strip()) > 0
                }
            except Exception as e:
                print(e)
    context_widget.observe(on_context_changed, names='value')

    sort_button = widgets.Button(
        description='Sort phrases',
        disabled=False,
        button_style=''
    )
    def on_sort(b):
        # Re-render both areas from the parsed sets (sorted, blanks dropped).
        sync_key_widget()
        sync_context_widget()
    sort_button.on_click(on_sort)

    show_filter_widgets()
    display(key_widget)
    display(context_widget)
    display(sort_button)
    display(status_output)
    # Populate the text areas from any phrases kept from a previous run.
    sync_key_widget()
    sync_context_widget()

                     
def get_filters():
    """Collect the active search filters from FILTER_WIDGETS.

    Returns a dict with any of 'show', 'channel', 'start', 'end'. Dropdowns
    left at 'All' and empty date pickers contribute nothing. Requires
    show_filter_widgets() to have populated FILTER_WIDGETS first.
    """
    filters = {}
    for key in ('show', 'channel'):
        choice = FILTER_WIDGETS[key].value
        if choice != 'All':
            filters[key] = choice
    for key, widget_name in (('start', 'start_date'), ('end', 'end_date')):
        picked = FILTER_WIDGETS[widget_name].value
        if picked:
            filters[key] = picked
    return filters

                     
def show_video_controls():
    """Render the controls that run the segment search and show the videos.

    "Show videos" runs find_segments with the current KEY_PHRASES,
    CONTEXT_PHRASES and filters. It renders the result with display_segments
    (labels go into GROUND_TRUTH) inside an Output area, and
    "Dismiss videos" clears that area.
    """
    # Dropdown label -> `selection` value understood by display_segments.
    selection_by_label = {
        'All': 'all',
        'Unlabeled': 'unlabeled',
        'Labeled Positive': 'positive',
        'Labeled Negative': 'negative',
    }
    show_videos_output = widgets.Output()
    limit_slider = widgets.BoundedIntText(
        style=WIDGET_STYLE_ARGS, value=1000, min=1, max=10000,
        description='Video limit:', disabled=False)
    results_per_page_slider = widgets.BoundedIntText(
        style=WIDGET_STYLE_ARGS, value=40, min=1, max=100,
        description='Results per page:', disabled=False)
    filter_videos_dropdown = widgets.Dropdown(
        style=WIDGET_STYLE_ARGS, options=list(selection_by_label),
        value='All', description='Videos:', disabled=False)
    show_videos_button = widgets.Button(
        style=WIDGET_STYLE_ARGS, description='Show videos',
        disabled=False, button_style='danger')
    clear_videos_button = widgets.Button(
        style=WIDGET_STYLE_ARGS, description='Dismiss videos',
        disabled=False, button_style='')

    def on_show_videos(b):
        with show_videos_output:
            clear_output()
            label = filter_videos_dropdown.value
            if label not in selection_by_label:
                raise Exception('Unknown option...')
            topic_results = find_segments(KEY_PHRASES, CONTEXT_PHRASES, get_filters())
            display_segments(
                topic_results, GROUND_TRUTH,
                limit=limit_slider.value,
                results_per_page=results_per_page_slider.value,
                selection=selection_by_label[label],
            )

    def on_clear_videos(b):
        with show_videos_output:
            clear_output()

    show_videos_button.on_click(on_show_videos)
    clear_videos_button.on_click(on_clear_videos)
    display(widgets.HBox([
        limit_slider, results_per_page_slider, filter_videos_dropdown]))
    display(widgets.HBox([show_videos_button, clear_videos_button]))
    display(show_videos_output)
    
    
def init_global_variables():
    """Create the notebook's shared state on first run and keep it on re-runs.

    KEY_PHRASES and CONTEXT_PHRASES start as empty sets; both are reset if
    either is missing. GROUND_TRUTH starts as an empty GroundTruth. Values
    that already exist survive re-executing this cell, so labels and
    phrases are not lost.
    """
    global KEY_PHRASES, CONTEXT_PHRASES, GROUND_TRUTH
    module_vars = globals()
    if 'KEY_PHRASES' not in module_vars or 'CONTEXT_PHRASES' not in module_vars:
        KEY_PHRASES = set()
        CONTEXT_PHRASES = set()
    if 'GROUND_TRUTH' not in module_vars:
        GROUND_TRUTH = GroundTruth({}, {})


init_global_variables()

print('Done initializing notebook.', file=sys.stderr)

Initializing notebook. Please wait...
Done initializing notebook.


Some constants to help with visualization.

In [8]:
# Amount passed to PostingUtil.dilate to widen each key-phrase hit
# (presumably seconds, like the other posting times).
KEY_PHRASE_WINDOW_SIZE = 5
# Max gap allowed between a key segment and a context-phrase hit for the
# context hit to be merged into (and so extend) the segment.
CONTEXT_PHRASE_EXTEND_THRESH = 120
# Proposed segments with end - start below this are discarded.
MIN_PROPOSED_SEGMENT_LEN = 30
# When True, display_segments also plots per-video segment-time statistics.
DEBUG = False

# Stories from a Lexicon

Stories are retrieved via lexicons of words. You can search for a story by defining a set of <b>key phrases</b>. This will find all segments in the data set where the phrases appear. 

For instance, if you are looking for segments about 'Hurricane Irma', then the relevant key phrases may include 'Irma' and 'Hurricane Irma'. Searching for segments will retrieve all of these mentions. Note that 'hurricane' would be a poor key phrase in this case because it will match all hurricane segments, regardless of whether they are about Irma or not. Ideally, your key phrases should be unique to the story.

It can be useful to search for and visualize other phrases in addition to the key phrases. <b>Context phrases</b> are phrases that are relevant to the story, but not unique to it. For instance, words such as 'devastation' and 'storm' will be used in the context of 'Hurricane Irma' but also in the context of other hurricanes and weather disasters. A later cell plots these words on the same timeline. Context phrases found near a key phrase also extend the proposed story segments, which affects the order in which results are presented. However, adding <b>context phrases</b> will not change the set of videos shown, as that set is defined purely by <b>key phrases</b>. 

<b>Instructions:</b>
- Enter relevant filters.
- Enter key phrases to start (required; see caption-index query syntax)
- Enter a few context phrases (optional)
- Show videos (see videos section)

In [9]:
show_story_widgets()

HBox(children=(Dropdown(description='Channel:', options=('All', 'CNN', 'FOXNEWS', 'MSNBC'), style=DescriptionS…

Textarea(value='', description='Key phrases:', layout=Layout(width='100%'), placeholder='Phrases (one per line…

Textarea(value='', description='Context phrases:', layout=Layout(width='100%'), placeholder='Phrases (one per …

Button(description='Sort phrases', style=ButtonStyle())

Output()

## Videos

Show videos and the retrieved topic segments on a timeline. Enter key phrases in the cell above before running this; clicking 'Show videos' runs the segment search with the current phrases and filters.

Timeline colors:
- Green = key phrases
- Orange = context phrases
- Purple = proposed story segment
- Black = commercial

Timeline (Human Labeled) colors:
- Blue = labeled positive segment
- Red = labeled negative segment

Videos will be ordered by descending amount of proposed time identified as the story.

Select positive segments with <b>;</b> and negative segments with <b>'</b>. Use <b>[</b> and <b>]</b> to accept or reject all proposed story segments in the video. Click 'Update ground truth' to save your labels.

In [10]:
show_video_controls()

HBox(children=(BoundedIntText(value=1000, description='Video limit:', max=10000, min=1, style=DescriptionStyle…

HBox(children=(Button(button_style='danger', description='Show videos', style=ButtonStyle()), Button(descripti…

Output()

## Automatically propose context words

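Once we have some segments corresponding to the lexicon, we can propose new context words automatically. `propose_context_phrases()` ranks words that are over-represented in the proposed segments relative to the whole corpus (by pointwise mutual information) and lets you select which ones to add as context phrases.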

In [None]:
propose_context_phrases()

## Analysis

Run `analysis()` to compute statistics over the retrieved story segments. These graphs will respond to the filters set earlier.

In [11]:
MALE_ID = 1
FEMALE_ID = 2


# Lazily loaded contents of identities_by_video.json; the existing value is
# kept when this cell is re-run.
if '_FACE_IDENTS' not in globals():
    _FACE_IDENTS = None


def get_face_idents():
    """Return the parsed face-identity JSON, reading it from disk on first use.

    The parsed object is cached in the module-level ``_FACE_IDENTS``.
    """
    global _FACE_IDENTS
    if _FACE_IDENTS is not None:
        return _FACE_IDENTS
    face_idents_path = '/app/data/stories-data/identities_by_video.json'
    print('Loading face identities...', file=sys.stderr)
    with open(face_idents_path, 'r') as fp:
        _FACE_IDENTS = json.load(fp)
    return _FACE_IDENTS


# Lazily loaded contents of genders_race_by_video.json; the existing value is
# kept when this cell is re-run.
if '_FACE_GENDERS_AND_RACE' not in globals():
    _FACE_GENDERS_AND_RACE = None


def get_face_genders():
    """Return the parsed per-video gender/race JSON, reading it on first use.

    The parsed object is cached in the module-level ``_FACE_GENDERS_AND_RACE``.
    """
    global _FACE_GENDERS_AND_RACE
    if _FACE_GENDERS_AND_RACE is None:
        face_gender_race_path = '/app/data/stories-data/genders_race_by_video.json'
        print('Loading face genders...', file=sys.stderr)
        with open(face_gender_race_path, 'r') as fp:
            _FACE_GENDERS_AND_RACE = json.load(fp)
    return _FACE_GENDERS_AND_RACE


def _analysis(topic_results):
    """Print and plot statistics about how a story is covered.

    Args:
        topic_results: a ``TopicSegments``. Its ``video_to_key_phrases`` keys
            choose the videos to analyze (duplicate and corrupted videos are
            dropped) and ``video_to_segments`` maps each video id to the
            postings that are measured, each with ``start``/``end`` in seconds.

    Displays, in order: a "Clear Analysis" button; topic time per channel,
    per day, per hour of day and per weekday (US/Eastern); the shows with the
    most topic time; face screen time by gender for all faces, hosts and
    non-hosts; screen time of faces labeled black as a share of all face
    time; and the identities with the most screen time. Returns ``None``.
    """
    # The notebook's import cell only takes Counter/namedtuple from
    # collections, so import defaultdict here rather than rely on a star import.
    from collections import defaultdict

    video_qs = Video.objects.filter(
        id__in=list(topic_results.video_to_key_phrases.keys()),
        duplicate=False, corrupted=False)
    video_to_meta = {
        v['id']: {
            'channel': v['channel__name'],
            'show': v['show__canonical_show__name'],
            'time': v['time'],
            'fps': v['fps'],
            'is_3y': v['threeyears_dataset'],
            'path': v['path']
        } for v in video_qs.values(
            'id', 'channel__name', 'show__canonical_show__name', 'time', 'fps',
            'threeyears_dataset', 'path'
        )
    }
    if len(video_to_meta) == 0:
        print('No videos to analyze.',
              file=sys.stderr)
        return

    clear_button = widgets.Button(
        description='Clear Analysis',
        disabled=False,
        button_style=''
    )
    def on_clear(b):
        clear_output()
    clear_button.on_click(on_clear)
    display(clear_button)

    channels = [c.name for c in Channel.objects.all()]
    utc = timezone('UTC')
    eastern = timezone('US/Eastern')

    # --- Topic time (seconds) per channel, day, hour of day, weekday, show ---
    channel_to_time = {c: 0. for c in channels}
    channel_to_daypart_to_time = {c: np.zeros(24) for c in channels}
    channel_to_weekday_to_time = {c: np.zeros(7) for c in channels}
    channel_to_time_to_time = {c: defaultdict(float) for c in channels}
    show_to_time = Counter()
    for video_id, postings in topic_results.video_to_segments.items():
        if video_id not in video_to_meta:
            continue

        video_topic_len = sum(p.end - p.start for p in postings)
        channel = video_to_meta[video_id]['channel']
        channel_to_time[channel] += video_topic_len

        # Video times are treated as naive UTC and binned in US/Eastern.
        video_dt = utc.localize(video_to_meta[video_id]['time']).astimezone(eastern)
        # Offset of the video start from local midnight, in seconds. Keeping
        # the minutes/seconds means a posting 40 min into an 8:30 broadcast
        # is binned at hour 9, not hour 8. Each posting is attributed wholly
        # to the hour in which it starts.
        video_start_sec = (
            video_dt.hour * 3600 + video_dt.minute * 60 + video_dt.second)
        for p in postings:
            hour_of_day = int((video_start_sec + p.start) // 3600) % 24
            channel_to_daypart_to_time[channel][hour_of_day] += p.end - p.start

        channel_to_weekday_to_time[channel][video_dt.weekday()] += video_topic_len
        channel_to_time_to_time[channel][video_dt.date()] += video_topic_len

        show = video_to_meta[video_id]['show']
        show_to_time[(channel, show)] += video_topic_len

    print('Topic time by channel:')
    for c in channel_to_time:
        print('  {}: {:0.3f} hours'.format(c, channel_to_time[c] / 3600))

    print('\nTopic time by day:')
    def plot_timeline():
        plt.figure(figsize=(11, 3))
        for c in channels:
            data = sorted(channel_to_time_to_time[c].items())
            plt.scatter(
                [x for x, _ in data], [y / 60 for _, y in data],
                alpha=0.5, s=2, label=c)
        plt.legend()
        plt.ylabel('Minutes')
        plt.xlabel('Day')
        plt.show()
    plot_timeline()

    # One bar per channel in each tick's group; the group is centered on the
    # tick for any number of channels.
    bar_width = 1 / (len(channels) + 1)
    def bar_offset(i):
        return (i - (len(channels) - 1) / 2) * bar_width

    print('\nTopic time by daypart:')
    def plot_daypart():
        plt.figure(figsize=(11,3))
        for i, c in enumerate(channels):
            plt.bar(np.arange(24) + bar_offset(i),
                    channel_to_daypart_to_time[c] / 60,
                    width=bar_width, alpha=0.5, label=c)
        plt.xticks(np.arange(24))
        plt.legend()
        plt.ylabel('Minutes')
        plt.xlabel('Hour of Day')
        plt.show()
    plot_daypart()

    print('\nTopic time by weekday:')
    def plot_weekday():
        plt.figure(figsize=(11,3))
        for i, c in enumerate(channels):
            plt.bar(np.arange(7) + bar_offset(i),
                    channel_to_weekday_to_time[c] / 60,
                    width=bar_width, alpha=0.5, label=c)
        plt.xticks(np.arange(7), ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'])
        plt.legend()
        plt.ylabel('Minutes')
        plt.xlabel('Weekday')
        plt.show()
    plot_weekday()

    top_n = 10
    print('\nShows with most coverage (top-{}):'.format(top_n))
    for (channel, show), seconds in show_to_time.most_common(top_n):
        print('  {} ({}): {:0.1f} minutes'.format(show, channel, seconds / 60))

    def join_face_labels_and_postings(labels, postings, label_len=3):
        """Sum screen time per label over face labels that overlap postings.

        ``labels`` yields tuples whose [0] is the label (gender or identity
        id) and whose [1] is a time in the postings' units; each label counts
        as covering ``[t, t + label_len]``. ``postings`` yields objects with
        ``start``/``end``. Both iterators are walked once in a merge, so both
        must be sorted by time (assumed for the JSON inputs -- TODO confirm).
        Returns a Counter of label -> time.
        """
        result = Counter()
        try:
            label_head = next(labels)
            postings_head = next(postings)
            while True:
                if label_head[1] > postings_head.end:
                    # Label is past this posting: move to the next posting.
                    postings_head = next(postings)
                elif label_head[1] + label_len < postings_head.start:
                    # Label ends before this posting starts: skip it.
                    label_head = next(labels)
                else:
                    result[label_head[0]] += label_len
                    label_head = next(labels)
        except StopIteration:
            # Either stream is exhausted: no further overlaps are possible.
            pass
        return result

    # --- Face screen time by gender and race ---
    face_genders = get_face_genders()
    gender_to_time = Counter()
    gender_to_time_host = Counter()
    gender_to_time_black = Counter()
    gender_to_time_black_host = Counter()
    channel_to_gender_to_time = defaultdict(Counter)
    channel_to_gender_to_time_host = defaultdict(Counter)
    channel_to_gender_to_time_black = defaultdict(Counter)
    channel_to_gender_to_time_black_host = defaultdict(Counter)
    for video_id, postings in topic_results.video_to_segments.items():
        if video_id not in video_to_meta:
            continue
        video_genders = face_genders.get(str(video_id), [])
        video_channel = video_to_meta[video_id]['channel']

        # Compute for all faces
        video_story_genders = join_face_labels_and_postings(
            iter(video_genders), iter(postings))
        gender_to_time.update(video_story_genders)
        channel_to_gender_to_time[video_channel].update(video_story_genders)

        # Compute for hosts only (last field of a face label is the host flag)
        video_story_genders_host = join_face_labels_and_postings(
            filter(lambda x: x[-1] == 1, video_genders), iter(postings))
        gender_to_time_host.update(video_story_genders_host)
        channel_to_gender_to_time_host[video_channel].update(
            video_story_genders_host)

        # Compute for black only (second-to-last field is the black flag)
        video_story_genders_black = join_face_labels_and_postings(
            filter(lambda x: x[-2] == 1, video_genders), iter(postings))
        gender_to_time_black.update(video_story_genders_black)
        channel_to_gender_to_time_black[video_channel].update(
            video_story_genders_black)

        # Compute for black hosts only
        video_story_genders_black_host = join_face_labels_and_postings(
            filter(lambda x: x[-2] == 1 and x[-1] == 1, video_genders),
            iter(postings))
        gender_to_time_black_host.update(video_story_genders_black_host)
        channel_to_gender_to_time_black_host[video_channel].update(
            video_story_genders_black_host)
    # Only channels with at least one face in the story (separate from
    # `channels`, which lists every channel for the time plots above).
    face_channels = sorted(channel_to_gender_to_time)

    def plot_gender_screen_time(data):
        """Plot side-by-side men/women percentages for (name, Counter) rows."""
        male_props = []
        for _name, label_to_time in data:
            total = sum(label_to_time.values())
            # An empty group (e.g. a channel with no host faces in the story)
            # gets no bars instead of raising ZeroDivisionError.
            male_props.append(
                label_to_time[MALE_ID] / total if total > 0 else np.nan)

        x = np.arange(len(data))
        names = [row[0] for row in data]
        male_props = np.array(male_props)
        width = 0.4

        plt.figure(figsize=(11,3))
        ax = plt.gca()
        ax.bar(x - width / 2, male_props * 100, width,
               color='lightblue', label='Men')
        # NOTE(review): the women bar is 1 - men share, which assumes faces
        # only carry MALE_ID or FEMALE_ID labels -- confirm against the data.
        ax.bar(x + width / 2, (-male_props + 1.) * 100, width,
               color='salmon', label='Women')
        for i in range(len(male_props)):
            p = male_props[i]
            if not np.isfinite(p):
                continue
            ax.text(i - width / 2, p * 100 + 5,
                    str(round(p * 100, 1)), ha='center', color='darkblue')
            ax.text(i + width / 2, (1 - p) * 100 + 5,
                    str(round((1 - p) * 100, 1)), ha='center',
                    color='darkred')
        plt.ylabel('Percentage')
        plt.ylim(0, 100)
        plt.xticks(x, names, rotation=45, ha='right')
        plt.legend()
        plt.show()

    print('\nFace screen time by gender:')
    gender_screen_time_data = []
    gender_screen_time_data.append(('All channels', gender_to_time))
    for channel in face_channels:
        gender_screen_time_data.append(
            (channel, channel_to_gender_to_time[channel]))
    gender_screen_time_data.append(
        ('All channels (hosts)', gender_to_time_host))
    for channel in face_channels:
        gender_screen_time_data.append(
            ('{} (hosts)'.format(channel),
             channel_to_gender_to_time_host[channel]))
    gender_screen_time_data.append(
        ('All channels (non-hosts)',
         gender_to_time - gender_to_time_host))
    for channel in face_channels:
        gender_screen_time_data.append(
            ('{} (non-hosts)'.format(channel),
             channel_to_gender_to_time[channel]
             - channel_to_gender_to_time_host[channel]))

    plot_gender_screen_time(gender_screen_time_data)

    def plot_black_screen_time(data):
        """Plot stacked black men/women shares for (name, total, Counter) rows."""
        male_props = []
        female_props = []
        for _name, total_face_time, black_gender_to_time in data:
            if total_face_time > 0:
                male_props.append(
                    black_gender_to_time[MALE_ID] / total_face_time)
                female_props.append(
                    black_gender_to_time[FEMALE_ID] / total_face_time)
            else:
                # No faces at all in this group: leave its bars empty.
                male_props.append(np.nan)
                female_props.append(np.nan)

        x = np.arange(len(data))
        names = [row[0] for row in data]
        male_props = np.array(male_props)
        female_props = np.array(female_props)
        width = 0.8

        plt.figure(figsize=(11,3))
        ax = plt.gca()
        ax.bar(x, male_props * 100, width,
               color='lightblue', label='Black Men')
        ax.bar(x, female_props * 100, width,
               bottom=male_props * 100,
               color='salmon', label='Black Women')
        for i in range(len(male_props)):
            pm = male_props[i]
            pw = female_props[i]
            if not (np.isfinite(pm) and np.isfinite(pw)):
                continue
            ax.text(i, pm * 100 / 2,
                    str(round(pm * 100, 1)), ha='center', va='center',
                    color='darkblue')
            ax.text(i, (pm + pw / 2) * 100,
                    str(round(pw * 100, 1)), ha='center', va='center',
                    color='darkred')
            ax.text(i, (pm + pw) * 100 + 5,
                    str(round((pm + pw) * 100, 1)), ha='center',
                    color='black')
        plt.ylabel('Percentage')
        plt.ylim(0, 100)
        plt.xticks(x, names, rotation=45, ha='right')
        plt.legend()
        plt.show()

    print('\nBlack screen time:')
    black_screen_time_data = []
    black_screen_time_data.append(
        ('All channels', sum(gender_to_time.values()), gender_to_time_black))
    for channel in face_channels:
        black_screen_time_data.append(
            (channel,
             sum(channel_to_gender_to_time[channel].values()),
             channel_to_gender_to_time_black[channel]))
    black_screen_time_data.append(
        ('All channels (hosts)', sum(gender_to_time_host.values()),
         gender_to_time_black_host))
    for channel in face_channels:
        black_screen_time_data.append(
            ('{} (hosts)'.format(channel),
             sum(channel_to_gender_to_time_host[channel].values()),
             channel_to_gender_to_time_black_host[channel]))
    black_screen_time_data.append(
        ('All channels (non-hosts)',
         sum((gender_to_time - gender_to_time_host).values()),
         gender_to_time_black - gender_to_time_black_host))
    for channel in face_channels:
        black_screen_time_data.append(
            ('{} (non-hosts)'.format(channel),
             sum((channel_to_gender_to_time[channel]
                  - channel_to_gender_to_time_host[channel]).values()),
             channel_to_gender_to_time_black[channel]
             - channel_to_gender_to_time_black_host[channel]))
    plot_black_screen_time(black_screen_time_data)

    # --- Screen time by identity ---
    face_idents = get_face_idents()
    ident_id_to_time = Counter()
    for video_id, postings in topic_results.video_to_segments.items():
        if video_id not in video_to_meta:
            continue
        video_idents = face_idents.get(str(video_id), [])
        ident_id_to_time.update(
            join_face_labels_and_postings(iter(video_idents), iter(postings)))
    print('\nPeople with most screen time (top-{}):'.format(top_n))
    for ident_id, seconds in ident_id_to_time.most_common(top_n):
        print('  {}: {:0.1f} minutes'.format(
            Identity.objects.get(id=ident_id).name,
            seconds / 60))


def analysis():
    """Run ``_analysis`` on the segments found by the current lexicon.

    Uses the global ``KEY_PHRASES`` and ``CONTEXT_PHRASES`` together with the
    filters returned by ``get_filters()``.
    """
    active_filters = get_filters()
    story_segments = find_segments(KEY_PHRASES, CONTEXT_PHRASES, active_filters)
    _analysis(story_segments)


def analysis_handlabeled():
    """Run ``_analysis`` on the hand-labeled positive segments only.

    ``GROUND_TRUTH.positive`` maps video id -> (start, end) pairs; after
    ``load_notebook_state`` these are unordered sets. They are sorted by start
    because ``_analysis`` merges segments against time-ordered face labels.
    The same mapping is passed as ``video_to_key_phrases``, which selects the
    videos, and as ``video_to_segments``, which is the time that is measured.
    ``_analysis`` iterates both, so neither may be ``None``.
    """
    video_to_postings = {
        video_id: [
            CaptionIndex.Posting(start, end, None, None)
            for start, end in sorted(segments, key=lambda seg: seg[0])
        ]
        for video_id, segments in GROUND_TRUTH.positive.items()
    }
    topic_segments = TopicSegments(
        video_to_key_phrases=video_to_postings,
        video_to_context_phrases=None,
        video_to_segments=video_to_postings)
    _analysis(topic_segments)

In [None]:
# Statistics over the segments found with the current lexicon and filters.
analysis()

Run `analysis_handlabeled()` to compute the same statistics only on hand-labeled positive segments.

In [None]:
# Same statistics, restricted to hand-labeled positive segments.
analysis_handlabeled()

# Saving & Loading Progress

Save your progress locally.

In [12]:
# `os` is not in the notebook's import cell; import it explicitly here
# (it is used by this cell and by save/load_notebook_state below).
import os

# Saved stories (lexicon + hand labels) are written here as <name>.json.
STORY_DIRECTORY = '/app/data/stories/'
# exist_ok makes this idempotent on re-run and avoids the isdir/makedirs race.
os.makedirs(STORY_DIRECTORY, exist_ok=True)

    
def save_notebook_state():
    """Interactively save the current lexicon and hand labels as a story.

    Prompts for a story name (spaces become underscores) and writes
    ``KEY_PHRASES``, ``CONTEXT_PHRASES`` and ``GROUND_TRUTH`` as JSON to
    ``STORY_DIRECTORY/<name>.json``. If that file exists, the user must answer
    'y' to overwrite it; otherwise the save is canceled.
    """
    story_name = input('Enter a story name: ').strip().replace(' ', '_')
    assert story_name != '', 'Name cannot be empty'
    out_path = os.path.join(STORY_DIRECTORY, '{}.json'.format(story_name))

    if os.path.exists(out_path):
        prompt = 'File: {} already exists. Overwrite (y/N)? '.format(out_path)
        answer = input(prompt).strip().lower()
        if answer != 'y':
            print('Canceled by user.')
            return

    def as_json_lists(video_to_labels):
        # JSON has no set type, so each video's labels are stored as a list.
        return {video_id: list(labels)
                for video_id, labels in video_to_labels.items()}

    with open(out_path, 'w') as out_file:
        story = {
            'key_phrases': list(KEY_PHRASES),
            'context_phrases': list(CONTEXT_PHRASES),
            'ground_truth': {
                'positive_labels': as_json_lists(GROUND_TRUTH.positive),
                'negative_labels': as_json_lists(GROUND_TRUTH.negative),
            }
        }
        json.dump(story, out_file)
    print('Saved:', out_path)
    
    
def load_notebook_state():
    """Interactively load a saved story into the notebook's globals.

    Lists the ``.json`` stories in ``STORY_DIRECTORY`` (underscores shown as
    spaces), prompts for one, and replaces ``KEY_PHRASES``,
    ``CONTEXT_PHRASES`` and ``GROUND_TRUTH`` with its contents before
    re-syncing the phrase widgets. If the entered story does not exist, it
    prints a message to stderr and leaves the current state untouched.
    """
    story_files = sorted(
        fname for fname in os.listdir(STORY_DIRECTORY)
        if fname.endswith('.json'))
    print('The following stories are saved:')
    for fname in story_files:
        # splitext strips only the extension, so names containing dots are
        # shown in full and can be typed back in.
        print('', os.path.splitext(fname)[0].replace('_', ' '))

    name = input('Enter a story to load: ').strip().replace(' ', '_')
    in_path = os.path.join(STORY_DIRECTORY, '{}.json'.format(name))
    if not os.path.isfile(in_path):
        print('No saved story at:', in_path, file=sys.stderr)
        return
    with open(in_path, 'r') as f:
        data = json.load(f)
    global KEY_PHRASES, CONTEXT_PHRASES, GROUND_TRUTH
    KEY_PHRASES = set(data['key_phrases'])
    CONTEXT_PHRASES = set(data['context_phrases'])
    # JSON object keys are strings; video ids are ints in the notebook.
    GROUND_TRUTH = GroundTruth(
        {int(k): set(tuple(y) for y in v)
         for k, v in data['ground_truth']['positive_labels'].items()},
        {int(k): set(tuple(y) for y in v)
         for k, v in data['ground_truth']['negative_labels'].items()}
    )
    print('Loaded:', in_path)
    sync_context_widget()
    sync_key_widget()

In [None]:
# Prompt for a story name and write the lexicon and hand labels to disk.
save_notebook_state()

In [None]:
# Prompt for a saved story and replace the current lexicon and hand labels.
load_notebook_state()

# Demo Code
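Load one of the example lexicons below for debugging. Each cell replaces the key and context phrases and clears the hand labels.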

In [None]:
# Demo lexicon: Hurricane Irma coverage. Replaces the key/context phrases,
# clears all hand labels, and re-syncs the phrase widgets.
KEY_PHRASES = {
    'HURRICANE & IRMA :: 30'
}
CONTEXT_PHRASES = { 
    'ADVISORY', 'ATLANTIC', 'BANDS', 'BEACH', 'BOATS', 'BRACING', 'BRIDGES',
    'CARIBBEAN', 'CATASTROPHIC', 'CATEGORY', 'CLEANUP', 'COAST', 'COASTAL',
    'CUBA', 'DAMAGE', 'DEBRIS', 'DESTRUCTION', 'DESTRUCTIVE', 'DEVASTATED',
    'DEVASTATING', 'DEVASTATION', 'DISASTERS', 'DOWNTOWN', 'ELECTRICITY',
    'EVACUATE', 'EVACUATED', 'EVACUATION', 'EVACUATIONS', 'FEMA', 'FLOOD',
    'FLOODED', 'FLOODING', 'FLORIDA', 'FORECAST', 'GUSTS', 'HARVEY', 'HURRICANE',
    'HURRICANES', 'IMPACTED', 'IMPACTS', 'INTENSITY', 'IRMA', 'ISLAND', 'ISLANDS',
    'JOSE', 'KEYS', 'LANDFALL', 'MANDATORY', 'METEOROLOGIST', 'MIAMI', 'MONSTER',
    'MYERS', 'NURSING', 'ORLANDO', 'OUTAGES', 'OUTER', 'PALM', 'POWER',
    'PREPARATION', 'PUERTO', 'RAIN', 'RAINFALL', 'RAINS', 'REBUILD',
    'RESPONDERS', 'RESTORED', 'RICO', 'SHELTER', 'SHELTERS', 'STORM',
    'STORMS', 'STRONGEST', 'SUPPLIES', 'SURGE', 'SUSTAINED', 'TAMPA',
    'TIDE', 'TREES', 'TROPICAL', 'WARNINGS', 'WATER', 'WAVES', 'WIND', 'WINDS'
}
GROUND_TRUTH = GroundTruth({}, {})
sync_context_widget()
sync_key_widget()

In [None]:
# Demo lexicon: the battle/siege of Mosul. Replaces the key/context phrases,
# clears all hand labels, and re-syncs the phrase widgets.
KEY_PHRASES = {
    'MOSUL & (BATTLE | SIEGE) :: 60'
}
CONTEXT_PHRASES = {
    'MOSUL', 'COMMANDERS', 'KURDISH', 'STRATEGIC', 'EXPLOSIONS', 'ENEMY',
    'BOMBERS', 'GUNFIRE', 'CIVILIANS', 'OFFENSIVE', 'OPERATION', 'KURDS',
    'DEFEAT', 'FIERCE', 'IRAQIS', 'PROVINCE', 'EXPLOSIVES', 'BAGHDAD',
    'URBAN', 'BATTLES', 'DAM', 'ISIL', 'RETREAT', 'ISIS', 'COMBAT',
    'SURROUNDED', 'TERRITORY', 'DECISIVE', 'STRIKES', 'CIVILIAN', 'OPERATIONS',
    'BOMBINGS', 'FLEEING', 'SUNNI', 'BATTLE', 'FLEE', 'ARMY', 'COALITION',
    'FIGHTING', 'BATTLEFIELD', 'FLED', 'CASUALTIES', 'FIGHTERS', 'IRAQI',
    'DEFEATED', 'BOMBS', 'FORCES', 'TROOPS', 'TUNNELS', 'SIEGE', 'MILITIA',
    'MILITANTS', 'TACTICAL', 'ARTILLERY', 'IRAQ', 'ISLAMIC', 'RESISTANCE',
}
GROUND_TRUTH = GroundTruth({}, {})
sync_context_widget()
sync_key_widget()

In [None]:
# Demo lexicon: the Parkland (Stoneman Douglas) school shooting. Replaces the
# key/context phrases, clears all hand labels, and re-syncs the phrase widgets.
KEY_PHRASES = {
    '(PARKLAND | STONEMAN DOUGLAS | FLORIDA) & SHOOTING :: 60'
}
CONTEXT_PHRASES = {
    'DEPUTIES', 'DEADLY', 'PARKLAND', 'HORRIFIC', 'FIREARMS', 'SHERIFF',
    'DOUGLAS', 'GUN', 'STONEMAN', 'SURVIVOR', 'MASSACRE', 'SHOOTER',
    'SHOOTINGS', 'FRESHMAN', 'RIFLES', 'MASS', 'CLASSES', 'SCHOOL',
    'SHOT', 'SURVIVORS', 'SHOOTING', 'HIGH', 'FLORIDA', 'STUDENTS',
    'TEACHERS', 'VICTIMS', 'SURVIVED', 'ORGANIZERS'
}
GROUND_TRUTH = GroundTruth({}, {})
sync_context_widget()
sync_key_widget()

In [None]:
# Demo lexicon: the FIFA corruption scandal. Replaces the key/context phrases,
# clears all hand labels, and re-syncs the phrase widgets.
# These must be KEY_PHRASES / CONTEXT_PHRASES, the globals used by analysis(),
# save_notebook_state() and the other demo cells; assigning to KEY_WORDS /
# CONTEXT_WORDS left the previous lexicon in place.
KEY_PHRASES = {
    'FIFA'
}
CONTEXT_PHRASES = {
    'ETHICS', 'INVESTIGATING', 'INDICTMENT', 'RESIGN', 'CORRUPTION',
    'ARRESTS', 'PLEADED', 'ACCUSATION', 'SUSPENDED', 'RESIGNATION',
    'INDICTED', 'ALLEGATION', 'SCANDAL', 'ALLEGATIONS', 'ARRESTED',
    'SCANDALS', 'BRIBERY', 'RESIGNED', 'ABUSED', 'ACCUSATIONS',
    'CHARGES', 'CORRUPT'
}
GROUND_TRUTH = GroundTruth({}, {})
sync_context_widget()
sync_key_widget()

In [None]:
# Discard all hand labels (positive and negative); the lexicon is unchanged.
GROUND_TRUTH = GroundTruth({}, {})