In [None]:
import numpy as np
import skvideo.io
from matplotlib import pyplot as plt
from skimage.color import rgb2gray
from skimage.measure import label
from skimage.morphology import remove_small_objects, erosion, dilation, disk

### Utils for data I/O and visualization

In [None]:
def show_rand_imgs(data, num=5, cmap=None):
    """Show `num` randomly picked entries of `data` side by side in one row.

    Args:
        data: array-like indexable along its first axis and exposing
            `.shape[0]`; each `data[k]` is handed to `plt.imshow`.
        num: number of panels to draw. Indices are drawn independently, so
            the same entry may be shown more than once.
        cmap: colormap forwarded to `plt.imshow` (e.g. 'gray').
    """
    plt.figure(figsize=(20,5))
    for i in range(num):
        rand_idx = np.random.randint(0, data.shape[0])
        # Use the integer (nrows, ncols, index) form: string position specs
        # are rejected by current matplotlib, and a three-digit spec cannot
        # express layouts with 10 or more panels.
        plt.subplot(1, num, i+1)
        plt.imshow(data[rand_idx], cmap=cmap)
        plt.axis('off')
    plt.show()
    
def load_video(filename, down_sample=3):
    """Read a video file and return it together with a grayscale copy.

    Args:
        filename: path passed to `skvideo.io.vread`.
        down_sample: keep every `down_sample`-th entry along the first axis
            of the loaded array (presumably the frame axis — confirm against
            the scikit-video docs); 1 keeps everything.

    Returns:
        (frames, gray_frames): the (possibly subsampled) loaded array and the
        result of `rgb2gray` applied to it. Both shapes are printed.
    """
    raw_frames = skvideo.io.vread(filename)
    # Only slice when subsampling was actually requested.
    frames = raw_frames if down_sample == 1 else raw_frames[::down_sample]
    gray_frames = rgb2gray(frames)
    print(frames.shape, gray_frames.shape)
    return frames, gray_frames

def save_video(video, video_gray, center_video, output_path="outputvideo.mp4", marker_radius=3):
    """Overlay the detected centers on the video as red dots and write it to disk.

    Args:
        video: colour frames, indexed (frame, row, col, channel) with three
            channels. Floating-point input is treated as lying in [0, 1] and
            scaled to 0-255; any other dtype is assumed to already be on the
            0-255 scale and is only cast to uint8.
        video_gray: only its shape is used, to build the empty green/blue
            planes of the overlay; it must match `center_video`'s shape.
        center_video: (frame, row, col) array that is non-zero (1) at each
            detected center.
        output_path: file written by `skvideo.io.vwrite`.
        marker_radius: radius of the disk used to grow every center pixel
            into a visible dot.
    """
    video = np.asarray(video)
    if np.issubdtype(video.dtype, np.floating):
        video = (np.clip(video, 0.0, 1.0)*255).astype(np.uint8)
    else:
        # Multiplying uint8 frames by 255 would wrap around modulo 256 and
        # scramble the colours, so integer input is kept as is.
        video = video.astype(np.uint8)
    dummy = np.zeros(video_gray.shape, dtype=np.uint8)

    # Leading singleton axis: dilate within each frame only, never across time.
    struct = disk(marker_radius)[None,:,:]
    center_video = dilation(center_video, struct)
    center_video = (center_video*255).astype(np.uint8)
    # Markers go into the first (red) channel; the other two stay empty.
    center_video = np.stack([center_video, dummy, dummy], 3)

    # Element-wise maximum paints the markers on top of the original frames.
    output_video = np.maximum(video, center_video)
    print(output_video.shape)
    skvideo.io.vwrite(output_path, output_video)

### Object segmentation functions

In [None]:
def find_valid_region(image, thres=64, size_thres=64, show_imgs=False):
    """Build a mask of the pixels that are not part of a large bright area.

    The image is min-max normalised to 0-255 and thresholded; connected
    bright components smaller than `size_thres` pixels are discarded. What
    remains is treated as invalid.

    Args:
        image: 2-D array (the notebook passes the temporal mean frame).
        thres: threshold on the normalised 0-255 scale; pixels above it are
            considered bright.
        size_thres: minimum size handed to `remove_small_objects`.
        show_imgs: if True, plot the normalised image and one mask per label.

    Returns:
        uint8 array shaped like `image`: 1 where the pixel is valid (label 0
        after small-object removal), 0 inside the remaining bright components.
    """
    lo = image.min()
    span = image.max() - lo
    if span > 0:
        image = (image - lo) / span
    else:
        # Constant image: avoid 0/0 (NaN + RuntimeWarning); nothing is bright.
        image = np.zeros(image.shape)
    image = (image*255).astype(np.uint8)
    binary = (image > thres).astype(np.uint8)
    segmentation = label(binary)
    segmentation = remove_small_objects(segmentation, size_thres)

    if show_imgs:
        plt.imshow(image)
        plt.show()
        for i in np.unique(segmentation):
            temp = (segmentation==i).astype(np.uint8)*255
            plt.imshow(temp)
            plt.title(i)
            plt.show()

    return (segmentation==0).astype(np.uint8)
    
