In [1]:
from video_mask_dataset import VideoMaskDataset
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, HTML
import numpy as np

# Experiment configuration as nested dicts. In the cells shown here only the
# `train_datasets` and `anchors` sections are passed on (to VideoMaskDataset);
# `network`, `hp`, `lr` and `loss` are kept alongside for reference.
# NOTE(review): "round_dight" looks like a misspelling of "round_digit", but it
# is a runtime key — only rename it together with whatever code reads it.
config = dict(
    network=dict(arch="Custom"),
    hp=dict(
        instance_size=255,
        base_size=8,
        out_size=127,
        seg_thr=0.35,
        penalty_k=0.12,
        window_influence=0.42,
        lr=0.5,
    ),
    lr=dict(type="log", start_lr=0.005, end_lr=0.0025),
    loss=dict(weight=[0, 0, 36]),
    train_datasets=dict(
        # Source datasets: crop directory, annotation file and sampling options.
        datasets=dict(
            ytb_vos=dict(
                root="../data/ytb_vos/crop511",
                anno="../data/ytb_vos/train.json",
                num_use=100000,
                frame_range=20,
            ),
            coco=dict(
                root="../data/coco/crop511",
                anno="../data/coco/train2017.json",
                frame_range=1,
            ),
        ),
        template_size=127,
        search_size=255,
        base_size=8,
        size=25,
        num=200000,
        augmentation=dict(
            template=dict(shift=4, scale=0.05),
            search=dict(shift=8, scale=0.18, blur=0.18),
            neg=0,
            gray=0.25,
        ),
    ),
    anchors=dict(
        stride=8,
        ratios=[0.33, 0.5, 1, 2, 3],
        scales=[8],
        round_dight=0,
    ),
)


# Build the training dataset from the `train_datasets` and `anchors` sections
# of the config above; the other sections are not passed here.
# NOTE(review): presumably this reads the annotation files listed under
# train_datasets.datasets (relative ../data paths) — confirm in VideoMaskDataset.
dataset = VideoMaskDataset(config['train_datasets'], config['anchors'])

In [15]:
# Inspect one training sample. `__getitem__` is called explicitly because the
# `debug=True` keyword cannot be passed through `dataset[index]` syntax.
SAMPLE_INDEX = 349
template, search, clses, deltas, delta_weights, \
    bboxes, masks, mask_weight = dataset.__getitem__(SAMPLE_INDEX, debug=True)

# Print each array's shape next to its name so the output is self-describing.
for sample_name, sample_array in [
    ("template", template), ("search", search), ("clses", clses),
    ("deltas", deltas), ("delta_weights", delta_weights), ("bboxes", bboxes),
    ("masks", masks), ("mask_weight", mask_weight),
]:
    print("{:<14}{}".format(sample_name, sample_array.shape))

def plot_sequence_images(image_array, dpi=72.0, interval=60):
    """Render a stack of frames as an inline HTML5 video.

    Parameters
    ----------
    image_array : array-like
        Frames stacked along axis 0: (N, H, W) or (N, 1, H, W) single-channel
        frames (replicated to 3 channels), or (N, 3, H, W) frames whose channel
        order is reversed before display (presumably BGR -> RGB). Values are
        cast directly to uint8, so callers should clip to [0, 255] first.
    dpi : float, optional
        Figure resolution; the figure is sized so one image pixel maps to one
        figure pixel.
    interval : int, optional
        Delay between frames in milliseconds.

    Raises
    ------
    ValueError
        If there are no frames to show.

    Notes
    -----
    ``Animation.to_html5_video`` needs matplotlib's movie writer
    (``rcParams['animation.writer']``, ffmpeg by default) to be installed.
    """
    frames = np.asarray(image_array)
    if frames.ndim < 4:
        # (N, H, W) -> (N, 1, H, W)
        frames = np.expand_dims(frames, 1)
    if frames.shape[1] == 1:
        # Replicate the single channel so every frame has 3 channels.
        frames = np.repeat(frames, 3, axis=1)
    if len(frames) == 0:
        raise ValueError("plot_sequence_images needs at least one frame")

    height, width = frames.shape[2:]
    # Reverse channel order, then go channels-last (N, H, W, 3) for figimage.
    frames = frames[:, [2, 1, 0], :, :].astype(np.uint8).transpose(0, 2, 3, 1)

    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    im = plt.figimage(frames[0])

    def animate(i):
        im.set_array(frames[i])
        return (im,)

    anim = animation.FuncAnimation(fig, animate, frames=len(frames),
                                   interval=interval, repeat_delay=1, repeat=True)
    try:
        video_html = anim.to_html5_video()
    finally:
        # Close the figure so Jupyter does not additionally render the empty
        # static canvas ("<Figure ... with 0 Axes>") at the end of the cell.
        plt.close(fig)
    display(HTML(video_html))
    
def plot_video(search, masks=None):
    """Play the search frames as a video, optionally tinted by their masks.

    Parameters
    ----------
    search : np.ndarray
        (N, 3, H, W) frames, e.g. (20, 3, 255, 255) for the sample above.
    masks : np.ndarray, optional
        (N, H, W) per-frame masks. The tint added is ``(mask + 1) * 50`` on
        channels 0 and 1 and nothing on channel 2. NOTE(review): this assumes
        mask labels in {-1, 1} (background +0, foreground +100) — confirm
        against VideoMaskDataset.

    Neither input array is modified. The tinted frames are clipped to
    [0, 255] and passed to ``plot_sequence_images``.
    """
    if masks is None:
        frames = search
    else:
        # Build the tint as a fresh float array: the caller's masks and search
        # stay untouched, and integer search frames cannot wrap around before
        # clipping.
        tint = (np.asarray(masks, dtype=np.float64)[:, np.newaxis] + 1) * 50
        tint = np.repeat(tint, 3, axis=1)
        tint[:, 2] = 0
        frames = np.asarray(search, dtype=np.float64) + tint
    plot_sequence_images(np.clip(frames, 0, 255))

# Pass copies so the overlay drawing cannot alter the sample arrays loaded above.
plot_video(search.copy(), masks=masks.copy())

(3, 127, 127) (20, 3, 255, 255) (20, 5, 25, 25) (20, 4, 5, 25, 25) (20, 5, 25, 25) (20, 4) (20, 255, 255) (20, 25, 25)


<Figure size 254x254 with 0 Axes>