Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -821,9 +821,8 @@ public class Tensor<T: Scalar>: Equatable {
lhs.anyTensor == rhs.anyTensor
}

// MARK: Internal

let anyTensor: AnyTensor
// Wrapped AnyTensor instance.
public let anyTensor: AnyTensor
}

@available(*, deprecated, message: "This API is experimental.")
Expand Down
60 changes: 60 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,66 @@ class TensorTest: XCTestCase {
XCTAssertEqual(try tensor.scalars().first, 42)
}

func testExtractAnyTensorMatchesOriginalDataAndMetadata() {
let tensor = Tensor([1, 2, 3, 4], shape: [2, 2])
let anyTensor = tensor.anyTensor
XCTAssertEqual(anyTensor.shape, tensor.shape)
XCTAssertEqual(anyTensor.strides, tensor.strides)
XCTAssertEqual(anyTensor.dimensionOrder, tensor.dimensionOrder)
XCTAssertEqual(anyTensor.count, tensor.count)
XCTAssertEqual(anyTensor.dataType, tensor.dataType)
XCTAssertEqual(anyTensor.shapeDynamism, tensor.shapeDynamism)
let newTensor = Tensor<Int>(anyTensor)
XCTAssertEqual(newTensor, tensor)
}

func testReconstructGenericTensorViaInitAndAsTensor() {
let tensor = Tensor([5, 6, 7])
let anyTensor = tensor.anyTensor
let tensorInit = Tensor<Int>(anyTensor)
let tensorFromAny: Tensor<Int> = anyTensor.asTensor()!
XCTAssertEqual(tensorInit, tensorFromAny)
}

func testAsTensorMismatchedTypeReturnsNil() {
let tensor = Tensor([8, 9, 10])
let anyTensor = tensor.anyTensor
let wrongTypedTensor: Tensor<Float>? = anyTensor.asTensor()
XCTAssertNil(wrongTypedTensor)
}

func testViewSharesDataAndResizeAltersShapeNotData() throws {
var scalars = [11, 12, 13, 14]
let tensor = Tensor(&scalars, shape: [2, 2])
let viewTensor = Tensor(tensor)
let scalarsAddress = scalars.withUnsafeBufferPointer { $0.baseAddress }
let tensorDataAddress = try tensor.withUnsafeBytes { $0.baseAddress }
let viewTensorDataAddress = try viewTensor.withUnsafeBytes { $0.baseAddress }
XCTAssertEqual(tensorDataAddress, scalarsAddress)
XCTAssertEqual(tensorDataAddress, viewTensorDataAddress)

scalars[2] = 42
XCTAssertEqual(try tensor.scalars(), scalars)
XCTAssertEqual(try viewTensor.scalars(), scalars)

XCTAssertNoThrow(try viewTensor.resize(to: [4, 1]))
XCTAssertEqual(viewTensor.shape, [4, 1])
XCTAssertEqual(tensor.shape, [2, 2])
XCTAssertEqual(try tensor.scalars(), scalars)
XCTAssertEqual(try viewTensor.scalars(), scalars)
}

func testMultipleGenericFromAnyReflectChanges() {
let tensor = Tensor([2, 4, 6, 8], shape: [2, 2])
let anyTensor = tensor.anyTensor
let tensor1: Tensor<Int> = anyTensor.asTensor()!
let tensor2: Tensor<Int> = anyTensor.asTensor()!

XCTAssertEqual(tensor1, tensor2)
XCTAssertNoThrow(try tensor1.withUnsafeMutableBytes { $0[1] = 42 })
XCTAssertEqual(try tensor2.withUnsafeBytes { $0[1] }, 42)
}

func testEmpty() {
let tensor = Tensor<Float>.empty(shape: [3, 4])
XCTAssertEqual(tensor.shape, [3, 4])
Expand Down
Loading