In [4]:
# Geometry of the key-point grid and extracted windows (in pixels), and the
# maximum number of random-walk steps taken per click in the browser below.
stride, window_size = 16, 32
playback_random_walk_length = 10

# Inputs: the source video, the memory-graph database and the stored patch images.
video_path = "../../media/tabletop_objects/videos/288_brush_carrot_clippers_cup_flowers_hanger_ketchup.mp4"
db_path = "../../data/test1.db"
patch_dir = "../../patches/test1"

# Build Graph

In [None]:
# Graph building is disabled by default; set to True to (re)run build_graph
# for the video, database and patch directory configured above.
rebuild_graph = False

if rebuild_graph:
    from vgg16_window_walker_lib_b import build_graph
    build_graph(db_path, video_path, patch_dir, stride=stride, window_size=window_size, max_frames=30*30)

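# Video Patch Browser

The left panel shows the current video frame; press any key to advance to the next frame. Releasing a mouse button on the left panel starts a random walk from that position, and the right panel shows the walked windows together with the stored patches of the matching memory-graph groups. Log messages and tracebacks appear in the output area below the figure.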

In [5]:
%matplotlib widget
import matplotlib.pyplot as plt

from PIL import Image
from IPython.display import display
import ipywidgets as widgets
from os.path import split
import traceback
import math
import numpy as np
import cv2

from tensorflow.keras.applications import vgg16
from tensorflow.keras.applications.vgg16 import preprocess_input

# These imports need to be in a separate cell than the class or "error src/tcmalloc.cc:332] Attempt to free invalid pointer"
from vgg16_window_walker_lib_c import resize_frame, key_point_grid, next_pos_play, extract_window, paint_windows, MemoryGraph

In [6]:
class VideoPatchBrowser:
    """Interactive browser over a video's frames and the patches stored in a memory graph.

    ``ax1`` shows the current frame. Clicking on it runs a random walk over the
    frame's key-point grid from the clicked cell, embeds the walked windows with
    VGG16, looks up groups in the memory graph for those features, and paints the
    walked windows plus the stored patches of each group into ``ax2``. A key press
    advances to the next frame.

    Parameters
    ----------
    video_path : str
        Video to read frames from; its file name also selects the patch sub-directory.
    db_path : str
        Database passed to ``MemoryGraph``.
    patch_dir : str
        Root of the stored patches, read as ``<patch_dir>/<video file name>/patch_<id>.png``.
    out : context manager (e.g. ``ipywidgets.Output``)
        Receives log messages and tracebacks from the event callbacks.
    ax1, ax2 : matplotlib Axes
        Axes for the current frame and for the painted patches.
    playback_random_walk_length : int
        Maximum number of random-walk steps per click.
    stride : int
        Key-point grid stride in pixels.
    window_size : int
        Side length of the extracted windows. NOTE(review): the VGG16 input shape
        and the 512-d feature reshape are fixed for 32x32 windows, so other values
        presumably do not work — confirm.
    """

    def __init__(self, video_path, db_path, patch_dir, out, ax1, ax2, playback_random_walk_length=10, stride=16, window_size=32):
        self.video_path = video_path
        self.patch_dir = patch_dir
        self.out = out
        # Stored so next_patches does not depend on a notebook-global of the same name.
        self.playback_random_walk_length = playback_random_walk_length
        self.stride = stride
        self.window_size = window_size
        self.orb = cv2.ORB_create(nfeatures=100000, fastThreshold=7)
        self.memory_graph = MemoryGraph(db_path, space='cosine', dim=512)
        # next_patches reshapes the output to 512 values per window, matching dim=512 above.
        self.model = vgg16.VGG16(weights="imagenet", include_top=False, input_shape=(32, 32, 3))
        self.cap = cv2.VideoCapture(video_path)
        self.ax1 = ax1
        self.ax2 = ax2
        self.next_frame()

    def _log(self, *args):
        """Print ``args`` inside the ``out`` widget so messages from canvas callbacks are visible."""
        with self.out:
            print(*args)

    def on_press(self, event):
        """Key-press callback: advance to the next frame.

        Any exception is printed to ``out`` as a traceback and re-raised.
        """
        try:
            self._log("on_press")
            self.next_frame()
        except Exception:
            with self.out:
                traceback.print_exc()
            raise

    def on_click(self, event):
        """Mouse-release callback: start a patch walk at the clicked position on ``ax1``.

        Releases outside ``ax1`` are logged and ignored. Any exception is printed
        to ``out`` as a traceback and re-raised.
        """
        try:
            # Compare with this browser's axes, not the notebook-global ``ax1``.
            if event.inaxes is not self.ax1:
                self._log("outside")
                return

            # Click in data coordinates as (row, column) = (y, x).
            pos = (event.ydata, event.xdata)
            self._log("on_click", pos)
            self.next_patches(pos)
        except Exception:
            with self.out:
                traceback.print_exc()
            raise

    def next_frame(self):
        """Read the next video frame and show it on ``ax1``.

        When no frame can be read, releases the capture and logs "No More Frames".
        """
        if self.cap.isOpened():
            ret, frame = self.cap.read()
            if ret:
                self.frame = frame
                self.update_ax1()
                return
            # End of stream (or read failure): free the capture; later calls only log.
            self.cap.release()

        self._log("No More Frames")

    def next_patches(self, pos):
        """Random-walk the key-point grid from ``pos`` and show the matching patches on ``ax2``.

        Parameters
        ----------
        pos : tuple
            Click position as ``(y, x)`` in ``ax1`` data coordinates.
        """
        self._log("show_patches")

        res_frame = resize_frame(self.frame)
        kp_grid = key_point_grid(self.orb, res_frame, self.stride)
        self._log("len(kp_grid)", len(kp_grid))

        # Map the click to a grid cell: the first cell sits 16 px in, plus half of the
        # margin left after tiling (size - 32) by the stride.
        # NOTE(review): 32/16 look like window_size and window_size/2; kept literal
        # because key_point_grid is not given window_size — confirm.
        # NOTE(review): offsets use the original frame shape while kp_grid is built on
        # the resized frame — confirm resize_frame keeps the frame size.
        row_offset = ((self.frame.shape[0] - 32) % self.stride) / 2.0 + 16
        col_offset = ((self.frame.shape[1] - 32) % self.stride) / 2.0 + 16
        g_pos = (int(math.floor((pos[0] - row_offset) / self.stride)),
                 int(math.floor((pos[1] - col_offset) / self.stride)))
        self._log("g_pos", g_pos)

        path = []
        for _ in range(self.playback_random_walk_length):
            g_pos, pos = next_pos_play(kp_grid, res_frame.shape, g_pos, self.stride)
            self._log("g_pos, pos", g_pos, pos)
            if g_pos is None:
                break
            path.append(pos)

        # Drop revisited positions, keeping first-visit order.
        path = list(dict.fromkeys(path))
        if not path:
            # The first step already returned None: nothing to embed or paint.
            self._log("no walk positions")
            return

        windows = np.array([extract_window(res_frame, p, self.window_size) for p in path])

        # NOTE(review): the return value of preprocess_input is discarded. Per the Keras
        # docs it only writes over its input for compatible (float) dtypes; if windows is
        # uint8 (presumably, coming from a cv2 frame) the model gets raw pixels. Left as-is
        # so features stay comparable with those already in the memory graph — confirm
        # how build_graph embeds windows before switching to
        # ``windows = preprocess_input(windows)``.
        preprocess_input(windows)
        features = self.model.predict(windows)
        features = features.reshape((windows.shape[0], 512))
        self._log("windows.shape, feats.shape", windows.shape, features.shape)

        self.patches = self.build_patches(windows, features, path, self.frame.shape, self.memory_graph)
        self.update_ax2()

    def _read_patch(self, video_file_name, observation_id):
        """Load the stored patch image of ``observation_id``; raise FileNotFoundError if unreadable."""
        patch_path = self.patch_dir + "/" + video_file_name + '/patch_' + str(observation_id) + '.png'
        patch = cv2.imread(patch_path)
        if patch is None:
            # cv2.imread signals a missing/unreadable file by returning None, not raising.
            raise FileNotFoundError(patch_path)
        return patch

    def build_patches(self, path_windows, path_features, path_positions, frame_shape, memory_graph):
        """Paint the walked windows and the stored patches of matching memory-graph groups.

        Returns a new ``uint8`` image of ``frame_shape[:2]`` with 3 channels. The walked
        windows are painted with index 0 and the observations of the ``i``-th group
        with index ``i + 1`` (the last argument of paint_windows; presumably a color or
        label index — confirm).
        """
        video_file_name = split(self.video_path)[1]

        frame = np.zeros((frame_shape[0], frame_shape[1], 3), np.uint8)
        paint_windows(path_positions, path_windows, frame, self.window_size, 0)

        # search_group arguments: features, feature_dis, community_dis, k
        groups = list(memory_graph.search_group(path_features, .2, .2, 30))
        self._log("groups", groups)

        for i, group in enumerate(groups):
            observation_ids = []
            for node_id in group:
                observation_ids.extend(memory_graph.get_integrated_observations(node_id))
                observation_ids.extend(memory_graph.get_predicted_observations(node_id))

            if not observation_ids:
                # Nothing to paint for this group.
                continue

            windows = np.array([self._read_patch(video_file_name, observation_id)
                                for observation_id in observation_ids])
            observations = memory_graph.get_observations(observation_ids)
            positions = [(obs["y"], obs["x"]) for obs in observations]

            paint_windows(positions, windows, frame, self.window_size, i + 1)

        return frame

    def update_ax1(self):
        """Redraw ``ax1`` with the current frame, converted from BGR to RGB."""
        self.ax1.cla()
        self.ax1.axis("off")
        self.ax1.imshow(cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB))
        # Request a redraw explicitly so updates made in canvas callbacks appear.
        self.ax1.figure.canvas.draw_idle()

    def update_ax2(self):
        """Redraw ``ax2`` with the painted patch image, converted from BGR to RGB."""
        self.ax2.cla()
        self.ax2.axis("off")
        self.ax2.imshow(cv2.cvtColor(self.patches, cv2.COLOR_BGR2RGB))
        self.ax2.figure.canvas.draw_idle()

In [7]:
out = widgets.Output()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 10), dpi=100)

# Pass every configured parameter, including the walk length from the config cell.
browser = VideoPatchBrowser(video_path, db_path, patch_dir, out, ax1, ax2,
                            playback_random_walk_length=playback_random_walk_length,
                            stride=stride, window_size=window_size)

# Mouse-button release on the frame starts a patch walk; any key advances the frame.
fig.canvas.mpl_connect('button_release_event', browser.on_click)
fig.canvas.mpl_connect('key_press_event', browser.on_press)

fig.tight_layout()

ax1.axis("off")
ax2.axis("off")

plt.show()

display(out)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

MemoryGraph: loading nodes
MemoryGraph: loading graph
MemoryGraph: loaded 88462 nodes 121267 edges


Output()