Load the necessary libraries and set up Caffe and `caffe_root`.

In [1]:
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

plt.rcParams['figure.figsize'] = (10, 10)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# Make sure that caffe is on the python path:
import os
os.chdir('..')
caffe_root = './'
import sys
sys.path.insert(0, caffe_root + 'python')

import caffe
caffe.set_device(0)
caffe.set_mode_gpu()

voc_net = caffe.Net(caffe_root + 'models/VGGNet/VOC0712/DJI_6classes/deploy.prototxt',
                    caffe_root + 'models/VGGNet/VOC0712/DJI_6classes/DJI_6classes__iter_120000.caffemodel',
                    caffe.TEST)

The cell above puts Caffe in GPU mode (device 0) and loads the net in the TEST phase for inference. Next, load the model specification and the label map, then configure input preprocessing.

In [2]:
from google.protobuf import text_format
from caffe.proto import caffe_pb2

# Load the model specification (the DJI_6classes deploy prototxt) as a NetParameter.
# `with` closes the file; the original left the handle open.
with open(caffe_root + 'models/VGGNet/VOC0712/DJI_6classes/deploy.prototxt', 'r') as spec_file:
    voc_netspec = caffe_pb2.NetParameter()
    text_format.Merge(str(spec_file.read()), voc_netspec)

# Load the PASCAL VOC label map (label id -> display name).
# NOTE(review): the model is DJI_6classes but this is the VOC label map --
# confirm it maps the model's label ids to the intended names.
voc_labelmap_file = caffe_root + 'data/VOC0712/labelmap_voc.prototxt'
with open(voc_labelmap_file, 'r') as labelmap_file:
    voc_labelmap = caffe_pb2.LabelMap()
    text_format.Merge(str(labelmap_file.read()), voc_labelmap)

def get_labelname(labelmap, labels):
    """Translate numeric label ids into display names using a Caffe LabelMap.

    Args:
        labelmap: object with an ``item`` sequence whose entries expose
            ``label`` and ``display_name`` (e.g. ``caffe_pb2.LabelMap``).
        labels: a single label id, or a list/tuple of ids. Ids are matched by
            equality, so float ids such as ``1.0`` match an integer label ``1``.

    Returns:
        A list of display names, one per requested label, in input order.

    Raises:
        AssertionError: if a requested label is not present in ``labelmap``.
    """
    # Build the lookup once. setdefault keeps the FIRST entry for a duplicated
    # id, matching the original linear search, which stopped at the first hit.
    names_by_label = {}
    for item in labelmap.item:
        names_by_label.setdefault(item.label, item.display_name)

    if not isinstance(labels, (list, tuple)):
        labels = [labels]

    labelnames = []
    for label in labels:
        assert label in names_by_label, 'label %r not found in labelmap' % (label,)
        labelnames.append(names_by_label[label])
    return labelnames

In [3]:
# Preprocessing for the network input blob 'data' (net.inputs[0]), sized from
# the blob shape the net was loaded with.
transformer = caffe.io.Transformer({'data': voc_net.blobs['data'].data.shape})

# The four settings below are independent of one another; order does not matter.
transformer.set_channel_swap('data', (2,1,0))              # model expects BGR channel order, not RGB
transformer.set_raw_scale('data', 255)                     # model works on [0,255] pixels, not [0,1]
transformer.set_mean('data', np.array([104,117,123]))      # per-channel mean pixel to subtract
transformer.set_transpose('data', (2, 0, 1))               # H x W x C image -> C x H x W blob

Load a single test image, preprocess it, and copy it into the network's input blob.

In [4]:
# Single-image sanity check of the preprocessing pipeline.
# transformer.preprocess resizes to the shape the Transformer was built with,
# so that shape must match the 1 x 3 x 320 x 320 blob set below.
demo_image_path = '/home/dawn/data/VOCdevkit/VOC2007/JPEGImages/DJI_0001_00230.jpg'
image = caffe.io.load_image(demo_image_path)
transformed_image = transformer.preprocess('data', image)

# set net to batch size of 1
voc_net.blobs['data'].reshape(1, 3, 320, 320)

# copy the already-preprocessed image into the input blob
voc_net.blobs['data'].data[...] = transformed_image

# map the network input back to a displayable image
orig_image = transformer.deprocess('data', voc_net.blobs['data'].data)

In [5]:
import shutil

dataset_dir = '/home/dawn/data/VOCdevkit/VOC2007/'
testlist_path = dataset_dir + 'ImageSets/Main/test.txt'

# Both directories live under dataset_dir (same paths as before).
test_img_dir = dataset_dir + 'test_img/'
test_output_dir = dataset_dir + 'test_img_output/'

# makedirs also creates missing parents.
for out_dir in (test_img_dir, test_output_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

# test.txt lists one image id (no extension) per line.
# strip() replaces the old eachline[:-1], which truncated the last id when the
# file had no trailing newline and left a '\r' behind on CRLF files.
with open(testlist_path) as testlist:
    for eachline in testlist:
        image_id = eachline.strip()
        if not image_id:
            continue  # skip blank lines
        imgname = image_id + '.jpg'
        shutil.copy(dataset_dir + 'JPEGImages/' + imgname, test_img_dir + imgname)

print('Test images are save to : ' + test_img_dir)

print('Test images output are save to : ' + test_output_dir)

Test images are save to : /home/dawn/data/VOCdevkit/VOC2007/test_img/
Test images output are save to : /home/dawn/data/VOCdevkit/VOC2007/test_img_output/


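Run the detector on every image in `test_img_dir` and save an annotated copy of each to `test_output_dir`. Every detection with confidence ≥ 0.2 is drawn, not only the top 5.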

In [6]:
# Run the detector on every image under test_img_dir and save an annotated
# copy of each one to test_output_dir.
import os

# set net to batch size of 1
image_resize = 320
voc_net.blobs['data'].reshape(1, 3, image_resize, image_resize)

# Detections with confidence below this value are not drawn.
conf_threshold = 0.2

# Box colours indexed by label id. Built once because it does not depend on the image.
colors = plt.cm.hsv(np.linspace(0, 1, 21)).tolist()

# os.walk yields (directory path, subdirectory names, file names) for each
# directory in the tree. os.path.join is needed because subdirectory paths
# have no trailing slash, so plain string concatenation breaks for them.
for parent, _dirnames, filenames in os.walk(test_img_dir):
    for filename in filenames:
        image = caffe.io.load_image(os.path.join(parent, filename))
        transformed_image = transformer.preprocess('data', image)
        voc_net.blobs['data'].data[...] = transformed_image
        # The network input mapped back to an image, i.e. image_resize x image_resize.
        # The boxes below are scaled to this size, not to the original image.
        orig_image = transformer.deprocess('data', voc_net.blobs['data'].data)

        # Forward pass.
        detections = voc_net.forward()['detection_out']

        # Parse the outputs. Per row: column 1 = label, 2 = confidence,
        # 3..6 = xmin, ymin, xmax, ymax (presumably normalized coordinates,
        # since they are multiplied by image_resize below).
        det = detections[0, 0]
        keep = det[:, 2] >= conf_threshold
        top_conf = det[keep, 2]
        top_label_indices = det[keep, 1].tolist()
        top_labels = get_labelname(voc_labelmap, top_label_indices)
        top_boxes = det[keep, 3:7]

        fig, ax = plt.subplots()
        ax.imshow(orig_image)
        for score, label_id, label_name, box in zip(top_conf, top_label_indices, top_labels, top_boxes):
            xmin, ymin, xmax, ymax = [int(round(v * image_resize)) for v in box]
            label = int(label_id)
            # Modulo guards against label ids beyond the palette size.
            color = colors[label % len(colors)]
            display_txt = '%s: %.2f' % (label_name, score)
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin + 1, ymax - ymin + 1,
                                       fill=False, edgecolor=color, linewidth=2))
            ax.text(xmin, ymin, display_txt, bbox={'facecolor': color, 'alpha': 0.5})
        fig.savefig(os.path.join(test_output_dir, filename))

        # One figure is created per image; close it so they do not pile up in memory.
        plt.close(fig)

print('Done.')