In [1]:
%matplotlib widget

from PIL import Image
from IPython.display import display
import matplotlib.pyplot as plt
import ipywidgets as widgets

import math
import numpy as np
import cv2
from tensorflow.keras.applications import vgg16
from tensorflow.keras.applications.vgg16 import preprocess_input

from vgg16_window_walker_lib_a import resize_frame, key_point_grid, next_pos_play, extract_window, paint_windows, MemoryGraph, build_graph

In [None]:
class VideoPatchBrowser:
    """Step through a video and inspect memory-graph matches for a clicked spot.

    The top axes (``ax1``) show the current video frame.  A key press loads
    the next frame; a mouse click inside ``ax1`` starts a short walk over the
    ORB key-point grid from the clicked position, extracts a window at every
    visited position, embeds the windows with VGG16 and paints the windows
    together with the patches of matching memory-graph groups onto ``ax2``.

    Besides its constructor arguments the class reads these notebook-level
    names at call time (they are defined in the "Parameters" cell):
    ``db_path``, ``patch_dir``, ``stride``, ``window_size`` and
    ``playback_random_walk_length``.

    Parameters
    ----------
    video_file : str
        Path handed to ``cv2.VideoCapture``.
    ax1, ax2 : matplotlib axes
        Axes used for the frame view and the patch view respectively.
    out : context manager, optional
        Sink for log messages (e.g. an ``ipywidgets.Output``).  When omitted,
        a notebook-level ``out`` is used if one exists at logging time,
        otherwise messages go to plain ``print``.
    """

    def __init__(self, video_file, ax1, ax2, out=None):
        self.out = out
        self._log("VideoPatchBrowser")
        self.orb = cv2.ORB_create(nfeatures=100000, fastThreshold=7)
        self.memory_graph = MemoryGraph(db_path, space='cosine', dim=512)
        self.model = vgg16.VGG16(weights="imagenet", include_top=False, input_shape=(32, 32, 3))
        self.cap = cv2.VideoCapture(video_file)
        self.ax1 = ax1
        self.ax2 = ax2
        # Defined up front so callbacks can tell "nothing loaded yet" apart
        # from an AttributeError when the video could not be read.
        self.frame = None
        self.patches = None
        self.next_frame()


    def _log(self, *args):
        """Print ``args`` into the log sink, falling back to stdout.

        The sink is resolved lazily: the notebook creates its global ``out``
        widget in a later cell than this class, possibly after the browser
        has been constructed.
        """
        sink = getattr(self, "out", None)
        if sink is None:
            sink = globals().get("out")
        if sink is None:
            print(*args)
            return
        with sink:
            print(*args)


    def close(self):
        """Release the video capture held by this browser."""
        self.cap.release()


    def on_press(self, event):
        """Key-press callback: advance to the next video frame."""
        self._log("on_press")
        self.next_frame()


    def on_click(self, event):
        """Mouse callback: analyse the clicked position if it lies in ``ax1``."""
        # Compare against the axes this browser owns, not a notebook global.
        if self.ax1 != event.inaxes:
            self._log("outside")
            return

        # (row, column) order, i.e. (y, x) in data coordinates.
        pos = (event.ydata, event.xdata)

        self._log("on_click", pos)

        self.next_patches(pos)


    def next_frame(self):
        """Read the next frame from the capture and show it on ``ax1``.

        Logs "No More Frames" when the capture is closed or exhausted.
        """
        if self.cap.isOpened():
            ret, frame = self.cap.read()

            if ret:
                self.frame = frame
                self.update_ax1()
                return

        self._log("No More Frames")


    def next_patches(self, pos):
        """Walk the key-point grid from ``pos`` and render matches on ``ax2``.

        Parameters
        ----------
        pos : tuple
            ``(y, x)`` position of the click in ``ax1`` data coordinates.
        """
        self._log("show_patches")

        if self.frame is None:
            self._log("No frame loaded")
            return

        res_frame = resize_frame(self.frame)
        kp_grid = key_point_grid(self.orb, res_frame, stride)
        self._log("len(kp_grid)", len(kp_grid))

        # NOTE(review): the offsets use self.frame.shape while the key-point
        # grid is built from res_frame; confirm that resize_frame keeps the
        # frame size or that mixing the two is intended.
        grid_offset_x = ((self.frame.shape[0] - 32) % stride)/2.0 + 16
        grid_offset_y = ((self.frame.shape[1] - 32) % stride)/2.0 + 16
        g_pos = (int(math.floor((pos[0]-grid_offset_x)/stride)), int(math.floor((pos[1]-grid_offset_y)/stride)))

        self._log("g_pos", g_pos)
        path = []

        for _ in range(playback_random_walk_length):
            g_pos, pos = next_pos_play(kp_grid, res_frame.shape, g_pos, stride)
            self._log("g_pos, pos", g_pos, pos)
            if g_pos is None:
                break
            path.append(pos)

        # Drop repeated positions.
        path = list(set(path))

        if not path:
            # Without any window there is nothing to feed to the model.
            self._log("No key points reachable from", g_pos)
            return

        windows = np.array([extract_window(res_frame, p, window_size) for p in path])

        # NOTE(review): the return value is discarded here.  For integer
        # input preprocess_input presumably returns a converted copy instead
        # of modifying `windows`, which would make this call a no-op.  Kept
        # as-is so the features stay comparable with whatever build_graph
        # stored -- confirm against the library before changing it.
        preprocess_input(windows)
        features = self.model.predict(windows)
        features = features.reshape((windows.shape[0], 512))

        self._log("windows.shape, feats.shape", windows.shape, features.shape)

        self.patches = self.build_patches(windows, features, path, self.frame.shape, self.memory_graph, window_size)
        self.update_ax2()


    def build_patches(self, path_windows, path_features, path_positions, frame_shape, memory_graph, window_size):
        """Paint the walked windows and their memory-graph matches on a canvas.

        Parameters
        ----------
        path_windows : array of images extracted along the walk.
        path_features : feature array passed to ``memory_graph.search_group``.
        path_positions : positions belonging to ``path_windows``.
        frame_shape : shape whose first two entries size the canvas.
        memory_graph : graph queried for groups and observations.
        window_size : forwarded to ``paint_windows``.

        Returns
        -------
        numpy.ndarray
            ``uint8`` image of shape ``(frame_shape[0], frame_shape[1], 3)``.
        """
        frame = np.zeros((frame_shape[0], frame_shape[1], 3), np.uint8)

        paint_windows(path_positions, path_windows, frame, window_size, 0)

        # features, feature_dis, community_dis, k=30
        groups = list(memory_graph.search_group(path_features, .2, .2, 30))

        self._log("groups", groups)

        for i, group in enumerate(groups):
            observation_ids = []
            for node_id in group:
                observation_ids.extend(memory_graph.get_integrated_observations(node_id))
                observation_ids.extend(memory_graph.get_predicted_observations(node_id))

            if not observation_ids:
                continue

            # Assumes get_observations returns one record per id, in the same
            # order as `observation_ids` (the original pairing relied on the
            # same assumption) -- TODO confirm.
            observations = memory_graph.get_observations(observation_ids)

            positions = []
            windows = []
            for observation_id, obs in zip(observation_ids, observations):
                # Patches live in the directory configured in the Parameters
                # cell; presumably build_graph writes them as patch<id>.png.
                image = cv2.imread(patch_dir + '/patch' + str(observation_id) + '.png')
                if image is None:
                    # cv2.imread signals a missing/unreadable file with None.
                    continue
                positions.append((obs["y"], obs["x"]))
                windows.append(image)

            if not windows:
                continue

            # i+1 keeps group markers distinct from the walked path, which uses 0.
            paint_windows(positions, np.array(windows), frame, window_size, i+1)

        return frame


    def update_ax1(self):
        """Redraw ``ax1`` with the current frame (converted from BGR to RGB)."""
        self.ax1.cla()
        self.ax1.axis("off")
        self.ax1.imshow(cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB))
        # Request a repaint; updates triggered from event callbacks are not
        # guaranteed to refresh the canvas on their own.
        self.ax1.figure.canvas.draw_idle()


    def update_ax2(self):
        """Redraw ``ax2`` with the current patch canvas (BGR to RGB)."""
        self.ax2.cla()
        self.ax2.axis("off")
        self.ax2.imshow(cv2.cvtColor(self.patches, cv2.COLOR_BGR2RGB))
        self.ax2.figure.canvas.draw_idle()

