Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@ typedef NS_ENUM(uint8_t, ExecuTorchShapeDynamism) {
ExecuTorchShapeDynamismDynamicUnbound,
} NS_SWIFT_NAME(ShapeDynamism);

/**
* Returns the size in bytes of the specified data type.
*
* @param dataType An ExecuTorchDataType value representing the tensor's element type.
* @return An NSInteger indicating the size in bytes.
*/
FOUNDATION_EXPORT
__attribute__((deprecated("This API is experimental.")))
NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType)
NS_SWIFT_NAME(size(ofDataType:));

/**
* Computes the total number of elements in a tensor based on its shape.
*
* @param shape An NSArray of NSNumber objects, where each element represents a dimension size.
* @return An NSInteger equal to the product of the sizes of all dimensions.
*/
FOUNDATION_EXPORT
__attribute__((deprecated("This API is experimental.")))
NSInteger ExecuTorchElementCountOfShape(NSArray<NSNumber *> *shape)
NS_SWIFT_NAME(elementCount(ofShape:));

/**
* A tensor class for ExecuTorch operations.
*
Expand Down
12 changes: 12 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
using namespace executorch::aten;
using namespace executorch::extension;

NSInteger ExecuTorchSizeOfDataType(ExecuTorchDataType dataType) {
return elementSize(static_cast<ScalarType>(dataType));
}

NSInteger ExecuTorchElementCountOfShape(NSArray<NSNumber *> *shape) {
NSInteger count = 1;
for (NSNumber *dimension in shape) {
count *= dimension.integerValue;
}
return count;
}

@implementation ExecuTorchTensor {
TensorPtr _tensor;
NSArray<NSNumber *> *_shape;
Expand Down
44 changes: 44 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,50 @@
import XCTest

class TensorTest: XCTestCase {
func testElementCountOfShape() {
XCTAssertEqual(elementCount(ofShape: [2, 3, 4]), 24)
XCTAssertEqual(elementCount(ofShape: [5]), 5)
XCTAssertEqual(elementCount(ofShape: []), 1)
}

func testSizeOfDataType() {
let expectedSizes: [DataType: Int] = [
.byte: 1,
.char: 1,
.short: 2,
.int: 4,
.long: 8,
.half: 2,
.float: 4,
.double: 8,
.complexHalf: 4,
.complexFloat: 8,
.complexDouble: 16,
.bool: 1,
.qInt8: 1,
.quInt8: 1,
.qInt32: 4,
.bFloat16: 2,
.quInt4x2: 1,
.quInt2x4: 1,
.bits1x8: 1,
.bits2x4: 1,
.bits4x2: 1,
.bits8: 1,
.bits16: 2,
.float8_e5m2: 1,
.float8_e4m3fn: 1,
.float8_e5m2fnuz: 1,
.float8_e4m3fnuz: 1,
.uInt16: 2,
.uInt32: 4,
.uInt64: 8,
]
for (dataType, expectedSize) in expectedSizes {
XCTAssertEqual(size(ofDataType: dataType), expectedSize, "Size for \(dataType) should be \(expectedSize)")
}
}

func testInitBytesNoCopy() {
var data: [Float] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
let tensor = data.withUnsafeMutableBytes {
Expand Down
Loading