diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 3bef6a2f95a..34b14250314 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -248,7 +248,16 @@ class ComputeGraph final { return values_.at(idx).is##type_name(); \ } - GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(vTensorPtr, tensor, Tensor) + protected: + inline vTensorPtr get_tensor(const ValueRef idx) { + return vTensorPtr(this, idx); + } + + public: + inline bool val_is_tensor(const ValueRef idx) const { + return values_.at(idx).isTensor(); + } + GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(TensorRefPtr, tref, TensorRef) GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(StagingPtr, staging, Staging) GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(IntListPtr, int_list, IntList) @@ -970,6 +979,8 @@ class ComputeGraph final { friend class SymIntPtr; friend struct TmpTensor; + friend struct SharedObject; + friend class BlitNode; }; template