In [None]:
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import Interval, IntervalList
from rekall.temporal_predicates import *
from rekall.logical_predicates import *
from esper.rekall import intrvllists_to_result, add_intrvllists_to_result
#from vgrid_jupyter import VGridWidget
#from esper.widget import intervals_with_metadata
from esper.prelude import esper_widget
from query.models import FaceIdentity
from django.db.models import IntegerField, F
from django.db.models.functions import Cast

def get_fps_map(vids):
    """Look up the frame rate of each requested video.

    Args:
        vids: iterable of video ids, passed to an ``id__in`` filter on ``Video``.

    Returns:
        dict mapping ``video.id`` to ``video.fps`` for every matching row.
        Ids with no matching ``Video`` row are simply absent from the result.
    """
    from query.models import Video

    fps_by_video = {}
    for video in Video.objects.filter(id__in=vids):
        fps_by_video[video.id] = video.fps
    return fps_by_video

def frame_second_conversion(c, mode='f2s'):
    """Rescale every interval in ``c`` between frame and second units.

    Args:
        c: collection exposing ``get_intervals()`` and
            ``get_grouped_intervals()``; its keys are treated as video ids.
        mode: ``'f2s'`` divides both bounds by the video's fps (frames to
            seconds); ``'s2f'`` multiplies them by it (seconds to frames).

    Returns:
        A new ``VideoIntervalCollection`` with the converted intervals.

    Raises:
        ValueError: if ``mode`` is neither ``'f2s'`` nor ``'s2f'``.

    Both directions truncate the converted bounds with ``int()``, so frames to
    seconds drops any fractional second.

    NOTE(review): the fps lookup reads keys from ``c.get_intervals()`` while
    the conversion loop iterates ``c.get_grouped_intervals()`` -- confirm both
    exist on the collection type that callers pass in.
    """
    # Validate first: otherwise an unknown mode only surfaces later as an
    # UnboundLocalError on `fn`, after a database query has already been made.
    if mode not in ('f2s', 's2f'):
        raise ValueError(
            "mode must be 'f2s' or 's2f', got {!r}".format(mode))

    fps_map = get_fps_map(set(c.get_intervals().keys()))

    def second_to_frame(fps):
        def map_fn(intrvl):
            i2 = intrvl.copy()
            t1, t2 = intrvl.t
            i2.t = (int(t1 * fps), int(t2 * fps))
            return i2
        return map_fn

    def frame_to_second(fps):
        def map_fn(intrvl):
            i2 = intrvl.copy()
            t1, t2 = intrvl.t
            i2.t = (int(t1 / fps), int(t2 / fps))
            return i2
        return map_fn

    fn = frame_to_second if mode == 'f2s' else second_to_frame

    output = {}
    for vid, intervals in c.get_grouped_intervals().items():
        # Each video gets a mapper bound to its own frame rate.
        output[vid] = intervals.map(fn(fps_map[vid]))
    return VideoIntervalCollection(output)

def frame_to_second_collection(c):
    """Return ``c`` converted from frames to seconds (``mode='f2s'``)."""
    return frame_second_conversion(c, mode='f2s')

def second_to_frame_collection(c):
    """Return ``c`` converted from seconds to frames (``mode='s2f'``)."""
    return frame_second_conversion(c, mode='s2f')

from query.models import LabeledInterview, LabeledPanel, LabeledCommercial, Video, FaceIdentity, Face
# Hard-coded video ids; used below to restrict the FaceIdentity query via
# `face__shot__video_id__in=sandbox_videos`.
sandbox_videos = [529, 763, 2648, 3459, 3730, 3769, 3952, 4143, 4611, 5281, 6185, 7262, 8220,
    8697, 8859, 9215, 9480, 9499, 9901, 10323, 10335, 11003, 11555, 11579, 11792,
    12837, 13058, 13141, 13247, 13556, 13827, 13927, 13993, 14482, 15916, 16215,
    16542, 16693, 16879, 17458, 17983, 19882, 19959, 20380, 20450, 23181, 23184,
    24193, 24847, 24992, 25463, 26386, 27188, 27410, 29001, 31378, 32472, 32996,
    33004, 33387, 33541, 33800, 34359, 34642, 36755, 37107, 37113, 37170, 38275,
    38420, 40203, 40856, 41480, 41725, 42756, 45472, 45645, 45655, 45698, 48140,
    49225, 49931, 50164, 50561, 51175, 52075, 52749, 52945, 53355, 53684, 54377,
    55711, 57384, 57592, 57708, 57804, 57990, 59122, 59398, 60186]

