diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift index 148b8f03cf0..b00fba87b39 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Value.swift @@ -8,14 +8,6 @@ @_exported import ExecuTorch -/// A protocol that provides a uniform way to convert different Swift types -/// into a `Value`. -@available(*, deprecated, message: "This API is experimental.") -public protocol ValueConvertible { - /// Converts the instance into a `Value`. - func asValue() -> Value -} - @available(*, deprecated, message: "This API is experimental.") public extension Value { /// Creates a `Value` instance encapsulating a `Tensor`. @@ -41,6 +33,52 @@ public extension Value { } } +/// A protocol that provides a uniform way to convert different Swift types +/// into a `Value`. +@available(*, deprecated, message: "This API is experimental.") +public protocol ValueConvertible { + /// Converts the instance into a `Value`. + func asValue() -> Value +} + +/// A protocol that provides a uniform way to create an instance from a `Value`. +@available(*, deprecated, message: "This API is experimental.") +public protocol ValueConstructible { + /// Constructs the instance from a `Value`. + static func from(_ value: Value) throws -> Self +} + +@available(*, deprecated, message: "This API is experimental.") +public extension ValueConstructible { + /// Sugar on top of `decode(from:)` + init(_ value: Value) throws { + self = try Self.from(value) + } +} + +/// A protocol that provides a uniform way to create an instance from an array of `Value`. +@available(*, deprecated, message: "This API is experimental.") +public protocol ValueSequenceConstructible { + /// Constructs the instance from a `Value` array. + static func from(_ values: [Value]) throws -> Self +} + +@available(*, deprecated, message: "This API is experimental.") +extension ValueSequenceConstructible where Self: ValueConstructible { + public static func from(_ values: [Value]) throws -> Self { + guard values.count == 1 else { throw Error(code: .invalidType) } + return try Self.from(values[0]) + } +} + +@available(*, deprecated, message: "This API is experimental.") +public extension ValueSequenceConstructible { + /// Sugar on top of `decode(from:)` + init(_ values: [Value]) throws { + self = try Self.from(values) + } +} + // MARK: - ValueConvertible Conformances @available(*, deprecated, message: "This API is experimental.") @@ -150,3 +188,224 @@ extension UInt: ValueConvertible { /// Converts the `UInt` into a `Value`. public func asValue() -> Value { Value(NSNumber(value: self)) } } + +// MARK: - ValueConstructible Conformances + +@available(*, deprecated, message: "This API is experimental.") +extension Value: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + value as! Self + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension AnyTensor: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let tensor = value.anyTensor else { + throw Error(code: .invalidType, description: "Value is not a tensor") + } + return tensor as! Self + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Tensor: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let anyTensor = value.anyTensor else { + throw Error(code: .invalidType, description: "Value is not a tensor") + } + guard let tensor = Tensor(anyTensor) as? Self else { + throw Error(code: .invalidType, description: "Tensor is not of type \(Self.self)") + } + return tensor + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension String: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let string = value.string else { + throw Error(code: .invalidType, description: "Value is not a string") + } + return string + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension NSNumber: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar as? Self else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + return scalar + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt8: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = UInt8(exactly: scalar.uint8Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Int8: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = Int8(exactly: scalar.int8Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Int16: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = Int16(exactly: scalar.int16Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Int32: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = Int32(exactly: scalar.int32Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Int64: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = Int64(exactly: scalar.int64Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Int: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = Int(exactly: scalar.intValue) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Float: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard value.isFloat else { + throw Error(code: .invalidType, description: "Value is not a float") + } + return value.float as Self + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Double: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard value.isDouble else { + throw Error(code: .invalidType, description: "Value is not a double") + } + return value.double as Self + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension Bool: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard value.isBoolean else { + throw Error(code: .invalidType, description: "Value is not a boolean") + } + return value.boolean as Self + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt16: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = UInt16(exactly: scalar.uint16Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt32: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = UInt32(exactly: scalar.uint32Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt64: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = UInt64(exactly: scalar.uint64Value) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +@available(*, deprecated, message: "This API is experimental.") +extension UInt: ValueConstructible, ValueSequenceConstructible { + public static func from(_ value: Value) throws -> Self { + guard let scalar = value.scalar else { + throw Error(code: .invalidType, description: "Value is not a scalar") + } + guard let integer = UInt(exactly: scalar.uintValue) else { + throw Error(code: .invalidType, description: "Cannot convert scalar to \(Self.self)") + } + return integer + } +} + +// MARK: - ValueSequenceConstructible Conformances + +@available(*, deprecated, message: "This API is experimental.") +extension Array: ValueSequenceConstructible where Element: ValueConstructible { + public static func from(_ values: [Value]) throws -> [Element] { + return try values.map { try Element.from($0) } + } +} diff --git a/extension/apple/ExecuTorch/__tests__/ValueTest.swift b/extension/apple/ExecuTorch/__tests__/ValueTest.swift index 34c3d12e14d..c28f9db2fe8 100644 --- a/extension/apple/ExecuTorch/__tests__/ValueTest.swift +++ b/extension/apple/ExecuTorch/__tests__/ValueTest.swift @@ -123,3 +123,169 @@ class ValueTest: XCTestCase { XCTAssertFalse(tensorValue1.isEqual(tensorValueDifferent)) } } + +class ValueProtocolTest: XCTestCase { + private func encoded(_ inputs: ValueConvertible...) -> [Value] { + inputs.map { $0.asValue() } + } + + func testEncodeDecodeBool() throws { + let original: Bool = true + let value = original.asValue() + XCTAssertTrue(value.isBoolean) + let decoded: Bool = try Bool.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeInt() throws { + let original: Int = 123 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: Int = try Int.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeInt8() throws { + let original: Int8 = -42 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: Int8 = try Int8.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeInt16() throws { + let original: Int16 = 1024 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: Int16 = try Int16.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeInt32() throws { + let original: Int32 = -2048 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: Int32 = try Int32.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeInt64() throws { + let original: Int64 = 1_000_000_000 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: Int64 = try Int64.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeUInt8() throws { + let original: UInt8 = 255 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: UInt8 = try UInt8.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeUInt16() throws { + let original: UInt16 = 65_535 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: UInt16 = try UInt16.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeUInt32() throws { + let original: UInt32 = 4_294_967_295 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: UInt32 = try UInt32.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeUInt64() throws { + let original: UInt64 = 18_446_744_073_709_551_615 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: UInt64 = try UInt64.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeUInt() throws { + let original: UInt = 42 + let value = original.asValue() + XCTAssertTrue(value.isInteger) + let decoded: UInt = try UInt.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeFloat() throws { + let original: Float = 3.1415 + let value = original.asValue() + XCTAssertTrue(value.isFloat) + let decoded: Float = try Float.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeDouble() throws { + let original: Double = 2.71828 + let value = original.asValue() + XCTAssertTrue(value.isDouble) + let decoded: Double = try Double.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeString() throws { + let original = "swift" + let value = original.asValue() + XCTAssertTrue(value.isString) + let decoded: String = try String.from(value) + XCTAssertEqual(decoded, original) + } + + func testEncodeDecodeNSNumber() throws { + let original = NSNumber(value: 7.0) + let value = original.asValue() + XCTAssertTrue(value.isDouble) + let decoded: NSNumber = try NSNumber.from(value) + XCTAssertEqual(decoded, original) + } + + func testSequenceDecodeSingleInt() throws { + let values = encoded(99) + let decoded = try Int.from(values) + XCTAssertEqual(decoded, 99) + } + + func testSequenceDecodeSingleBool() throws { + let values = encoded(false) + let decoded = try Bool.from(values) + XCTAssertEqual(decoded, false) + } + + func testSequenceDecodeMultipleFailure() { + let values = encoded(1, 2) + XCTAssertThrowsError(try Int.from(values)) + } + + func testArrayDecodeInts() throws { + let values = encoded(1, 2, 3, 4) + let decoded: [Int] = try [Int].from(values) + XCTAssertEqual(decoded, [1, 2, 3, 4]) + } + + func testArrayDecodeFloats() throws { + let values = encoded(1.5, 2.5, 3.5) + let decoded: [Float] = try [Float].from(values) + XCTAssertEqual(decoded, [1.5, 2.5, 3.5]) + } + + func testArrayDecodeMismatchFailure() { + let values = encoded(1, "two", 3) + XCTAssertThrowsError(try [Int].from(values)) + } + + func testArrayDecodeEmpty() throws { + let values: [Value] = encoded() + let decoded: [Int] = try [Int].from(values) + XCTAssertEqual(decoded, []) + } +}