diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift index dc9ca543649..29af8f78a5a 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift @@ -11,40 +11,40 @@ /// A protocol that types conform to in order to be used as tensor element types. /// Provides the mapping from the Swift type to the underlying `DataType`. @available(*, deprecated, message: "This API is experimental.") -protocol Scalar { +public protocol Scalar { /// The `DataType` corresponding to this scalar type. static var dataType: DataType { get } } @available(*, deprecated, message: "This API is experimental.") -extension UInt8: Scalar { static var dataType: DataType { .byte } } +extension UInt8: Scalar { public static var dataType: DataType { .byte } } @available(*, deprecated, message: "This API is experimental.") -extension Int8: Scalar { static var dataType: DataType { .char } } +extension Int8: Scalar { public static var dataType: DataType { .char } } @available(*, deprecated, message: "This API is experimental.") -extension Int16: Scalar { static var dataType: DataType { .short } } +extension Int16: Scalar { public static var dataType: DataType { .short } } @available(*, deprecated, message: "This API is experimental.") -extension Int32: Scalar { static var dataType: DataType { .int } } +extension Int32: Scalar { public static var dataType: DataType { .int } } @available(*, deprecated, message: "This API is experimental.") -extension Int64: Scalar { static var dataType: DataType { .long } } +extension Int64: Scalar { public static var dataType: DataType { .long } } @available(*, deprecated, message: "This API is experimental.") -extension Int: Scalar { static var dataType: DataType { .long } } +extension Int: Scalar { public static var dataType: DataType { .long } } @available(*, deprecated, message: "This API is experimental.") -extension Float: Scalar { static var dataType: DataType { .float } } +extension Float: Scalar { public static var dataType: DataType { .float } } @available(*, deprecated, message: "This API is experimental.") -extension Double: Scalar { static var dataType: DataType { .double } } +extension Double: Scalar { public static var dataType: DataType { .double } } @available(*, deprecated, message: "This API is experimental.") -extension Bool: Scalar { static var dataType: DataType { .bool } } +extension Bool: Scalar { public static var dataType: DataType { .bool } } @available(*, deprecated, message: "This API is experimental.") -extension UInt16: Scalar { static var dataType: DataType { .uInt16 } } +extension UInt16: Scalar { public static var dataType: DataType { .uInt16 } } @available(*, deprecated, message: "This API is experimental.") -extension UInt32: Scalar { static var dataType: DataType { .uInt32 } } +extension UInt32: Scalar { public static var dataType: DataType { .uInt32 } } @available(*, deprecated, message: "This API is experimental.") -extension UInt64: Scalar { static var dataType: DataType { .uInt64 } } +extension UInt64: Scalar { public static var dataType: DataType { .uInt64 } } @available(*, deprecated, message: "This API is experimental.") -extension UInt: Scalar { static var dataType: DataType { .uInt64 } } +extension UInt: Scalar { public static var dataType: DataType { .uInt64 } } @available(*, deprecated, message: "This API is experimental.") -extension Tensor { +public extension Tensor { /// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements. /// /// - Parameter body: A closure that receives an `UnsafeBufferPointer` bound to the tensor’s data. diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 4013cb2b296..3a6dc4cfe75 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -153,7 +153,7 @@ class TensorTest: XCTestCase { let tensor = data.withUnsafeMutableBytes { Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float) } - let array: [Float] = try tensor.withUnsafeBytes { Array($0) } + let array = try tensor.withUnsafeBytes([Float].init) XCTAssertEqual(array, data) }