Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f9f4817
commit 07ba056
Showing
5 changed files
with
331 additions
and
158 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,322 @@ | ||
""" | ||
An example: | ||
Please make sure you are using torchvision >= 0.5.0. | ||
wget https://cloudstor.aarnet.edu.au/plus/s/38fQAdi2HBkn274/download -O fcos_imprv_R_50_FPN_1x.onnx | ||
python onnx/test_fcos_onnx_model.py \ | ||
--onnx-model fcos_imprv_R_50_FPN_1x.onnx \ | ||
--config-file configs/fcos/fcos_imprv_R_50_FPN_1x.yaml \ | ||
TEST.IMS_PER_BATCH 1 \ | ||
INPUT.MIN_SIZE_TEST 800 | ||
If you encounter an out of memory error, please try to reduce INPUT.MIN_SIZE_TEST. | ||
""" | ||
from fcos_core.utils.env import setup_environment # noqa F401 isort:skip | ||
|
||
import argparse | ||
import os | ||
|
||
import torch | ||
from torch import nn | ||
import onnx | ||
from fcos_core.config import cfg | ||
from fcos_core.data import make_data_loader | ||
from fcos_core.engine.inference import inference | ||
from fcos_core.utils.collect_env import collect_env_info | ||
from fcos_core.utils.comm import synchronize, get_rank | ||
from fcos_core.utils.logger import setup_logger | ||
from fcos_core.utils.miscellaneous import mkdir | ||
from fcos_core.modeling.rpn.fcos.inference import make_fcos_postprocessor | ||
import caffe2.python.onnx.backend as backend | ||
import numpy as np | ||
from fcos_core.structures.bounding_box import BoxList | ||
from fcos_core.structures.boxlist_ops import cat_boxlist | ||
from fcos_core.structures.boxlist_ops import remove_small_boxes | ||
from torchvision.ops.boxes import batched_nms | ||
|
||
|
||
class FCOSPostProcessor(torch.nn.Module):
    """
    Decodes the raw FCOS head outputs (per-level classification, box
    regression and centerness maps) into final, NMS-filtered BoxLists.
    This is only used in the testing.
    """
    def __init__(
        self,
        pre_nms_thresh,
        pre_nms_top_n,
        nms_thresh,
        fpn_post_nms_top_n,
        min_size,
        num_classes
    ):
        """
        Arguments:
            pre_nms_thresh (float): score threshold a (location, class)
                pair must exceed to become a candidate detection
            pre_nms_top_n (int): cap on candidates kept per level/image
                before NMS
            nms_thresh (float): IoU threshold passed to batched_nms
            fpn_post_nms_top_n (int): cap on detections kept per image
                over all classes after NMS
            min_size (int): boxes smaller than this are removed
                (see remove_small_boxes)
            num_classes (int): number of classification channels; stored
                but not read in this class
        """
        super(FCOSPostProcessor, self).__init__()
        self.pre_nms_thresh = pre_nms_thresh
        self.pre_nms_top_n = pre_nms_top_n
        self.nms_thresh = nms_thresh
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.min_size = min_size
        self.num_classes = num_classes

    def forward_for_single_feature_map(
            self, locations, box_cls,
            box_regression, centerness,
            image_sizes):
        """
        Decode one FPN level into per-image BoxLists.

        Arguments:
            locations: tensor of size H*W, 2 with the (x, y) image
                coordinates of each feature-map cell
            box_cls: tensor of size N, C, H, W (raw logits)
            box_regression: tensor of size N, 4, H, W (l, t, r, b
                distances to the box sides)
            centerness: tensor of size N, 1, H, W (raw logits)
            image_sizes: per-image (h, w) sizes
        Returns:
            list of N BoxLists, one per image, for this level only
        """
        N, C, H, W = box_cls.shape

        # put in the same format as locations: flatten H*W into one axis
        # so row k of each tensor corresponds to locations[k]
        box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1)
        box_cls = box_cls.reshape(N, -1, C).sigmoid()
        box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1)
        box_regression = box_regression.reshape(N, -1, 4)
        centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1)
        centerness = centerness.reshape(N, -1).sigmoid()

        # boolean mask of (location, class) pairs above the score threshold;
        # per-image candidate counts are capped at self.pre_nms_top_n
        candidate_inds = box_cls > self.pre_nms_thresh
        pre_nms_top_n = candidate_inds.view(N, -1).sum(1)
        pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n)

        # multiply the classification scores with centerness scores
        # (note: thresholding above used the un-multiplied scores)
        box_cls = box_cls * centerness[:, :, None]

        results = []
        for i in range(N):
            # keep only the above-threshold scores for this image
            per_box_cls = box_cls[i]
            per_candidate_inds = candidate_inds[i]
            per_box_cls = per_box_cls[per_candidate_inds]

            # nonzero() rows are (location index, class index) pairs;
            # +1 shifts class ids so 0 remains the background label
            per_candidate_nonzeros = per_candidate_inds.nonzero()
            per_box_loc = per_candidate_nonzeros[:, 0]
            per_class = per_candidate_nonzeros[:, 1] + 1

            # gather the regression targets and cell locations that
            # correspond to the surviving candidates
            per_box_regression = box_regression[i]
            per_box_regression = per_box_regression[per_box_loc]
            per_locations = locations[per_box_loc]

            per_pre_nms_top_n = pre_nms_top_n[i]

            # if over the cap, keep only the top-scoring candidates
            if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
                per_box_cls, top_k_indices = \
                    per_box_cls.topk(per_pre_nms_top_n, sorted=False)
                per_class = per_class[top_k_indices]
                per_box_regression = per_box_regression[top_k_indices]
                per_locations = per_locations[top_k_indices]

            # decode (l, t, r, b) distances into xyxy corner coordinates
            detections = torch.stack([
                per_locations[:, 0] - per_box_regression[:, 0],
                per_locations[:, 1] - per_box_regression[:, 1],
                per_locations[:, 0] + per_box_regression[:, 2],
                per_locations[:, 1] + per_box_regression[:, 3],
            ], dim=1)

            h, w = image_sizes[i]
            boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy")
            boxlist.add_field("labels", per_class)
            # sqrt of (cls * centerness) = geometric mean of the two scores
            boxlist.add_field("scores", torch.sqrt(per_box_cls))
            boxlist = boxlist.clip_to_image(remove_empty=False)
            boxlist = remove_small_boxes(boxlist, self.min_size)
            results.append(boxlist)

        return results

    def forward(self, locations, box_cls, box_regression, centerness, image_sizes):
        """
        Decode all FPN levels and merge them per image.

        Arguments:
            locations: list[tensor], per-level cell coordinates
            box_cls: list[tensor], per-level classification logits
            box_regression: list[tensor], per-level regression maps
            centerness: list[tensor], per-level centerness logits
            image_sizes: list[(h, w)]
        Returns:
            boxlists (list[BoxList]): the post-processed detections, after
            applying box decoding and NMS
        """
        sampled_boxes = []
        for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)):
            sampled_boxes.append(
                self.forward_for_single_feature_map(
                    l, o, b, c, image_sizes
                )
            )

        # transpose from per-level lists of per-image boxes to per-image
        # lists of per-level boxes, then concatenate the levels
        boxlists = list(zip(*sampled_boxes))
        boxlists = [cat_boxlist(boxlist) for boxlist in boxlists]
        boxlists = self.select_over_all_levels(boxlists)

        return boxlists

    # TODO very similar to filter_results from PostProcessor
    # but filter_results is per image
    # TODO Yang: solve this issue in the future. No good solution
    # right now.
    def select_over_all_levels(self, boxlists):
        """Run class-aware NMS per image and cap total detections."""
        num_images = len(boxlists)
        results = []
        for i in range(num_images):
            # multiclass nms: batched_nms only suppresses boxes that
            # share the same "labels" value
            keep = batched_nms(
                boxlists[i].bbox,
                boxlists[i].get_field("scores"),
                boxlists[i].get_field("labels"),
                self.nms_thresh
            )
            result = boxlists[i][keep]
            number_of_detections = len(result)

            # Limit to max_per_image detections **over all classes** by
            # thresholding at the k-th lowest surviving score
            if number_of_detections > self.fpn_post_nms_top_n > 0:
                cls_scores = result.get_field("scores")
                image_thresh, _ = torch.kthvalue(
                    cls_scores.cpu(),
                    number_of_detections - self.fpn_post_nms_top_n + 1
                )
                keep = cls_scores >= image_thresh.item()
                keep = torch.nonzero(keep).squeeze(1)
                result = result[keep]
            results.append(result)
        return results
