In [13]:
from scripts.models import VehTraj, PedTraj, Trajectory
from esper.widget import vgrid_widget
from vgrid import VideoBlockFormat, NestedFormat, NamedIntervalSet, IntervalBlock
import cv2
import numpy as np
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.poutine import trace
import math
from tqdm import tqdm_notebook as tqdm
from scripts.utils import img_grid, vgrid_traj, make_batch
import pandas as pd
from rekall import IntervalSetMapping, IntervalSet, Interval, Bounds3D
from enum import Enum
from torch import tensor
import torch
from pyro import plate
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
from pprint import pprint
from torch.distributions import constraints
from copy import copy
import re
from scripts import pattern
from esper.spark import EsperSpark

# Turn on Pyro's validation checks globally for every model/guide run in this notebook
# (see the `pyro.enable_validation` docs for exactly what is checked).
pyro.enable_validation(True)

In [None]:
# Pull every Vehicle / Pedestrian record into memory and wrap each one in its
# trajectory helper from scripts.models.
# NOTE(review): `Vehicle` and `Pedestrian` are not imported in any visible cell --
# presumably injected by the esper notebook environment; confirm on a fresh kernel.
vehicles = [*Vehicle.objects.all()]
veh_trajectories = list(map(VehTraj, vehicles))

pedestrians = [*Pedestrian.objects.all()]
ped_trajectories = list(map(PedTraj, pedestrians))

In [None]:
def model(N, vels=None, ncat=4):
    """Hidden-Markov model over ``N`` steps of discrete states with velocity emissions.

    Args:
        N: number of time steps.
        vels: optional length-``N`` sequence of observed velocity values. When
            ``None`` the ``vel_i`` sites are sampled rather than conditioned.
        ncat: number of hidden states (previously hard-coded to 4).

    The transition matrix gets a Dirichlet prior whose concentration is
    ``0.9 * I + 0.1`` (1.0 on the diagonal, 0.1 elsewhere), favouring staying in the
    same state. Hidden states are marked for parallel enumeration, so the model is
    meant to be trained with ``TraceEnum_ELBO`` (TODO: pick ``max_plate_nesting``).

    NOTE(review): the emission model below (per-state Normal location, unit scale)
    is a placeholder. The original cell called ``pyro.sample('vel_i')`` with no
    distribution, which raises. Replace it with the intended velocity likelihood.
    """
    transition_probs = pyro.sample(
        'transition_probs',
        dist.Dirichlet(0.9 * torch.eye(ncat) + 0.1).to_event(1))
    # One emission location per hidden state (placeholder prior).
    vel_locs = pyro.sample(
        'vel_locs',
        dist.Normal(torch.zeros(ncat), 10.).to_event(1))

    state = 0
    # pyro.markov needs an iterable (or callable) -- passing the int N cannot be iterated.
    for i in pyro.markov(range(N)):
        state = pyro.sample('state_{}'.format(i), dist.Categorical(transition_probs[state]),
                            infer={"enumerate": "parallel"})
        # Observations must be tensors when validation is enabled.
        obs = None if vels is None else torch.as_tensor(vels[i], dtype=vel_locs.dtype)
        pyro.sample('vel_{}'.format(i), dist.Normal(vel_locs[state], 1.), obs=obs)

In [None]:
# Pseudocode spec for the "car stopping for an object" query. It is NOT valid
# Python (`observe(... = 0)` and `||x||` are notation, not syntax), so it is kept
# as comments; as live code this cell raised SyntaxError and broke Restart & Run All.
# The executable version appears to be the `scripts.pattern` query in the next cell.
#
#   vehicle = sample(uniform(Vehicle))
#   start, end = sample(uniform_interval(vehicle.start, vehicle.end))
#   stopping_car = sample(TrajectoryRegex('<+_'))
#   observe(stopping_car - vehicle.trajectory()[start:end] = 0)
#
#   obj = sample(uniform(Object))
#   time = vehicle.trajectory()[end].frame
#   observe(||obj.pos - time.pos|| = 0)
#
# NOTE(review): `time` is a frame, so `time.pos` presumably means the vehicle's
# position at that frame -- confirm the intended notation.

In [19]:
def slows_to_stop(traj, eps=0.01, start_factor=7.0):
    """Return True if ``traj`` starts out moving and ends (nearly) stationary.

    Speeds are the row-wise norms of ``traj.velocity(stride=3)`` (assumed to be a
    2-D array with one velocity vector per row -- TODO confirm).

    Args:
        traj: trajectory-like object exposing ``velocity(stride=...)``.
        eps: speed below which the last step counts as stopped.
        start_factor: the first step's speed must exceed ``start_factor * eps``.
            The default 7.0 reproduces the previously hard-coded threshold.

    Returns:
        bool. False when there are no velocity samples (window too short), instead
        of raising IndexError.
    """
    vel = np.asarray(traj.velocity(stride=3))
    if len(vel) == 0:
        return False
    speeds = np.linalg.norm(vel, axis=1)
    return bool(speeds[0] > start_factor * eps and speeds[-1] < eps)

# Query: windows of a vehicle's trajectory (minlen=10, maxlen=30, stride=10) that
# satisfy `slows_to_stop`, matched at the window's end time against an object
# related to the vehicle via `pattern.Close`.
vehicle = pattern.Vehicle()
stopping_car = (vehicle.trajectory()
                .window(minlen=10, maxlen=30, stride=10)
                .where(slows_to_stop))
