From 89b7a1bbd33a93589487b426185dc68ec99aec4d Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Wed, 26 Mar 2025 20:37:42 -0700 Subject: [PATCH] Overloads for scalar constructor. Summary: https://github.com/pytorch/executorch/issues/8366 Reviewed By: bsoyluoglu Differential Revision: D71931436 --- .../ExecuTorch/Exported/ExecuTorchTensor.h | 104 ++++++++++++++++ .../ExecuTorch/Exported/ExecuTorchTensor.mm | 117 ++++++++++++++++++ .../ExecuTorch/__tests__/TensorTest.swift | 2 +- 3 files changed, 222 insertions(+), 1 deletion(-) diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h index 108adabbc29..45b6bd768cb 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h @@ -564,6 +564,110 @@ __attribute__((deprecated("This API is experimental."))) */ - (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:)); +/** + * Initializes a tensor with a byte scalar value. + * + * @param scalar A uint8_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithByte:(uint8_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a char scalar value. + * + * @param scalar An int8_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithChar:(int8_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a short scalar value. + * + * @param scalar An int16_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithShort:(int16_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with an int scalar value. + * + * @param scalar An int32_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithInt:(int32_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a long scalar value. + * + * @param scalar An int64_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithLong:(int64_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a float scalar value. + * + * @param scalar A float value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithFloat:(float)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a double scalar value. + * + * @param scalar A double value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithDouble:(double)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a boolean scalar value. + * + * @param scalar A BOOL value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithBool:(BOOL)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a uint16 scalar value. + * + * @param scalar A uint16_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithUInt16:(uint16_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a uint32 scalar value. + * + * @param scalar A uint32_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithUInt32:(uint32_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with a uint64 scalar value. + * + * @param scalar A uint64_t value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithUInt64:(uint64_t)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with an NSInteger scalar value. + * + * @param scalar An NSInteger value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithInteger:(NSInteger)scalar NS_SWIFT_NAME(init(_:)); + +/** + * Initializes a tensor with an NSUInteger scalar value. + * + * @param scalar An NSUInteger value. + * @return An initialized ExecuTorchTensor instance. + */ +- (instancetype)initWithUnsignedInteger:(NSUInteger)scalar NS_SWIFT_NAME(init(_:)); + @end NS_ASSUME_NONNULL_END diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 5ddb0cc3c97..547fc6f1950 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -477,4 +477,121 @@ - (instancetype)initWithScalar:(NSNumber *)scalar { shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; } +- (instancetype)initWithByte:(uint8_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeByte + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithChar:(int8_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeChar + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithShort:(int16_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeShort + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithInt:(int32_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeInt + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithLong:(int64_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeLong + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithFloat:(float)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeFloat + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithDouble:(double)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeDouble + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithBool:(BOOL)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeBool + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithUInt16:(uint16_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeUInt16 + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithUInt32:(uint32_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeUInt32 + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithUInt64:(uint64_t)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:ExecuTorchDataTypeUInt64 + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithInteger:(NSInteger)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:(sizeof(scalar) == 8 ? ExecuTorchDataTypeLong : ExecuTorchDataTypeInt) + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + +- (instancetype)initWithUnsignedInteger:(NSUInteger)scalar { + return [self initWithBytes:&scalar + shape:@[] + strides:@[] + dimensionOrder:@[] + dataType:(sizeof(scalar) == 8 ? ExecuTorchDataTypeUInt64 : ExecuTorchDataTypeUInt32) + shapeDynamism:ExecuTorchShapeDynamismDynamicBound]; +} + @end diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index cbab95148cb..4f89216886a 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -394,7 +394,7 @@ class TensorTest: XCTestCase { } func testInitFloat() { - let tensor = Tensor(Float(42.0) as NSNumber) + let tensor = Tensor(Float(42.0)) XCTAssertEqual(tensor.dataType, .float) XCTAssertEqual(tensor.shape, []) XCTAssertEqual(tensor.strides, [])