|
||
|
||
class ONNX_FCOS(nn.Module):
    """Inference-only wrapper around an exported FCOS ONNX graph.

    The backbone and heads execute through the Caffe2 ONNX backend, while
    decoding the raw head outputs into boxes is still done in PyTorch via
    FCOSPostProcessor.
    """

    def __init__(self, onnx_model_path, cfg):
        super(ONNX_FCOS, self).__init__()
        # Load the serialized graph and bind it to the configured device.
        self.onnx_model = backend.prepare(
            onnx.load(onnx_model_path),
            device=cfg.MODEL.DEVICE.upper()
        )
        # Note that we still use PyTorch for postprocessing
        self.postprocessing = FCOSPostProcessor(
            pre_nms_thresh=cfg.MODEL.FCOS.INFERENCE_TH,
            pre_nms_top_n=cfg.MODEL.FCOS.PRE_NMS_TOP_N,
            nms_thresh=cfg.MODEL.FCOS.NMS_TH,
            fpn_post_nms_top_n=cfg.TEST.DETECTIONS_PER_IMG,
            min_size=0,
            num_classes=cfg.MODEL.FCOS.NUM_CLASSES
        )
        self.cfg = cfg
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES

    def forward(self, images):
        """Run the ONNX graph on an image batch and decode detections."""
        raw_outputs = self.onnx_model.run(images.tensors.cpu().numpy())
        device = self.cfg.MODEL.DEVICE
        tensors = [torch.from_numpy(arr).to(device) for arr in raw_outputs]
        # The graph emits, for each FPN level: classification logits, box
        # regression, and centerness — grouped by kind, in that order.
        levels = len(tensors) // 3
        logits = tensors[:levels]
        bbox_reg = tensors[levels:2 * levels]
        centerness = tensors[2 * levels:]

        locations = self.compute_locations(logits)
        return self.postprocessing(
            locations, logits, bbox_reg, centerness, images.image_sizes
        )

    def compute_locations(self, features):
        """Return per-level location tensors, one per FPN feature map."""
        return [
            self.compute_locations_per_level(
                fmap.size(-2), fmap.size(-1), self.fpn_strides[level], fmap.device
            )
            for level, fmap in enumerate(features)
        ]

    def compute_locations_per_level(self, h, w, stride, device):
        """Map each cell of an h x w feature map to (x, y) image coords."""
        xs = torch.arange(
            0, w * stride, step=stride,
            dtype=torch.float32, device=device
        )
        ys = torch.arange(
            0, h * stride, step=stride,
            dtype=torch.float32, device=device
        )
        grid_y, grid_x = torch.meshgrid(ys, xs)
        # Offset by half a stride so each location sits at its cell center.
        flat = torch.stack(
            (grid_x.reshape(-1), grid_y.reshape(-1)), dim=1
        )
        return flat + stride // 2