close_obj = pattern.Object()
cars_stopping_for = (stopping_car.end()
                     .time()
                     .match(pattern.Close(close_obj, vehicle)))

traces, _ = cars_stopping_for.eval()
# Render the first 50 matches: the stopping window next to the nearby object's track.
vgrid_traj([(match[stopping_car], match[close_obj].trajectory())
            for match in traces[:50]])

VGridWidget(vgrid_spec={'database': {'videos': [{'num_frames': 575, 'fps': 23.976, 'id': 4, 'width': 838, 'hei…

In [5]:
def turning_left(traj):
    """Score how far a trajectory's net heading change is from a quarter turn.

    The heading change is ``psi`` at the first position minus ``psi`` at the last,
    wrapped into [-pi, pi). Without the wrap, a quarter turn that crosses the
    +/-pi branch cut would look like a ~270-degree rotation and score ~2*pi.

    Args:
        traj: trajectory-like object whose ``pos`` sequence elements have a
            ``psi`` heading (assumed radians, given the comparison with pi/2).

    Returns:
        float ``|wrapped_turn - pi/2|``: 0.0 for an exact quarter turn in the
        start-minus-end sense. It is used as a ``.weight`` cost, so smaller is
        presumably a better match.

    NOTE(review): whether start-minus-end = +pi/2 is a *left* turn depends on the
    orientation convention of ``psi`` -- confirm.
    """
    turn = traj.pos[0].psi - traj.pos[-1].psi
    # Headings are angles: wrap the difference into [-pi, pi) before comparing.
    turn = (turn + math.pi) % (2 * math.pi) - math.pi
    return abs(turn - math.pi / 2)
        
# Query: every vehicle trajectory weighted by `turning_left`; show the first 10
# results of the evaluated pattern.
vehicle = pattern.Vehicle()
turning_car = vehicle.trajectory().weight(turning_left)

traces, scores = turning_car.eval()
vgrid_traj([match[turning_car]
            for match in traces[:10]])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 39.12it/s]


VGridWidget(vgrid_spec={'database': {'videos': [{'num_frames': 263, 'fps': 23.976, 'id': 22, 'width': 786, 'he…

In [4]:
def is_speeding(traj):
    """Cost that is low for fast trajectories: the reciprocal of mean speed.

    Speeds are the row-wise norms of ``traj.velocity(stride=3)`` (assumed to be a
    2-D array with one velocity vector per row -- TODO confirm).

    Args:
        traj: trajectory-like object exposing ``velocity(stride=...)``.

    Returns:
        ``1 / mean(speed)``, or ``math.inf`` when the trajectory is stationary or
        has no velocity samples. Previously these cases produced a divide-by-zero
        warning and a NaN respectively. NaN compares False with every number, so it
        cannot be screened out by a numeric threshold such as ``max=``.
    """
    vel = np.asarray(traj.velocity(stride=3))
    if len(vel) == 0:
        return math.inf
    mean_speed = np.linalg.norm(vel, axis=1).mean()
    if mean_speed == 0:
        return math.inf
    return 1 / mean_speed

def is_close(tup, no_overlap_cost=10000000):
    """Closest per-frame distance between two trajectories over their shared frames.

    Args:
        tup: pair ``(o1, o2)`` of trajectory-like objects. Each exposes ``pos`` (a
            sequence whose first/last elements carry a ``frame`` number) and
            ``frame(t)``, which returns an object with ``for_numpy()``.
        no_overlap_cost: value returned when the two trajectories share no frames
            (defaults to the previously hard-coded sentinel).

    Returns:
        The minimum, over shared frames t, of the Euclidean norm of
        ``o1.frame(t).for_numpy() - o2.frame(t).for_numpy()``, or ``no_overlap_cost``.
    """
    o1, o2 = tup
    start = max(o1.pos[0].frame, o2.pos[0].frame)
    end = min(o1.pos[-1].frame, o2.pos[-1].frame)
    if start > end:
        return no_overlap_cost
    diff = np.array([o1.frame(t).for_numpy() - o2.frame(t).for_numpy()
                     for t in range(start, end + 1)])
    # Norm each frame's difference separately. A bare np.linalg.norm(diff) is the
    # Frobenius norm over *all* frames, which made the old trailing .min() a no-op.
    per_frame = np.linalg.norm(diff.reshape(len(diff), -1), axis=1)
    return per_frame.min()

# Query: short vehicle windows (minlen=6, maxlen=7) weighted by `is_speeding`
# (presumably discarding weights above max=1.5), joined with pedestrian
# trajectories and weighted by `is_close`; show the first 30 results.
vehicle = pattern.Vehicle()
pedestrian = pattern.Pedestrian().trajectory()
fast_car = (vehicle.trajectory()
            .window(minlen=6, maxlen=6 + 1)
            .weight(is_speeding, max=1.5))
fast_car_near_person = fast_car.join(pedestrian).weight(is_close)

traces, scores = fast_car_near_person.eval()
vgrid_traj([(match[fast_car], match[pedestrian])
            for match in traces[:30]])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:19<00:00,  1.16it/s]


VGridWidget(vgrid_spec={'database': {'videos': [{'num_frames': 166, 'fps': 23.976, 'id': 2, 'width': 808, 'hei…