In [22]:
from pathlib import Path
import importlib

import numpy as np
import chainer
from chainer import training, optimizer_hooks
from sklearn.model_selection import train_test_split

In [23]:
from YOLOv3.lib.links.yolov3 import YOLOv3
from YOLOv3.lib.links.loss import YOLOv3Loss
from YOLOv3.lib.links.predict import YOLOv3Predictor

from src import Dataset
importlib.reload(Dataset)

<module 'src.Dataset' from '/home/kitamura/workdir/YOLO/src/Dataset.py'>

In [13]:
# モデルの用意

n_class = 5
base = None # default is yolo
ignore_thresh = 0.5
device = -1 # cpu (< 0) or gpu (>= 0)

yolov3 = YOLOv3(n_class, None, ignore_thresh)
model = YOLOv3Loss(yolov3)

optimizer = chainer.optimizers.MomentumSGD(lr=0.001)
optimizer.setup(model)
optimizer.add_hook(optimizer_hooks.WeightDecay(0.0005), "hook_decay")
optimizer.add_hook(optimizer_hooks.GradientClipping(10.0), "hook_grad_clip")

In [21]:
# データの用意
dataset_path = Path.home() / Path("work/dataset/COCO/")
batchsize = 8

dataset = Dataset.load_dataset(dataset_path)
train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=0)
train, test = Dataset.YOLODataset(train_dataset), Dataset.YOLODataset(test_dataset)

train_iter = chainer.iterators.SerialIterator(train, batchsize)
test_iter = chainer.iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)

TypeError: Expected sequence or array-like, got <class 'NoneType'>

In [None]:
# トレーニングの構築

iteration = 5000
output_dir = "result"
display_interval = 100
snapshot_interval = 1000

if not Path(output_dir).exists():
    Path(output_dir).mkdir()

updater = training.StandardUpdater(train_iter, optimizer, converter=Dataset.concat_yolo, device=device)
trainer = training.Trainer(updater, (iteration, 'iteration'), out=output_dir)

trainer.extend(extensions.Evaluator(test_iter, model, converter=Dataset.concat_yolo, device=device),
               trigger=display_interval)
trainer.extend(extensons.dump_graph("main/loss"))
trainer.extend(extensons.LogReport(trigger=display_interval))
trainer.extend(extensions.PrintReport(
    ['epoch', 'iteration', 'main/loss', 'validation/main/loss', 'elapsed_time'],
    trigger=display_interval))
trainer.extend(extensions.snapshot_object(yolov3, "yolov3_snapshot_epoch-{.updater.epoch}"),
               trigger=snapshot_interval)

In [None]:
trainer.run()

In [None]:
# デバッグ、プロット
thresh = 0.5
detector = YOLOv3Predictor(yolov3, thresh)

image, bboxes, labels = test_dataset[0]
detection = detector(image[np.newaxis, :, :, :])[0]

fig = plt.figure()
ax = plt.axes()

Plot.plot_image_and_bbox(ax, image, detection['box'], detection['prob'])

plt.show()