add multiple directories and associated code
1 parent: a82e48c · commit: 8d3be73
Showing 63 changed files with 8,855 additions and 0 deletions.
Empty file.
@@ -0,0 +1,3 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
@@ -0,0 +1,320 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import time
from loguru import logger

import cv2

import torch

from yolox.data.data_augment import ValTransform
from yolox.data.datasets import COCO_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument(
        "demo", default="image", help="demo type, e.g. image, video or webcam"
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    parser.add_argument(
        "--path", default="./assets/dog.jpg", help="path to images or video"
    )
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="device to run the model on, either cpu or gpu",
    )
    parser.add_argument("--conf", default=0.3, type=float, help="test confidence threshold")
    parser.add_argument("--nms", default=0.3, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopt mixed-precision evaluation.",
    )
    parser.add_argument(
        "--legacy",
        dest="legacy",
        default=False,
        action="store_true",
        help="To be compatible with older versions",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Use TensorRT model for testing.",
    )
    return parser
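# Editor's sketch (not part of this commit): plausible invocations, assuming
# this script is saved as tools/demo.py and a YOLOX-S checkpoint exists at
# ./yolox_s.pth (both names are assumptions, not taken from the diff):
#
#   python3 tools/demo.py image -n yolox-s -c ./yolox_s.pth \
#       --path assets/dog.jpg --conf 0.3 --nms 0.45 --tsize 640 \
#       --save_result --device gpu
#   python3 tools/demo.py webcam -n yolox-s -c ./yolox_s.pth --camid 0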
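# Editor's note: recursively collects files under `path` whose extension is in
# IMAGE_EXT; the check is case-sensitive, so e.g. ".JPG" files are skipped.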
def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=COCO_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

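            # Editor's note: run one dummy forward pass so the model is
            # initialized before swapping in the deserialized TRT module.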
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt

    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
        img_info["ratio"] = ratio

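        # Editor's note (assumption about ValTransform's behavior): it resizes
        # and pads the image to self.test_size and returns a CHW float array.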
        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0)
        img = img.float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

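        # Editor's note: each row of `output` is
        # (x1, y1, x2, y2, obj_conf, cls_conf, cls_id) in network-input
        # coordinates, so the boxes are rescaled back to the raw image below.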
        bboxes = output[:, 0:4]

        # undo the preprocessing resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
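        # Editor's note: checks for Esc/q to allow early exit; since no window
        # is opened here, cv2.waitKey may return immediately on some platforms.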
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    if args.save_result:
        save_folder = os.path.join(
            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
        )
        os.makedirs(save_folder, exist_ok=True)
        if args.demo == "video":
            save_path = os.path.join(save_folder, os.path.basename(args.path))
        else:
            save_path = os.path.join(save_folder, "camera.mp4")
        logger.info(f"video save_path is {save_path}")
        vid_writer = cv2.VideoWriter(
            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
        )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            else:
                cv2.namedWindow("yolox", cv2.WINDOW_NORMAL)
                cv2.imshow("yolox", result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break
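    # Editor's addition (not in the original commit): release the capture and,
    # when saving, the writer so the output file is finalized cleanly.
    cap.release()
    if args.save_result:
        vid_writer.release()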


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    vis_folder = None
    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
        if args.fp16:
            model.half()  # to FP16
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT models do not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(
            trt_file
        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT for inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(
        model, exp, COCO_CLASSES, trt_file, decoder,
        args.device, args.fp16, args.legacy,
    )
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)