In [1]:
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

plt.rcParams['figure.figsize'] = (10, 10)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# Make sure that caffe is on the python path:
import os
os.chdir('..')
caffe_root = './'
import sys
sys.path.insert(0, caffe_root + 'python')

import caffe
caffe.set_device(0)
caffe.set_mode_gpu()

voc_net = caffe.Net(caffe_root + 'models/VGGNet/VOC0712/DJI_6classes/deploy.prototxt',
                    caffe_root + 'models/VGGNet/VOC0712/DJI_6classes/DJI_6classes__iter_120000.caffemodel',
                    caffe.TEST)

In [2]:
from google.protobuf import text_format
from caffe.proto import caffe_pb2

# load PASCAL VOC model specs
file = open(caffe_root + 'models/VGGNet/VOC0712/DJI_6classes/deploy.prototxt', 'r')
voc_netspec = caffe_pb2.NetParameter()
text_format.Merge(str(file.read()), voc_netspec)

# load PASCAL VOC labels
voc_labelmap_file = caffe_root + 'data/VOC0712/labelmap_voc.prototxt'
file = open(voc_labelmap_file, 'r')
voc_labelmap = caffe_pb2.LabelMap()
text_format.Merge(str(file.read()), voc_labelmap)

def get_labelname(labelmap, labels):
    num_labels = len(labelmap.item)
    labelnames = []
    if type(labels) is not list:
        labels = [labels]
    for label in labels:
        found = False
        for i in xrange(0, num_labels):
            if label == labelmap.item[i].label:
                found = True
                labelnames.append(labelmap.item[i].display_name)
                break
        assert found == True
    return labelnames

In [3]:
# input preprocessing: 'data' is the name of the input blob == net.inputs[0]
transformer = caffe.io.Transformer({'data': voc_net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))
transformer.set_mean('data', np.array([104,117,123])) # mean pixel
transformer.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
transformer.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB

In [13]:
import os
import os.path                               
import cv2

image_resize = 320
voc_net.blobs['data'].reshape(1,3,image_resize,image_resize)

videoCapture = cv2.VideoCapture('/home/dawn/DJI_0001.MOV')
fourcc = cv2.VideoWriter_fourcc(*'MPEG')
frame_width = 1920
frame_height = 1080
videoWriter = cv2.VideoWriter('/home/dawn/DJI_0001_out.avi', fourcc, 25, (frame_width, frame_height))
if not videoWriter.open:
    print('videoWriter can not write!')
success, frame = videoCapture.read()

while success:
    
    image = frame/255.0
    transformed_image = transformer.preprocess('data', image)
    voc_net.blobs['data'].data[...] = transformed_image

    # Forward pass.
    detections = voc_net.forward()['detection_out']

    # Parse the outputs.
    det_label = detections[0,0,:,1]
    det_conf = detections[0,0,:,2]
    det_xmin = detections[0,0,:,3]
    det_ymin = detections[0,0,:,4]
    det_xmax = detections[0,0,:,5]
    det_ymax = detections[0,0,:,6]

    # Get detections with confidence higher than 0.6.
    top_indices = [i for i, conf in enumerate(det_conf) if conf >= 0.2]

    top_conf = det_conf[top_indices]
    top_label_indices = det_label[top_indices].tolist()
    top_labels = get_labelname(voc_labelmap, top_label_indices)
    top_xmin = det_xmin[top_indices]
    top_ymin = det_ymin[top_indices]
    top_xmax = det_xmax[top_indices]
    top_ymax = det_ymax[top_indices]

    colors = [(255,255,255),(255,0,0),(0,255,0),(0,0,255),(255,255,0),(255,0,255),(0,255,255)]

    for i in xrange(top_conf.shape[0]):
        xmin = int(round(top_xmin[i] * frame_width))
        ymin = int(round(top_ymin[i] * frame_height))
        xmax = int(round(top_xmax[i] * frame_width))
        ymax = int(round(top_ymax[i] * frame_height))
        score = top_conf[i]
        label = int(top_label_indices[i])
        label_name = top_labels[i]
        display_txt = '%s: %.2f'%(label_name, score)
        color = colors[label]
        cv2.rectangle(frame, (xmin,ymin), (xmax,ymax), color, 2)
        cv2.putText(frame, display_txt, (xmin,ymin), 1, 1, color, 2, cv2.LINE_AA)
    cv2.imshow('Output', frame)
    cv2.waitKey(1)
    videoWriter.write(frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
    success, frame = videoCapture.read()

videoCapture.release()
videoWriter.release()
cv2.destroyAllWindows()
    
print('Done.')

#  
#                       _oo0oo_  
#                      o8888888o  
#                      88" . "88  
#                      (| -_- |)  
#                      0\  =  /0  
#                    ___/`---'\___  
#                  .' \\|     |# '.  
#                 / \\|||  :  |||# \  
#                / _||||| -:- |||||- \  
#               |   | \\\  -  #/ |   |  
#               | \_|  ''\---/''  |_/ |  
#               \  .-\__  '-'  ___/-. /  
#             ___'. .'  /--.--\  `. .'___  
#          ."" '<  `.___\_<|>_/___.' >' "".  
#         | | :  `- \`.;`\ _ /`;.`/ - ` : | |  
#         \  \ `_.   \_ __\ /__ _/   .-` /  /  
#     =====`-.____`.___ \_____/___.-`___.-'=====  
#                       `=---='  
#  
#  
#     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~  
#  
#               佛祖保佑         永无BUG  
#  
# 
#

IndexError: list index out of range