diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 9569e96e039..108adabbc29 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -544,4 +544,26 @@ __attribute__((deprecated("This API is experimental."))) @end +@interface ExecuTorchTensor (Scalar) + +/** + * Initializes a tensor with a single scalar value and a specified data type. + * + * @param scalar An NSNumber representing the scalar value. + * @param dataType An ExecuTorchDataType value specifying the element type. + * @return An initialized ExecuTorchTensor instance representing the scalar. + */ +- (instancetype)initWithScalar:(NSNumber *)scalar + dataType:(ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:)); + +/** + * Initializes a tensor with a single scalar value, automatically deducing its data type. + * + * @param scalar An NSNumber representing the scalar value. + * @return An initialized ExecuTorchTensor instance representing the scalar. + */ +- (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:)); + +@end + NS_ASSUME_NONNULL_END diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index c5b16a1982e..5ddb0cc3c97 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -455,3 +455,26 @@ - (instancetype)initWithScalars:(NSArray *)scalars { } @end + +@implementation ExecuTorchTensor (Scalar) + +- (instancetype)initWithScalar:(NSNumber *)scalar + dataType:(ExecuTorchDataType)dataType { + return [self initWithScalars:@[scalar] + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:dataType + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithScalar:(NSNumber *)scalar { + return [self initWithScalars:@[scalar] + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:static_cast(utils::deduceType(scalar)) + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +@end diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index c6145fcc5be..cbab95148cb 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -392,4 +392,16 @@ class TensorTest: XCTestCase { XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt.self), count: count)), data) } } + + func testInitFloat() { + let tensor = Tensor(Float(42.0) as NSNumber) + XCTAssertEqual(tensor.dataType, .float) + XCTAssertEqual(tensor.shape, []) + XCTAssertEqual(tensor.strides, []) + XCTAssertEqual(tensor.dimensionOrder, []) + XCTAssertEqual(tensor.count, 1) + tensor.bytes { pointer, count, dataType in + XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count).first, 42.0) + } + } }