In [None]:
%matplotlib inline
from query.datasets.prelude import *
from query.datasets.tvnews.shot_detect import shot_detect, shot_stitch
from scannerpy.stdlib import parsers
from scipy.spatial import distance
from unionfind import unionfind
import cv2

In [None]:
# Hand-labelled shot boundaries for the video. The literal values below are
# each moved 20 frames earlier to produce the final `gt` list.
gt = [
    frame - 20
    for frame in (
        226, 822, 2652, 3893, 4058, 4195, 4326, 4450, 4583, 4766,
        5021, 5202, 5294, 5411, 6584, 7140, 7236, 7388, 7547, 7673,
        7823, 7984, 8148, 8338, 8494, 8625, 8914, 9042, 9207, 9308,
        11395, 11823, 12198, 12563, 13516, 13878, 13991, 14162, 14237, 14333,
        14488, 14688, 14770, 14825, 15017, 15537, 15701, 15866, 16012, 16112,
        16295, 16452, 16601, 16880, 17018, 17184, 17310, 17446, 17962, 18713,
        18860, 19120, 19395, 19543, 19660, 19839, 19970, 20079, 20248, 20291,
        20862,
    )
]

In [None]:
# Look up the video analysed by this notebook, and fetch (or create on first
# run) the 'shot-histogram' labeler that shots written later are tagged with.
video = Video.objects.get(path='tvnews/videos/CNNW_20161229_130000_New_Day.mp4')
labeler, _ = Labeler.objects.get_or_create(name='shot-histogram')

# NOTE(review): the commented-out block below is the only place in this
# notebook where `hists` is assigned. Later cells read `hists`, so unless it is
# provided some other way (e.g. by the prelude import or a leftover kernel
# variable) they will fail on a fresh kernel -- confirm and restore the load step.
# with Database() as db:
#     frame = db.ops.FrameInput()
#     histogram = db.ops.Histogram(frame=frame, device=DeviceType.GPU)
#     output = db.ops.Output(columns=[histogram])
#     job = Job(
#         op_args={frame: db.table(video.path).column('frame'),
#                  output: video.path + '_hist'})
#     bulk_job = BulkJob(output=output, jobs=[job])
#     # [hists_table] = db.run(bulk_job, force=True, io_packet_size=10000)
#     hists_table = db.table(video.path + '_hist')

#     print('Loading histograms...')
#     hists = [h for _, h in hists_table.load(['histogram'], parsers.histograms)]
#     print('Loaded!')

In [None]:
# Per-frame difference signal: for each pair of consecutive frames, average the
# Chebyshev distance between their histograms over the first three histogram
# slots. A leading 0 is inserted so that diffs[i] lines up with frame i.
diffs = np.insert(
    np.array([
        np.mean([distance.chebyshev(prev_hist[c], cur_hist[c]) for c in range(3)])
        for prev_hist, cur_hist in zip(hists[:-1], hists[1:])
    ]),
    0,
    0,
)

In [None]:
# Width, in frames, of the sliding window used for the local mean/std.
WINDOW_SIZE = 500
# Candidate boundaries closer than this many frames are merged into one group.
GROUP_THRESHOLD = 10
# A frame must exceed the window mean by this many standard deviations.
STD_DEV_FACTOR = 1
# A frame's difference value must also exceed this absolute magnitude.
MAGNITUDE_THRESHOLD = 5000