def _with_frame_bounds(labeled_objects):
    """Annotate labeled segments with ``fps``, ``min_frame`` and ``max_frame``.

    ``min_frame``/``max_frame`` are the row's ``start``/``end`` multiplied by
    the fps of the related video (presumably ``start``/``end`` are stored in
    seconds -- TODO confirm against the model).

    Args:
        labeled_objects: a Django manager or queryset whose model has a
            ``video`` relation plus ``start`` and ``end`` fields.

    Returns:
        The annotated queryset.
    """
    return (
        labeled_objects
        .annotate(fps=F('video__fps'))
        .annotate(min_frame=F('fps') * F('start'))
        .annotate(max_frame=F('fps') * F('end'))
    )


# Hand-labeled segments, each carrying frame bounds derived from the video fps.
# Interviews are additionally restricted to rows flagged `original=True`.
interviews_django_qs = _with_frame_bounds(LabeledInterview.objects).filter(original=True)
panels = _with_frame_bounds(LabeledPanel.objects)
commercials = _with_frame_bounds(LabeledCommercial.objects)

def display(cols_and_colors):
    """Render several interval collections together in the Esper widget.

    Args:
        cols_and_colors: non-empty sequence of ``(collection, color)`` pairs.
            The first pair seeds the result; every later pair is added to it
            with its own color.

    Returns:
        The widget produced by ``esper_widget`` (Jupyter keybindings enabled,
        captions disabled).

    Note: this name shadows IPython's own ``display`` helper within the
    notebook namespace.
    """
    first_pair = cols_and_colors[0]
    result = intrvllists_to_result(
        first_pair[0].get_allintervals(), color=first_pair[1])

    for pair in cols_and_colors[1:]:
        col, color = pair
        add_intrvllists_to_result(result, col.get_allintervals(), color=color)

    return esper_widget(result, jupyter_keybindings=True, disable_captions=True)

def intersect(intrvl1, intrvl2):
    """Join merge op: single-element list holding ``intrvl1.overlap(intrvl2)``."""
    overlapped = intrvl1.overlap(intrvl2)
    return [overlapped]

def span(intrvl1, intrvl2):
    """Join merge op: single-element list holding ``intrvl1.merge(intrvl2)``."""
    merged = intrvl1.merge(intrvl2)
    return [merged]

# Step 1: Watch a bunch of TV

In [None]:
# Load interviews from database
# interviews_django_qs is a representation of a SQL table
#   rows are hand-annotated interviews
# from_django_qs presumably reads the annotated min_frame/max_frame columns
# to build one interval per row -- confirm against the rekall API.
interviews_gt = VideoIntervalCollection.from_django_qs(interviews_django_qs)

In [None]:
# Show every hand-annotated interview in red in the Esper widget.
display([(interviews_gt, 'red')])

# Step 2: Gather ground truth

In [None]:
# Filter the SQL table
# NOTE(review): only the interviewer1/guest1 columns are matched; if the model
# has further interviewer/guest columns, rows listing these people there are
# excluded -- confirm this is intended.
jake_bernie_interviews_qs = interviews_django_qs.filter(
    interviewer1="jake tapper",
    guest1="bernie sanders"
)

In [None]:
# Import into rekall: ground-truth intervals for Jake Tapper interviewing
# Bernie Sanders, built from the filtered queryset.
interviews_jake_bernie_gt = VideoIntervalCollection.from_django_qs(
    jake_bernie_interviews_qs
)

In [None]:
# Show the Jake/Bernie ground-truth intervals in black.
display([(interviews_jake_bernie_gt, 'black')])

# First query: frames where Jake Tapper and Bernie Sanders appear together

In [None]:
# Load face identifications of Jake Tapper and Bernie Sanders in the sandbox
# videos, keeping only labels with probability above 0.7.
# This may take a while to materialize all the data.
identities = FaceIdentity.objects.filter(face__shot__video_id__in=sandbox_videos)
jake_qs = identities.filter(identity__name="jake tapper").filter(probability__gt=0.7)
bernie_qs = identities.filter(identity__name="bernie sanders").filter(probability__gt=0.7)

