In [35]:
__author__ = 'stephen'

import caffe
import cv2, os
import numpy as np
import matplotlib.pyplot as plt
from lib.vdbc.dataset_factory import VDBC
from lib.vdbc.evaluate import Evaluator
from lib.vdbc.sample import gaussian_sample

from lib.data_layer.layer import get_next_mini_batch

PARAMS = (0.3, 0.3, 0.05, 0.7, 0.3)
IMS_PER_FRAME = 256
PRE_TRAIN_ITERS = IMS_PER_FRAME * 15
VISUAL = False
print 'Number of samples for each frame is {} and pre-train iteration is {}.'.format(IMS_PER_FRAME, PRE_TRAIN_ITERS)

# get the deploy solver and net with pre-trained caffe model
TRAIN = os.path.join('model', 'deploy_solver.prototxt')
TEST = os.path.join('model', 'deploy_test.prototxt')
WEIGHTS = os.path.join('model', 'MDNet_iter_2946816.caffemodel')

# set GPU mode
caffe.set_mode_gpu()
caffe.set_device(1)

# get the Evaluator
dtype = 'VOT'
dbpath = os.path.join('data', 'vot2014')
gtpath = dbpath
vdbc = VDBC(dbtype=dtype, dbpath=dbpath, gtpath=gtpath, flush=True)
evl = Evaluator(vdbc)

Number of samples for each frame is 256 and pre-train iteration is 3840.
save image_list.json successfully.
save gt_info.json successfully.
VDBC instance built.
[Evaluator]Video set: ['ball', 'basketball', 'bicycle', 'bolt', 'car', 'david', 'diving', 'drunk', 'fernando', 'fish1', 'fish2', 'gymnastics', 'hand1', 'hand2', 'jogging', 'motocross', 'polarbear', 'skating', 'sphere', 'sunshade', 'surfing', 'torus', 'trellis', 'tunnel', 'woman']
[Evaluator]Number of sets: 25


In [36]:
def vis_detection(im_path, gt, box):
    im = cv2.imread(im_path)[:, :, (2, 1, 0)]
    plt.cla()
    plt.imshow(im)
    # add ground-truth box
    plt.gca().add_patch(
        plt.Rectangle(
            (gt[0], gt[1]),
            gt[2], gt[3],
            fill=False,
            edgecolor='red',
            linewidth=1.5
        )
    )
    # add detection box
    plt.gca().add_patch(
        plt.Rectangle(
            (box[0], box[1]),
            box[2], box[3],
            fill=False,
            edgecolor='blue',
            linewidth=1.5
        )
    )

    plt.show()

In [37]:
def get_solver_net(train, test, weights):
    solver = caffe.SGDSolver(train)
    solver.net.copy_from(weights)

    net = caffe.Net(test, caffe.TEST)
    net.share_with(solver.net)

    return solver, net

In [38]:
def evaluate(evl):
    # Step 1: get solver and net, and then pre-train it 
    # Get solver and net with prototxt and weights
    solver, net = get_solver_net(TRAIN, TEST, WEIGHTS)
    # Initialize the net with the first frame
    im_path, gt = evl.init_frame()
    im = cv2.imread(im_path)
    samples = gaussian_sample(im, gt, PARAMS, PRE_TRAIN_ITERS)
    db = []
    for i in range(len(samples)):
        db.append({
                'path': im_path,
                'img': im,
                'gt': gt,
                'samples': [samples[i]]
            })
    solver.net.layers[0].get_db(db)
    solver.step(PRE_TRAIN_ITERS)
    
    # Step 2: sample from each frame and then take the 
    # ROI with highest score as the target
    for it in range(10):
        im_path = evl.next_frame()
        im = cv2.imread(im_path)
        samples=gaussian_sample(im, gt, PARAMS, IMS_PER_FRAME)
        # get all boxes
        candidate_boxes = [s['box'] for s in samples]
        scores = np.zeros(IMS_PER_FRAME, dtype=np.float32)
        for i in range(IMS_PER_FRAME):
            db = [{
                'path': im_path,
                'img': im,
                'samples': [samples[i]]
            }]
            blob = get_next_mini_batch(db)
            blob = {'data': blob['data'].astype(np.float32, copy=True)}
            net.blobs['data'].reshape(*blob['data'].shape)
            out = net.forward(data=blob['data'])['cls_prob']
            scores[i] = out[0, 1]
        # Choose the target box
        ind = scores.argmax()
        target_box = candidate_boxes[ind]
        evl.report(target_box) # report the target
        ground_truth_box = evl.get_ground_truth()
        if VISUAL:
            vis_detection(im_path, ground_truth_box, target_box)

In [39]:
evl.set_video(0)
evaluate(evl)

[DataLayer]Number of classes: 2.
[DataLayer] Get the database.


In [40]:
evl.get_results()

{'ball': [0.85987132310407821,
  0.69519236330605982,
  0.81187381151241267,
  0.8395410990819332,
  0.70098948699289088,
  0.86046756913150346,
  0.70393134769170185,
  0.8645565993632337,
  0.80748055335571534,
  0.52798452758593373]}