From afe7fc99965a588bbde608211990bba59eb2b691 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Tue, 5 Aug 2025 14:12:57 -0700 Subject: [PATCH] [ET-VK][ez] Make `get_tensor()` API protected ## Changes As title; make the `get_tensor()` API protected. ## Motivation See the below diff/PR in the stack. The goal is to encourage operator authors to go through the `ComputeGraph` to access/modify tensors so that the activity can be tracked. Differential Revision: [D79564596](https://our.internmc.facebook.com/intern/diff/D79564596/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ComputeGraph.h | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 3bef6a2f95a..34b14250314 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -248,7 +248,16 @@ class ComputeGraph final { return values_.at(idx).is##type_name(); \ } - GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(vTensorPtr, tensor, Tensor) + protected: + inline vTensorPtr get_tensor(const ValueRef idx) { + return vTensorPtr(this, idx); + } + + public: + inline bool val_is_tensor(const ValueRef idx) const { + return values_.at(idx).isTensor(); + } + GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(TensorRefPtr, tref, TensorRef) GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(StagingPtr, staging, Staging) GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(IntListPtr, int_list, IntList) @@ -970,6 +979,8 @@ class ComputeGraph final { friend class SymIntPtr; friend struct TmpTensor; + friend struct SharedObject; + friend class BlitNode; }; template