From 4ccb8b538774e109f2a64fcdaac044d133ebdb1d Mon Sep 17 00:00:00 2001 From: Michael Shamash Date: Sun, 3 Apr 2022 18:33:21 -0400 Subject: [PATCH 1/3] Add NMS to CoreML model output, works with Vision Reference issues: #5157 , #343 , #7011 The current version of the export.py script outputs a CoreML model without NMS. This means that certain Vision APIs cannot be used with the model directly, as the output during inference is VNCoreMLFeatureValueObservation. The changes implemented here add a NMS layer to the CoreML output, so the output from inference is VNRecognizedObjectObservation. By adding NMS to the model directly, as opposed to later in code, the performance of the overall image/video processing is improved. This also allows use of the "Preview" tab in Xcode for quickly testing the model. Default IoU and confidence thresholds are taken from the `--iou-thres` and `--conf-thres` arguments during export.py script runtime. The user can also change these later by using a CoreML MLFeatureProvider in their application (see [https://developer.apple.com/documentation/coreml/mlfeatureprovider](https://developer.apple.com/documentation/coreml/mlfeatureprovider)). This has no effect on training, as it only adds an additional layer during CoreML export for NMS. --- export.py | 127 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 123 insertions(+), 4 deletions(-) diff --git a/export.py b/export.py index 87be00376778..bee91227dbec 100644 --- a/export.py +++ b/export.py @@ -180,7 +180,24 @@ def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')): LOGGER.info(f'\n{prefix} export failure: {e}') -def export_coreml(model, im, file, prefix=colorstr('CoreML:')): +class CoreMLExportModel(torch.nn.Module): + def __init__(self, base_model, img_size): + super(CoreMLExportModel, self).__init__() + self.base_model = base_model + self.img_size = img_size + + def forward(self, x): + x = self.base_model(x)[0] + x = x.squeeze(0) + # Convert box coords to normalized coordinates [0 ... 1] + w = self.img_size[0] + h = self.img_size[1] + objectness = x[:, 4:5] + class_probs = x[:, 5:] * objectness + boxes = x[:, :4] * torch.tensor([1./w, 1./h, 1./w, 1./h]) + return class_probs, boxes + +def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, iou_thres, prefix=colorstr('CoreML:')): # YOLOv5 CoreML export try: check_requirements(('coremltools',)) @@ -189,8 +206,109 @@ def export_coreml(model, im, file, prefix=colorstr('CoreML:')): LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') f = file.with_suffix('.mlmodel') - ts = torch.jit.trace(model, im, strict=False) # TorchScript model - ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) + export_model = CoreMLExportModel(model, img_size=opt.imgsz) + + ts = torch.jit.trace(export_model, im, strict=False) # TorchScript model + orig_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) + + spec = orig_model.get_spec() + old_box_output_name = spec.description.output[1].name + old_scores_output_name = spec.description.output[0].name + ct.utils.rename_feature(spec, old_scores_output_name, "raw_confidence") + ct.utils.rename_feature(spec, old_box_output_name, "raw_coordinates") + spec.description.output[0].type.multiArrayType.shape.extend([num_boxes, num_classes]) + spec.description.output[1].type.multiArrayType.shape.extend([num_boxes, 4]) + spec.description.output[0].type.multiArrayType.dataType = ct.proto.FeatureTypes_pb2.ArrayFeatureType.DOUBLE + spec.description.output[1].type.multiArrayType.dataType = ct.proto.FeatureTypes_pb2.ArrayFeatureType.DOUBLE + + yolo_model = ct.models.MLModel(spec) + + # Build Non Maximum Suppression model + nms_spec = ct.proto.Model_pb2.Model() + nms_spec.specificationVersion = 3 + + for i in range(2): + decoder_output = spec.description.output[i].SerializeToString() + + nms_spec.description.input.add() + nms_spec.description.input[i].ParseFromString(decoder_output) + + nms_spec.description.output.add() + nms_spec.description.output[i].ParseFromString(decoder_output) + + nms_spec.description.output[0].name = "confidence" + nms_spec.description.output[1].name = "coordinates" + + output_sizes = [num_classes, 4] + for i in range(2): + ma_type = nms_spec.description.output[i].type.multiArrayType + ma_type.shapeRange.sizeRanges.add() + ma_type.shapeRange.sizeRanges[0].lowerBound = 0 + ma_type.shapeRange.sizeRanges[0].upperBound = -1 + ma_type.shapeRange.sizeRanges.add() + ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i] + ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i] + del ma_type.shape[:] + + nms = nms_spec.nonMaximumSuppression + nms.confidenceInputFeatureName = "raw_confidence" + nms.coordinatesInputFeatureName = "raw_coordinates" + nms.confidenceOutputFeatureName = "confidence" + nms.coordinatesOutputFeatureName = "coordinates" + nms.iouThresholdInputFeatureName = "iouThreshold" + nms.confidenceThresholdInputFeatureName = "confidenceThreshold" + + nms.iouThreshold = iou_thres + nms.confidenceThreshold = conf_thres + nms.pickTop.perClass = False + nms.stringClassLabels.vector.extend(labels) + + nms_model = ct.models.MLModel(nms_spec) + + # Assembling a pipeline model from the two models + input_features = [("image", ct.models.datatypes.Array(3, 300, 300)), + ("iouThreshold", ct.models.datatypes.Double()), + ("confidenceThreshold", ct.models.datatypes.Double())] + + output_features = ["confidence", "coordinates"] + + pipeline = ct.models.pipeline.Pipeline(input_features, output_features) + + pipeline.add_model(yolo_model) + pipeline.add_model(nms_model) + + # The "image" input should really be an image, not a multi-array + pipeline.spec.description.input[0].ParseFromString(spec.description.input[0].SerializeToString()) + + # Copy the declarations of the "confidence" and "coordinates" outputs + # The Pipeline makes these strings by default + pipeline.spec.description.output[0].ParseFromString(nms_spec.description.output[0].SerializeToString()) + pipeline.spec.description.output[1].ParseFromString(nms_spec.description.output[1].SerializeToString()) + + # Add descriptions to the inputs and outputs + pipeline.spec.description.input[1].shortDescription = "(optional) IOU Threshold override" + pipeline.spec.description.input[2].shortDescription = "(optional) Confidence Threshold override" + pipeline.spec.description.output[0].shortDescription = u"Boxes Class confidence" + pipeline.spec.description.output[1].shortDescription = u"Boxes [x, y, width, height] (normalized to [0...1])" + + # Add metadata to the model + pipeline.spec.description.metadata.shortDescription = "YOLOv5 object detector" + pipeline.spec.description.metadata.author = "Ultralytics" + + # Add the default threshold values and list of class labels + user_defined_metadata = { + "iou_threshold": str(iou_thres), + "confidence_threshold": str(conf_thres), + "classes": ", ".join(labels) + } + pipeline.spec.description.metadata.userDefined.update(user_defined_metadata) + + # Don't forget this or Core ML might attempt to run the model on an unsupported operating system version! + pipeline.spec.specificationVersion = 3 + + ct_model = ct.models.MLModel(pipeline.spec) + + f = str(file).replace('.pt', '.mlmodel') ct_model.save(f) LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') @@ -504,7 +622,8 @@ def run( if xml: # OpenVINO f[3] = export_openvino(model, im, file) if coreml: - _, f[4] = export_coreml(model, im, file) + nb = shape[1] + _, f[4] = export_coreml(model, im, file, nb, nc, names, conf_thres, iou_thres) # TensorFlow Exports if any((saved_model, pb, tflite, edgetpu, tfjs)): From 749d8eb7689eeaf591cf88792359e7bdd6025261 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Apr 2022 22:34:58 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- export.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/export.py b/export.py index bee91227dbec..e2552667be3b 100644 --- a/export.py +++ b/export.py @@ -182,7 +182,7 @@ def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')): class CoreMLExportModel(torch.nn.Module): def __init__(self, base_model, img_size): - super(CoreMLExportModel, self).__init__() + super().__init__() self.base_model = base_model self.img_size = img_size @@ -194,9 +194,10 @@ def forward(self, x): h = self.img_size[1] objectness = x[:, 4:5] class_probs = x[:, 5:] * objectness - boxes = x[:, :4] * torch.tensor([1./w, 1./h, 1./w, 1./h]) + boxes = x[:, :4] * torch.tensor([1. / w, 1. / h, 1. / w, 1. / h]) return class_probs, boxes + def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, iou_thres, prefix=colorstr('CoreML:')): # YOLOv5 CoreML export try: @@ -288,8 +289,8 @@ def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, i # Add descriptions to the inputs and outputs pipeline.spec.description.input[1].shortDescription = "(optional) IOU Threshold override" pipeline.spec.description.input[2].shortDescription = "(optional) Confidence Threshold override" - pipeline.spec.description.output[0].shortDescription = u"Boxes Class confidence" - pipeline.spec.description.output[1].shortDescription = u"Boxes [x, y, width, height] (normalized to [0...1])" + pipeline.spec.description.output[0].shortDescription = "Boxes Class confidence" + pipeline.spec.description.output[1].shortDescription = "Boxes [x, y, width, height] (normalized to [0...1])" # Add metadata to the model pipeline.spec.description.metadata.shortDescription = "YOLOv5 object detector" @@ -299,15 +300,14 @@ def export_coreml(model, im, file, num_boxes, num_classes, labels, conf_thres, i user_defined_metadata = { "iou_threshold": str(iou_thres), "confidence_threshold": str(conf_thres), - "classes": ", ".join(labels) - } + "classes": ", ".join(labels)} pipeline.spec.description.metadata.userDefined.update(user_defined_metadata) # Don't forget this or Core ML might attempt to run the model on an unsupported operating system version! pipeline.spec.specificationVersion = 3 ct_model = ct.models.MLModel(pipeline.spec) - + f = str(file).replace('.pt', '.mlmodel') ct_model.save(f) From 6bff76942f3bc77e72e370ebffde363be05fcea4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 Apr 2022 12:51:04 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/export.py b/export.py index 8b82bdbcda80..a24f16c61180 100644 --- a/export.py +++ b/export.py @@ -187,6 +187,7 @@ def export_openvino(model, im, file, prefix=colorstr('OpenVINO:')): class CoreMLExportModel(torch.nn.Module): + def __init__(self, base_model, img_size): super().__init__() self.base_model = base_model