Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ __attribute__((deprecated("This API is experimental.")))
error:(NSError **)error
NS_SWIFT_NAME(resize(to:));

/**
* Determines whether the current tensor is equal to another tensor.
*
* @param other Another ExecuTorchTensor instance to compare against.
* @return YES if the tensors have the same type, shape, strides, and data; otherwise, NO.
*/
- (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other;

+ (instancetype)new NS_UNAVAILABLE;
- (instancetype)init NS_UNAVAILABLE;

Expand Down
26 changes: 26 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,32 @@ - (BOOL)resizeToShape:(NSArray<NSNumber *> *)shape
return YES;
}

- (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other {
if (!other) {
return NO;
}
const auto *data = _tensor->unsafeGetTensorImpl()->data();
const auto *otherData = other->_tensor->unsafeGetTensorImpl()->data();
const auto size = self.count * ExecuTorchSizeOfDataType(self.dataType);
return self.dataType == other.dataType &&
self.count == other.count &&
[self.shape isEqual:other.shape] &&
[self.dimensionOrder isEqual:other.dimensionOrder] &&
[self.strides isEqual:other.strides] &&
self.shapeDynamism == other.shapeDynamism &&
(data && otherData ? std::memcmp(data, otherData, size) == 0 : data == otherData);
}

- (BOOL)isEqual:(nullable id)other {
if (self == other) {
return YES;
}
if (![other isKindOfClass:[ExecuTorchTensor class]]) {
return NO;
}
return [self isEqualToTensor:(ExecuTorchTensor *)other];
}

@end

@implementation ExecuTorchTensor (BytesNoCopy)
Expand Down
20 changes: 20 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,30 @@ typedef NS_ENUM(uint32_t, ExecuTorchValueTag) {
ExecuTorchValueTagOptionalTensorList,
} NS_SWIFT_NAME(ValueTag);

/**
* A dynamic value type used by ExecuTorch.
*
* ExecuTorchValue encapsulates a value that may be of various types such as
* a tensor or a scalar. The value’s type is indicated by its tag.
*/
NS_SWIFT_NAME(Value)
__attribute__((deprecated("This API is experimental.")))
@interface ExecuTorchValue : NSObject

/**
* The tag that indicates the dynamic type of the value.
*
* @return An ExecuTorchValueTag value.
*/
@property(nonatomic, readonly) ExecuTorchValueTag tag;

/**
* Returns YES if the value is of type None.
*
* @return A BOOL indicating whether the value is None.
*/
@property(nonatomic, readonly) BOOL isNone;

@end

NS_ASSUME_NONNULL_END
13 changes: 0 additions & 13 deletions extension/apple/ExecuTorch/Exported/ExecuTorchValue.m

This file was deleted.

44 changes: 44 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchValue.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#import "ExecuTorchValue.h"

@interface ExecuTorchValue ()

- (instancetype)initWithTag:(ExecuTorchValueTag)tag
value:(nullable id)value NS_DESIGNATED_INITIALIZER;

@end

@implementation ExecuTorchValue {
ExecuTorchValueTag _tag;
id _value;
}

- (instancetype)init {
return [self initWithTag:ExecuTorchValueTagNone value:nil];
}

- (instancetype)initWithTag:(ExecuTorchValueTag)tag
value:(nullable id)value {
if (self = [super init]) {
_tag = tag;
_value = value;
}
return self;
}

- (ExecuTorchValueTag)tag {
return _tag;
}

- (BOOL)isNone {
return _tag == ExecuTorchValueTagNone;
}

@end
23 changes: 23 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,27 @@ class TensorTest: XCTestCase {
}
XCTAssertThrowsError(try tensor.resize(to: [2, 3]))
}

func testIsEqual() {
var data: [Float] = [1.0, 2.0, 3.0, 4.0]
let tensor1 = data.withUnsafeMutableBytes {
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .float)
}
let tensor2 = Tensor(tensor1)
XCTAssertTrue(tensor1.isEqual(tensor2))
XCTAssertTrue(tensor2.isEqual(tensor1))

var dataModified: [Float] = [1.0, 2.0, 3.0, 5.0]
let tensor3 = dataModified.withUnsafeMutableBytes {
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .float)
}
XCTAssertFalse(tensor1.isEqual(tensor3))
let tensor4 = data.withUnsafeMutableBytes {
Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1], dataType: .float)
}
XCTAssertFalse(tensor1.isEqual(tensor4))
XCTAssertTrue(tensor1.isEqual(tensor1))
XCTAssertFalse(tensor1.isEqual(NSString(string: "Not a tensor")))
XCTAssertFalse(tensor4.isEqual(tensor2.copy()))
}
}
4 changes: 3 additions & 1 deletion extension/apple/ExecuTorch/__tests__/ValueTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import XCTest

class ValueTest: XCTestCase {
func test() {
func testNone() {
let value = Value()
XCTAssertTrue(value.isNone)
}
}
Loading