diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index b840cd2faac..6bd9854526c 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -61,10 +61,32 @@ typedef NS_ENUM(uint8_t, ExecuTorchShapeDynamism) { ExecuTorchShapeDynamismDynamicUnbound, } NS_SWIFT_NAME(ShapeDynamism); +/** + * A tensor class for ExecuTorch operations. + * + * This class encapsulates a native TensorPtr instance and provides a variety of + * initializers and utility methods to work with tensor data. + */ NS_SWIFT_NAME(Tensor) __attribute__((deprecated("This API is experimental."))) @interface ExecuTorchTensor : NSObject +/** + * Pointer to the underlying native TensorPtr instance. + * + * @return A raw pointer to the native TensorPtr held by this Tensor class. + */ +@property(nonatomic, readonly) void *nativeInstance NS_SWIFT_UNAVAILABLE(""); + +/** + * Initializes a tensor with a native TensorPtr instance. + * + * @param nativeInstance A pointer to a native TensorPtr instance. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithNativeInstance:(void *)nativeInstance + NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE(""); + + (instancetype)new NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE; diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 4b072444bec..ef93b2eb842 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -12,8 +12,23 @@ #import +using namespace executorch::extension; + @implementation ExecuTorchTensor { - ::executorch::extension::TensorPtr _tensor; + TensorPtr _tensor; +} + +- (instancetype)initWithNativeInstance:(void *)nativeInstance { + ET_CHECK(nativeInstance); + if (self = [super init]) { + _tensor = std::move(*reinterpret_cast(nativeInstance)); + ET_CHECK(_tensor); + } + return self; +} + +- (void *)nativeInstance { + return &_tensor; } @end