From 153c96fff54d20441173fa425f3e96bd4203d873 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 20:35:35 -0700 Subject: [PATCH] Tensor constructor to create with a single scalar. Summary: https://github.com/pytorch/executorch/issues/8366 Reviewed By: bsoyluoglu Differential Revision: D71930917 --- .../ExecuTorch/Exported/ExecuTorchTensor.h | 22 ++++++++++++++++++ .../ExecuTorch/Exported/ExecuTorchTensor.mm | 23 +++++++++++++++++++ .../ExecuTorch/__tests__/TensorTest.swift | 12 ++++++++++ 3 files changed, 57 insertions(+) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 9569e96e039..108adabbc29 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -544,4 +544,26 @@ __attribute__((deprecated("This API is experimental."))) @end +@interface ExecuTorchTensor (Scalar) + +/** + * Initializes a tensor with a single scalar value and a specified data type. + * + * @param scalar An NSNumber representing the scalar value. + * @param dataType An ExecuTorchDataType value specifying the element type. + * @return An initialized ExecuTorchTensor instance representing the scalar. + */ +- (instancetype)initWithScalar:(NSNumber *)scalar + dataType:(ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:)); + +/** + * Initializes a tensor with a single scalar value, automatically deducing its data type. + * + * @param scalar An NSNumber representing the scalar value. + * @return An initialized ExecuTorchTensor instance representing the scalar. + */ +- (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:)); + +@end + NS_ASSUME_NONNULL_END diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index c5b16a1982e..5ddb0cc3c97 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -455,3 +455,26 @@ - (instancetype)initWithScalars:(NSArray *)scalars { } @end + +@implementation ExecuTorchTensor (Scalar) + +- (instancetype)initWithScalar:(NSNumber *)scalar + dataType:(ExecuTorchDataType)dataType { + return [self initWithScalars:@[scalar] + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:dataType + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithScalar:(NSNumber *)scalar { + return [self initWithScalars:@[scalar] + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:static_cast(utils::deduceType(scalar)) + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +@end diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index c6145fcc5be..cbab95148cb 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -392,4 +392,16 @@ class TensorTest: XCTestCase { XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt.self), count: count)), data) } } + + func testInitFloat() { + let tensor = Tensor(Float(42.0) as NSNumber) + XCTAssertEqual(tensor.dataType, .float) + XCTAssertEqual(tensor.shape, []) + XCTAssertEqual(tensor.strides, []) + XCTAssertEqual(tensor.dimensionOrder, []) + XCTAssertEqual(tensor.count, 1) + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count).first, 42.0) + } + } }