# Cheap Optical Flow: Is It Good?


## Quickstart

## Credits

Some portions of this notebook are adapted from:
 * [Middlebury Flow code by Johannes Oswald](https://github.com/Johswald/flow-code-python/blob/master/readFlowFile.py)
 * [DeepDeform Demo Code](https://github.com/AljazBozic/DeepDeform)
 * [OpticalFlowToolkit by RUOTENG LI](https://github.com/liruoteng/OpticalFlowToolkit)
 * [OpenCV Samples](https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py)

In [None]:
# Notebook parameters
# When True, the dataset cells below render demo pairs inline and record them in DEMO_FPS.
SHOW_DEMO_OUTPUT = True
# Accumulates the demo `OpticalFlowPair` instances appended by the dataset cells below.
DEMO_FPS = []

In [None]:
## Data Model & Utility Code

!pip3 install pypng

import attr
import cv2
import imageio
import IPython.display
import math
import os
import PIL.Image
import six

from oarphpy import plotting as op_plt
import numpy as np


@attr.s(slots=True, eq=False, weakref_slot=False)
class OpticalFlowPair(object):
    """A flyweight for a pair of images with an optical flow field.

    Supports lazy-loading of large data attributes: `img1` / `img2` may be
    given as strings (read with `imageio.imread` on first access) and `flow`
    may be given as a zero-argument callable (invoked on first access).  The
    loaded value replaces the lazy one, so the work happens at most once.
    """

    dataset = attr.ib(type=str, default='')
    """To which dataset does this pair belong?"""

    id1 = attr.ib(type=str, default='')
    """Identifier or URI for the first image"""

    id2 = attr.ib(type=str, default='')
    """Identifier or URI for the second image"""

    img1 = attr.ib(default=None)
    """URI or numpy array for the first image (source image)"""

    img2 = attr.ib(default=None)
    """URI or numpy array for the second image (target image)"""

    flow = attr.ib(default=None)
    """A callable or numpy array representing optical flow from img1 -> img2"""

    def get_img1(self):
        """Return the first image, reading and caching it if `img1` is a string."""
        if isinstance(self.img1, six.string_types):
            self.img1 = imageio.imread(self.img1)
        return self.img1

    def get_img2(self):
        """Return the second image, reading and caching it if `img2` is a string."""
        if isinstance(self.img2, six.string_types):
            self.img2 = imageio.imread(self.img2)
        return self.img2

    def get_flow(self):
        """Return the flow, invoking and caching it if `flow` is a callable.

        Only callables are invoked; any other value (an array, or the default
        `None`) is returned as-is instead of failing with a `TypeError`.
        """
        if callable(self.flow):
            self.flow = self.flow()
        return self.flow

    def to_html(self):
        """Return an HTML table showing both images and a flow visualization.

        Images are embedded inline via `img_to_data_uri`; the flow is drawn
        over the first image with `draw_flow`.
        """
        im1 = self.get_img1()
        im2 = self.get_img2()
        flow = self.get_flow()
        fviz = draw_flow(im1, flow)
        html = """
            <table>
            
            <tr><td style="text-align:left"><b>Dataset:</b> {dataset}</td></tr>
            
            <tr><td style="text-align:left"><b>Source Image:</b> {id1}</td></tr>
            <tr><td><img src="{im1}" width="100%" /></td></tr>

            <tr><td style="text-align:left"><b>Target Image:</b> {id2}</td></tr>
            <tr><td><img src="{im2}" width="100%" /></td></tr>

            <tr><td style="text-align:left"><b>Flow</b></td></tr>
            <tr><td><img src="{fviz}" width="100%" /></td></tr>
            </table>
        """.format(
                dataset=self.dataset,
                id1=self.id1, id2=self.id2,
                im1=img_to_data_uri(im1), im2=img_to_data_uri(im2),
                fviz=img_to_data_uri(fviz))
        return html


## General Utilities
    
def imshow(x):
    """Display the array `x` inline by wrapping it in a PIL image."""
    pil_image = PIL.Image.fromarray(x)
    IPython.display.display(pil_image)

def show_html(x):
    """Render the HTML string `x` as rich notebook output.

    Uses the public `IPython.display` module (already imported by this file)
    rather than the deprecated `IPython.core.display` import path.
    """
    IPython.display.display(IPython.display.HTML(x))
    
def load_image(x):
    """Return `x` as an image array.

    Numpy arrays and numpy scalars pass through untouched; strings are handed
    to `imageio.imread`.  Any other input raises `ValueError`.
    """
    if isinstance(x, (np.ndarray, np.generic)):
        return x
    if not isinstance(x, six.string_types):
        raise ValueError("Can't handle %s" % (x,))
    return imageio.imread(x)

def img_to_data_uri(x):
    """Encode image `x` with `op_plt.img_to_data_uri(..., format='png')`.

    The result is used as the `src` of `<img>` tags in this notebook's HTML.
    """
    return op_plt.img_to_data_uri(x, format='png')
# # TODO correct oarphpy img uri to be png or jpeg   data_url = 'data:image/png;base64,{}'.format(parse.quote(data)) 
# def img_to_data_uri(img, format='png', jpeg_quality=75):
#     """Given a numpy array `img`, return a `data:` URI suitable for use in 
#     an HTML image tag."""

#     from io import BytesIO
#     out = BytesIO()

#     import imageio
#     kwargs = dict(format=format)
#     if format == 'jpg':
#         kwargs.update(quality=jpeg_quality)
#     imageio.imwrite(out, img, **kwargs)

#     from base64 import b64encode
#     data = b64encode(out.getvalue()).decode('ascii')

#     from six.moves.urllib import parse
#     data_url = 'data:image/png;base64,{}'.format(parse.quote(data))

#     return data_url


def draw_flow(img, flow, step=8):
    """Overlay a sparse grid of flow vectors on a copy of `img`.

    Based upon OpenCV sample: https://github.com/opencv/opencv/blob/master/samples/python/opt_flow.py

    Grid points are sampled every `step` pixels (starting half a step in);
    each gets a green line to its flow-displaced position and a green dot at
    its origin.  `img` itself is not modified.
    """
    height, width = img.shape[:2]
    grid = np.mgrid[step/2:height:step, step/2:width:step]
    ys, xs = grid.reshape(2, -1).astype(int)
    dx, dy = flow[ys, xs].T
    # One (start, end) point pair per sampled grid location.
    segments = np.vstack([xs, ys, xs + dx, ys + dy]).T.reshape(-1, 2, 2)
    segments = np.int32(segments + 0.5)
    vis = img.copy()
    cv2.polylines(vis, segments, 0, (0, 255, 0))
    for start, _end in segments:
        x1, y1 = start
        cv2.circle(vis, (x1, y1), 1, (0, 255, 0), -1)
    return vis


## Middlebury Optical Flow



In [None]:
# Please unzip `other-color-allframes.zip` and `other-gt-flow.zip` to a directory and provide the target below:
MIDD_DATA_ROOT = '/outer_root/host_mnt/Volumes/970-evo-raid0/middlebury-flow/'

# For the Middlebury Flow dataset, we only consider the real scenes.
# Every scene uses the same layout: frame10 -> frame11 with ground truth flow10.
MIDD_SCENES = [
    {
        'input': 'other-data/%s/frame10.png' % scene_name,
        'expected_out': 'other-data/%s/frame11.png' % scene_name,
        'flow_gt': 'other-gt-flow/%s/flow10.flo' % scene_name,
    }
    for scene_name in ('Dimetrodon', 'Hydrangea', 'RubberWhale')
]


def midd_read_flow(file):
    """Read a Middlebury `.flo` file into an `(h, w, 2)` float32 array.

    The file layout read here is: one float32 magic tag (202021.25), int32
    width, int32 height, then `2 * w * h` float32 flow values.

    Values >= 1666 are treated as invalid and replaced with 0.

    Raises:
        AssertionError: if `file` is not a str, does not exist, does not end
            in `.flo`, or does not start with the expected magic tag.
        ValueError: if the file holds fewer flow values than its header says.
    """
    # Based upon: https://github.com/Johswald/flow-code-python/blob/master/readFlowFile.py
    # compute colored image to visualize optical flow file .flo
    # Author: Johannes Oswald, Technical University Munich
    # Contact: johannes.oswald@tum.de
    # Date: 26/04/2017
    # For more information, check http://vision.middlebury.edu/flow/

    assert isinstance(file, str), "file is not str %r" % str(file)
    assert os.path.isfile(file) is True, "file does not exist %r" % str(file)
    assert file[-4:] == '.flo', "file ending is not .flo %r" % file[-4:]

    TAG_FLOAT = 202021.25
    # `with` guarantees the handle is closed even when a check below fails.
    with open(file, 'rb') as f:
        flo_number = np.fromfile(f, np.float32, count=1)[0]
        assert flo_number == TAG_FLOAT, 'Flow number %r incorrect. Invalid .flo file' % flo_number
        # Pull out plain Python ints; int() on 1-element arrays is deprecated.
        w = int(np.fromfile(f, np.int32, count=1)[0])
        h = int(np.fromfile(f, np.int32, count=1)[0])
        data = np.fromfile(f, np.float32, count=2 * w * h)

    # reshape (unlike np.resize) refuses to silently tile a truncated file.
    flow = data.reshape((h, w, 2))

    # We found that there are some invalid (?) (i.e. very large) flows, so we're going
    # to ignore those for this experiment.
    invalid = (flow >= 1666)
    flow[invalid] = 0

    return flow


# Build a lazily-loaded flow pair per Middlebury scene and (optionally) show it.
for scene in MIDD_SCENES:
    p = OpticalFlowPair(
            dataset="Middlebury Optical Flow",
            id1=scene['input'],
            img1='file://' + os.path.join(MIDD_DATA_ROOT, scene['input']),
            id2=scene['expected_out'],
            img2='file://' + os.path.join(MIDD_DATA_ROOT, scene['expected_out']),
            # Bind `scene` as a default argument: a plain closure would look the
            # name up only when the flow is first requested, by which time the
            # loop variable may refer to a different scene.
            flow=lambda scene=scene: midd_read_flow(os.path.join(MIDD_DATA_ROOT, scene['flow_gt'])))

    if SHOW_DEMO_OUTPUT:
        show_html(p.to_html() + "<br/><br/><br/>")
        DEMO_FPS.append(p)


## DeepDeform

In [None]:
# Please extract deepdeform_v1.7z to a directory and provide the target below:
DD_DATA_ROOT = '/outer_root/host_mnt/Volumes/970-evo-raid0/deepdeform_v1/'

# Demo pairs, described as (sequence, subject name, source frame, target frame).
DD_DEMO_SCENES = [
    {
        "input": "train/%s/color/%s.jpg" % (seq, src),
        "expected_out": "train/%s/color/%s.jpg" % (seq, dst),
        "flow_gt": "train/%s/optical_flow/%s_%s_%s.oflow" % (seq, subject, src, dst),
    }
    for seq, subject, src, dst in (
        ('seq000', 'blackdog', '000000', '000200'),
        ('seq000', 'blackdog', '000000', '001200'),
        ('seq001', 'lady', '003400', '003600'),
        ('seq337', 'adult', '000050', '000350'),
    )
]

def dd_load_flow(path):
    """Read a DeepDeform `.oflow` file into an `(h, w, channels)` float32 array.

    The file layout read here is three native unsigned ints (width, height,
    channels) followed by `channels * height * width` native float32 values
    stored in `[channels, height, width]` order.  Entries equal to `-inf`
    are replaced with 0.

    Raises:
        AssertionError: if `path` is not an existing file.
        ValueError: if the payload is shorter than the header declares.
    """
    # Based upon https://github.com/AljazBozic/DeepDeform/blob/master/utils.py#L1
    import struct

    assert os.path.isfile(path)

    with open(path, 'rb') as fin:
        width, height, channels = struct.unpack('III', fin.read(12))
        n_elems = height * width * channels
        raw = fin.read(n_elems * 4)

    # frombuffer yields a read-only view of `raw`; copy so we can edit below.
    flow_gt = np.frombuffer(raw, dtype=np.float32).copy()
    flow_gt = flow_gt.reshape([channels, height, width])

    # Match format used in this analysis
    flow_gt = np.moveaxis(flow_gt, 0, -1)  # (h, w, channels)
    # `np.inf` is the spelling that survives NumPy 2.0 (`np.Inf` was removed).
    invalid_flow = flow_gt == -np.inf
    flow_gt[invalid_flow] = 0.0
    return flow_gt

def dd_create_fp(info):
    """Build a lazily-loaded `OpticalFlowPair` for one DeepDeform scene.

    Args:
        info: dict with keys 'input', 'expected_out' and 'flow_gt', each a
            path relative to `DD_DATA_ROOT`.

    The flow file is only read when the pair's flow is first requested.
    """
    # Read from the `info` argument, not from whatever module-level `scene`
    # happens to be bound when this function (or the lazy flow) runs.
    return OpticalFlowPair(
        dataset="DeepDeform Semi-Synthetic Optical Flow",
        id1=info['input'],
        img1='file://' + os.path.join(DD_DATA_ROOT, info['input']),
        id2=info['expected_out'],
        img2='file://' + os.path.join(DD_DATA_ROOT, info['expected_out']),
        flow=lambda: dd_load_flow(os.path.join(DD_DATA_ROOT, info['flow_gt'])))

import json
# Load the training alignment index; `with` closes the file handle promptly.
with open(os.path.join(DD_DATA_ROOT, 'train_alignments.json')) as alignments_file:
    DD_ALIGNMENTS = json.load(alignments_file)

# Re-key every alignment record to the scene schema used in this notebook.
ALL_DD_SCENES = [
    {
        "input": ascene['source_color'],
        "expected_out": ascene['target_color'],
        "flow_gt": ascene['optical_flow'],
    }
    for ascene in DD_ALIGNMENTS
]

print("Found %s DeepDeform scenes" % len(ALL_DD_SCENES))
if SHOW_DEMO_OUTPUT:
    for scene in DD_DEMO_SCENES:
        p = dd_create_fp(scene)
        show_html(p.to_html())
        DEMO_FPS.append(p)
    
# for i, scene in enumerate(ALL_DD_SCENES):
    

    
#     img_in = imageio.imread(os.path.join(DD_DATA_ROOT, scene['input']))
#     expected = imageio.imread(os.path.join(DD_DATA_ROOT, scene['expected_out']))
#     flow_gt = dd_load_flow(os.path.join(DD_DATA_ROOT, scene['flow_gt']))
    
#     show_scene(img_in, expected, flow_gt, dataset='deepdeform', exname=str(i))
#     continue
    
#     vis = draw_flow(img_in, flow_gt)
#     imshow(vis)

#     warped = warp_flow(img_in, flow_gt[:, :, :2])
#     imshow(warped)
    
#     exclude = (flow_gt == np.array([0, 0])).all(axis=-1)
#     warped[exclude] = np.array([0, 0, 0])
#     imshow(warped)
    
#     imshow(expected)



In [None]:
## Kitti Scene Flow Benchmark (2015)


In [None]:
# Please unzip `data_scene_flow.zip` and `data_scene_flow_calib.zip` to a directory and provide that target below:
KITTI_SF15_DATA_ROOT = '/outer_root/host_mnt/Volumes/970-evo-raid0/kitti_sceneflow_scratch/'

# You have to `ls` the flow_occ directory to see which frame ids are available.
# Each demo scene pairs frame `<id>_10` with `<id>_11`, with ground truth flow for `<id>_10`.
KITTI_SF15_DEMO_SCENES = [
    {
        'input': 'training/image_2/%s_10.png' % frame_id,
        'expected_out': 'training/image_2/%s_11.png' % frame_id,
        'flow_gt': 'training/flow_occ/%s_10.png' % frame_id,
    }
    for frame_id in ('000000', '000007', '000023', '000051', '000003')
]

from oarphpy import util as oputil
# Base names of every ground-truth flow PNG found under training/flow_occ.
KITTI_SF15_ALL_FLOW_OCC = list(map(
    os.path.basename,
    oputil.all_files_recursive(
        os.path.join(KITTI_SF15_DATA_ROOT, 'training/flow_occ'), pattern='*.png')))

# One scene per flow file: the `_10` frame is the input, the matching `_11`
# frame is the expected output.
KITTI_SF15_ALL_SCENES = [
    dict(
        input='training/image_2/%s' % fname,
        expected_out='training/image_2/%s' % fname.replace('_10', '_11'),
        flow_gt='training/flow_occ/%s' % fname,
    )
    for fname in KITTI_SF15_ALL_FLOW_OCC
]
print("Found %s KITTI SceneFlow 2015 scenes" % len(KITTI_SF15_ALL_SCENES))


def kitti_sf15_load_flow_from_png(path):
    """Decode a KITTI flow PNG into an `(h, w, 2)` float64 flow array.

    The PNG's three interleaved channels are read per row.  The first two
    are decoded as `(value - 2**15) / 64.0`; pixels whose third channel is 0
    are treated as invalid and get zero flow.
    """
    # Based upon https://github.com/liruoteng/OpticalFlowToolkit/blob/master/lib/flowlib.py#L559
    import png
    _width, _height, row_iter, meta = png.Reader(filename=path).asDirect()
    rows = list(row_iter)
    w, h = meta['size']
    flow = np.zeros((h, w, 3), dtype=np.float64)
    for row_idx, row in enumerate(rows):
        # De-interleave the three channels of this row.
        for ch in range(3):
            flow[row_idx, :, ch] = row[ch::3]

    invalid_idx = (flow[:, :, 2] == 0)
    flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 15) / 64.0
    flow[invalid_idx, :2] = 0
    return flow[:, :, :2]

def kitti_sf15_create_fp(info):
    """Build a lazily-loaded `OpticalFlowPair` for one KITTI Scene Flow 2015 scene.

    Args:
        info: dict with keys 'input', 'expected_out' and 'flow_gt', each a
            path relative to `KITTI_SF15_DATA_ROOT`.

    The flow PNG is only decoded when the pair's flow is first requested.
    """
    # Read from the `info` argument, not from whatever module-level `scene`
    # happens to be bound when this function (or the lazy flow) runs.
    return OpticalFlowPair(
        dataset="KITTI Scene Flow 2015",
        id1=info['input'],
        img1='file://' + os.path.join(KITTI_SF15_DATA_ROOT, info['input']),
        id2=info['expected_out'],
        img2='file://' + os.path.join(KITTI_SF15_DATA_ROOT, info['expected_out']),
        flow=lambda: kitti_sf15_load_flow_from_png(
            os.path.join(KITTI_SF15_DATA_ROOT, info['flow_gt'])))

# Render every KITTI demo scene inline and keep its pair in DEMO_FPS.
if SHOW_DEMO_OUTPUT:
    for scene in KITTI_SF15_DEMO_SCENES:
        p = kitti_sf15_create_fp(scene)
        show_html(p.to_html())
        DEMO_FPS.append(p)

## Reconstruction via Optical Flow

In [None]:
## Reconstruction via Optical Flow

def zero_flow(flow):
    """Return an `(h, w)` boolean mask of pixels whose first two flow channels are both 0."""
    first_two = flow[:, :, :2]
    return np.all(first_two == np.array([0, 0]), axis=-1)

def warp_flow_backwards(img, flow):
    """Given an image, apply the inverse of `flow`.

    Builds a sampling map of `pixel coordinate - flow` and resamples `img`
    through it with `cv2.remap` (bilinear).  The input `flow` is not modified.
    """
    rows, cols = flow.shape[:2]
    # np.negative allocates a new array, so the caller's flow stays intact.
    sample_map = np.negative(flow)
    sample_map[:, :, 0] += np.arange(cols)
    sample_map[:, :, 1] += np.arange(rows)[:, np.newaxis]
    return cv2.remap(img, sample_map.astype(np.float32), None, cv2.INTER_LINEAR)
    
def warp_flow_forwards(img, flow):
    """Given an image, apply the given optical flow `flow`.

    Returns not only the warped image, but a `mask` indicating warped pixels
    (i.e. there was non-zero flow *into* these pixels).
    With some help from https://stackoverflow.com/questions/41703210/inverting-a-real-valued-index-grid/46009462#46009462

    Args:
        img: image array indexed as `img[:, :, channel]`.
        flow: per-pixel flow; channel 0 is added to the column index and
            channel 1 to the row index.

    Returns:
        `(out, valid_mask)`: `out` has shape `(h, w, channels)` and `img`'s
        dtype; `valid_mask` is an `(h, w)` boolean array.  When the flow is
        zero everywhere, a copy of `img` and an all-False mask are returned.
    """
    h, w = img.shape[:2]
    pts = flow.copy()
    pts[:, :, 0] += np.arange(w)
    pts[:, :, 1] += np.arange(h)[:, np.newaxis]
    exclude = zero_flow(flow)
    if exclude.all():
        # No flow anywhere!
        return img.copy(), np.zeros((h, w), dtype=bool)

    from scipy.interpolate import griddata

    moved = ~exclude
    # Destination coordinates (x, y) of every pixel that has non-zero flow.
    inpts = np.reshape(pts[moved], [-1, 2])
    grid_y, grid_x = np.mgrid[:h, :w]
    chan_out = []
    for ch in range(img.shape[-1]):
        # Scatter this channel's moved pixel values to their destinations and
        # interpolate them back onto the regular pixel grid.
        spts = img[:, :, ch][moved].reshape([-1, 1])
        mapped = griddata(inpts, spts, (grid_x, grid_y), method='linear')
        chan_out.append(mapped.astype(img.dtype))
    out = np.stack(chan_out, axis=-1)
    out = out.reshape([h, w, len(chan_out)])

    # Mark the (rounded) in-bounds destinations as valid.  The builtin `int`
    # and `bool` replace the `np.int` / `np.bool` aliases removed from NumPy.
    dest = np.rint(inpts).astype(int)
    in_bounds = (
        (dest[:, 0] >= 0) & (dest[:, 0] < w) &
        (dest[:, 1] >= 0) & (dest[:, 1] < h))
    dest = dest[in_bounds]
    valid_mask = np.zeros((h, w), dtype=bool)
    valid_mask[dest[:, 1], dest[:, 0]] = True

    return out, valid_mask

@attr.s(slots=True, eq=False, weakref_slot=False)
class FlowReconstructedImagePair(object):
    """A pair of reconstructed images using an input pair of images and optical
    flow field (i.e. an `OpticalFlowPair` instance)."""

    # Defaults use factories so instances never share one mutable default.
    opair = attr.ib(factory=OpticalFlowPair)
    """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""

    img2_recon_fwd = attr.ib(factory=lambda: np.array([]))
    """A Numpy image containing the result of FORWARDS-WARPING OpticalFlowPair::img1
    via OpticalFlowPair::flow to reconstruct OpticalFlowPair::img2"""

    img2_recon_fwd_valid = attr.ib(factory=lambda: np.array([]))
    """A Numpy boolean mask indicating which pixels of `img2_recon_fwd` were modified via non-zero flow"""

    img1_recon_bkd = attr.ib(factory=lambda: np.array([]))
    """A Numpy image containing the result of BACKWARDS-WARPING OpticalFlowPair::img2
    via OpticalFlowPair::flow to reconstruct OpticalFlowPair::img1"""

    img1_recon_bkd_valid = attr.ib(factory=lambda: np.array([]))
    """A Numpy boolean mask indicating which pixels of `img1_recon_bkd` were modified via non-zero flow"""

    @classmethod
    def create_from(cls, oflow_pair: OpticalFlowPair):
        """Reconstruct both images of `oflow_pair` from its flow field.

        img1 is forward-warped (`warp_flow_forwards`) to reconstruct img2;
        img2 is backward-warped (`warp_flow_backwards`) to reconstruct img1.
        The backward validity mask is simply the set of non-zero-flow pixels.
        """
        flow = oflow_pair.get_flow()

        # Forward Warp
        fwarped, fvalid = warp_flow_forwards(oflow_pair.get_img1(), flow)

        # Backwards Warp
        exclude = zero_flow(flow)
        bwarped = warp_flow_backwards(oflow_pair.get_img2(), -flow[:, :, :2])
        bvalid = ~exclude

        return cls(
                opair=oflow_pair,
                img2_recon_fwd=fwarped,
                img2_recon_fwd_valid=fvalid,
                img1_recon_bkd=bwarped,
                img1_recon_bkd_valid=bvalid)

    @staticmethod
    def _fill_untouched(recon, valid, reference, fade):
        """Return a float32 copy of `recon` whose non-`valid` pixels are taken
        from `reference` and scaled by `fade`."""
        viz = recon.copy().astype(np.float32)
        untouched = ~valid
        # With an all-False mask both statements are no-ops, so a fully valid
        # reconstruction is shown unchanged.
        viz[untouched] = reference[untouched]
        viz[untouched] *= fade
        return viz

    def to_html(self):
        """Return an HTML table showing the forwards and backwards reconstructions.

        Pixels the warp did not touch are filled from the image being
        reconstructed and dimmed, so that warped content stands out.
        """
        # We use pixels from the destination image in order to make the reconstruction
        # easier to interpret; we'll fade them in intensity so that they are more
        # conspicuous.
        FADE_UNTOUCHED_PIXELS = 0.3

        viz_fwd = self._fill_untouched(
            self.img2_recon_fwd, self.img2_recon_fwd_valid,
            self.opair.get_img2(), FADE_UNTOUCHED_PIXELS)
        viz_bkd = self._fill_untouched(
            self.img1_recon_bkd, self.img1_recon_bkd_valid,
            self.opair.get_img1(), FADE_UNTOUCHED_PIXELS)

        html = """
            <table>
            
            <tr><td style="text-align:left"><b>Forwards Warped:</b></td></tr>
            <tr><td><img src="{viz_fwd}" width="100%" /></td></tr>

            <tr><td style="text-align:left"><b>Backwards Warped:</b></td></tr>
            <tr><td><img src="{viz_bkd}" width="100%" /></td></tr>

            </table>
        """.format(
                viz_fwd=img_to_data_uri(viz_fwd.astype(np.uint8)),
                viz_bkd=img_to_data_uri(viz_bkd.astype(np.uint8)))
        return html

        
# Reconstruct every demo pair, show it, and keep the result for the analysis cell.
if SHOW_DEMO_OUTPUT:
    DEMO_RECONS = []
    for flow_pair in DEMO_FPS:
        recon = FlowReconstructedImagePair.create_from(flow_pair)
        recon_html = recon.to_html()
        show_html(recon_html + "</br></br></br>")
        DEMO_RECONS.append(recon)


## Analysis

In [None]:
# Analysis Utils

def mse(i1, i2, valid):
    """Return the mean squared difference between `i1` and `i2` over `valid`.

    Args:
        i1, i2: arrays of the same shape.
        valid: boolean mask used to index both arrays.

    The selected values are promoted to float64 first: on uint8 images the
    subtraction and the squaring would otherwise wrap around modulo 256.
    """
    diff = i1[valid].astype(np.float64) - i2[valid].astype(np.float64)
    return np.mean(diff ** 2)

def rmse(i1, i2, valid):
    """Return the square root of `mse(i1, i2, valid)`."""
    mean_sq_err = mse(i1, i2, valid)
    return math.sqrt(mean_sq_err)

def psnr(i1, i2, valid):
    """Return the peak signal-to-noise ratio in dB, using a peak value of 255.

    The MSE is floored at 1e-12 so identical inputs do not hit `log10(0)`.
    """
    floored_mse = max((mse(i1, i2, valid), 1e-12))
    peak_term = 20 * math.log10(255)
    return peak_term - 10 * math.log10(floored_mse)

def ssim(i1, i2, valid):
    """Return the mean of the per-pixel SSIM map between `i1` and `i2` over `valid`.

    Uses scikit-image's `structural_similarity` with an 11-pixel window and
    the last axis treated as the channel axis.
    """
    # Some variance out there ...
    # https://github.com/scikit-image/scikit-image/blob/master/skimage/metrics/_structural_similarity.py#L12-L232
    # https://github.com/nianticlabs/monodepth2/blob/13200ab2f29f2f10dec3aa5db29c32a23e29d376/layers.py#L218
    # https://cvnote.ddlee.cn/2019/09/12/psnr-ssim-python
    # We will just use SKImage for now ...
    import inspect
    from skimage.metrics import structural_similarity

    # Newer scikit-image replaced `multichannel=True` with `channel_axis`;
    # pick whichever keyword the installed version understands.
    if 'channel_axis' in inspect.signature(structural_similarity).parameters:
        channel_kwargs = {'channel_axis': -1}
    else:
        channel_kwargs = {'multichannel': True}

    _mssim, ssim_map = structural_similarity(
        i1, i2, win_size=11, full=True, **channel_kwargs)
    return np.mean(ssim_map[valid])

def to_edge_im(img):
    """Return a 3-channel float32 edge image of the RGB image `img`.

    Channels are, in order: Laplacian (ksize=1), Sobel d/dx (ksize=3) and
    Sobel d/dy (ksize=3), each computed on the grayscale conversion of `img`.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    laplacian = cv2.Laplacian(gray, cv2.CV_32F, ksize=1)
    sobel_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    return np.stack([laplacian, sobel_x, sobel_y], axis=-1)

def edges_mse(i1, i2, valid):
    """Return `mse` computed on the `to_edge_im` edge images of `i1` and `i2`."""
    edges1 = to_edge_im(i1)
    edges2 = to_edge_im(i2)
    return mse(edges1, edges2, valid)


def oflow_coverage(valid):
    """Return the fraction of entries in the 2-D mask `valid` that are set."""
    n_rows, n_cols = valid.shape[0], valid.shape[1]
    return valid.sum() / (n_rows * n_cols)

def oflow_magnitude_hist(flow, valid, bins=50):
    """Histogram the L2 magnitudes of the flow vectors selected by `valid`.

    Args:
        flow: array indexed as `flow[valid][:, 0]` / `[:, 1]` for the two
            flow components.
        valid: boolean mask selecting the pixels to include.
        bins: bin specification forwarded to `np.histogram_bin_edges`.

    Returns:
        `(bin_edges, bin_counts)`; `bin_edges` has one more entry than
        `bin_counts`.
    """
    selected = flow[valid]
    flow_l2s = np.sqrt(selected[:, 0] ** 2 + selected[:, 1] ** 2)
    bin_edges = np.histogram_bin_edges(flow_l2s, bins=bins)
    # np.histogram returns (counts, edges); keep only the counts.
    bin_counts, _ = np.histogram(flow_l2s, bins=bin_edges)
    return bin_edges, bin_counts


# Analysis Data Model

class OFlowReconErrors(object):
    """Various measures of reconstruction error for a `FlowReconstructedImagePair` instance.
    Encapsulated as two dictionaries of stats for easy interop with Spark SQL.

    Attributes:
        forward_stats: metric name -> value, comparing img2 against its
            forward-warped reconstruction.
        backward_stats: metric name -> value, comparing img1 against its
            backward-warped reconstruction.
    """

    # Display name -> metric function called as `func(expected, recon, valid)`.
    RECONSTRUCTION_ERR_METRICS = {
        'Mean Squared Error (MSE)': mse,
        'Root-MSE': rmse,
        'Peak Signal-to-Noise Ratio (PSNR)': psnr,
        'Structured Similarity (SSIM)': ssim,
        'Edges MSE (non-standard metric)': edges_mse,
    }

    def __init__(self, recon_pair: FlowReconstructedImagePair):
        """Compute every metric for both reconstruction directions of `recon_pair`."""
        opair = recon_pair.opair
        self.forward_stats = self._compute_stats(
            opair.get_img2(),
            recon_pair.img2_recon_fwd,
            recon_pair.img2_recon_fwd_valid)
        self.backward_stats = self._compute_stats(
            opair.get_img1(),
            recon_pair.img1_recon_bkd,
            recon_pair.img1_recon_bkd_valid)

    @classmethod
    def _compute_stats(cls, expected, recon, valid):
        """Evaluate every registered metric on one (expected, recon, valid) triple."""
        return {
            name: func(expected, recon, valid)
            for name, func in cls.RECONSTRUCTION_ERR_METRICS.items()
        }

    def to_html(self):
        """Return an HTML table with one row per metric (sorted by name) and
        one column per warp direction."""
        stat_names = self.RECONSTRUCTION_ERR_METRICS.keys()

        rows = [
            """
            <tr>   <td><b>{name}</b></td>  <td>{fwd}</td>  <td>{bkd}</td>  </tr>
            """.format(name=name, fwd=self.forward_stats[name], bkd=self.backward_stats[name])
            for name in sorted(stat_names)
        ]

        html = """
            <table>
              <tr>
                  <th></th> <th><b>Forwards Warp</b></th> <th><b>Backwards Warp</b></th>
              </tr>

              {table_rows}

            </table>
        """.format(table_rows="".join(rows))

        return html
            
@attr.s(slots=True, eq=False, weakref_slot=False)
class OFlowStats(object):
    """Stats on the optical flow of a `OpticalFlowPair` instance"""

    # Defaults use factories so instances never share one mutable default.
    opair = attr.ib(factory=OpticalFlowPair)
    """The original `OpticalFlowPair` with the source of the data for this reconstruction result."""

    coverage = attr.ib(default=0)
    """Fraction of the image with valid flow"""

    magnitude_hist = attr.ib(factory=lambda: [np.array([]), np.array([])])
    """Histogram [bin edges, bin counts] of flow magnitudes"""

    @classmethod
    def create_from(cls, oflow_pair: OpticalFlowPair):
        """Compute coverage and the magnitude histogram over the non-zero flow
        pixels of `oflow_pair`."""
        flow = oflow_pair.get_flow()
        valid = ~zero_flow(flow)
        return cls(
                 opair=oflow_pair,
                 coverage=oflow_coverage(valid),
                 magnitude_hist=oflow_magnitude_hist(flow, valid))

    def to_html(self):
        """Return an HTML table with the coverage value and a bar chart of the
        stored magnitude histogram."""
        import matplotlib.pyplot as plt

        bin_edges, bin_counts = self.magnitude_hist
        fig = plt.figure()
        # There is one more edge than there are counts: anchor each bar at its
        # left edge and make it as wide as its bin.
        plt.bar(bin_edges[:-1], bin_counts, width=np.diff(bin_edges), align='edge')
        plt.title("Histogram of Flow Magnitudes")
        plt.xlabel('Flow Magnitude (pixels)')
        plt.ylabel('Count')

        hist_img = matplotlib_fig_to_img(fig)
        # The figure has been rasterized; close it so figures do not pile up.
        plt.close(fig)

        html = """
            <table>           
            <tr><td style="text-align:left"><b>Flow Coverage:</b> {coverage} </td></tr>
            <tr><td><img src="{flow_hist}" width="100%" /></td></tr>
            </table>
        """.format(
                coverage=self.coverage,
                flow_hist=img_to_data_uri(hist_img))
        return html


# Misc

def matplotlib_fig_to_img(fig):
    """Rasterize the matplotlib figure `fig` and return it as an image array.

    The figure is saved as PNG into an in-memory buffer and decoded with
    `imageio.imread`.
    """
    import io
    import imageio

    buf = io.BytesIO()
    # Save through the figure itself: no `plt` name is needed in this scope,
    # and the figure that was passed in is the one that gets rendered.
    fig.savefig(buf, format='png')
    buf.seek(0)
    img = imageio.imread(buf)
    buf.close()
    return img


# For every demo reconstruction, show flow stats followed by the error table.
if SHOW_DEMO_OUTPUT:
    for recon in DEMO_RECONS:
        # `FlowReconstructedImagePair` stores its source pair as `opair`.
        p = recon.opair
        errors = OFlowReconErrors(recon)
        err_html = errors.to_html()

        fstats = OFlowStats.create_from(p)
        stats_html = fstats.to_html()

        title = "<b>{dataset} {id1} -> {id2}</b>".format(dataset=p.dataset, id1=p.id1, id2=p.id2)

        show_html(title + stats_html + err_html + "</br></br></br>")
            

In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('/opt/psegs')

from psegs.datasets import nuscenes as psnusc
from psegs.spark import Spark
spark = Spark.getOrCreate()

from oarphpy import plotting as pl


from pyspark.sql import Row
# Toy table: one row per integer 0..100 with its residue mod 11 and its square.
df = spark.createDataFrame(
  [Row(x=value, mod_11=int(value % 11), square=value * value)
   for value in range(101)])


### Check plotting with custom example plotter
class PlotterWithMicroFacet(pl.HistogramWithExamplesPlotter):
  """Histogram plotter whose per-bucket display lists each row's x, square and mod_11."""

  NUM_BINS = 25

  def display_bucket(self, sub_pivot, bucket_id, irows):
    """Return `(bucket_id, html)` where `html` lists the bucket's rows ordered by `x`."""
    ordered_rows = sorted(irows, key=lambda r: r.x)
    row_lines = [
        "x: {x} square: {square} mod_11: {mod_11}".format(**row.asDict())
        for row in ordered_rows
    ]
    rows_str = "<br />".join(row_lines)
    TEMPLATE = """
      <b>Pivot: {spv} Bucket: {bucket_id} </b> <br/>
      {rows}
      <br/> <br/>
    """
    return bucket_id, TEMPLATE.format(
              spv=sub_pivot,
              bucket_id=bucket_id,
              rows=rows_str)

from bokeh.io import output_notebook
from bokeh.plotting import show

# Run the custom plotter over `df` for the 'square' column.
plotter = PlotterWithMicroFacet()
fig = plotter.run(df, 'square')

# Route Bokeh output to the notebook, then render the figure.
output_notebook()
show(fig)