Permalink
Find file
Fetching contributors…
Cannot retrieve contributors at this time
113 lines (91 sloc) 4.08 KB
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""以下URLをそれっぽく日本語にしてテキスト化したやつ.
http://dlib.net/train_object_detector.py.html
"""
import os
import sys
import glob
import dlib
from skimage import io
'''
このスクリプトは引数にface_datasetディレクトリを指定し、それを学習します
'''
if len(sys.argv) != 2:
print(
"Give the path to the examples/faces directory as the argument to this"
"program. For example, if you are in the python_examples folder then "
"execute this program by running:\n"
" ./train_object_detector.py ../examples/faces")
exit()
faces_folder = sys.argv[1]
# simple_object_detectorの訓練用オプションを取ってくる
options = dlib.simple_object_detector_training_options()
# 左右対照に学習データを増やすならtrueで訓練(メモリを使う)
options.add_left_right_image_flips = True
# SVMを使ってるのでC値を設定する必要がある
options.C = 5
# スレッド数指定
options.num_threads = 16
# 学習途中の出力をするかどうか
options.be_verbose = True
# 停止許容範囲
options.epsilon = 0.001
# サンプルを増やす最大数(大きすぎるとメモリを使う)
options.upsample_limit = 8
# 矩形検出の最小窓サイズ(80*80=6400となる)
options.detection_window_size = 6400
# ファルダパスとxmlパスを設定
# xmlを作るにはdlib本体のtools/imglabディレクトリ
# GUIツールがあるからdlib本体のtools/imglab/README.txtを参考
training_xml_path = os.path.join(faces_folder, "training.xml")
testing_xml_path = os.path.join(faces_folder, "testing.xml")
# detector.svmに学習結果が保存されるように設定
dlib.train_simple_object_detector(training_xml_path, "detector.svm", options)
# 学習用、テスト用データを入れてprecision, recallとその平均を出力
print("Training accuracy: {}".format(
dlib.test_simple_object_detector(training_xml_path, "detector.svm")))
print("Testing accuracy: {}".format(
dlib.test_simple_object_detector(testing_xml_path, "detector.svm")))
# 保存されたdetector.svmによる検出
detector = dlib.simple_object_detector("detector.svm")
# 学習に使ったHoGフィルタを見る時は以下の通り
win_det = dlib.image_window()
win_det.set_image(detector)
# フォルダ内の画像で検出します
win = dlib.image_window()
for f in glob.glob(os.path.join(faces_folder, "*.jpg")):
print("Processing file: {}".format(f))
img = io.imread(f)
dets = detector(img)
print("Number of faces detected: {}".format(len(dets)))
for k, d in enumerate(dets):
print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
k, d.left(), d.top(), d.right(), d.bottom()))
# 画像オーバーレイして表示してEnter待つとかもできるよ
win.clear_overlay()
win.set_image(img)
win.add_overlay(dets)
dlib.hit_enter_to_continue()
# 実はXML形式で入力しなくても以下のように呼び出して使う事もできるよ
# io.imreadした画像のリストを作成
images = [io.imread(faces_folder + '/2008_002506.jpg'),
io.imread(faces_folder + '/2009_004587.jpg')]
# それぞれの矩形部分をrectangleで保存しておく
# そのリストを作成
boxes_img1 = ([dlib.rectangle(left=329, top=78, right=437, bottom=186),
dlib.rectangle(left=224, top=95, right=314, bottom=185),
dlib.rectangle(left=125, top=65, right=214, bottom=155)])
boxes_img2 = ([dlib.rectangle(left=154, top=46, right=228, bottom=121),
dlib.rectangle(left=266, top=280, right=328, bottom=342)])
boxes = [boxes_img1, boxes_img2]
# 学習(上記とは違った感じで)
detector2 = dlib.train_simple_object_detector(images, boxes, options)
# 結果を保存する時はこんな感じでもOK
detector2.save('detector2.svm')
# 新しいHoGフィルタの結果
win_det.set_image(detector2)
dlib.hit_enter_to_continue()
# 読み込み形式での出力
print("\nTraining accuracy: {}".format(
dlib.test_simple_object_detector(images, boxes, detector2)))