Skip to content

Commit

Permalink
feat(python): 检测器实现
Browse files Browse the repository at this point in the history
  • Loading branch information
zjZSTU committed May 13, 2020
1 parent 1d55c2d commit b873fe1
Showing 1 changed file with 122 additions and 0 deletions.
122 changes: 122 additions & 0 deletions py/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import glob
import os
import time

import torch
from PIL import Image
from vizer.draw import draw_boxes

from ssd.config import cfg
from ssd.data.datasets import COCODataset, VOCDataset
import argparse
import numpy as np

from ssd.data.transform import build_transforms
from ssd.models.detector import build_detection_model
from ssd.utils import mkdir
from ssd.utils.checkpoint import CheckPointer


@torch.no_grad()
def run_demo(cfg, ckpt, score_threshold, images_dir, output_dir, dataset_type):
if dataset_type == "voc":
class_names = VOCDataset.class_names
elif dataset_type == 'coco':
class_names = COCODataset.class_names
else:
raise NotImplementedError('Not implemented now.')
device = torch.device(cfg.MODEL.DEVICE)

model = build_detection_model(cfg)
model = model.to(device)
checkpointer = CheckPointer(model, save_dir=cfg.OUTPUT_DIR)
checkpointer.load(ckpt, use_latest=ckpt is None)
weight_file = ckpt if ckpt else checkpointer.get_checkpoint_file()
print('Loaded weights from {}'.format(weight_file))

image_paths = glob.glob(os.path.join(images_dir, '*.jpg'))
mkdir(output_dir)

cpu_device = torch.device("cpu")
transforms = build_transforms(cfg, is_train=False)
model.eval()
for i, image_path in enumerate(image_paths):
start = time.time()
image_name = os.path.basename(image_path)

image = np.array(Image.open(image_path).convert("RGB"))
height, width = image.shape[:2]
images = transforms(image)[0].unsqueeze(0)
load_time = time.time() - start

start = time.time()
result = model(images.to(device))[0]
inference_time = time.time() - start

result = result.resize((width, height)).to(cpu_device).numpy()
boxes, labels, scores = result['boxes'], result['labels'], result['scores']

indices = scores > score_threshold
boxes = boxes[indices]
labels = labels[indices]
scores = scores[indices]
meters = ' | '.join(
[
'objects {:02d}'.format(len(boxes)),
'load {:03d}ms'.format(round(load_time * 1000)),
'inference {:03d}ms'.format(round(inference_time * 1000)),
'FPS {}'.format(round(1.0 / inference_time))
]
)
print('({:04d}/{:04d}) {}: {}'.format(i + 1, len(image_paths), image_name, meters))

drawn_image = draw_boxes(image, boxes, labels, scores, class_names).astype(np.uint8)
Image.fromarray(drawn_image).save(os.path.join(output_dir, image_name))


def main():
parser = argparse.ArgumentParser(description="SSD Demo.")
parser.add_argument(
"--config-file",
default="",
metavar="FILE",
help="path to config file",
type=str,
)
parser.add_argument("--ckpt", type=str, default=None, help="Trained weights.")
parser.add_argument("--score_threshold", type=float, default=0.7)
parser.add_argument("--images_dir", default='demo', type=str, help='Specify a image dir to do prediction.')
parser.add_argument("--output_dir", default='demo/result', type=str,
help='Specify a image dir to save predicted images.')
parser.add_argument("--dataset_type", default="voc", type=str,
help='Specify dataset type. Currently support voc and coco.')

parser.add_argument(
"opts",
help="Modify config options using the command-line",
default=None,
nargs=argparse.REMAINDER,
)
args = parser.parse_args()
print(args)

cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()

print("Loaded configuration file {}".format(args.config_file))
with open(args.config_file, "r") as cf:
config_str = "\n" + cf.read()
print(config_str)
print("Running with config:\n{}".format(cfg))

run_demo(cfg=cfg,
ckpt=args.ckpt,
score_threshold=args.score_threshold,
images_dir=args.images_dir,
output_dir=args.output_dir,
dataset_type=args.dataset_type)


if __name__ == '__main__':
main()

0 comments on commit b873fe1

Please sign in to comment.