# Each identification becomes an interval covering the whole shot that
# contains the face (the shot's min_frame..max_frame), keyed by the shot's video.
jake = VideoIntervalCollection.from_django_qs(jake_qs
    .annotate(video_id=F("face__shot__video_id"))
    .annotate(min_frame=F("face__shot__min_frame"))
    .annotate(max_frame=F("face__shot__max_frame")))
bernie = VideoIntervalCollection.from_django_qs(bernie_qs
    .annotate(video_id=F("face__shot__video_id"))
    .annotate(min_frame=F("face__shot__min_frame"))
    .annotate(max_frame=F("face__shot__max_frame")))

In [None]:
# Sanity check: keys of the Jake Tapper collection (presumably the ids of the
# videos that contain at least one Jake interval).
jake.get_allintervals().keys()

In [None]:
# Join Jake and Bernie, and merge any intervals that are touching or overlapping
# `intersect` keeps only the overlapping part of each matched pair.
jake_and_bernie = jake.join(
    bernie,
    predicate = overlaps(),
    merge_op = intersect
).coalesce()

In [None]:
# Ground truth in black, query result in red.
display([
    (interviews_jake_bernie_gt, 'black'),
    (jake_and_bernie, 'red')
])

# Increase Recall by expressing a temporal pattern

In [None]:
'''
We're going to look for the following patterns:
    (Bernie + Jake) -> Jake OR
    Jake -> (Bernie + Jake) OR
    (Bernie + Jake) -> Bernie OR
    Bernie -> (Bernie + Jake)

We'll coalesce that, and then check in with the Esper widget again.
'''
# max_dist=10 is presumably in the same unit as the interval bounds (frames,
# since the collections were built from min_frame/max_frame) -- TODO confirm.
# `span` merges each matched pair into one interval via Interval.merge.
new_interviews = jake_and_bernie.join(
    jake, predicate = or_pred(before(max_dist=10), after(max_dist=10), arity=2),
    merge_op = span
).set_union(
    jake_and_bernie.join(
        bernie, predicate=or_pred(before(max_dist=10), after(max_dist=10), arity=2),
        merge_op = span
    )
).coalesce()

# Ground truth in black, query result in red.
display([
    (interviews_jake_bernie_gt, 'black'),
    (new_interviews, 'red')
])

# Get rid of small gaps with dilation

In [None]:
# Dilate by 600 frames (20 seconds at 30 fps -- TODO confirm the videos' fps),
# coalesce so intervals that now touch are merged, then dilate by -600 to
# undo the growth at the outer edges.
new_interviews_no_gaps = new_interviews.dilate(600).coalesce().dilate(-600)

# Ground truth in black, query result in red.
display([
    (interviews_jake_bernie_gt, 'black'),
    (new_interviews_no_gaps, 'red')
])

# Increase precision by filtering

In [None]:
# Filter out segments shorter than two minutes
# (3600 frames, assuming roughly 30 fps -- TODO confirm).
more_precise_interviews = new_interviews_no_gaps.filter_length(min_length=3600)

# Ground truth in black, query result in red.
display([
    (interviews_jake_bernie_gt, 'black'),
    (more_precise_interviews, 'red')
])

# Full query


In [None]:
# Overlap of the Jake and Bernie intervals. Computed once and reused by both
# follow-up joins below, instead of evaluating the same join twice.
jake_bernie_overlap = jake.join(
    bernie,
    predicate = overlaps(),
    merge_op = intersect
)

# Full pipeline in one expression:
#   1. extend each overlap with a Jake-only or Bernie-only interval that lies
#      just before or after it (max_dist=10), taking the span of the pair;
#   2. union both extensions and coalesce;
#   3. dilate by 600, coalesce, dilate by -600 to close small gaps;
#   4. drop results shorter than 3600.
# Units are presumably frames (the collections were built from
# min_frame/max_frame) -- TODO confirm.
bernie_jake_interviews = jake_bernie_overlap.join(
    jake, predicate = or_pred(before(max_dist=10), after(max_dist=10), arity=2),
    merge_op = span
).set_union(
    jake_bernie_overlap.join(
        bernie, predicate=or_pred(before(max_dist=10), after(max_dist=10), arity=2),
        merge_op = span
    )
).coalesce().dilate(
    600
).coalesce().dilate(
    -600
).filter_length(min_length=3600)

# Ground truth in black, query result in red.
display([
    (interviews_jake_bernie_gt, 'black'),
    (bernie_jake_interviews, 'red')
])