diff --git a/caffe2/opt/onnxifi_transformer.cc b/caffe2/opt/onnxifi_transformer.cc index 8089314c3100..2dd8c8d2d8b4 100644 --- a/caffe2/opt/onnxifi_transformer.cc +++ b/caffe2/opt/onnxifi_transformer.cc @@ -1195,11 +1195,11 @@ void OnnxifiTransformer::applyFilteringRules( blocklistCpuPartition(net, blocklisted_ops); } -void OnnxifiTransformer::getBackendId() { +std::vector OnnxifiTransformer::getBackendId() { idx_ = 0; if (opts_.use_onnx) { - return; + return backend_ids_; } // Try to find a backend that support Caffe2 proto. Note that this is quite // opportunistic as we don't officially support Caffe2 proto. @@ -1214,6 +1214,7 @@ void OnnxifiTransformer::getBackendId() { break; } } + return backend_ids_; } NetDef OnnxifiTransformer::TransformViaC2( diff --git a/caffe2/opt/onnxifi_transformer.h b/caffe2/opt/onnxifi_transformer.h index d88eb739750c..d1af1731013d 100644 --- a/caffe2/opt/onnxifi_transformer.h +++ b/caffe2/opt/onnxifi_transformer.h @@ -61,6 +61,17 @@ class TORCH_API OnnxifiTransformer final : public BackendTransformerBase { const ShapeInfoMap& shape_hints, const std::unordered_set& blocklisted_ops) override; + // Query whether an operator is supported by passing C2 protobuf + bool supportOpC2( + const caffe2::OperatorDef& op, + const ShapeInfoMap& shape_hints, + const std::unordered_set& weights, + const std::unordered_set& blocklisted_ops, + onnxBackendID backend_id) const; + + // Determine backend id + std::vector getBackendId(); + private: // Since we create new tensors during the conversion process, we actually need // into inject them into the original workspace @@ -114,14 +125,6 @@ class TORCH_API OnnxifiTransformer final : public BackendTransformerBase { ShapeInfoMap* shape_hints_max_bs, const std::unordered_map &shape_hints_per_bs); - // Query whether an operator is supported by passing C2 protobuf - bool supportOpC2( - const caffe2::OperatorDef& op, - const ShapeInfoMap& shape_hints, - const std::unordered_set& weights, - const std::unordered_set& blocklisted_ops, - onnxBackendID backend_id) const; - // Query whether an operator is supported by passing ONNX protobuf bool supportOpOnnx( const caffe2::OperatorDef& op, @@ -152,9 +155,6 @@ class TORCH_API OnnxifiTransformer final : public BackendTransformerBase { const std::unordered_set& weights, std::unordered_set* blocklisted_ops) const; - // Determine backend id - void getBackendId(); - // Extract partition info from the original net void extractPartitionInfo(const NetDef& net);