def compute_shot_boundaries(hists):
    """Detect shot boundaries from a sequence of per-frame histograms.

    Steps:
      1. For every pair of consecutive frames, average the Chebyshev distance
         between their histograms over the first three histogram slots.
      2. Mark frame ``i`` as a candidate when its difference is above
         ``MAGNITUDE_THRESHOLD`` and exceeds the mean of the surrounding
         ``WINDOW_SIZE``-frame window by more than ``STD_DEV_FACTOR`` standard
         deviations of that window.
      3. Merge runs of candidates whose consecutive gaps are below
         ``GROUP_THRESHOLD`` and keep the middle candidate of each run.

    Args:
        hists: indexable sequence where ``hists[i][j]`` is a vector accepted by
            ``scipy.spatial.distance.chebyshev`` for ``j`` in ``0..2``.

    Returns:
        List of frame indices (ascending), one per merged group of candidates.
    """
    # Compute the mean difference between each pair of adjacent frames
    log.debug('Computing means')
    diffs = np.array([
        np.mean([distance.chebyshev(hists[i - 1][j], hists[i][j]) for j in range(3)])
        for i in range(1, len(hists))
    ])
    # Pad with a leading 0 so diffs[i] corresponds to frame i.
    diffs = np.insert(diffs, 0, 0)
    n = len(diffs)

    # Floor division keeps the slice bounds integral; true division yields a
    # float under Python 3, which is not a valid slice index.
    half_window = WINDOW_SIZE // 2

    # Do simple outlier detection to find boundaries between shots
    log.debug('Detecting outliers')
    boundaries = []
    for i in range(1, n):
        window = diffs[max(i - half_window, 0):min(i + half_window, n)]
        if diffs[i] > MAGNITUDE_THRESHOLD and diffs[i] - np.mean(window) > STD_DEV_FACTOR * np.std(window):
            boundaries.append(i)

    # `boundaries` is already ascending, so groups of nearby candidates are
    # simply runs whose consecutive gaps stay below GROUP_THRESHOLD. One linear
    # pass builds the same groups as a pairwise union-find would.
    log.debug('Grouping adjacent frames')
    groups = []
    for b in boundaries:
        if groups and b - groups[-1][-1] < GROUP_THRESHOLD:
            groups[-1].append(b)
        else:
            groups.append([b])

    # Represent each group by its middle candidate.
    grouped_boundaries = [g[len(g) // 2] for g in groups]

    return grouped_boundaries

boundaries = compute_shot_boundaries(hists)
print('Done!')
print(boundaries)

In [None]:
#initial_boundaries = [s.min_frame for s in shot_detect([video])[0]]
#final_boundaries = [s.min_frame for s in shot_stitch([video], None, None, None, None)[0]]
# Starting frames of the stored shots for this video, in ascending order, for
# labeler ids 1 and 2 respectively.
# NOTE(review): ids 1/2 presumably correspond to the shot_detect / shot_stitch
# outputs referenced in the commented lines above -- confirm.
initial_boundaries, final_boundaries = [
    [
        shot.min_frame
        for shot in Shot.objects.filter(labeler_id=labeler_pk, video=video).order_by('min_frame')
    ]
    for labeler_pk in (1, 2)
]
print(len(initial_boundaries), len(final_boundaries))

In [None]:
# Copy the video from the gs://esper bucket into /tmp. The command is passed as
# an argument list (no shell), so spaces or shell metacharacters in video.path
# cannot break or alter the command.
sp.check_call(['gsutil', 'cp', 'gs://esper/{}'.format(video.path), '/tmp'])

In [None]:
# Render 'test.mkv': the source video with a small plot of the frame-difference
# signal around the current frame pasted into the bottom-right corner.
local_path = '/tmp/{}'.format(video.path.split('/')[-1])
#local_path = '/tmp/MSNBCW_20130404_060000_Hardball_With_Chris_Matthews.mp4'

# Upper limit of the plot's y axis.
YMAX = 50000
# Floor division keeps slice bounds integral (true division gives a float
# under Python 3, which cannot be used as a slice index).
half_window = WINDOW_SIZE // 2
# Membership is tested once per marker per frame; a set makes that O(1).
final_boundary_set = set(final_boundaries)

in_vid = cv2.VideoCapture(local_path)
out_vid = cv2.VideoWriter('test.mkv', cv2.VideoWriter_fourcc(*'XVID'), 30.0, (video.width, video.height))

try:
    for i in range(20000):
        if i % 50 == 0: print(i)
        ret, frame = in_vid.read()
        if not ret:
            # No frame could be read (end of video or decode failure); stop
            # instead of failing on `frame.shape` below.
            break

        # Window of the difference signal centred on frame i, clipped at 0.
        wmin = max(i - half_window, 0)
        wmax = i + half_window
        wsize = wmax - wmin
        window = diffs[wmin:wmax]
        # x position of the current frame inside the window: i while the
        # window is clipped at the start, half_window afterwards.
        mpx = i - wmin

        fig = plt.figure(tight_layout=True)
        ax = fig.add_subplot(111)

        ax.plot(window, 'tab:blue')

        # Current frame, local outlier threshold, and absolute threshold.
        ax.axvline(x=mpx, color='tab:red')
        ax.axhline(y=np.mean(window) + STD_DEV_FACTOR * np.std(window), color='tab:green')
        ax.axhline(y=MAGNITUDE_THRESHOLD, color='tab:olive')

        # Mark initial boundaries that fall inside the window; orange when the
        # boundary is also a final boundary, purple otherwise.
        for b in initial_boundaries:
            idx = b - wmin
            if 0 < idx < wsize:
                ax.scatter(idx, min(diffs[b], YMAX-2000),
                           c='tab:orange' if b in final_boundary_set else 'tab:purple', s=300, marker='X')

        ax.set_ylim(0, YMAX)
        ax.axis('off')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        fig.canvas.draw()

        # Rasterise the figure into an (H, W, 3) uint8 RGB array, then close it
        # so figures do not accumulate across iterations.
        plt_img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        plt_img = plt_img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        plt.close(fig)

        # Halve the plot and convert to OpenCV's BGR channel order.
        plt_img = cv2.resize(plt_img, None, fx=0.5, fy=0.5)
        plt_img = cv2.cvtColor(plt_img, cv2.COLOR_RGB2BGR)
        [ph, pw, pc] = plt_img.shape

        # Paste the plot into the bottom-right corner of the frame.
        [fh, fw, fc] = frame.shape
        frame[(fh-ph):fh, (fw-pw):fw, :] = plt_img

        out_vid.write(frame)
finally:
    # Always release both handles so the output file is finalised and the
    # input is closed even if rendering fails part-way through.
    in_vid.release()
    out_vid.release()

In [None]:
# Replace the shots previously stored for this video under this notebook's
# labeler. The delete is scoped to `labeler` so shots written by other labelers
# (which an earlier cell reads via labeler_id=1 and labeler_id=2) are not wiped.
Shot.objects.filter(video=video, labeler=labeler).delete()

# Each shot runs from one boundary up to the frame before the next boundary.
# NOTE(review): the first shot is forced to start at frame 0, so frames before
# boundaries[0] are merged into it, and no shot is created after the last
# boundary -- confirm this is intended.
shots = [
    Shot(video=video, labeler=labeler,
         min_frame=0 if idx == 0 else start,
         max_frame=next_start - 1)
    for idx, (start, next_start) in enumerate(zip(boundaries, boundaries[1:]))
]

_ = Shot.objects.bulk_create(shots)

In [None]:
# Score detected boundaries against the ground truth. A detection is a true
# positive when it lies within DIST_THRESHOLD frames of a ground-truth boundary
# that has not already been matched; each ground-truth entry matches only once.
DIST_THRESHOLD = 15
# Ground-truth boundaries still unmatched; whatever remains at the end is the
# set of misses (false negatives).
gt_copy = gt[:]

# Only score detections that fall before the last labelled boundary.
boundaries = [n for n in boundaries if n < gt[-1]]

tp = 0
fp = 0
for i in boundaries:
    # Index of the first unmatched ground-truth boundary close enough to i.
    valid = next(
        (k for k, j in enumerate(gt_copy) if abs(i - j) < DIST_THRESHOLD),
        None,
    )
    if valid is None:
        fp += 1
    else:
        tp += 1
        # Consume the matched entry so it cannot be matched a second time.
        del gt_copy[valid]

fn = len(gt_copy)
print(tp, fp, fn)

# Guard the denominators: with no detections tp + fp is 0 and the plain
# division would raise ZeroDivisionError.
precision = tp / float(tp + fp) if tp + fp else 0.0
recall = tp / float(tp + fn) if tp + fn else 0.0
print('Precision: {:.3f}, recall: {:.3f}, #det/#gt: {:.3f}'.format(precision, recall, len(boundaries) / float(len(gt))))
print(len(boundaries), len(gt))

print(boundaries)
print(gt_copy)
print(gt)