# Parameters

In [2]:
# Window geometry handed to build_graph and read by VideoPatchBrowser.
window_size = 32
stride = 16

# Maximum number of steps of the random walk started by a click.
playback_random_walk_length = 10

# Locations, relative to the notebook's working directory.
patch_dir = "../../patches"
db_path = "../../data/tabletop_objects_001_apple.db"
video_path = "../../media/Tabletop Objects/videos/001_apple.mp4"

# Build Graph

In [3]:
# Build the memory graph for the configured video; progress is logged per frame.
build_graph(
    db_path,
    video_path,
    patch_dir,
    stride=stride,
    window_size=window_size,
)

Starting...
MemoryGraph: loading nodes
MemoryGraph: loading edges
MemoryGraph: loaded 0 nodes, 0 edges
Run 0
vid 1 frame 1 res 200 near 27 iden 0 pred 0 acc 0 many 0
vid 1 frame 2 res 3 near 61 iden 2 pred 0 acc 0 many 0
vid 1 frame 3 res 2 near 67 iden 2 pred 49 acc 9 many 0
vid 1 frame 4 res 2 near 71 iden 1 pred 65 acc 13 many 0
vid 1 frame 5 res 1 near 69 iden 2 pred 70 acc 16 many 0
vid 1 frame 6 res 0 near 82 iden 3 pred 69 acc 22 many 0
vid 1 frame 7 res 1 near 90 iden 2 pred 82 acc 26 many 0
vid 1 frame 8 res 2 near 82 iden 3 pred 89 acc 28 many 0
vid 1 frame 9 res 2 near 81 iden 2 pred 80 acc 22 many 1
vid 1 frame 10 res 4 near 87 iden 3 pred 79 acc 29 many 0
vid 1 frame 11 res 2 near 85 iden 3 pred 86 acc 25 many 0
vid 1 frame 12 res 1 near 94 iden 2 pred 84 acc 21 many 0
vid 1 frame 13 res 1 near 85 iden 2 pred 94 acc 26 many 0
vid 1 frame 14 res 3 near 92 iden 6 pred 84 acc 32 many 0
vid 1 frame 15 res 2 near 90 iden 2 pred 91 acc 30 many 2
vid 1 frame 16 res 6 near 103 ide

# Play Video

In [None]:
# Browse the same clip that was configured in the Parameters cell.
video_file = video_path

# Create the log widget BEFORE the browser: VideoPatchBrowser writes to the
# notebook-level `out` from its constructor onwards, so building the browser
# first fails with a NameError on a fresh kernel.
out = widgets.Output()

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12,14), dpi= 100)

browser = VideoPatchBrowser(video_file, ax1, ax2)

# Click inside the top axes to analyse that position; press any key (with the
# figure focused) to advance to the next frame.
fig.canvas.mpl_connect('button_release_event', browser.on_click)
fig.canvas.mpl_connect('key_press_event', browser.on_press)

fig.tight_layout()

ax1.axis("off")
ax2.axis("off")

# Displayed here, as before, so the log stays at its original place in the
# cell output; messages captured earlier are shown once it is displayed.
display(out)

plt.show()