Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,4 +544,26 @@ __attribute__((deprecated("This API is experimental.")))

@end

@interface ExecuTorchTensor (Scalar)

/**
* Initializes a tensor with a single scalar value and a specified data type.
*
* @param scalar An NSNumber representing the scalar value.
* @param dataType An ExecuTorchDataType value specifying the element type.
* @return An initialized ExecuTorchTensor instance representing the scalar.
*/
- (instancetype)initWithScalar:(NSNumber *)scalar
dataType:(ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:));

/**
* Initializes a tensor with a single scalar value, automatically deducing its data type.
*
* @param scalar An NSNumber representing the scalar value.
* @return An initialized ExecuTorchTensor instance representing the scalar.
*/
- (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:));

@end

NS_ASSUME_NONNULL_END
23 changes: 23 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,26 @@ - (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars {
}

@end

@implementation ExecuTorchTensor (Scalar)

- (instancetype)initWithScalar:(NSNumber *)scalar
dataType:(ExecuTorchDataType)dataType {
return [self initWithScalars:@[scalar]
shape:@[]
strides:@[]
dimensionOrder:@[]
dataType:dataType
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
}

- (instancetype)initWithScalar:(NSNumber *)scalar {
return [self initWithScalars:@[scalar]
shape:@[]
strides:@[]
dimensionOrder:@[]
dataType:static_cast<ExecuTorchDataType>(utils::deduceType(scalar))
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
}

@end
12 changes: 12 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,16 @@ class TensorTest: XCTestCase {
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt.self), count: count)), data)
}
}

func testInitFloat() {
let tensor = Tensor(Float(42.0) as NSNumber)
XCTAssertEqual(tensor.dataType, .float)
XCTAssertEqual(tensor.shape, [])
XCTAssertEqual(tensor.strides, [])
XCTAssertEqual(tensor.dimensionOrder, [])
XCTAssertEqual(tensor.count, 1)
tensor.bytes { pointer, count, dataType in
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count).first, 42.0)
}
}
}
Loading