Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/AutoDiff/refcounting.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public struct Vector : AdditiveArithmetic, VectorProtocol, Differentiable, Equat
public var y: Float
public var nonTrivialStuff = NonTrivialStuff()
public typealias TangentVector = Vector
public typealias Scalar = Float
public typealias VectorSpaceScalar = Float
public static var zero: Vector { return Vector(0) }
public init(_ scalar: Float) { self.x = scalar; self.y = scalar }

Expand Down
8 changes: 4 additions & 4 deletions test/Sema/struct_key_path_iterable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ extension TensorParameters : VectorProtocol {
static func - (lhs: TensorParameters, rhs: TensorParameters) -> TensorParameters {
return TensorParameters(w: lhs.w + rhs.w, b: lhs.b + rhs.b)
}
typealias Scalar = Tensor<Float>
static func * (lhs: Scalar, rhs: TensorParameters) -> TensorParameters {
typealias VectorSpaceScalar = Tensor<Float>
static func * (lhs: VectorSpaceScalar, rhs: TensorParameters) -> TensorParameters {
return TensorParameters(w: lhs + rhs.w, b: lhs + rhs.b)
}
}
Expand Down Expand Up @@ -89,7 +89,7 @@ func pow<T : BinaryFloatingPoint>(_ x: T, _ y: T) -> T {
}

struct AdamOptimizer<P : KeyPathIterable, Scalar : BinaryFloatingPoint>
where P : VectorProtocol, P.Scalar == Tensor<Scalar>
where P : VectorProtocol, P.VectorSpaceScalar == Tensor<Scalar>
{
let learningRate: Scalar
var beta1: Scalar
Expand Down Expand Up @@ -134,7 +134,7 @@ struct AdamOptimizer<P : KeyPathIterable, Scalar : BinaryFloatingPoint>
func testOptimizer<P : KeyPathIterable, Scalar : BinaryFloatingPoint>(
parameters: inout P, withGradients gradients: P
)
where P : VectorProtocol, P.Scalar == Tensor<Scalar>
where P : VectorProtocol, P.VectorSpaceScalar == Tensor<Scalar>
{
var optimizer = AdamOptimizer<P, Scalar>()
print(parameters)
Expand Down
8 changes: 4 additions & 4 deletions test/Sema/struct_vector_protocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
// RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_vector_protocol_other_module.swift

func testVectorProtocol<T : VectorProtocol>(
_ x: inout T, scalar: T.Scalar
_ x: inout T, scalar: T.VectorSpaceScalar
) {
// Test `AdditiveArithmetic` requirements: `zero`, `+`, `-`.
let zero = T.zero
x += x + zero
x -= x - zero
// Test `VectorProtocol` requirements: `Scalar`, `*`.
// Test `VectorProtocol` requirements: `VectorSpaceScalar`, `*`.
x *= scalar
_ = scalar * x
_ = x * scalar
Expand All @@ -32,7 +32,7 @@ func testVector2(float2: Float2) {
_ = Vector2<Double>(x: 1, y: 1)
_ = Vector2<Float2>(x: float2, y: float2)
}
func testGeneric<T : VectorProtocol>(vec2: inout Vector2<T>, scalar: T.Scalar) {
func testGeneric<T : VectorProtocol>(vec2: inout Vector2<T>, scalar: T.VectorSpaceScalar) {
testVectorProtocol(&vec2, scalar: scalar)
}

Expand Down Expand Up @@ -90,7 +90,7 @@ extension GenericExtended : Equatable, AdditiveArithmetic, VectorProtocol where
struct Empty : VectorProtocol {} // expected-error {{type 'Empty' does not conform to protocol 'VectorProtocol'}}

// Test type whose members conform to `VectorProtocol`
// but have different `Scalar` associated type.
// but have different `VectorSpaceScalar` associated type.
struct InvalidMixedScalar: VectorProtocol { // expected-error {{type 'InvalidMixedScalar' does not conform to protocol 'VectorProtocol'}}
var float: Float
var double: Double
Expand Down