def segment_image(image, show_imgs=False, thres=64, size_thres=64, valid_region=None,
                  min_area=300, max_area=1500):
    """Locate the center of the first bright object of plausible size in a frame.

    The frame is min-max normalised to 0-255, thresholded, eroded once and
    optionally masked by `valid_region`. Connected components are labelled,
    small ones removed, and the first component whose pixel count lies
    strictly between `min_area` and `max_area` is taken as the target.

    Args:
        image: 2-D grayscale frame.
        show_imgs: if True, plot the binary mask, the label image, the target
            mask and the frame with the center marked.
        thres: threshold on the normalised 0-255 scale.
        size_thres: minimum size handed to `remove_small_objects`.
        valid_region: optional 0/1 mask multiplied into the binary image.
        min_area: exclusive lower bound on the target's pixel count.
        max_area: exclusive upper bound on the target's pixel count.

    Returns:
        `[row, col]` (ints) of the target's centroid, or `[]` when no
        component qualifies.
    """
    lo = image.min()
    span = image.max() - lo
    if span > 0:
        image = (image - lo) / span
    else:
        # Constant frame: avoid 0/0 (NaN + RuntimeWarning); nothing is bright.
        image = np.zeros(image.shape)
    image = (image*255).astype(np.uint8)
    binary = (image > thres).astype(np.uint8)
    binary = erosion(binary)
    if valid_region is not None:
        binary = binary * valid_region

    segmentation = label(binary)
    # Only filter when there is more than one label value present.
    if len(np.unique(segmentation)) > 1:
        segmentation = remove_small_objects(segmentation, size_thres)

    indices, counts = np.unique(segmentation, return_counts=True)
    # Label 0 is the background and must never be picked as the target,
    # even if its pixel count happens to fall inside the accepted range.
    candidates = [idx for idx, cnt in zip(indices, counts)
                  if idx != 0 and min_area < cnt < max_area]
    # print(indices, counts, candidates)
    if not candidates:
        return []

    target = (segmentation==candidates[0]).astype(np.uint8)
    eroded = erosion(target)
    # A thin target can vanish under erosion; fall back to the un-eroded
    # mask so the centroid below is always defined.
    if eroded.any():
        target = eroded
    rows, cols = np.nonzero(target)
    center = [int(rows.astype(float).mean()),
              int(cols.astype(float).mean())]

    if show_imgs:
        plt.figure(figsize=(20,10))
        plt.subplot(141)
        plt.imshow(binary*255, cmap='gray')
        plt.axis('off')
        plt.subplot(142)
        plt.imshow(segmentation, cmap='tab20c')
        plt.axis('off')
        plt.subplot(143)
        plt.imshow(target, cmap='gray')
        plt.axis('off')
        plt.subplot(144)
        plt.imshow(image, cmap='gray')
        plt.scatter(center[1], center[0], c='r', s=20)
        plt.axis('off')
        plt.show()

    return center

### 0. Load the video

* Change `filename = "name_of_video.wmv"` to the name of your own video.
* Set `down_sample=n` to keep only every n-th frame, which speeds up the video by a factor of n (`down_sample=1` keeps every frame).

In [None]:
# Path of the input video; edit this to point at your own recording.
filename = "name_of_video.wmv"
# Load the clip, keeping every 5th frame, as colour and grayscale stacks.
video, video_gray = load_video(filename=filename, down_sample=5)
# Preview three random frames from each stack as a quick sanity check.
show_rand_imgs(data=video, num=3)
show_rand_imgs(data=video_gray, num=3, cmap='gray')

### 1. Find valid region in the field-of-view (FoV)

In [None]:
# Derive the field-of-view mask from the mean of the grayscale stack along its first axis.
valid_region = find_valid_region(image=video_gray.mean(axis=0))

### 2. Process and save the video

In [None]:
# One uint8 plane per frame; a single pixel is set to 1 at each detected center.
center_video = np.zeros(video_gray.shape, dtype=np.uint8)
for i, frame in enumerate(video_gray):
    center = segment_image(frame, valid_region=valid_region, show_imgs=False)
    # segment_image returns [] when nothing qualifies; skip those frames.
    if len(center) != 2:
        continue
    center_video[i, center[0], center[1]] = 1
save_video(video, video_gray, center_video)