In [1]:
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, HTML
from matplotlib import animation
from PIL import Image

from video_mask_dataset import VideoMaskDataset

# Experiment configuration, assembled one section at a time (same key order
# as a single nested literal). In this notebook only config['train_datasets']
# and config['anchors'] are handed to VideoMaskDataset; the other sections are
# carried along unchanged.
config = {}

config["network"] = {"arch": "Custom"}

# Tracking hyper-parameters (not passed to the dataset in this notebook).
config["hp"] = {
    "instance_size": 255,
    "base_size": 8,
    "out_size": 127,
    "seg_thr": 0.35,
    "penalty_k": 0.12,
    "window_influence": 0.42,
    "lr": 0.5,
}

# Learning-rate schedule (not passed to the dataset in this notebook).
config["lr"] = {"type": "log", "start_lr": 0.005, "end_lr": 0.0025}

config["loss"] = {"weight": [0, 0, 36]}

config["train_datasets"] = {
    "datasets": {
        "ytb_vos": {
            "root": "../data/ytb_vos/crop511",
            "anno": "../data/ytb_vos/train.json",
            "num_use": 200000,
            "frame_range": 20,
        },
    },
    # Crop sizes; consistent with the template (3, 127, 127) and search
    # (..., 3, 255, 255) shapes printed by the sampling cell.
    "template_size": 127,
    "search_size": 255,
    "base_size": 8,
    # Label-grid side; consistent with the 25x25 label maps printed later.
    "size": 25,
    "num": 200000,
    "augmentation": {
        "template": {"shift": 4, "scale": 0.05},
        "search": {"shift": 8, "scale": 0.18, "blur": 0.18},
        "neg": 0,
        "gray": 0.25,
    },
}

# 5 ratios x 1 scale -> 5 anchors per location, consistent with the
# (..., 5, 25, 25) classification labels printed later.
# NOTE(review): the "round_dight" spelling is kept verbatim; presumably it must
# match the name the anchor code reads — verify before renaming.
config["anchors"] = {
    "stride": 8,
    "ratios": [0.33, 0.5, 1, 2, 3],
    "scales": [8],
    "round_dight": 0,
}


# Build the training dataset from the dataset and anchor sections only.
dataset = VideoMaskDataset(
    config['train_datasets'],
    config['anchors'],
)

In [5]:
def plot_sequence_images(image_array, interval=150):
    """Display a stack of frames inline as a looping HTML5 video.

    Parameters
    ----------
    image_array : array-like
        Frames shaped ``(N, C, H, W)`` or ``(N, H, W)``. Single-channel frames
        are replicated to three channels; channels ``[2, 1, 0]`` are taken in
        that order (i.e. BGR input is shown as RGB). Values are clipped to
        ``[0, 255]`` and cast to ``uint8``.
    interval : int, optional
        Delay between frames in milliseconds (default 150).

    Raises
    ------
    ValueError
        If ``image_array`` contains no frames.

    Notes
    -----
    ``Animation.to_html5_video`` needs a movie writer (ffmpeg) available to
    matplotlib. The figure is closed after encoding so the inline backend does
    not additionally render an empty ``<Figure ...>`` for every call.
    """
    dpi = 72.0

    video = np.asarray(image_array)
    if video.ndim < 4:
        # (N, H, W) -> (N, 1, H, W)
        video = np.expand_dims(video, 1)
    if video.shape[1] == 1:
        # Grey frames: replicate into three identical channels.
        video = np.repeat(video, 3, axis=1)
    if len(video) == 0:
        raise ValueError('plot_sequence_images: image_array has no frames')

    height, width = video.shape[2:]
    # Reorder channels (BGR -> RGB) and clip before casting: converting
    # out-of-range values to uint8 is not well defined.
    video = np.clip(video[:, [2, 1, 0]], 0, 255).astype(np.uint8)
    # (N, C, H, W) -> (N, H, W, C), the layout figimage expects.
    video = video.transpose(0, 2, 3, 1)

    # One image pixel per figure pixel.
    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        im = fig.figimage(video[0])

        def animate(i):
            im.set_array(video[i])
            return (im,)

        anim = animation.FuncAnimation(fig, animate, frames=len(video),
                                       interval=interval, repeat_delay=1,
                                       repeat=True)
        display(HTML(anim.to_html5_video()))
    finally:
        # Otherwise the inline backend also shows the figure as an empty
        # "<Figure ... with 0 Axes>" and keeps it alive for the session.
        plt.close(fig)
    