|
||
|
||
def main():
    """Evaluate an exported FCOS ONNX model on the configured test sets."""
    parser = argparse.ArgumentParser(description="Test onnx models of FCOS")
    parser.add_argument(
        "--config-file",
        default="/private/home/fmassa/github/detectron.pytorch_v2/configs/e2e_faster_rcnn_R_50_C4_1x_caffe2.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--onnx-model",
        default="fcos_imprv_R_50_FPN_1x.onnx",
        metavar="FILE",
        help="path to the onnx model",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    # Layer the config: file first, then command-line overrides.
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # The onnx model can only be used with DATALOADER.NUM_WORKERS = 0
    cfg.DATALOADER.NUM_WORKERS = 0
    cfg.freeze()

    logger = setup_logger("fcos_core", "", get_rank())
    logger.info(cfg)
    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + collect_env_info())

    model = ONNX_FCOS(args.onnx_model, cfg)
    model.to(cfg.MODEL.DEVICE)

    # Evaluate boxes always; masks/keypoints only when those heads are on.
    eval_types = ["bbox"]
    if cfg.MODEL.MASK_ON:
        eval_types.append("segm")
    if cfg.MODEL.KEYPOINT_ON:
        eval_types.append("keypoints")
    iou_types = tuple(eval_types)

    dataset_names = cfg.DATASETS.TEST
    output_folders = [None] * len(dataset_names)
    if cfg.OUTPUT_DIR:
        for idx, dataset_name in enumerate(dataset_names):
            folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name)
            mkdir(folder)
            output_folders[idx] = folder

    # Dense detectors (FCOS/RetinaNet) always produce full boxes, so
    # box_only is forced off for them; otherwise follow RPN_ONLY.
    dense_head = cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON
    data_loaders_val = make_data_loader(cfg, is_train=False, is_distributed=False)
    for output_folder, dataset_name, data_loader_val in zip(
        output_folders, dataset_names, data_loaders_val
    ):
        inference(
            model,
            data_loader_val,
            dataset_name=dataset_name,
            iou_types=iou_types,
            box_only=False if dense_head else cfg.MODEL.RPN_ONLY,
            device=cfg.MODEL.DEVICE,
            expected_results=cfg.TEST.EXPECTED_RESULTS,
            expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
            output_folder=output_folder,
        )
        synchronize()
|
||
|
||
# Script entry point: run evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
Oops, something went wrong.