Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ __attribute__((deprecated("This API is experimental.")))
- (instancetype)initWithNativeInstance:(void *)nativeInstance
NS_DESIGNATED_INITIALIZER NS_SWIFT_UNAVAILABLE("");

/**
* Creates a new tensor by copying an existing tensor.
*
* @param otherTensor The tensor instance to copy.
* @return A new ExecuTorchTensor instance that is a copy of otherTensor.
*/
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
NS_SWIFT_NAME(init(_:));

/**
* Executes a block with a pointer to the tensor's immutable byte data.
*
Expand Down
8 changes: 8 additions & 0 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance {
return self;
}

- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
ET_CHECK(otherTensor);
auto tensor = make_tensor_ptr(
**reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance)
);
return [self initWithNativeInstance:&tensor];
}

- (void *)nativeInstance {
return &_tensor;
}
Expand Down
14 changes: 14 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,18 @@ class TensorTest: XCTestCase {
XCTAssertEqual(updatedData, [2, 4, 6, 8])
}
}

func testInitWithTensor() {
var data: [Int] = [10, 20, 30, 40]
let tensor1 = data.withUnsafeMutableBytes {
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .int)
}
let tensor2 = Tensor(tensor1)

XCTAssertEqual(tensor2.dataType, tensor1.dataType)
XCTAssertEqual(tensor2.shape, tensor1.shape)
XCTAssertEqual(tensor2.strides, tensor1.strides)
XCTAssertEqual(tensor2.dimensionOrder, tensor1.dimensionOrder)
XCTAssertEqual(tensor2.count, tensor1.count)
}
}
Loading