def plot_video(search, masks=None, mask_weight=None, bboxes=None):
    """Show the augmented search clip and, optionally, its labels overlaid.

    Each overlay is rendered as a separate inline video through
    ``plot_sequence_images``. None of the input arrays are modified.

    Parameters
    ----------
    search : np.ndarray
        Search frames shaped ``(N, 3, H, W)``; channel order as expected by
        ``plot_sequence_images`` (channels are reversed for display).
    masks : np.ndarray, optional
        Per-frame masks shaped ``(N, H, W)``. ``(mask + 1) * 50`` is added to
        channels 0 and 1 (channel 2 is left alone).
        NOTE(review): presumably mask values are in {-1, 1} so background adds
        0 — confirm against the dataset.
    mask_weight : np.ndarray, optional
        Per-frame weights shaped ``(N, h, w)``. Each map is resized to the
        search frame size with ``cv2.INTER_AREA`` and ``(weight + 1) * 50`` is
        added to channel 1.
    bboxes : np.ndarray, optional
        Boxes shaped ``(N, 4)`` as ``x1, y1, x2, y2`` in search-frame pixels,
        drawn as rectangles with color ``(0, 255, 0)``.
    """
    print('Augmented video')
    plot_sequence_images(search.clip(0, 255))

    if masks is not None:
        # (N, H, W) -> (N, 3, H, W). `+ 1` creates a new array, so the
        # caller's masks (np.expand_dims returns a view) stay untouched.
        mask_overlay = np.repeat(np.expand_dims(masks, 1) + 1, 3, axis=1)
        mask_overlay[:, 2:3] = 0
        print('Label mask')
        plot_sequence_images((search + mask_overlay * 50).clip(0, 255))

    if mask_weight is not None:
        out_h, out_w = search.shape[2:]
        # cv2.resize takes dsize as (width, height); match the search frames
        # instead of assuming a fixed 255x255 crop.
        weight_up = np.stack(
            [cv2.resize(mask_weight[i], (out_w, out_h), interpolation=cv2.INTER_AREA)
             for i in range(mask_weight.shape[0])], 0)
        weight_overlay = np.expand_dims(weight_up, 1) + 1
        search2 = search.copy()
        search2[:, 1:2, :, :] += weight_overlay * 50
        print('Label mask weight')
        plot_sequence_images(search2.clip(0, 255))

    if bboxes is not None:
        boxed_frames = []
        for i in range(search.shape[0]):
            # Draw on a contiguous copy: the transposed frame is a
            # non-contiguous view of `search`, which OpenCV's bindings may
            # reject or silently copy, and drawing into it would also modify
            # the caller's array.
            image = search[i].transpose(1, 2, 0).copy()
            x1, y1, x2, y2 = (int(round(v)) for v in bboxes[i])
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0))
            boxed_frames.append(image.transpose(2, 0, 1))
        print('Label bounding box')
        plot_sequence_images(np.stack(boxed_frames, 0).clip(0, 255))

# Seed the global RNGs right before sampling so that re-running this cell
# reproduces the same augmented clip.
# NOTE(review): assumes the dataset's augmentation draws from `random` and/or
# `np.random` — confirm against VideoMaskDataset.
SAMPLE_SEED = 0
random.seed(SAMPLE_SEED)
np.random.seed(SAMPLE_SEED)

# `debug=True` is a keyword argument, which `dataset[14]` cannot pass, hence
# the explicit __getitem__ call.
template, search, clses, deltas, delta_weights, \
    bboxes, masks, mask_weight = dataset.__getitem__(14, debug=True)
print(template.shape, search.shape, clses.shape, deltas.shape, delta_weights.shape,
      bboxes.shape, masks.shape, mask_weight.shape)

plot_video(search, masks=masks, mask_weight=mask_weight, bboxes=bboxes)

(3, 127, 127) (29, 3, 255, 255) (29, 5, 25, 25) (29, 4, 5, 25, 25) (29, 5, 25, 25) (29, 4) (29, 255, 255) (29, 25, 25)
Augmented video


Label mask


Label mask weight


Label bounding box


<Figure size 254x254 with 0 Axes>

<Figure size 254x254 with 0 Axes>

<Figure size 254x254 with 0 Axes>

<Figure size 254x254 with 0 Axes>