diff --git a/Sources/DeepLearning/DifferentialOperators.swift b/Sources/DeepLearning/DifferentialOperators.swift new file mode 100644 index 000000000..bfb53db77 --- /dev/null +++ b/Sources/DeepLearning/DifferentialOperators.swift @@ -0,0 +1,178 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +import TensorFlow +#endif + +//===------------------------------------------------------------------------------------------===// +// Method-style Differential Operators +//===------------------------------------------------------------------------------------------===// + +public extension Differentiable { + @inlinable + func gradient( + in f: @differentiable (Self) -> Tensor + ) -> CotangentVector { + return self.pullback(in: f)(Tensor(1)) + } + + @inlinable + func valueWithGradient( + in f: @differentiable (Self) -> Tensor + ) -> (value: Tensor, gradient: CotangentVector) { + let (y, pb) = self.valueWithPullback(in: f) + return (y, pb(Tensor(1))) + } + + @inlinable + func gradient( + at x: T, + in f: @differentiable (Self, T) -> Tensor + ) -> (CotangentVector, T.CotangentVector) { + return self.pullback(at: x, in: f)(Tensor(1)) + } + + @inlinable + func valueWithGradient( + at x: T, + in f: @differentiable (Self, T) -> Tensor + ) -> (value: Tensor, gradient: (CotangentVector, T.CotangentVector)) { + let (y, pb) = self.valueWithPullback(at: x, in: f) + return (y, pb(Tensor(1))) + } +} + +//===------------------------------------------------------------------------------------------===// +// Free-Function-Style Differential Operators +//===------------------------------------------------------------------------------------------===// + +// Value with gradient + +@inlinable +public func valueWithGradient( + at x: T, + in f: @differentiable (T) -> Tensor +) -> (value: Tensor, gradient: T.CotangentVector) +where T: Differentiable, R: TensorFlowFloatingPoint { + let (y, pullback) = valueWithPullback(at: x, in: f) + return (y, pullback(Tensor(1))) +} + +@inlinable +public func valueWithGradient( + at x: T, + _ y: U, + in f: @differentiable (T, U) -> Tensor +) -> (value: Tensor, gradient: (T.CotangentVector, U.CotangentVector)) + where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint { + let (y, pullback) = valueWithPullback(at: x, y, in: f) + return (y, pullback(Tensor(1))) +} + +@inlinable +public func valueWithGradient( + at x: T, + _ y: U, + _ z: V, + in f: @differentiable (T, U, V) -> Tensor +) -> (value: Tensor, gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint { + let (y, pullback) = valueWithPullback(at: x, y, z, in: f) + return (y, pullback(Tensor(1))) +} + +// Value with gradient (curried) + +@inlinable +public func valueWithGradient( + of f: @escaping @differentiable (T) -> Tensor +) -> (T) -> (value: Tensor, gradient: T.CotangentVector) + where T: 
Differentiable, R: TensorFlowFloatingPoint { + return { x in valueWithGradient(at: x, in: f) } +} + +@inlinable +public func valueWithGradient( + of f: @escaping @differentiable (T, U) -> Tensor +) -> (T, U) -> (value: Tensor, gradient: (T.CotangentVector, U.CotangentVector)) + where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint { + return { x, y in valueWithGradient(at: x, y, in: f) } +} + +@inlinable +public func valueWithGradient( + of f: @escaping @differentiable (T, U, V) -> Tensor +) -> (T, U, V) -> ( + value: Tensor, + gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint { + return { x, y, z in valueWithGradient(at: x, y, z, in: f) } +} + +// Gradient + +@inlinable +public func gradient( + at x: T, + in f: @differentiable (T) -> Tensor +) -> T.CotangentVector where T: Differentiable, R: TensorFlowFloatingPoint { + return pullback(at: x, in: f)(Tensor(1)) +} + +@inlinable +public func gradient( + at x: T, + _ y: U, + in f: @differentiable (T, U) -> Tensor +) -> (T.CotangentVector, U.CotangentVector) + where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint { + return pullback(at: x, y, in: f)(Tensor(1)) +} + +@inlinable +public func gradient( + at x: T, + _ y: U, + _ z: V, + in f: @differentiable (T, U, V) -> Tensor +) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) + where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint { + return pullback(at: x, y, z, in: f)(Tensor(1)) +} + +// Gradient (curried) + +@inlinable +public func gradient( + of f: @escaping @differentiable (T) -> Tensor +) -> (T) -> T.CotangentVector where T: Differentiable, R: TensorFlowFloatingPoint { + return { x in gradient(at: x, in: f) } +} + +@inlinable +public func gradient( + of f: @escaping @differentiable (T, U) -> Tensor +) -> (T, U) -> (T.CotangentVector, U.CotangentVector) + where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint { + return { x, y in gradient(at: x, y, in: f) } +} + +@inlinable +public func gradient( + of f: @escaping @differentiable (T, U, V) -> Tensor +) -> (T, U, V) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) + where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint { + return { x, y, z in gradient(at: x, y, z, in: f) } +} diff --git a/Sources/DeepLearning/Helpers.swift b/Sources/DeepLearning/Helpers.swift index 86aec74bf..4d9c0217b 100644 --- a/Sources/DeepLearning/Helpers.swift +++ b/Sources/DeepLearning/Helpers.swift @@ -13,12 +13,20 @@ // limitations under the License. #if !COMPILING_TENSORFLOW_MODULE -import TensorFlow +@_exported import TensorFlow #endif +/// Returns a tensor with the same shape and scalars as the specified tensor. +@inlinable +@differentiable +public func identity(_ x: Tensor) -> Tensor { + return x +} + // `pow` is defined in Darwin/Glibc on `Float` and `Double`, but there doesn't exist a generic // version for `FloatingPoint`. // This is a manual definition. -func pow(_ x: T, _ y: T) -> T { +@inlinable +func pow(_ x: T, _ y: T) -> T { return T(pow(Double(x), Double(y))) } diff --git a/Sources/DeepLearning/Initializers.swift b/Sources/DeepLearning/Initializers.swift index 0655b817b..3ab3f5654 100644 --- a/Sources/DeepLearning/Initializers.swift +++ b/Sources/DeepLearning/Initializers.swift @@ -13,9 +13,294 @@ // limitations under the License. 
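// A minimal usage sketch for the method- and free-function-style differential operators
// added in DifferentialOperators.swift above (a sketch, assuming a Swift for TensorFlow
// toolchain in which this module builds; expected values shown in trailing comments):
let x = Tensor<Float>([1, 2, 3])
let dydx = x.gradient { x in (x * x).sum() }                      // [2, 4, 6]
let (y, grad) = valueWithGradient(at: x) { x in (x * x).sum() }   // y == 14, grad == [2, 4, 6]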
#if !COMPILING_TENSORFLOW_MODULE -@_exported import TensorFlow +import TensorFlow #endif +public extension Tensor { + /// Creates a tensor with the specified shape and a single, repeated scalar + /// value. + /// + /// - Parameters: + /// - shape: The dimensions of the tensor. + /// - repeatedValue: The scalar value to repeat. + @inlinable + @available(*, deprecated, renamed: "init(repeating:shape:)") + init(shape: TensorShape, repeating repeatedValue: Scalar) { + self.init(repeating: repeatedValue, shape: shape) + } + + /// Creates a tensor with the specified shape and a single, repeated scalar value. + /// + /// - Parameters: + /// - repeatedValue: The scalar value to repeat. + /// - shape: The dimensions of the tensor. + @inlinable + @differentiable( + vjp: _vjpInit(repeating:shape:) where Scalar: TensorFlowFloatingPoint) + init(repeating repeatedValue: Scalar, shape: TensorShape) { + self = Raw.fill( + dims: Tensor(shape.dimensions.map(Int32.init)), + value: Tensor(repeatedValue)) + } + + /// Creates a tensor by broadcasting the given scalar to a given rank with + /// all dimensions being 1. + @inlinable + // @differentiable(where Scalar: TensorFlowFloatingPoint) + init(broadcasting scalar: Scalar, rank: Int) { + self = Tensor(scalar).reshaped(to: TensorShape(repeating: 1, count: rank)) + } + + /// Creates a tensor of shape `[4]` from a 4-tuple. + /// - Note: This is intended for internal use, for example, to initialize a + /// tensor attribute from `convolved2D`'s `strides` argument. + @inlinable + internal init(_ scalars: (Scalar, Scalar, Scalar, Scalar)) { + self.init([scalars.0, scalars.1, scalars.2, scalars.3]) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpInit( + repeating repeatedValue: Scalar, + shape: TensorShape + ) -> (Tensor, (Tensor) -> Scalar) { + return (Tensor(repeating: repeatedValue, shape: shape), { + $0.sum().scalarized() + }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Casting +//===------------------------------------------------------------------------------------------===// + +public extension Tensor where Scalar: Numeric { + /// Perform an element-wise type conversion from a `Bool` tensor. + @inlinable + init(_ other: Tensor) { + self = Raw.cast(other) + } + + /// Perform an element-wise conversion from another `Tensor`. + @inlinable + @differentiable( + vjp: _vjpCast where Scalar: TensorFlowFloatingPoint, OtherScalar: TensorFlowFloatingPoint) + init(_ other: Tensor) { + self = Raw.cast(other) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpCast( + _ other: Tensor + ) -> (Tensor, (Tensor) -> Tensor) { + return (Tensor(other), { v in Tensor(v) }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Stacking / Concatenating / Tiling +//===------------------------------------------------------------------------------------------===// + +public extension Tensor { + /// Creates a tensor from an array of tensors (which may themselves be scalars). + @inlinable + @differentiable(where Scalar: TensorFlowFloatingPoint) + init(_ elements: [Tensor]) { + self = Tensor(stacking: elements) + } + + /// Stacks `tensors`, along the `axis` dimension, into a new tensor with rank one higher than + /// the current tensor and each tensor in `tensors`. 
+ /// + /// Given that `tensors` all have shape `[A, B, C]`, and `tensors.count = N`, then: + /// - if `axis == 0` then the resulting tensor will have the shape `[N, A, B, C]`. + /// - if `axis == 1` then the resulting tensor will have the shape `[A, N, B, C]`. + /// - etc. + /// + /// For example: + /// ``` + /// // 'x' is [1, 4] + /// // 'y' is [2, 5] + /// // 'z' is [3, 6] + /// Tensor(stacking: [x, y, z]) // is [[1, 4], [2, 5], [3, 6]] + /// Tensor(stacking: [x, y, z], alongAxis: 1) // is [[1, 2, 3], [4, 5, 6]] + /// ``` + /// + /// This is the opposite of `Tensor.unstack(alongAxis:)`. + /// + /// - Parameters: + /// - tensors: Tensors to stack. + /// - axis: Dimension along which to stack. Negative values wrap around. + /// + /// - Precondition: All tensors must have the same shape. + /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the + /// provided tensors. + /// + /// - Returns: The stacked tensor. + @inlinable + @differentiable(vjp: _vjpStacking where Scalar: TensorFlowFloatingPoint) + init(stacking tensors: [Tensor], alongAxis axis: Int = 0) { + self = Raw.pack(tensors, axis: Int64(axis)) + } + + /// Concatenates `tensors` along the `axis` dimension. + /// + /// Given that `tensors[i].shape = [D0, D1, ... Daxis(i), ...Dn]`, then the concatenated result + /// has shape `[D0, D1, ... Raxis, ...Dn]`, where `Raxis = sum(Daxis(i))`. That is, the data + /// from the input tensors is joined along the `axis` dimension. + /// + /// For example: + /// ``` + /// // t1 is [[1, 2, 3], [4, 5, 6]] + /// // t2 is [[7, 8, 9], [10, 11, 12]] + /// Tensor(concatenating: [t1, t2]) // is [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] + /// Tensor(concatenating: [t1, t2], alongAxis: 1) // is [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]] + /// + /// // t3 has shape [2, 3] + /// // t4 has shape [2, 3] + /// Tensor(concatenating: [t3, t4]) // has shape [4, 3] + /// Tensor(concatenating: [t3, t4], alongAxis: 1) // has shape [2, 6] + /// ``` + /// + /// - Note: If you are concatenating along a new axis consider using + /// `Tensor.init(stacking:alongAxis:)`. + /// + /// - Parameters: + /// - tensors: Tensors to concatenate. + /// - axis: Dimension along which to concatenate. Negative values wrap around. + /// + /// - Precondition: All tensors must have the same rank and all dimensions except `axis` + /// must be equal. + /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the + /// provided tensors. + /// + /// - Returns: The concatenated tensor. + @inlinable + @differentiable(vjp: _vjpConcatenating where Scalar: TensorFlowFloatingPoint) + init(concatenating tensors: [Tensor], alongAxis axis: Int = 0) { + precondition(tensors.count > 0) + self = Raw.concatV2(tensors, axis: Tensor(Int32(axis))) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpStacking( + stacking tensors: [Tensor], + alongAxis axis: Int = 0 + ) -> (Tensor, (Tensor) -> Array.DifferentiableView) { + let result = Tensor(stacking: tensors, alongAxis: axis) + return (result, { v in + Array.DifferentiableView(v.unstack(alongAxis: axis)) + }) + } + + @inlinable + static func _vjpConcatenating( + concatenating tensors: [Tensor], + alongAxis axis: Int = 0 + ) -> (Tensor, (Tensor) -> Array.DifferentiableView) { + let result = Tensor(concatenating: tensors, alongAxis: axis) + let posAxis = axis < 0 ? 
axis + tensors[0].rank: axis + let sizes = Tensor(stacking: tensors.map { $0.shapeTensor[posAxis] }) + return (result, { [count = tensors.count] v in + if count == 1 { return Array.DifferentiableView([v]) } + let splits = v.split(sizes: sizes, alongAxis: posAxis) + return Array.DifferentiableView(splits) + }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Numeric +//===------------------------------------------------------------------------------------------===// + +public extension Tensor where Scalar: Numeric { + /// Creates a tensor with all scalars set to zero. + /// + /// - Parameter shape: Shape of the tensor. + @inlinable + init(zeros shape: TensorShape) { + self.init(repeating: 0, shape: shape) + } + + /// Creates a tensor with all scalars set to one. + /// + /// - Parameter shape: Shape of the tensor. + @inlinable + init(ones shape: TensorShape) { + self.init(repeating: 1, shape: shape) + } + + /// Creates a 1-D tensor representing a sequence from a starting value to, but not including, + /// an end value, stepping by the specified amount. + /// + /// - Parameters: + /// - start: The starting value to use for the sequence. If the sequence + /// contains any values, the first one is `start`. + /// - end: An end value to limit the sequence. `end` is never an element of + /// the resulting sequence. + /// - stride: The amount to step by with each iteration. `stride` must be + /// positive. + /// + @inlinable + init(rangeFrom start: Scalar, to end: Scalar, stride: Scalar) { + self = Raw.range(start: Tensor(start), limit: Tensor(end), delta: Tensor(stride)) + } + + /// Creates a one-hot tensor at given indices. The locations represented by + /// `indices` take value `onValue` (`1` by default), while all other locations + /// take value `offValue` (`0` by default). If the input `indices` is rank + /// `n`, the new tensor will have rank `n+1`. The new axis is created at + /// dimension `axis` (by default, the new axis is appended at the end). + /// + /// If `indices` is a scalar, the new tensor's shape will be a vector of + /// length `depth`. + /// + /// If `indices` is a vector of length `features`, the output shape will be: + /// features x depth, if axis == -1 + /// depth x features, if axis == 0 + /// + /// If `indices` is a matrix (batch) with shape `[batch, features]`, the + /// output shape will be: + /// batch x features x depth, if axis == -1 + /// batch x depth x features, if axis == 1 + /// depth x batch x features, if axis == 0 + /// + /// - Parameters: + /// - indices: A `Tensor` of indices. + /// - depth: A scalar defining the depth of the one hot dimension. + /// - onValue: A scalar defining the value at the location referred to by + /// some index in `indices`. + /// - offValue: A scalar defining the value at a location that is not + /// referred to by any index in `indices`. + /// - axis: The axis to fill. The default is `-1`, a new inner-most axis. 
+ /// + @inlinable + init( + oneHotAtIndices indices: Tensor, + depth: Int, + onValue: Scalar = 1, + offValue: Scalar = 0, + axis: Int = -1 + ) { + self = Raw.oneHot( + indices: indices, + depth: Tensor(Int32(depth)), + onValue: Tensor(onValue), + offValue: Tensor(offValue), + axis: Int64(axis)) + } +} + +//===------------------------------------------------------------------------------------------===// +// Random +//===------------------------------------------------------------------------------------------===// + public extension Tensor where Scalar == Int32 { /// Creates a tensor with the specified shape, randomly sampling scalar values /// from a discrete uniform distribution. @@ -24,8 +309,10 @@ public extension Tensor where Scalar == Int32 { /// - shape: The dimensions of the tensor. /// - generator: Random number generator to use. /// - init(randomStandardUniform shape: TensorShape, - generator: inout G) { + init( + randomStandardUniform shape: TensorShape, + generator: inout G + ) { let dist = UniformIntegerDistribution() var scalars: [Scalar] = [] for _ in 0 ..< shape.contiguousSize { @@ -94,8 +381,10 @@ public extension Tensor where Scalar: BinaryFloatingPoint, /// - shape: The dimensions of the tensor. /// - generator: Random number generator to use. /// - init(randomUniform shape: TensorShape, - generator: inout G) { + init( + randomUniform shape: TensorShape, + generator: inout G + ) { let dist = UniformFloatingPointDistribution() var scalars: [Scalar] = [] for _ in 0 ..< shape.contiguousSize { @@ -113,10 +402,12 @@ public extension Tensor where Scalar: BinaryFloatingPoint, /// - stddev: The standard deviation of the distribution. /// - generator: Random number generator to use. /// - init(randomNormal shape: TensorShape, - mean: Scalar = 0, - stddev: Scalar = 1, - generator: inout G) { + init( + randomNormal shape: TensorShape, + mean: Scalar = 0, + stddev: Scalar = 1, + generator: inout G + ) { let dist = NormalDistribution(mean: mean, standardDeviation: stddev) var scalars: [Scalar] = [] for _ in 0 ..< shape.contiguousSize { @@ -126,7 +417,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint, } } -fileprivate extension Tensor where Scalar : BinaryFloatingPoint { +fileprivate extension Tensor where Scalar: BinaryFloatingPoint { private static func glorot( fromStandardUniform randomUniform: __shared Tensor, shape: __shared TensorShape @@ -150,9 +441,11 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { /// - Parameters: /// - shape: The dimensions of the tensor. /// - init(glorotUniform shape: TensorShape, - seed: (Int64, Int64) = (Int64.random(in: Int64.min.. Tensor { + return Tensor(repeating: self, shape: TensorShape(rank)) + } +} + +public extension Tensor { + /// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors. Unpacks + /// `N` tensors from this tensor by chipping it along the `axis` dimension, where `N` is + /// inferred from this tensor's shape. For example, given a tensor with shape `[A, B, C, D]`: + /// + /// - If `axis == 0` then the `i`-th tensor in the returned array is the slice + /// `self[i, :, :, :]` and each tensor in that array will have shape `[B, C, D]`. + /// (Note that the dimension unpacked along is gone, unlike + /// `Tensor.split(numSplits:alongAxis)`, or `Tensor.split(sizes:alongAxis)`). + /// - If `axis == 1` then the `i`-th tensor in the returned array is the slice + /// `value[:, i, :, :]` and each tensor in that array will have shape `[A, C, D]`. + /// - Etc. 
+ /// + /// This is the opposite of `Tensor.init(stacking:alongAxis:)`. + /// + /// - Parameters: + /// - axis: Dimension along which to unstack. Negative values wrap around. + /// + /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the + /// provided tensors. + /// + /// - Returns: Array containing the unstacked tensors. + @inlinable + @differentiable(vjp: _vjpUnstack(alongAxis:) where Scalar: TensorFlowFloatingPoint) + func unstack(alongAxis axis: Int = 0) -> [Tensor] { + return Raw.unpack(value: self, num: Int64(shape[axis]), axis: Int64(axis)) + } + + /// Splits a tensor into multiple tensors. The tensor is split along dimension `axis` into + /// `numSplits` smaller tensors. This requires that `numSplits` evenly divides `shape[axis]`. + /// + /// For example: + /// ``` + /// // 'value' is a tensor with shape [5, 30] + /// // Split 'value' into 3 tensors along dimension 1: + /// let parts = value.split(numSplits: 3, alongAxis: 1) + /// parts[0] // has shape [5, 10] + /// parts[1] // has shape [5, 10] + /// parts[2] // has shape [5, 10] + /// ``` + /// + /// - Parameters: + /// - numSplits: Number of splits to create. + /// - axis: Dimension along which to split this tensor. Negative values wrap around. + /// + /// - Precondition: `numSplits` must divide the size of dimension `axis` evenly. + /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the + /// provided tensors. + /// + /// - Returns: Array containing the tensors parts. + @inlinable + @differentiable(vjp: _vjpSplit(numSplits:alongAxis:) where Scalar: TensorFlowFloatingPoint) + func split(numSplits: Int, alongAxis axis: Int = 0) -> [Tensor] { + return Raw.split( + splitDim: Tensor(Int32(axis)), value: self, numSplit: Int64(numSplits)) + } + + /// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces. + /// The shape of the `i`-th piece has the same shape as this tensor except along dimension + /// `axis` where the size is `sizes[i]`. + /// + /// For example: + /// ``` + /// // 'value' is a tensor with shape [5, 30] + /// // Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1: + /// let parts = value.split(sizes: Tensor([4, 15, 11]), alongAxis: 1) + /// parts[0] // has shape [5, 4] + /// parts[1] // has shape [5, 15] + /// parts[2] // has shape [5, 11] + /// ``` + /// + /// - Parameters: + /// - sizes: 1-D tensor containing the size of each split. + /// - axis: Dimension along which to split this tensor. Negative values wrap around. + /// + /// - Precondition: The values in `sizes` must add up to the size of dimension `axis`. + /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the + /// provided tensors. + /// + /// - Returns: Array containing the tensors parts. + @inlinable + @differentiable( + wrt: self, + vjp: _vjpSplit(sizes:alongAxis:) where Scalar: TensorFlowFloatingPoint) + func split(sizes: Tensor, alongAxis axis: Int = 0) -> [Tensor] { + return Raw.splitV( + value: self, + sizeSplits: sizes, + splitDim: Tensor(Int32(axis)), + numSplit: Int64(sizes.shape[0])) + } + + /// Reshape to the shape of the specified `Tensor`. + /// - Precondition: The number of scalars matches the new shape. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func reshaped(like other: Tensor) -> Tensor { + return reshaped(toShape: other.shapeTensor) + } + + /// Reshape to the specified shape. 
+ /// - Precondition: The number of scalars matches the new shape. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func reshaped(to newShape: TensorShape) -> Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + return reshaped(toShape: Tensor({newShape.dimensions.map(Int32.init)}())) + } + + /// Reshape to the specified `Tensor` representing a shape. + /// - Precondition: The number of scalars matches the new shape. + @inlinable + @differentiable( + wrt: self, + vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint) + func reshaped(toShape newShape: Tensor) -> Tensor { + return Raw.reshape(self, shape: newShape) + } + + /// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func flattened() -> Tensor { + return reshaped(to: [-1]) + } + + /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the specified shape + /// indices. + @inlinable + @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) + func expandingShape(at axes: Int...) -> Tensor { + return expandingShape(at: axes) + } + + /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the + /// specified shape indices. + @inlinable + @differentiable(wrt: self, vjp: _vjpExpandingShape(at:) where Scalar: TensorFlowFloatingPoint) + func expandingShape(at axes: [Int]) -> Tensor { + var result = self + for i in axes { result = Raw.expandDims(result, dim: Tensor(Int32(i))) } + return result + } + + /// Returns a rank-lifted `Tensor` with a leading dimension of 1. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func rankLifted() -> Tensor { + return expandingShape(at: 0) + } + + /// Remove the specified dimensions of size 1 from the shape of a tensor. If no dimensions are + /// specified, then all dimensions of size 1 will be removed. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func squeezingShape(at axes: Int...) -> Tensor { + return squeezingShape(at: axes) + } + + /// Remove the specified dimensions of size 1 from the shape of a tensor. If no dimensions are + /// specified, then all dimensions of size 1 will be removed. 
+ @inlinable + @differentiable(wrt: self, vjp: _vjpSqueezingShape(at:) where Scalar: TensorFlowFloatingPoint) + func squeezingShape(at axes: [Int]) -> Tensor { + return Raw.squeeze(self, squeezeDims: axes.map(Int32.init)) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + func _vjpUnstack( + alongAxis axis: Int = 0 + ) -> ([Tensor], (Array.CotangentVector) -> Tensor) { + let result = unstack(alongAxis: axis) + return (result, { v in Tensor(stacking: v.base, alongAxis: axis) }) + } + + @inlinable + func _vjpSplit( + numSplits: Int, + alongAxis axis: Int = 0 + ) -> ([Tensor], (Array.CotangentVector) -> Tensor) { + let result = split(numSplits: numSplits, alongAxis: axis) + return (result, { v in Tensor(concatenating: v.base, alongAxis: axis) }) + } + + @inlinable + func _vjpSplit( + sizes: Tensor, + alongAxis axis: Int = 0 + ) -> ([Tensor], (Array.CotangentVector) -> Tensor) { + let result = split(sizes: sizes, alongAxis: axis) + return (result, { v in Tensor(concatenating: v.base, alongAxis: axis) }) + } + + @inlinable + func _vjpReshaped(toShape newShape: Tensor) -> (Tensor, (Tensor) -> Tensor) { + let value = reshaped(toShape: newShape) + return (value, { [shape = shapeTensor] v in v.reshaped(toShape: shape) }) + } + + @inlinable + func _vjpExpandingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { + let value = self.expandingShape(at: axes) + return (value, { v in v.squeezingShape(at: axes) }) + } + + @inlinable + func _vjpSqueezingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { + let value = squeezingShape(at: axes) + return (value, { [shape = shapeTensor] v in v.reshaped(toShape: shape) }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Other Tensor Transformations +//===------------------------------------------------------------------------------------------===// + +infix operator ++: AdditionPrecedence + +public extension Tensor { + /// Returns a transposed tensor, with dimensions permuted in the specified order. + @inlinable + @differentiable( + wrt: self, + vjp: _vjpTransposed(withPermutations:) where Scalar: TensorFlowFloatingPoint) + func transposed(withPermutations permutations: Tensor) -> Tensor { + return Raw.transpose(self, perm: permutations) + } + + /// Returns a transposed tensor, with dimensions permuted in the specified order. + @inlinable + @differentiable( + wrt: self, + vjp: _vjpTransposed(withPermutations:) where Scalar: TensorFlowFloatingPoint) + func transposed(withPermutations permutations: [Int]) -> Tensor { + let permutations = permutations.map(Int32.init) + return transposed(withPermutations: Tensor(permutations)) + } + + /// Returns a transposed tensor, with dimensions permuted in the specified order. + @inlinable + @differentiable( + wrt: self, vjp: _vjpTransposed(withPermutations:) where Scalar: TensorFlowFloatingPoint) + func transposed(withPermutations permutations: Int...) -> Tensor { + return transposed(withPermutations: permutations) + } + + /// Returns a transposed tensor, with dimensions permuted in reverse order. + @inlinable + @differentiable(wrt: self, vjp: _vjpTransposed() where Scalar: TensorFlowFloatingPoint) + func transposed() -> Tensor { + let defaultPermutations = rankTensor - 1 - Tensor( + rangeFrom: 0, to: Int32(rank), stride: 1) + return transposed(withPermutations: Tensor(defaultPermutations)) + } + + /// Concatenates tensors along the specified axis. 
+ /// - Precondition: The tensors must have the same dimensions, except for the + /// specified axis. + /// - Precondition: The axis must be in the range `-rank.. Tensor { + return Tensor(concatenating: [self, other], alongAxis: axis) + } + + /// Concatenation operator. + /// - Note: `++` is a custom operator that does not exist in Swift, but does + /// in Haskell/Scala. Its addition is not an insignificant language change + /// and may be controversial. The existence/naming of `++` will be discussed + /// during a later API design phase. + @inlinable + @differentiable(where Scalar: TensorFlowFloatingPoint) + static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor { + return lhs.concatenated(with: rhs) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + func _vjpTransposed( + withPermutations permutations: Tensor + ) -> (Tensor, (Tensor) -> Tensor) { + let value = transposed(withPermutations: permutations) + return (value, { $0.transposed(withPermutations: permutations) }) + } + + @inlinable + func _vjpTransposed(withPermutations permutations: [Int]) -> (Tensor, (Tensor) -> Tensor) { + let value = transposed(withPermutations: permutations) + return (value, { $0.transposed(withPermutations: permutations) }) + } + + @inlinable + func _vjpTransposed(withPermutations permutations: Int...) -> (Tensor, (Tensor) -> Tensor) { + let value = transposed(withPermutations: permutations) + return (value, { $0.transposed(withPermutations: permutations) }) + } + + @inlinable + func _vjpTransposed() -> (Tensor, (Tensor) -> Tensor) { + return (transposed(), { $0.transposed() }) + } + + @inlinable + func _vjpConcatenated( + with other: Tensor, + alongAxis axis: Int + ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let idx = axis < 0 ? axis + rank: axis + let splits = Tensor([shapeTensor[idx], other.shapeTensor[idx]]) + return (concatenated(with: other, alongAxis: axis), { result in + let gradients = result.split(sizes: splits, alongAxis: axis) + return (gradients[0], gradients[1]) + }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Broadcasting +//===------------------------------------------------------------------------------------------===// + +// TODO: What about precedence? Also, why is this operator meaningful for broadcasting? +infix operator .= + +public extension Tensor { + @inlinable + func broadcast(toShape shape: Tensor) -> Tensor { + return Raw.broadcastTo(self, shape: shape) + } + + @inlinable + func broadcast(to shape: TensorShape) -> Tensor { + return broadcast(toShape: Tensor(shape.dimensions.map(Int32.init))) + } + + /// Broadcast to the same shape as the specified `Tensor`. + /// - Precondition: The specified shape must be compatible for broadcasting. + @inlinable + func broadcast(like other: Tensor) -> Tensor { + return broadcast(toShape: other.shapeTensor) + } + + @inlinable + static func .= (lhs: inout Tensor, rhs: Tensor) { + lhs = rhs.broadcast(like: lhs) + } +} + +// TODO: Why is this limited only to numeric data types whereas `broadcast` is not? 
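// A short usage sketch for the initializers and shape operations defined above (a sketch,
// assuming a toolchain in which this module builds; resulting shapes in trailing comments):
let x = Tensor<Float>(rangeFrom: 0, to: 6, stride: 1).reshaped(to: [2, 3])  // shape [2, 3]
let stacked = Tensor(stacking: [x, x])                                      // shape [2, 2, 3]
let joined = Tensor(concatenating: [x, x], alongAxis: 1)                    // shape [2, 6]
let parts = joined.split(numSplits: 2, alongAxis: 1)                        // two tensors of shape [2, 3]
let grown = Tensor<Float>([1, 2, 3]).broadcast(to: [2, 3])                  // shape [2, 3]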
+public extension Tensor where Scalar: Numeric { + @inlinable + func unbroadcast(toShape otherShape: Tensor) -> Tensor { + let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted() + let ones: Tensor = Raw.fill(dims: rankDiff, value: Tensor(1)) + let paddedShape = ones ++ otherShape + let nonEqualIndices = paddedShape .!= shapeTensor + let broadcastIndices = Raw.where_(nonEqualIndices).flattened() + let unbroadcasted: Tensor = Raw.sum( + self, reductionIndices: Tensor(broadcastIndices), keepDims: false) + return Raw.reshape(unbroadcasted, shape: otherShape) + } + + @inlinable + func unbroadcast(like other: Tensor) -> Tensor { + return unbroadcast(toShape: other.shapeTensor) + } + + @inlinable + func unbroadcast(to shape: TensorShape) -> Tensor { + return unbroadcast(toShape: Tensor(shape.dimensions.map(Int32.init))) + } +} + +//===------------------------------------------------------------------------------------------===// +// Padding +//===------------------------------------------------------------------------------------------===// + +public extension Tensor where Scalar: Numeric { + /// Returns a padded tensor according to the specified padding sizes. + @inlinable + func padded(forSizes sizes: [(before: Int, after: Int)], with value: Scalar = 0) -> Tensor { + let paddings = Tensor( + shape: [sizes.count, 2], + scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] }) + return Raw.padV2(self, paddings: paddings, constantValues: Tensor(value)) + } +} + +//===------------------------------------------------------------------------------------------===// +// Indexing and Slicing +//===------------------------------------------------------------------------------------------===// + +// TODO: Negative indexing and strides syntax. + +public extension Tensor { + /// Extracts a slice from the tensor defined by lower and upper bounds for + /// each dimension. + /// + /// - Parameter lowerBounds: The lower bounds at each dimension. + /// - Parameter upperBounds: The upper bounds at each dimension. + @inlinable + @differentiable(wrt: self) + func slice(lowerBounds: [Int], upperBounds: [Int]) -> Tensor { + // TODO: Precondition `lowerBounds.count == upperBounds.count`, + // preferably in graph. + // TODO: Differentiating control flow is not supported yet, thus the thunks. 
+ let lowerBoundsTensor = Tensor({lowerBounds.map(Int32.init)}()) + let upperBoundsTensor = Tensor({upperBounds.map(Int32.init)}()) + return slice(lowerBounds: lowerBoundsTensor, sizes: upperBoundsTensor - lowerBoundsTensor) + } + + @inlinable + @differentiable(wrt: self, vjp: _vjpSlice) + func slice(lowerBounds: Tensor, sizes: Tensor) -> Tensor { + return Raw.slice(self, begin: lowerBounds, size: sizes) + } + + @inlinable + internal func _vjpSlice( + lowerBounds: Tensor, + sizes: Tensor + ) -> (Tensor, (Tensor) -> Tensor) { + let value = slice(lowerBounds: lowerBounds, sizes: sizes) + let afterPaddings = shapeTensor - value.shapeTensor - lowerBounds + return (value, { [after = afterPaddings] v in + let beforePaddings = lowerBounds.expandingShape(at: 1) + let afterPaddings = after.expandingShape(at: 1) + let paddings = Tensor( + concatenating: [beforePaddings, afterPaddings], alongAxis: 1) + return Raw.pad(v, paddings: paddings) + }) + } +} + +public enum TensorRange: TensorRangeExpression { + case ellipsis + case newAxis + case squeezeAxis + case index(Int) + case range(Range, stride: Int) + case closedRange(ClosedRange, stride: Int) + case partialRangeFrom(PartialRangeFrom, stride: Int) + case partialRangeUpTo(PartialRangeUpTo, stride: Int) + case partialRangeThrough(PartialRangeThrough, stride: Int) + + public var tensorRange: TensorRange { return self } +} + +extension TensorRange: Equatable { + public static func == (lhs: TensorRange, rhs: TensorRange) -> Bool { + switch (lhs, rhs) { + case (.ellipsis, .ellipsis), + (.newAxis, .newAxis), + (.squeezeAxis, .squeezeAxis): + return true + case (let .index(i1), let .index(i2)): return i1 == i2 + case (let .range(r1, s1), let .range(r2, s2)): return r1 == r2 && s1 == s2 + case (let .closedRange(r1, s1), let .closedRange(r2, s2)): + return r1 == r2 && s1 == s2 + case (let .partialRangeFrom(r1, s1), let .partialRangeFrom(r2, s2)): + return r1.lowerBound == r2.lowerBound && s1 == s2 + case (let .partialRangeUpTo(r1, s1), let .partialRangeUpTo(r2, s2)): + return r1.upperBound == r2.upperBound && s1 == s2 + case (let .partialRangeThrough(r1, s1), let .partialRangeThrough(r2, s2)): + return r1.upperBound == r2.upperBound && s1 == s2 + default: return false + } + } +} + +public protocol TensorRangeExpression { + var tensorRange: TensorRange { get } +} + +// TODO: Cannot extend non-nominal type 'UnboundedRange'. 
+// extension UnboundedRange: TensorRangeExpression { +// public var tensorRange: TensorRange { return .ellipsis } +// } + +extension Int: TensorRangeExpression { + public var tensorRange: TensorRange { return .index(self) } +} + +extension Range: TensorRangeExpression where Bound == Int { + public var tensorRange: TensorRange { + return .range(self, stride: 1) + } +} + +extension ClosedRange: TensorRangeExpression where Bound == Int { + public var tensorRange: TensorRange { + return .closedRange(self, stride: 1) + } +} + +extension PartialRangeFrom: TensorRangeExpression where Bound == Int { + public var tensorRange: TensorRange { + return .partialRangeFrom(self, stride: 1) + } +} + +extension PartialRangeUpTo: TensorRangeExpression where Bound == Int { + public var tensorRange: TensorRange { + return .partialRangeUpTo(self, stride: 1) + } +} + +extension PartialRangeThrough: TensorRangeExpression where Bound == Int { + public var tensorRange: TensorRange { + return .partialRangeThrough(self, stride: 1) + } +} + +infix operator ..: StridedRangeFormationPrecedence +precedencegroup StridedRangeFormationPrecedence { + associativity: left + higherThan: CastingPrecedence + lowerThan: RangeFormationPrecedence +} + +public extension Range where Bound == Int { + static func .. (range: Range, stride: Int) -> TensorRange { + return .range(range, stride: stride) + } +} + +public extension ClosedRange where Bound == Int { + static func .. (range: ClosedRange, stride: Int) -> TensorRange { + return .closedRange(range, stride: stride) + } +} + +public extension PartialRangeFrom where Bound == Int { + static func .. (range: PartialRangeFrom, stride: Int) -> TensorRange { + return .partialRangeFrom(range, stride: stride) + } +} + +public extension PartialRangeUpTo where Bound == Int { + static func .. (range: PartialRangeUpTo, stride: Int) -> TensorRange { + return .partialRangeUpTo(range, stride: stride) + } +} + +public extension PartialRangeThrough where Bound == Int { + static func .. 
(range: PartialRangeThrough, stride: Int) -> TensorRange {
+        return .partialRangeThrough(range, stride: stride)
+    }
+}
+
+public extension Tensor {
+    @_fixed_layout @usableFromInline
+    internal struct IndexPath {
+        @usableFromInline
+        let begin, end, strides: Tensor<Int32>
+
+        @usableFromInline
+        let beginMask, endMask, ellipsisMask, newAxisMask, squeezeAxisMask: Int64
+
+        @inlinable
+        public init(
+            begin: Tensor<Int32>, end: Tensor<Int32>, strides: Tensor<Int32>,
+            beginMask: Int64, endMask: Int64, ellipsisMask: Int64, newAxisMask: Int64,
+            squeezeAxisMask: Int64
+        ) {
+            self.begin = begin
+            self.end = end
+            self.strides = strides
+            self.beginMask = beginMask
+            self.endMask = endMask
+            self.ellipsisMask = ellipsisMask
+            self.newAxisMask = newAxisMask
+            self.squeezeAxisMask = squeezeAxisMask
+        }
+    }
+
+    @inlinable
+    @differentiable(wrt: self, vjp: _vjpSubscript)
+    internal subscript(_ indexPath: IndexPath) -> Tensor {
+        get {
+            return Raw.stridedSlice(
+                self, begin: indexPath.begin, end: indexPath.end,
+                strides: indexPath.strides, beginMask: indexPath.beginMask,
+                endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
+                newAxisMask: indexPath.newAxisMask,
+                shrinkAxisMask: indexPath.squeezeAxisMask)
+        }
+        set {
+            self = Raw.tensorStridedSliceUpdate(
+                self, begin: indexPath.begin, end: indexPath.end,
+                strides: indexPath.strides, value: newValue,
+                beginMask: indexPath.beginMask, endMask: indexPath.endMask,
+                ellipsisMask: indexPath.ellipsisMask,
+                newAxisMask: indexPath.newAxisMask,
+                shrinkAxisMask: indexPath.squeezeAxisMask)
+        }
+    }
+
+    @inlinable
+    // TODO: @differentiable(wrt: self)
+    subscript(_ ranges: TensorRangeExpression...) -> Tensor {
+        get {
+            return self[IndexPath(ranges.map { $0.tensorRange })]
+        }
+        set {
+            self[IndexPath(ranges.map { $0.tensorRange })] = newValue
+        }
+    }
+
+    @usableFromInline
+    internal func _vjpSubscript(
+        _ indexPath: IndexPath
+    ) -> (Tensor, (Tensor) -> Tensor) {
+        return (self[indexPath], { [shape = shapeTensor] v in
+            Raw.stridedSliceGrad(
+                shape: shape, begin: indexPath.begin, end: indexPath.end,
+                strides: indexPath.strides, dy: v, beginMask: indexPath.beginMask,
+                endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
+                newAxisMask: indexPath.newAxisMask,
+                shrinkAxisMask: indexPath.squeezeAxisMask)
+        })
+    }
+}
+
+internal extension Tensor.IndexPath {
+    @inlinable
+    init(_ ranges: [TensorRange]) {
+        precondition(!ranges.isEmpty, "The tensor range collection cannot be empty.")
+        precondition(ranges.lazy.filter { $0 == TensorRange.ellipsis }.count < 2,
+                     "Only one ellipsis is allowed per tensor range collection.")
+
+        var begin = [Int32](repeating: 0, count: ranges.count)
+        var end = [Int32](repeating: 0, count: ranges.count)
+        var strides = [Int32](repeating: 1, count: ranges.count)
+        var beginMask: Int64 = 0
+        var endMask: Int64 = 0
+        var ellipsisMask: Int64 = 0
+        var newAxisMask: Int64 = 0
+        var squeezeAxisMask: Int64 = 0
+        for (i, index) in ranges.enumerated() {
+            switch index {
+            case .ellipsis: ellipsisMask |= 1 << i
+            case .newAxis: newAxisMask |= 1 << i
+            case .squeezeAxis: squeezeAxisMask |= 1 << i
+            case .index(let index):
+                begin[i] = Int32(index)
+                end[i] = Int32(index) + 1
+                squeezeAxisMask |= 1 << i
+            case .range(let range, let stride):
+                begin[i] = Int32(range.lowerBound)
+                end[i] = Int32(range.upperBound)
+                strides[i] = Int32(stride)
+            case .closedRange(let range, let stride):
+                begin[i] = Int32(range.lowerBound)
+                switch Int32(range.upperBound) {
+                case -1: endMask |= 1 << i
+                case let u: end[i] = u + 1
+                }
+                strides[i] = Int32(stride)
+            case
.partialRangeFrom(let range, let stride): + begin[i] = Int32(range.lowerBound) + strides[i] = Int32(stride) + endMask |= 1 << i + case .partialRangeUpTo(let range, let stride): + end[i] = Int32(range.upperBound) + strides[i] = Int32(stride) + beginMask |= 1 << i + case .partialRangeThrough(let range, let stride): + end[i] = Int32(range.upperBound) + 1 + strides[i] = Int32(stride) + beginMask |= 1 << i + } + } + + self.begin = Tensor(begin) + self.end = Tensor(end) + self.strides = Tensor(strides) + self.beginMask = beginMask + self.endMask = endMask + self.ellipsisMask = ellipsisMask + self.newAxisMask = newAxisMask + self.squeezeAxisMask = squeezeAxisMask + } +} diff --git a/Sources/DeepLearning/Operators/Comparison.swift b/Sources/DeepLearning/Operators/Comparison.swift new file mode 100644 index 000000000..02bf5fadf --- /dev/null +++ b/Sources/DeepLearning/Operators/Comparison.swift @@ -0,0 +1,237 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +import TensorFlow +#endif + +infix operator .<: ComparisonPrecedence +infix operator .<=: ComparisonPrecedence +infix operator .>=: ComparisonPrecedence +infix operator .>: ComparisonPrecedence +infix operator .==: ComparisonPrecedence +infix operator .!=: ComparisonPrecedence + +public extension Tensor where Scalar: Numeric & Comparable { + /// Computes `lhs < rhs` element-wise and returns a `Tensor` of Boolean /// scalars. + @inlinable + static func .< (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.less(lhs, rhs) + } + + /// Computes `lhs <= rhs` element-wise and returns a `Tensor` of Boolean scalars. + @inlinable + static func .<= (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.lessEqual(lhs, rhs) + } + + /// Computes `lhs > rhs` element-wise and returns a `Tensor` of Boolean scalars. + @inlinable + static func .> (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.greater(lhs, rhs) + } + + /// Computes `lhs >= rhs` element-wise and returns a `Tensor` of Boolean scalars. + @inlinable + static func .>= (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.greaterEqual(lhs, rhs) + } + + /// Computes `lhs < rhs` element-wise and returns a `Tensor` of Boolean scalars. + /// - Note: `.<` supports broadcasting. + @inlinable + static func .< (lhs: Scalar, rhs: Tensor) -> Tensor { + return Raw.less(Tensor(lhs), rhs) + } + + /// Computes `lhs <= rhs` element-wise and returns a `Tensor` of Boolean scalars. + /// - Note: `.<=` supports broadcasting. + @inlinable + static func .<= (lhs: Scalar, rhs: Tensor) -> Tensor { + return Raw.lessEqual(Tensor(lhs), rhs) + } + + /// Computes `lhs > rhs` element-wise and returns a `Tensor` of Boolean scalars. + /// - Note: `.>` supports broadcasting. + @inlinable + static func .> (lhs: Scalar, rhs: Tensor) -> Tensor { + return Raw.greater(Tensor(lhs), rhs) + } + + /// Computes `lhs >= rhs` element-wise and returns a `Tensor` of Boolean scalars. + /// - Note: `.>=` supports broadcasting. 
+    @inlinable
+    static func .>= (lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
+        return Raw.greaterEqual(Tensor(lhs), rhs)
+    }
+
+    /// Computes `lhs < rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.<` supports broadcasting.
+    @inlinable
+    static func .< (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
+        return Raw.less(lhs, Tensor(rhs))
+    }
+
+    /// Computes `lhs <= rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.<=` supports broadcasting.
+    @inlinable
+    static func .<= (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
+        return Raw.lessEqual(lhs, Tensor(rhs))
+    }
+
+    /// Computes `lhs > rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.>` supports broadcasting.
+    @inlinable
+    static func .> (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
+        return Raw.greater(lhs, Tensor(rhs))
+    }
+
+    /// Computes `lhs >= rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.>=` supports broadcasting.
+    @inlinable
+    static func .>= (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
+        return Raw.greaterEqual(lhs, Tensor(rhs))
+    }
+}
+
+extension Tensor: Equatable where Scalar: Equatable {
+    @inlinable
+    public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
+        return (lhs .== rhs).all()
+    }
+
+    @inlinable
+    public static func != (lhs: Tensor, rhs: Tensor) -> Bool {
+        return (lhs .!= rhs).any()
+    }
+}
+
+extension Tensor: Comparable where Scalar: Numeric & Comparable {
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically less than that of the second argument.
+    @inlinable
+    public static func < (lhs: Tensor, rhs: Tensor) -> Bool {
+        return (lhs .< rhs).all()
+    }
+
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically less than or equal to that of the second argument.
+    @inlinable
+    public static func <= (lhs: Tensor, rhs: Tensor) -> Bool {
+        return (lhs .<= rhs).all()
+    }
+
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically greater than that of the second argument.
+    @inlinable
+    public static func > (lhs: Tensor, rhs: Tensor) -> Bool {
+        return (lhs .> rhs).all()
+    }
+
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically greater than or equal to that of the second argument.
+    @inlinable
+    public static func >= (lhs: Tensor, rhs: Tensor) -> Bool {
+        return (lhs .>= rhs).all()
+    }
+}
+
+public extension Tensor where Scalar: Numeric & Comparable {
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically less than that of the second argument.
+    @inlinable
+    static func < (lhs: Tensor, rhs: Scalar) -> Bool {
+        return (lhs .< rhs).all()
+    }
+
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically less than or equal to that of the second argument.
+    @inlinable
+    static func <= (lhs: Tensor, rhs: Scalar) -> Bool {
+        return (lhs .<= rhs).all()
+    }
+
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically greater than that of the second argument.
+    @inlinable
+    static func > (lhs: Tensor, rhs: Scalar) -> Bool {
+        return (lhs .> rhs).all()
+    }
+
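// A usage sketch for the slicing subscript and comparison operators defined above (a sketch,
// assuming a toolchain in which this module builds; expected values in trailing comments):
let m = Tensor<Float>([[1, 2, 3], [4, 5, 6]])
let row = m[0]                // [1, 2, 3]; an `.index` entry squeezes the sliced axis
let block = m[0..<2, 1..<3]   // [[2, 3], [5, 6]]
let bigger = m .> 3           // [[false, false, false], [true, true, true]]
let allPositive = m > 0       // true: every element is greater than 0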
+    /// Returns a Boolean value indicating whether the value of the first argument is
+    /// lexicographically greater than or equal to that of the second argument.
+    @inlinable
+    static func >= (lhs: Tensor, rhs: Scalar) -> Bool {
+        return (lhs .>= rhs).all()
+    }
+}
+
+public extension Tensor where Scalar: Equatable {
+    /// Computes `lhs == rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.==` supports broadcasting.
+    @inlinable
+    static func .==(lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
+        return Raw.equal(lhs, rhs)
+    }
+
+    /// Computes `lhs != rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.!=` supports broadcasting.
+    @inlinable
+    static func .!=(lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
+        return Raw.notEqual(lhs, rhs)
+    }
+
+    /// Computes `lhs == rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.==` supports broadcasting.
+    @inlinable
+    static func .==(lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
+        return Tensor(lhs) .== rhs
+    }
+
+    /// Computes `lhs != rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.!=` supports broadcasting.
+    @inlinable
+    static func .!=(lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
+        return Tensor(lhs) .!= rhs
+    }
+
+    /// Computes `lhs == rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.==` supports broadcasting.
+    @inlinable
+    static func .==(lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
+        return lhs .== Tensor(rhs)
+    }
+
+    /// Computes `lhs != rhs` element-wise and returns a `Tensor` of Boolean scalars.
+    /// - Note: `.!=` supports broadcasting.
+    @inlinable
+    static func .!=(lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
+        return lhs .!= Tensor(rhs)
+    }
+}
+
+// TODO: infix operator ≈: ComparisonPrecedence
+
+public extension Tensor where Scalar: FloatingPoint & Equatable {
+    /// Returns a `Tensor` of Boolean values indicating whether the elements of `self` are
+    /// approximately equal to those of `other`.
+    @inlinable
+    func elementsApproximatelyEqual(
+        _ other: Tensor,
+        tolerance: Double = 0.00001
+    ) -> Tensor<Bool> {
+        return Raw.approximateEqual(self, other, tolerance: tolerance)
+    }
+}
diff --git a/Sources/DeepLearning/Operators/Math.swift b/Sources/DeepLearning/Operators/Math.swift
new file mode 100644
index 000000000..3255aea10
--- /dev/null
+++ b/Sources/DeepLearning/Operators/Math.swift
@@ -0,0 +1,1505 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#if !COMPILING_TENSORFLOW_MODULE
+import TensorFlow
+#endif
+
+#if COMPILING_TENSORFLOW_MODULE
+infix operator .>: ComparisonPrecedence
+infix operator .==: ComparisonPrecedence
+#endif
+
+// TODO:
+// - Consider explicit broadcasting for elementwise binary ops when
+//   scalarization and rank getter are implemented.
+
+//===------------------------------------------------------------------------------------------===//
+// Additive Group
+//===------------------------------------------------------------------------------------------===//
+
+extension Tensor: AdditiveArithmetic where Scalar: Numeric {
+    /// A scalar zero tensor.
+ @inlinable + public static var zero: Tensor { + return Tensor(zeros: []) + } + + /// Adds two tensors and produces their sum. + /// - Note: `+` supports broadcasting. + @inlinable + @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + public static func + (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.add(lhs, rhs) + } + + /// Subtracts one tensor from another and produces their difference. + /// - Note: `-` supports broadcasting. + @inlinable + @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + public static func - (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.sub(lhs, rhs) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in + (v.unbroadcast(toShape: lhsShape), v.unbroadcast(toShape: rhsShape)) + }) + } + + @inlinable + static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in + (v.unbroadcast(toShape: lhsShape), -v.unbroadcast(toShape: rhsShape)) + }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Vector Space +//===------------------------------------------------------------------------------------------===// + +extension Tensor: VectorNumeric where Scalar: Numeric { + /// Multiplies the scalar with every scalar of the tensor and produces the product. + @inlinable + @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + public static func * (lhs: Scalar, rhs: Tensor) -> Tensor { + return Tensor(lhs) * rhs + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpMultiply(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + return (lhs * rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in + ((rhs * v).unbroadcast(toShape: lhsShape), (lhs * v).unbroadcast(toShape: rhsShape)) + }) + } +} + +extension Tensor: ShapedVectorNumeric where Scalar: Numeric {} + +extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint { + public typealias TangentVector = Tensor + public typealias CotangentVector = Tensor + public typealias AllDifferentiableVariables = Tensor + + @inlinable + public func tangentVector(from cotangent: CotangentVector) -> TangentVector { + return cotangent + } +} + +//===------------------------------------------------------------------------------------------===// +// Additional Element-wise Operators +//===------------------------------------------------------------------------------------------===// + +public extension Tensor where Scalar: Numeric { + /// Adds the scalar to every scalar of the tensor and produces the sum. + @inlinable + @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func + (lhs: Scalar, rhs: Tensor) -> Tensor { + return Tensor(lhs) + rhs + } + + /// Adds the scalar to every scalar of the tensor and produces the sum. + @inlinable + @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func + (lhs: Tensor, rhs: Scalar) -> Tensor { + return lhs + Tensor(rhs) + } + + /// Subtracts the scalar from every scalar of the tensor and produces the difference. 
+ @inlinable + @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func - (lhs: Scalar, rhs: Tensor) -> Tensor { + return Tensor(lhs) - rhs + } + + /// Subtracts the scalar from every scalar of the tensor and produces the difference + @inlinable + @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func - (lhs: Tensor, rhs: Scalar) -> Tensor { + return lhs - Tensor(rhs) + } + + /// Adds two tensors and stores the result in the left-hand-side variable. + /// - Note: `+=` supports broadcasting. + @inlinable + static func += (lhs: inout Tensor, rhs: Tensor) { + lhs = lhs + rhs + } + + /// Adds the scalar to every scalar of the tensor and stores the result in the left-hand-side + /// variable. + @inlinable + static func += (lhs: inout Tensor, rhs: Scalar) { + lhs = lhs + rhs + } + + /// Subtracts the second tensor from the first and stores the result in the left-hand-side + /// variable. + /// - Note: `-=` supports broadcasting. + @inlinable + static func -= (lhs: inout Tensor, rhs: Tensor) { + lhs = lhs - rhs + } + + /// Subtracts the scalar from every scalar of the tensor and stores the result in the + /// left-hand-side variable. + @inlinable + static func -= (lhs: inout Tensor, rhs: Scalar) { + lhs = lhs - rhs + } + + /// Multiplies two tensors and produces their product. + /// - Note: `*` supports broadcasting. + @inlinable + @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func * (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.mul(lhs, rhs) + } + + /// Multiplies the scalar with every scalar of the tensor and produces the product. + @inlinable + @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func * (lhs: Tensor, rhs: Scalar) -> Tensor { + return lhs * Tensor(rhs) + } + + /// Multiplies two tensors and stores the result in the left-hand-side variable. + /// - Note: `*=` supports broadcasting. + @inlinable + static func *= (lhs: inout Tensor, rhs: Tensor) { + lhs = lhs * rhs + } + + /// Multiplies the tensor with the scalar, broadcasting the scalar, and stores the result in the + /// left-hand-side variable. + @inlinable + static func *= (lhs: inout Tensor, rhs: Scalar) { + lhs = lhs * rhs + } + + /// Returns the quotient of dividing the first tensor by the second. + /// - Note: `/` supports broadcasting. + @inlinable + @differentiable(vjp: _vjpDivide(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func / (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.div(lhs, rhs) + } + + /// Returns the quotient of dividing the scalar by the tensor, broadcasting the scalar. + @inlinable + @differentiable(vjp: _vjpDivide(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func / (lhs: Scalar, rhs: Tensor) -> Tensor { + return Tensor(lhs) / rhs + } + + /// Returns the quotient of dividing the tensor by the scalar, broadcasting the scalar. + @inlinable + @differentiable(vjp: _vjpDivide(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func / (lhs: Tensor, rhs: Scalar) -> Tensor { + return lhs / Tensor(rhs) + } + + /// Divides the first tensor by the second and stores the quotient in the left-hand-side + /// variable. + @inlinable + static func /= (lhs: inout Tensor, rhs: Tensor) { + lhs = lhs / rhs + } + + /// Divides the tensor by the scalar, broadcasting the scalar, and stores the quotient in the + /// left-hand-side variable. 
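// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). Scalar–tensor
// arithmetic and the compound-assignment operators defined above; values are
// made up.
var counters = Tensor<Float>([1, 2, 3])
counters += 1                                        // [2, 3, 4]
counters *= 2                                        // [4, 6, 8]
let halved = counters / 2                            // [2, 3, 4]
let flipped = 10 - counters                          // [6, 4, 2]
// ---------------------------------------------------------------------------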
+ @inlinable + static func /= (lhs: inout Tensor, rhs: Scalar) { + lhs = lhs / rhs + } + + /// Returns the remainder of dividing the first tensor by the second. + /// - Note: `%` supports broadcasting. + @inlinable + static func % (lhs: Tensor, rhs: Tensor) -> Tensor { + return Raw.mod(lhs, rhs) + } + + /// Returns the remainder of dividing the tensor by the scalar, broadcasting the scalar. + @inlinable + static func % (lhs: Tensor, rhs: Scalar) -> Tensor { + return lhs % Tensor(rhs) + } + + /// Returns the remainder of dividing the scalar by the tensor, broadcasting the scalar. + @inlinable + static func % (lhs: Scalar, rhs: Tensor) -> Tensor { + return Tensor(lhs) % rhs + } + + /// Divides the first tensor by the second and stores the remainder in the left-hand-side + /// variable. + @inlinable + static func %= (lhs: inout Tensor, rhs: Tensor) { + lhs = lhs % rhs + } + + /// Divides the tensor by the scalar and stores the remainder in the left-hand-side variable. + @inlinable + static func %= (lhs: inout Tensor, rhs: Scalar) { + lhs = lhs % rhs + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpAdd(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { + return (lhs + rhs, { v in (v, v.sum().scalarized()) }) + } + + @inlinable + static func _vjpAdd(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { + return (lhs + rhs, { v in (v.sum().scalarized(), v) }) + } + + @inlinable + static func _vjpSubtract(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { + return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) }) + } + + @inlinable + static func _vjpSubtract(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { + return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) }) + } + + @inlinable + static func _vjpMultiply(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { + return (lhs * rhs, { v in (v * rhs, (v * lhs).sum().scalarized()) }) + } + + @inlinable + static func _vjpMultiply(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { + return (lhs * rhs, { v in ((v * rhs).sum().scalarized(), v * lhs) }) + } + + @inlinable + static func _vjpDivide(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + return (lhs / rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in + ((v / rhs).unbroadcast(toShape: lhsShape), + ((-lhs) / rhs.squared() * v).unbroadcast(toShape: rhsShape)) + }) + } + + @inlinable + static func _vjpDivide(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { + return (lhs / rhs, { v in + (v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized()) + }) + } + + @inlinable + static func _vjpDivide(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { + return (lhs / rhs, { v in ((v / rhs).sum().scalarized(), v * -lhs / rhs.squared()) }) + } +} + +public extension Tensor where Scalar == Bool { + /// Computes `!self` element-wise. + @inlinable + func elementsLogicalNot() -> Tensor { + return Raw.logicalNot(self) + } + + /// Computes `self && other` element-wise. + /// - Note: `&&` supports broadcasting. + @inlinable + func elementsLogicalAnd(_ other: Tensor) -> Tensor { + return Raw.logicalAnd(self, other) + } + + /// Computes `self && other` element-wise, broadcasting `other`. + @inlinable + func elementsLogicalAnd(_ other: Scalar) -> Tensor { + return elementsLogicalAnd(Tensor(other)) + } + + /// Computes `self || other` element-wise. 
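// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). Combines the
// `%` remainder operator with the Boolean methods `elementsLogicalAnd`,
// `elementsLogicalOr` (continued just below), and `elementsLogicalNot`; values
// are made up.
let ints = Tensor<Int32>([1, 2, 3, 4, 6])
let isEven = (ints % 2) .== 0                        // [false, true, false, true, true]
let isDivisibleBy3 = (ints % 3) .== 0                // [false, false, true, false, true]
let isDivisibleBy6 = isEven.elementsLogicalAnd(isDivisibleBy3)
let isDivisibleByEither = isEven.elementsLogicalOr(isDivisibleBy3)
let isOdd = isEven.elementsLogicalNot()
// ---------------------------------------------------------------------------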
+ @inlinable + func elementsLogicalOr(_ other: Tensor) -> Tensor { + return Raw.logicalOr(self, other) + } + + /// Computes `self || other` element-wise, broadcasting `other`. + @inlinable + func elementsLogicalOr(_ other: Scalar) -> Tensor { + return elementsLogicalOr(Tensor(other)) + } +} + +//===------------------------------------------------------------------------------------------===// +// Element-wise Unary Math Functions +//===------------------------------------------------------------------------------------------===// + +// Export Glibc/Darwin math functions. We should not require users to import +// Foundation/Darwin/Glibc in order to use scalar math functions. +// +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +@_exported import Darwin.C +#else +@_exported import Glibc +#endif +// +// FIXME(rxwei): Scoped imports are not yet supported in parseable module +// interfaces, so `@_exported import` won't work. When that becomes supported, +// switch to `@_exported import` by removing `import Darwin.C/Glibc` above and +// uncommenting the following lines. In the meantime, consider using indirect +// wrappers for each function so that random libc symbols won't be leaked to +// users' code completion. +// +// #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) +// @_exported import func Darwin.C.sin +// @_exported import func Darwin.C.cos +// @_exported import func Darwin.C.tan +// @_exported import func Darwin.C.sinf +// @_exported import func Darwin.C.cosf +// @_exported import func Darwin.C.tanf +// @_exported import func Darwin.C.sinh +// @_exported import func Darwin.C.cosh +// @_exported import func Darwin.C.tanh +// @_exported import func Darwin.C.sinhf +// @_exported import func Darwin.C.coshf +// @_exported import func Darwin.C.tanhf +// @_exported import func Darwin.C.log +// @_exported import func Darwin.C.logf +// @_exported import func Darwin.C.exp +// @_exported import func Darwin.C.expf +// @_exported import func Darwin.C.pow +// @_exported import func Darwin.C.powf +// #else +// @_exported import func Glibc.sin +// @_exported import func Glibc.cos +// @_exported import func Glibc.tan +// @_exported import func Glibc.sinf +// @_exported import func Glibc.cosf +// @_exported import func Glibc.tanf +// @_exported import func Glibc.sinh +// @_exported import func Glibc.cosh +// @_exported import func Glibc.tanh +// @_exported import func Glibc.sinhf +// @_exported import func Glibc.coshf +// @_exported import func Glibc.tanhf +// @_exported import func Glibc.log +// @_exported import func Glibc.logf +// @_exported import func Glibc.exp +// @_exported import func Glibc.expf +// @_exported import func Glibc.pow +// @_exported import func Glibc.powf +// #endif + +public extension Tensor where Scalar: SignedNumeric { + /// Computes the negation of the specified tensor element-wise. + @inlinable + @differentiable(vjp: _vjpNegate(_:) where Scalar: TensorFlowFloatingPoint) + static prefix func - (rhs: Tensor) -> Tensor { + return Raw.neg(rhs) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpNegate(_ x: Tensor) -> (Tensor, (Tensor) -> Tensor) { + return (-x, { v in -v }) + } +} + +/// Computes the absolute value of the specified tensor element-wise. 
+@inlinable +@differentiable(vjp: _vjpAbs(_:) where T: TensorFlowFloatingPoint) +public func abs(_ x: Tensor) -> Tensor { + return Raw.abs(x) +} + +@inlinable +internal func _vjpAbs( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let sign = Raw.sign(x) + return (abs(x), { v in v * sign }) +} + +/// Computes the natural logarithm of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpLog(_:) where T: TensorFlowFloatingPoint) +public func log(_ x: Tensor) -> Tensor { + return Raw.log(x) +} + +@inlinable +internal func _vjpLog( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (log(x), { v in v / x }) +} + +/// Computes `sin` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpSin(_:) where T: TensorFlowFloatingPoint) +public func sin(_ x: Tensor) -> Tensor { + return Raw.sin(x) +} + +@inlinable +internal func _vjpSin( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (sin(x), { v in v * cos(x) }) +} + +/// Computes `cos` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpCos(_:) where T: TensorFlowFloatingPoint) +public func cos(_ x: Tensor) -> Tensor { + return Raw.cos(x) +} + +@inlinable +internal func _vjpCos( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (cos(x), { v in -v * sin(x) }) +} + +/// Computes `tan` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpTan(_:) where T: TensorFlowFloatingPoint) +public func tan(_ x: Tensor) -> Tensor { + return Raw.tan(x) +} + +@inlinable +internal func _vjpTan( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = tan(x) + return (value, { v in v * (1 + value.squared()) }) +} + +/// Computes `sinh` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpSinh(_:) where T: TensorFlowFloatingPoint) +public func sinh(_ x: Tensor) -> Tensor { + return Raw.sinh(x) +} + +@inlinable +internal func _vjpSinh( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (sinh(x), { v in v * cosh(x) }) +} + +/// Computes `cosh` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpCosh(_:) where T: TensorFlowFloatingPoint) +public func cosh(_ x: Tensor) -> Tensor { + return Raw.cosh(x) +} + +@inlinable +internal func _vjpCosh( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (cosh(x), { v in v * sinh(x) }) +} + +/// Computes `tanh` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpTanh(_:) where T: TensorFlowFloatingPoint) +public func tanh(_ x: Tensor) -> Tensor { + return Raw.tanh(x) +} + +@inlinable +internal func _vjpTanh( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = tanh(x) + return (value, { v in v * (1 - value.squared()) }) +} + +/// Computes the square of the tensor. +public extension Tensor where Scalar: Numeric { + @inlinable + @differentiable(wrt: self, vjp: _vjpSquared() where Scalar: TensorFlowFloatingPoint) + func squared() -> Tensor { + return Raw.square(self) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + func _vjpSquared() -> (Tensor, (Tensor) -> Tensor) { + return (squared(), { 2 * self * $0 }) + } +} + +/// Computes the square root of the specified tensor element-wise. 
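// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). The unary
// functions above compose with the module's differential operators; here we
// check that d/dx tanh(x) = 1 - tanh(x)^2. `gradient(at:in:)` is assumed from
// this module and the values are made up.
let angles = Tensor<Float>([0.0, 0.5, 1.0])
let tanhGrad = gradient(at: angles) { tanh($0).sum() }
let tanhGradClosedForm = 1 - tanh(angles).squared()
// `tanhGrad` and `tanhGradClosedForm` agree element-wise up to rounding.
// ---------------------------------------------------------------------------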
+@inlinable +@differentiable(vjp: _vjpSqrt(_:) where T: TensorFlowFloatingPoint) +public func sqrt(_ x: Tensor) -> Tensor { + return Raw.sqrt(x) +} + +@inlinable +internal func _vjpSqrt( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = sqrt(x) + return (value, { v in v / (2 * value) }) +} + +/// Computes the inverse square root of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpRsqrt(_:) where T: TensorFlowFloatingPoint) +public func rsqrt(_ x: Tensor) -> Tensor { + return Raw.rsqrt(x) +} + +@inlinable +internal func _vjpRsqrt( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = rsqrt(x) + return (value, { v in -v / 2 * value }) +} + +/// Computes `exp` of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpExp(_:) where T: TensorFlowFloatingPoint) +public func exp(_ x: Tensor) -> Tensor { + return Raw.exp(x) +} + +@inlinable +internal func _vjpExp( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = exp(x) + return (value, { v in value * v }) +} + +/// Computes the ceiling of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpCeil(_:) where T: TensorFlowFloatingPoint) +public func ceil(_ x: Tensor) -> Tensor { + return Raw.ceil(x) +} + +@inlinable +internal func _vjpCeil( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (ceil(x), { _ in Tensor(0).broadcast(like: x) }) +} + +/// Computes the floor of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpFloor(_:) where T: TensorFlowFloatingPoint) +public func floor(_ x: Tensor) -> Tensor { + return Raw.floor(x) +} + +@inlinable +internal func _vjpFloor( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (floor(x), { _ in Tensor(0).broadcast(like: x) }) +} + +/// Computes the sigmoid of the specified tensor element-wise. +/// Specifically, computes `1 / (1 + exp(-x))`. +@inlinable +@differentiable(vjp: _vjpSigmoid) +public func sigmoid(_ x: Tensor) -> Tensor { + return Raw.sigmoid(x) +} + +@inlinable +internal func _vjpSigmoid( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (sigmoid(x), { v in Raw.sigmoidGrad(x, dy: v) }) +} + +/// Computes the softmax of the specified tensor along the last axis. +/// Specifically, computes `exp(x) / exp(x).sum(alongAxes: -1)`. +@inlinable +@differentiable(vjp: _vjpSoftmax(_:) where T: TensorFlowFloatingPoint) +public func softmax(_ x: Tensor) -> Tensor { + return Raw.softmax(logits: x) +} + +/// Computes the softmax of the specified tensor along the specified axis. +/// Specifically, computes `exp(x) / exp(x).sum(alongAxes: axis)`. +@inlinable +// TODO: [AD]. +public func softmax(_ x: Tensor, alongAxis axis: Int) -> Tensor { + let xExp = exp(x) + return xExp / xExp.sum(alongAxes: Tensor(Int32(axis))) +} + +@inlinable +func _vjpSoftmax( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = softmax(x) + return (value, { v in + let sumChannels = (v * value).sum(alongAxes: -1) + return (v - sumChannels) * value + }) +} + +/// Computes the log-softmax of the specified tensor element-wise. +@inlinable +@differentiable(vjp: _vjpLogSoftmax(_:) where T: TensorFlowFloatingPoint) +public func logSoftmax(_ x: Tensor) -> Tensor { + return Raw.logSoftmax(logits: x) +} + +@inlinable +func _vjpLogSoftmax( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + let value = logSoftmax(x) + return (value, { v in v - v.sum(alongAxes: -1) * exp(value) }) +} + +/// Computes `relu` of the specified tensor element-wise. +/// Specifically, computes `max(0, x)`. 
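// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). `softmax`
// normalizes along the last axis so each row sums to 1; `logSoftmax` is the
// numerically safer way to get its logarithm. Values are made up.
let logits = Tensor<Float>([[1, 2, 3], [1, 1, 1]])
let probabilities = softmax(logits)                  // each row sums to 1
let rowSums = probabilities.sum(alongAxes: -1)       // ≈ [[1], [1]]
let logProbabilities = logSoftmax(logits)            // ≈ log(probabilities)
// ---------------------------------------------------------------------------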
+@inlinable +@differentiable(vjp: _vjpRelu(_:) where T: TensorFlowFloatingPoint) +public func relu(_ x: Tensor) -> Tensor { + return max(0, x) +} + +@inlinable +func _vjpRelu( + _ x: Tensor +) -> (Tensor, (Tensor) -> Tensor) { + return (relu(x), { v in Tensor(x .> 0) * v }) +} + +//===------------------------------------------------------------------------------------------===// +// Element-wise Binary Math Functions +//===------------------------------------------------------------------------------------------===// + +/// Computes the power of the first tensor to the second tensor. +@inlinable +@differentiable(vjp: _vjpPow(_:_:) where T: TensorFlowFloatingPoint) +public func pow(_ lhs: Tensor, _ rhs: Tensor) -> Tensor where T: FloatingPoint { + return Raw.pow(lhs, rhs) +} + +@inlinable +internal func _vjpPow( + _ x: Tensor, _ y: Tensor +) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = pow(x, y) + return (value, { v in + ((v * y * pow(x, y-1)).unbroadcast(like: x), + (v * log(x) * value).unbroadcast(like: y)) + }) +} + +/// Computes the power of the scalar to the tensor, broadcasting the scalar. +@inlinable +// @differentiable(where T: TensorFlowFloatingPoint) +public func pow(_ lhs: T, _ rhs: Tensor) -> Tensor where T: FloatingPoint { + return pow(Tensor(lhs), rhs) +} + +/// Computes the power of the tensor to the scalar, broadcasting the scalar. +@inlinable +// @differentiable(where T: TensorFlowFloatingPoint) +public func pow(_ lhs: Tensor, _ rhs: T) -> Tensor where T: FloatingPoint { + return pow(lhs, Tensor(rhs)) +} + +/// Computes the element-wise maximum of two tensors. +/// - Note: `max` supports broadcasting. +@inlinable +@differentiable(vjp: _vjpMax(_:_:) where T: TensorFlowFloatingPoint) +public func max(_ lhs: Tensor, _ rhs: Tensor) -> Tensor where T: Numeric & Comparable { + return Raw.maximum(lhs, rhs) +} + +@inlinable +internal func _vjpMax( + _ x: Tensor, _ y: Tensor +) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = max(x, y) + return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) }) +} + +/// Computes the element-wise maximum of the scalar and the tensor, broadcasting the scalar. +@inlinable +// @differentiable(where T: TensorFlowFloatingPoint) +public func max(_ lhs: T, _ rhs: Tensor) -> Tensor where T: Numeric & Comparable { + return max(Tensor(lhs), rhs) +} + +/// Computes the element-wise maximum of the scalar and the tensor, broadcasting the scalar. +@inlinable +// @differentiable(where T: TensorFlowFloatingPoint) +public func max(_ lhs: Tensor, _ rhs: T) -> Tensor where T: Numeric & Comparable { + return max(lhs, Tensor(rhs)) +} + +/// Computes the element-wise minimum of two tensors. +/// - Note: `min` supports broadcasting. +@inlinable +@differentiable(vjp: _vjpMin(_:_:) where T: TensorFlowFloatingPoint) +public func min(_ lhs: Tensor, _ rhs: Tensor) -> Tensor where T: Numeric & Comparable { + return Raw.minimum(lhs, rhs) +} + +@inlinable +internal func _vjpMin( + _ x: Tensor, _ y: Tensor +) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = min(x, y) + return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) }) +} + +/// Computes the element-wise minimum of the scalar and the tensor, broadcasting the scalar. +@inlinable +// @differentiable(where T: TensorFlowFloatingPoint) +public func min(_ lhs: T, _ rhs: Tensor) -> Tensor where T: Numeric & Comparable { + return min(Tensor(lhs), rhs) +} + +/// Computes the element-wise minimum of the scalar and the tensor, broadcasting the scalar. 
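// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). `relu` clamps
// negatives to zero and its VJP above masks gradients at the same positions;
// `max`/`min`/`pow` broadcast against scalars. Values are made up.
let z = Tensor<Float>([-2, -1, 0, 1, 2])
let activated = relu(z)                              // [0, 0, 0, 1, 2]
let clipped = min(max(z, -1), 1)                     // [-1, -1, 0, 1, 1]
let cubed = pow(z, 3)                                // [-8, -1, 0, 1, 8]
let reluGrad = gradient(at: z) { relu($0).sum() }    // [0, 0, 0, 1, 1]
// ---------------------------------------------------------------------------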
+@inlinable +// @differentiable(where T: TensorFlowFloatingPoint) +public func min(_ lhs: Tensor, _ rhs: T) -> Tensor where T: Numeric & Comparable { + return min(lhs, Tensor(rhs)) +} + +@inlinable +internal func _vjpMinMaxHelper( + _ x: Tensor, + _ y: Tensor, + originalValue: Tensor, + vector: Tensor +) -> (Tensor, Tensor) { + let denom = 1 + Tensor(x .== y) + let dfdx = vector * Tensor(x .== originalValue) / denom + let dfdy = vector * Tensor(y .== originalValue) / denom + return (dfdx.unbroadcast(like: x), dfdy.unbroadcast(like: y)) +} + +//===------------------------------------------------------------------------------------------===// +// Selection Functions +//===------------------------------------------------------------------------------------------===// + +public extension Tensor where Scalar == Bool { + /// Returns a new tensor containing elements from either `left` or `right`, + /// depending on the elements of `self`. + /// + /// `self` acts as a mask that chooses, based on the value at each scalar, + /// whether the corresponding scalar in the output should be taken from + /// `left` (if `true`) or `right` (if `false`). + /// + /// - Precondition: `left` and `right` must have the same shape. If + /// `left` and `right` are scalar, then `self` must also be scalar. If + /// `left` and `right` have rank greater than or equal to 1, then `self` + /// must be either have the same shape as `left` or be a 1-D `Tensor` such + /// that `self.scalarCount == left[0]`. + @available(*, deprecated, message: "Use '.replacing(with:mask:)' instead") + @inlinable + func selecting(_ left: Tensor, _ right: Tensor) -> Tensor { + return left.replacing(with: right, where: self) + } +} + +public extension Tensor { + /// Replaces elements of this tensor with `other` in the lanes where `mask` is + /// `true`. + /// + /// - Precondition: `self` and `other` must have the same shape. If + /// `self` and `other` are scalar, then `mask` must also be scalar. If + /// `self` and `other` have rank greater than or equal to `1`, then `mask` + /// must be either have the same shape as `self` or be a 1-D `Tensor` such + /// that `mask.scalarCount == self.shape[0]`. + @inlinable + @differentiable(wrt: (self, other), vjp: _vjpReplacing where Scalar: TensorFlowFloatingPoint) + func replacing(with other: Tensor, where mask: Tensor) -> Tensor { + return Raw.select(condition: mask, t: self, e: other) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + func _vjpReplacing( + with other: Tensor, + where mask: Tensor + ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + return (replacing(with: other, where: mask), { v in + let zeros = Tensor(zeros: v.shape) + return (v.replacing(with: zeros, where: mask), zeros.replacing(with: v, where: mask)) + }) + } +} + +//===------------------------------------------------------------------------------------------===// +// Reduction Functions +//===------------------------------------------------------------------------------------------===// + +public extension Tensor where Scalar == Bool { + /// Returns `true` if all scalars are equal to `true`. Otherwise, returns `false`. + // NOTE: This overload is necessary, otherwise `all()` would refer to the variadic method + // `all(squeezingAxes:)` with zero indices. + @inlinable + func all() -> Bool { + let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1) + return _TFGetScalarOrDie(Raw.all(self, reductionIndices: axes).handle) + } + + /// Returns `true` if any scalars are equal to `true`. 
Otherwise, returns `false`. + // NOTE: This overload is necessary, otherwise `any()` would refer to the variadic method + // `any(squeezingAxes:)` with zero indices. + @inlinable + func any() -> Bool { + let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1) + return _TFGetScalarOrDie(Raw.any(self, reductionIndices: axes).handle) + } + + /// Performs a logical AND operation along the specified axes. The reduced dimensions are + /// removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.all(self, reductionIndices: Tensor(axes), keepDims: false) + } + + /// Performs a logical AND operation along the specified axes. The reduced dimensions are + /// removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.any(self, reductionIndices: Tensor(axes), keepDims: false) + } + + /// Performs a logical AND operation along the specified axes. The reduced dimensions are + /// retained with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.all(self, reductionIndices: Tensor(axes), keepDims: true) + } + + /// Performs a logical OR operation along the specified axes. The reduced + /// dimensions are retained with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.any(self, reductionIndices: Tensor(axes), keepDims: true) + } +} + +public extension Tensor where Scalar: Numeric & Comparable { + // NOTE: This overload is necessary, otherwise `min()` would refer to the variadic method + // `min(squeezingAxes:)` with zero indices. + @inlinable + func min() -> Tensor { + let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1) + return Raw.min(self, reductionIndices: axes) + } + + // NOTE: This overload is necessary, otherwise `max()` would refer to the variadic method + // `max(squeezingAxes:)` with zero indices. + @inlinable + func max() -> Tensor { + let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1) + return Raw.max(self, reductionIndices: axes) + } + + /// Returns the maximum values along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.max(self, reductionIndices: Tensor(axes), keepDims: false) + } + + /// Returns the maximum values along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return max(squeezingAxes: axes) + } + + /// Returns the minimum values along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.min(self, reductionIndices: Tensor(axes), keepDims: false) + } + + /// Returns the minimum values along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. 
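// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). Combines
// `replacing(with:where:)` from the selection section with the reductions here:
// `squeezingAxes` drops the reduced dimension, `alongAxes` keeps it with size 1.
// The values and the Boolean mask are made up.
let scores = Tensor<Float>([[1, -5], [7, 3]])                  // shape [2, 2]
let isNegative = Tensor<Bool>([[false, true], [false, false]]) // e.g. from `scores .< 0`
let clamped = scores.replacing(with: Tensor(zeros: [2, 2]), where: isNegative)
let rowMax = clamped.max(squeezingAxes: 1)                     // shape [2]: [1, 7]
let rowMaxKept = clamped.max(alongAxes: 1)                     // shape [2, 1]: [[1], [7]]
let anyNegative = isNegative.any()                             // true
// ---------------------------------------------------------------------------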
+ /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return min(squeezingAxes: axes) + } + + /// Returns the indices of the maximum values along the specified axes. The reduced dimensions + /// are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return Raw.argMax(self, dimension: Tensor(Int32(axis))) + } + + /// Returns the indices of the minimum values along the specified axes. The reduced dimensions + /// are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return Raw.argMin(self, dimension: Tensor(Int32(axis))) + } + + /// Returns the minimum along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.min(self, reductionIndices: Tensor(axes), keepDims: true) + } + + /// Returns the minimum along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return min(alongAxes: axes) + } + + /// Returns the minimum along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + let axes = axes.map(Int32.init) + return Raw.max(self, reductionIndices: Tensor(axes), keepDims: true) + } + + /// Returns the minimum along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return max(alongAxes: axes) + } + + /// Returns the index of the maximum value of the flattened scalars. + @inlinable + func argmax() -> Tensor { + return flattened().argmax(squeezingAxis: 0) + } + + /// Returns the index of the minimum value of the flattened scalars. + @inlinable + func argmin() -> Tensor { + return flattened().argmin(squeezingAxis: 0) + } +} + +// MARK: - Numeric Reductions + +public extension Tensor where Scalar: Numeric { + // MARK: - Sum + + /// Returns the sum along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + @differentiable(wrt: self, vjp: _vjpSum(squeezingAxes:) where Scalar: TensorFlowFloatingPoint) + func sum(squeezingAxes axes: Tensor) -> Tensor { + return Raw.sum(self, reductionIndices: Tensor(axes), keepDims: false) + } + + /// Returns the sum along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func sum(squeezingAxes axes: [Int]) -> Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return sum(squeezingAxes: Tensor(axes)) + } + + /// Returns the sum along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. 
+ /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func sum(squeezingAxes axes: Int...) -> Tensor { + return sum(squeezingAxes: axes) + } + + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func sum() -> Tensor { + return flattened().sum(squeezingAxes: 0) + } + + /// Returns the sum along the specified axes. The reduced dimensions are retained with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank..) -> Tensor { + return Raw.sum(self, reductionIndices: axes, keepDims: true) + } + + /// Returns the sum along the specified axes. The reduced dimensions are retained with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return sum(alongAxes: Tensor(axes)) + } + + /// Returns the sum along the specified axes. The reduced dimensions are retained with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return sum(alongAxes: axes) + } + + // MARK: - Product + + /// Returns the product along the specified axes. The reduced dimensions are removed. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + // TODO: Make this @differentiable. + @inlinable + func product(squeezingAxes axes: Tensor) -> Tensor { + return Raw.prod(self, reductionIndices: axes, keepDims: false) + } + + /// Returns the product along the specified axes. The reduced dimensions are removed. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + func product(squeezingAxes axes: [Int]) -> Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return product(squeezingAxes: Tensor(axes)) + } + + /// Returns the product along the specified axes. The reduced dimensions are removed. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + func product(squeezingAxes axes: Int...) -> Tensor { + return product(squeezingAxes: axes) + } + + @inlinable + func product() -> Tensor { + return flattened().product(squeezingAxes: 0) + } + + /// Returns the product along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank..) -> Tensor { + return Raw.prod(self, reductionIndices: axes, keepDims: true) + } + + /// Returns the product along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return product(alongAxes: Tensor(axes)) + } + + /// Returns the product along the specified axes. The reduced dimensions are retained with + /// value 1. + /// - Parameter axes: The dimensions to reduce. 
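// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). The numeric
// reductions in this section: `sum()` with no arguments reduces every dimension
// to a single scalar tensor; the axis-wise variants mirror `max`/`min` above.
// Values are made up; `mean` is defined just below.
let table = Tensor<Float>([[1, 2, 3], [4, 5, 6]])
let total = table.sum()                              // 21
let columnSums = table.sum(squeezingAxes: 0)         // [5, 7, 9]
let rowProducts = table.product(alongAxes: 1)        // [[6], [120]]
let average = table.mean()                           // 3.5
// ---------------------------------------------------------------------------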
+ /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return product(alongAxes: axes) + } + + // MARK: - Mean + + /// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + @differentiable(wrt: self, vjp: _vjpMean(squeezingAxes:) where Scalar: TensorFlowFloatingPoint) + func mean(squeezingAxes axes: Tensor) -> Tensor { + return Raw.mean(self, reductionIndices: axes, keepDims: false) + } + + /// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func mean(squeezingAxes axes: [Int]) -> Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return mean(squeezingAxes: Tensor(axes)) + } + + /// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank...rank`. + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func mean(squeezingAxes axes: Int...) -> Tensor { + return mean(squeezingAxes: axes) + } + + @inlinable + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + func mean() -> Tensor { + return flattened().mean(squeezingAxes: [0]) + } + + /// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained + /// with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank..) -> Tensor { + return Raw.mean(self, reductionIndices: axes, keepDims: true) + } + + /// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained + /// with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return mean(alongAxes: Tensor(axes)) + } + + /// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained + /// with value 1. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return mean(alongAxes: axes) + } + + // MARK: - Variance + + /// Returns the variance along the specified axes. The reduced dimensions are removed. Does not + /// apply Bessel's correction. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank..) -> Tensor { + let squaredDiff = (self - mean(alongAxes: axes)).squared() + return squaredDiff.mean(squeezingAxes: axes) + } + + /// Returns the variance along the specified axes. The reduced dimensions are removed. Does not + /// apply Bessel's correction. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return variance(squeezingAxes: Tensor(axes)) + } + + /// Returns the variance along the specified axes. 
The reduced dimensions are retained with + /// value 1. Does not apply Bessel's correction. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return variance(squeezingAxes: axes) + } + + @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) + @inlinable + func variance() -> Tensor { + let mean = self.mean() + let squaredDiff = (self - mean).squared() + return squaredDiff.mean() + } + + /// Returns the variance along the specified axes. The reduced dimensions are retained with + /// value 1. Does not apply Bessel's correction. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank..) -> Tensor { + let squaredDiff = (self - mean(alongAxes: axes)).squared() + return squaredDiff.mean(alongAxes: axes) + } + + /// Returns the variance along the specified axes. The reduced dimensions are retained with + /// value 1. Does not apply Bessel's correction. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return variance(alongAxes: Tensor(axes)) + } + + /// Returns the variance along the specified axes. The reduced dimensions are retained with + /// value 1. Does not apply Bessel's correction. + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return variance(alongAxes: axes) + } +} + +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + func _vjpSum(alongAxes axes: Tensor) -> (Tensor, (Tensor) -> Tensor) { + let value = sum(alongAxes: axes) + return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) }) + } + + @inlinable + func _vjpSum(squeezingAxes axes: Tensor) -> (Tensor, (Tensor) -> Tensor) { + let value = sum(squeezingAxes: axes) + return (value, { [shape = shapeTensor] in + var result = $0 + for i in axes.array.scalars { result = result.expandingShape(at: Int(i)) } + return result.broadcast(toShape: shape) + }) + } + + @inlinable + func _vjpMean(alongAxes axes: Tensor) -> (Tensor, (Tensor) -> Tensor) { + let value = mean(alongAxes: axes) + let count = Raw.gather(params: shapeTensor, indices: axes).product() + return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) / Tensor(count) }) + } + + @inlinable + func _vjpMean(squeezingAxes axes: Tensor) -> (Tensor, (Tensor) -> Tensor) { + let value = mean(squeezingAxes: axes) + let count = Raw.gather(params: shapeTensor, indices: axes).product() + return (value, { [shape = shapeTensor] in + var result = $0 + for i in axes.array.scalars { result = result.expandingShape(at: Int(i)) } + return result.broadcast(toShape: shape) / Tensor(count) + }) + } +} + +// TODO: Consider making the return type be generic over `FloatingPoint` types +// so that `self`'s scalar type can be any `Numeric` type. +public extension Tensor where Scalar: TensorFlowFloatingPoint { + /// Returns the standard deviation of the elements along the specified axes. The reduced + /// dimensions are retained with value `1`. Does not apply Bessel's correction. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank..) 
-> Tensor { + return sqrt(variance(squeezingAxes: axes)) + } + + /// Returns the standard deviation of the elements along the specified axes. The reduced + /// dimensions are retained with value `1`. Does not apply Bessel's correction. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return sqrt(variance(squeezingAxes: axes)) + } + + /// Returns the standard deviation of the elements along the specified axes. The reduced + /// dimensions are retained with value `1`. Does not apply Bessel's correction. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return standardDeviation(squeezingAxes: axes) + } + + /// Returns the standard deviation of the elements along the specified axes. The reduced + /// dimensions are retained with value `1`. Does not apply Bessel's correction. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // Reduce along all dimensions. + return standardDeviation(squeezingAxes: Array(0..) -> Tensor { + return sqrt(variance(alongAxes: axes)) + } + + /// Returns the standard deviation of the elements along the specified axes. The reduced + /// dimensions are retained with value `1`. Does not apply Bessel's correction. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + // TODO(TF-433): Remove workaround for differentiating `map`. + let axes = {axes.map(Int32.init)}() + return standardDeviation(alongAxes: Tensor(axes)) + } + + /// Returns the standard deviation of the elements along the specified axes. The reduced + /// dimensions are retained with value `1`. Does not apply Bessel's correction. + /// + /// - Parameter axes: The dimensions to reduce. + /// - Precondition: Each value in `axes` must be in the range `-rank.. Tensor { + return sqrt(variance(alongAxes: axes)) + } +} + +//===------------------------------------------------------------------------------------------===// +// Linear Algebra +//===------------------------------------------------------------------------------------------===// + +/// Performs matrix multiplication with another tensor and produces the result. +@inlinable +@differentiable(vjp: _vjpMatmul(_:_:) where Scalar: TensorFlowFloatingPoint) +public func matmul( + _ lhs: Tensor, + _ rhs: Tensor +) -> Tensor { + // Default arguments specified explicitly to avoid "external declarations of SILFunctions with + // shared visibility is not allowed" SILVerifier error in + // "tests/AutoDiff/tensor_autodiff_runtime.swift". + return Raw.matMul(lhs, rhs, transposeA: false, transposeB: false) +} + +@inlinable +internal func _vjpMatmul( + _ lhs: Tensor, + _ rhs: Tensor +) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = matmul(lhs, rhs) + return (value, { v in + (matmul(v, rhs.transposed()), matmul(lhs.transposed(), v)) + }) +} + +infix operator •: MultiplicationPrecedence + +public extension Tensor where Scalar: Numeric { + // TODO: We have to define a custom VJP on • because AD can't yet differentiate generic methods. + // After AD can differentiate generic methods, remove the custom VJP. + + /// Performs matrix multiplication between two tensors and produces the result. 
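// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). Matrix
// multiplication via the free function `matmul` or the `•` operator defined
// just below; both differentiate through the custom VJP. Values are made up.
let a = Tensor<Float>([[1, 2], [3, 4]])              // shape [2, 2]
let b = Tensor<Float>([[0, 1], [1, 0]])              // shape [2, 2]
let ab = matmul(a, b)                                // [[2, 1], [4, 3]]
let abAgain = a • b                                  // same result
let aGrad = gradient(at: a) { matmul($0, b).sum() }  // [[1, 1], [1, 1]]
// ---------------------------------------------------------------------------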
+ @inlinable + @differentiable(vjp: _vjpMatmulOperator(lhs:rhs:) where Scalar: TensorFlowFloatingPoint) + static func • (lhs: Tensor, rhs: Tensor) -> Tensor { + return matmul(lhs, rhs) + } +} + +// TODO: We have to define a custom VJP on • because AD can't yet +// differentiate generic methods. After AD can differentiate generic methods, +// remove the custom VJP. +internal extension Tensor where Scalar: TensorFlowFloatingPoint { + @inlinable + static func _vjpMatmulOperator( + lhs: Tensor, + rhs: Tensor + ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + return _vjpMatmul(lhs, rhs) + } +} diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators/NN.swift similarity index 78% rename from Sources/DeepLearning/Operators.swift rename to Sources/DeepLearning/Operators/NN.swift index 918a197ae..c461a910b 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators/NN.swift @@ -16,22 +16,36 @@ import TensorFlow #endif -/// Returns the values of the specified tensor rounded to the nearest integer, element-wise. -public func round(_ x: Tensor) -> Tensor { - return Raw.round(x) -} - -/// Returns a tensor with the same shape and scalars as the specified tensor. -@differentiable -public func identity(_ x: Tensor) -> Tensor { - return x -} - //===------------------------------------------------------------------------------------------===// // Normalization //===------------------------------------------------------------------------------------------===// public extension Tensor where Scalar: TensorFlowFloatingPoint { + /// Computes the batch normalized tensor along the specified axis. + /// + /// Specifically, returns `(self - mu) / (var + epsilon) * gamma + beta` where `mu` and `var` + /// are respectively the mean and variance of `self` along `axis`. + /// + /// - Parameters: + /// - axis: The batch dimension. + /// - offset: The offset, also known as beta. + /// - scale: The scale, also known as gamma. + /// - epsilon: A small value added to the denominator for numerical stability. + @inlinable + @differentiable(wrt: (self, offset, scale), vjp: _vjpBatchNormalized) + func batchNormalized( + alongAxis axis: Int, + offset: Tensor = Tensor(0), + scale: Tensor = Tensor(1), + epsilon: Scalar = 0.001 + ) -> Tensor { + let mean = self.mean(alongAxes: axis) + let squaredDiff: Tensor = Raw.squaredDifference(self, mean) + let variance = squaredDiff.mean(alongAxes: axis) + let inv = rsqrt(variance + epsilon) * scale + return self * inv + offset - mean * inv + } + // TODO: Verify that these calculations are correct. 
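// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff).
// `batchNormalized(alongAxis:)` standardizes along the batch axis and then
// applies the scale and offset; with the defaults it simply normalizes each
// feature to roughly zero mean and unit variance. Values are made up.
let batch = Tensor<Float>([[0, 10], [2, 20], [4, 30]])   // [batch, features]
let normalized = batch.batchNormalized(alongAxis: 0)
// Each column of `normalized` has mean ≈ 0 and variance ≈ 1 (up to epsilon).
// ---------------------------------------------------------------------------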
@inlinable internal func _vjpBatchNormalized( @@ -40,18 +54,15 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { scale: Tensor, epsilon: Scalar ) -> (Tensor, (Tensor) -> (Tensor, Tensor, Tensor)) { - let value = batchNormalized(alongAxis: axis, offset: offset, scale: scale, - epsilon: epsilon) + let value = batchNormalized(alongAxis: axis, offset: offset, scale: scale, epsilon: epsilon) return (value, { v in let mean = self.mean(alongAxes: axis) let squaredDiff: Tensor = Raw.squaredDifference(self, mean) let variance = squaredDiff.mean(alongAxes: axis) - - let diff = self - mean + let diff = self - mean let inv = rsqrt(variance + epsilon) let norm = diff * inv - - let dNorm = v * scale + let dNorm = v * scale let dVariance = -(dNorm * diff).sum(alongAxes: axis) / 2 * pow(inv, -3) let dMean = (-dNorm * inv).sum(alongAxes: axis) + dVariance * (-diff * 2).mean(alongAxes: axis) @@ -80,9 +91,8 @@ public extension Tensor where Scalar: BinaryFloatingPoint { /// stability. @inlinable @differentiable( - wrt: (self, offset, scale), vjp: _vjpBatchNormalized - where Scalar : TensorFlowFloatingPoint - ) + wrt: (self, offset, scale), + vjp: _vjpBatchNormalized where Scalar: TensorFlowFloatingPoint) func batchNormalized( alongAxis axis: Int, offset: Tensor = Tensor(0), @@ -172,14 +182,12 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { _ strides: (Int, Int, Int, Int), _ padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = conv2DBackpropInput(shape: shape, filter: filter, strides: strides, - padding: padding) + let value = conv2DBackpropInput( + shape: shape, filter: filter, strides: strides, padding: padding) return (value, { v in - return ( - self.conv2DBackpropFilter(input: v, filterSizes: shape, strides: strides, - padding: padding), - v.convolved2D(withFilter: filter, strides: strides, padding: padding) - ) + (self.conv2DBackpropFilter( + input: v, filterSizes: shape, strides: strides, padding: padding), + v.convolved2D(withFilter: filter, strides: strides, padding: padding)) }) } @@ -190,14 +198,12 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { _ strides: (Int, Int, Int, Int), _ padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = conv2DBackpropFilter(input: input, filterSizes: filterSizes, - strides: strides, padding: padding) + let value = conv2DBackpropFilter( + input: input, filterSizes: filterSizes, strides: strides, padding: padding) return (value, { v in - return ( - self.conv2DBackpropInput(shape: filterSizes, filter: v, strides: strides, - padding: padding), - input.convolved2D(withFilter: v, strides: strides, padding: padding) - ) + (self.conv2DBackpropInput( + shape: filterSizes, filter: v, strides: strides, padding: padding), + input.convolved2D(withFilter: v, strides: strides, padding: padding)) }) } @@ -207,19 +213,12 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { strides: (Int, Int, Int, Int), padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = convolved2D(withFilter: filter, strides: strides, - padding: padding) + let value = convolved2D(withFilter: filter, strides: strides, padding: padding) return (value, { v in - return ( - v.conv2DBackpropInput( - shape: self.shapeTensor, filter: filter, - strides: strides, padding: padding - ), - v.conv2DBackpropFilter( - input: self, filterSizes: filter.shapeTensor, - strides: strides, padding: padding - ) - ) + (v.conv2DBackpropInput( + shape: self.shapeTensor, filter: filter, strides: strides, 
padding: padding), + v.conv2DBackpropFilter( + input: self, filterSizes: filter.shapeTensor, strides: strides, padding: padding)) }) } @@ -231,10 +230,9 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> Tensor) { // TODO: Currently this is not higher order differentiable. Redefine in // closed form. - let value = maxPooled(kernelSize: kernelSize, strides: strides, - padding: padding) + let value = maxPooled(kernelSize: kernelSize, strides: strides, padding: padding) return (value, { v in - return Raw.maxPoolGradV2( + Raw.maxPoolGradV2( origInput: self, origOutput: value, grad: v, @@ -242,8 +240,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { Int32(kernelSize.2), Int32(kernelSize.3)]), strides: Tensor([Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)]), - padding: padding.raw - ) + padding: padding.raw) }) } @@ -255,17 +252,15 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> Tensor) { // TODO: Currently this is not higher order differentiable. Redefine in // closed form. - let value = averagePooled(kernelSize: kernelSize, strides: strides, - padding: padding) + let value = averagePooled(kernelSize: kernelSize, strides: strides, padding: padding) return (value, { v in - return Raw.avgPoolGrad( + Raw.avgPoolGrad( origInputShape: self.shapeTensor, grad: v, ksize: [Int32(kernelSize.0), Int32(kernelSize.1), Int32(kernelSize.2), Int32(kernelSize.3)], strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)], - padding: padding.raw - ) + padding: padding.raw) }) } } @@ -281,11 +276,10 @@ public extension Tensor where Scalar: FloatingPoint { /// - padding: The padding for the operation. /// - Precondition: `self` must have rank 4. /// - Precondition: `filter` must have rank 4. - @inlinable @inline(__always) + @inlinable @differentiable( - wrt: (self, filter), vjp: _vjpConvolved2D - where Scalar: TensorFlowFloatingPoint - ) + wrt: (self, filter), + vjp: _vjpConvolved2D where Scalar: TensorFlowFloatingPoint) func convolved2D( withFilter filter: Tensor, strides: (Int, Int, Int, Int), @@ -307,11 +301,10 @@ public extension Tensor where Scalar: FloatingPoint { /// - strides: The strides of the sliding filter for each dimension of the /// input. /// - padding: The padding for the operation. - @inlinable @inline(__always) + @inlinable @differentiable( - wrt: self, vjp: _vjpMaxPooled(kernelSize:strides:padding:) - where Scalar : TensorFlowFloatingPoint - ) + wrt: self, + vjp: _vjpMaxPooled(kernelSize:strides:padding:) where Scalar: TensorFlowFloatingPoint) func maxPooled( kernelSize: (Int, Int, Int, Int), strides: (Int, Int, Int, Int), @@ -334,11 +327,10 @@ public extension Tensor where Scalar: FloatingPoint { /// - strides: The strides of the sliding filter for each dimension of the /// input. /// - padding: The padding for the operation. 
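// ---------------------------------------------------------------------------
// Editor's example (illustrative sketch, not part of this diff). `convolved2D`
// and the pooling methods operate on NHWC tensors; kernel sizes and strides are
// given per dimension (batch, height, width, channels). The shapes, values, and
// the `Tensor(ones:)` initializer are assumptions.
let image = Tensor<Float>(ones: [1, 8, 8, 3])            // NHWC input
let filter = Tensor<Float>(ones: [3, 3, 3, 16])          // height, width, in channels, out channels
let features = image.convolved2D(withFilter: filter, strides: (1, 1, 1, 1), padding: .same)
let pooled = features.maxPooled(kernelSize: (1, 2, 2, 1), strides: (1, 2, 2, 1), padding: .valid)
// features: [1, 8, 8, 16]; pooled: [1, 4, 4, 16]
// ---------------------------------------------------------------------------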
- @inlinable @inline(__always) + @inlinable @differentiable( - wrt: self, vjp: _vjpAveragePooled(kernelSize:strides:padding:) - where Scalar : TensorFlowFloatingPoint - ) + wrt: self, + vjp: _vjpAveragePooled(kernelSize:strides:padding:) where Scalar: TensorFlowFloatingPoint) func averagePooled( kernelSize: (Int, Int, Int, Int), strides: (Int, Int, Int, Int), diff --git a/Sources/DeepLearning/PythonConversion.swift b/Sources/DeepLearning/PythonConversion.swift new file mode 100644 index 000000000..5e52548c4 --- /dev/null +++ b/Sources/DeepLearning/PythonConversion.swift @@ -0,0 +1,171 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +import TensorFlow +#endif + +#if canImport(Python) +import Python + +/// The `numpy` Python module. +/// Note: Global variables are lazy, so the following declaration won't produce +// a Python import error until it is first used. +private let np = Python.import("numpy") + +private func debugLogNumpyError(_ message: String) { + debugLog("NumPy conversion error: " + message) +} + +extension ShapedArray: ConvertibleFromNumpyArray + where Scalar: NumpyScalarCompatible { + /// Creates a `ShapedArray` with the same shape and scalars as the specified + /// `numpy.ndarray` instance. + /// + /// - Parameter numpyArray: The `numpy.ndarray` instance to convert. + /// - Precondition: The `numpy` Python package must be installed. + /// - Precondition: `numpyArray` must have a compatible scalar `dtype`. + public init?(numpy numpyArray: PythonObject) { + // Check if input is a `numpy.ndarray` instance. + guard Python.isinstance(numpyArray, np.ndarray) == true else { + debugLogNumpyError(""" + PythonObject input has type '\(Python.type(numpyArray))' and is not \ + an instance of 'numpy.ndarray'. + """) + return nil + } + // Check if the dtype of the `ndarray` is compatible with the `Scalar` + // type. + guard Scalar.numpyScalarTypes.contains(numpyArray.dtype) else { + debugLogNumpyError(""" + 'numpy.ndarray' dtype '\(numpyArray.dtype)' is incompatible with \ + Swift type '\(Scalar.self)'. + """) + return nil + } + + let pyShape = numpyArray.__array_interface__["shape"] + guard let shape = [Int](pyShape) else { + debugLogNumpyError("cannot access shape of 'numpy.ndarray' instance.") + return nil + } + + // Make sure that the array is contiguous in memory. This does a copy if + // the array is not already contiguous in memory. + let contiguousNumpyArray = np.ascontiguousarray(numpyArray) + + guard let ptrVal = + UInt(contiguousNumpyArray.__array_interface__["data"].tuple2.0) else { + debugLogNumpyError("cannot access data of 'numpy.ndarray' instance.") + return nil + } + // Note: `ptr` is not nil even if the `ndarray` is empty (i.e. has a shape + // of `(0,)`). 
+ guard let ptr = UnsafePointer(bitPattern: ptrVal) else { + fatalError("'numpy.ndarray' data pointer was nil") + } + // This code avoids calling `init(shape: [Int], scalars: S)`, + // which inefficiently copies scalars one by one. Instead, + // `init(shape: [Int], scalars: [Scalar])` is called, which efficiently + // does a `memcpy` of the entire `scalars` array. + // Unnecessary copying is minimized. + let dummyPointer = UnsafeMutablePointer.allocate(capacity: 1) + let scalarCount = shape.reduce(1, *) + var scalars: [Scalar] = Array(repeating: dummyPointer.move(), count: scalarCount) + dummyPointer.deallocate() + scalars.withUnsafeMutableBufferPointer { buffPtr in + buffPtr.baseAddress!.assign(from: ptr, count: scalarCount) + } + self.init(shape: shape, scalars: scalars) + } +} + +extension Tensor: ConvertibleFromNumpyArray + where Scalar: NumpyScalarCompatible { + /// Creates a tensor with the same shape and scalars as the specified + /// `numpy.ndarray` instance. + /// + /// - Parameter numpyArray: The `numpy.ndarray` instance to convert. + /// - Precondition: The `numpy` Python package must be installed. + /// - Returns: `numpyArray` converted to a `Tensor`. Returns `nil` if + /// `numpyArray` does not have a compatible scalar `dtype`. + public init?(numpy numpyArray: PythonObject) { + // Check if input is a `numpy.ndarray` instance. + guard Python.isinstance(numpyArray, np.ndarray) == true else { + debugLogNumpyError(""" + PythonObject input has type '\(Python.type(numpyArray))' and is not \ + an instance of 'numpy.ndarray'. + """) + return nil + } + // Check if the dtype of the `ndarray` is compatible with the `Scalar` + // type. + guard Scalar.numpyScalarTypes.contains(numpyArray.dtype) else { + debugLogNumpyError(""" + 'numpy.ndarray' dtype '\(numpyArray.dtype)' is incompatible with \ + Swift type '\(Scalar.self)'. + """) + return nil + } + + let pyShape = numpyArray.__array_interface__["shape"] + guard let dimensions = [Int](pyShape) else { + debugLogNumpyError("cannot access shape of 'numpy.ndarray' instance.") + return nil + } + let shape = TensorShape(dimensions) + + // Make sure that the array is contiguous in memory. This does a copy if + // the array is not already contiguous in memory. + let contiguousNumpyArray = np.ascontiguousarray(numpyArray) + + guard let ptrVal = UInt(contiguousNumpyArray.__array_interface__["data"].tuple2.0) else { + debugLogNumpyError("cannot access data of 'numpy.ndarray' instance.") + return nil + } + // Note: `ptr` is not nil even if the `ndarray` is empty (i.e. has a shape + // of `(0,)`). + guard let ptr = UnsafePointer(bitPattern: ptrVal) else { + fatalError("'numpy.ndarray' data pointer was nil") + } + let buffPtr = UnsafeBufferPointer(start: ptr, count: Int(shape.contiguousSize)) + self.init(shape: shape, scalars: buffPtr) + } +} + +extension ShapedArray where Scalar: NumpyScalarCompatible { + /// Creates a `numpy.ndarray` instance with the same shape and scalars as + /// this `ShapedArray`. + /// + /// - Precondition: The `numpy` Python package must be installed. + public func makeNumpyArray() -> PythonObject { + return scalars.makeNumpyArray().reshape(shape) + } +} + +extension Tensor where Scalar: NumpyScalarCompatible { + /// Creates a `numpy.ndarray` instance with the same shape and scalars as + /// this tensor. + /// + /// - Precondition: The `numpy` Python package must be installed.
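// Round-trip sketch (illustrative, not part of this diff): converting between `numpy.ndarray`
// and `Tensor`, assuming Python and the `numpy` package are importable at runtime.
import Python
let np = Python.import("numpy")
let ndarray = np.ones([2, 3], dtype: np.float32)
if let tensor = Tensor<Float>(numpy: ndarray) {          // `nil` if the dtype is incompatible
    let roundTripped = tensor.makeNumpyArray()           // back to a `numpy.ndarray`
    print(Python.type(roundTripped), roundTripped.shape) // <class 'numpy.ndarray'> (2, 3)
}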
+ public func makeNumpyArray() -> PythonObject { return array.makeNumpyArray() } +} + +extension TensorShape: PythonConvertible { + public var pythonObject: PythonObject { + return dimensions.pythonObject + } +} + +#endif // canImport(Python) diff --git a/Sources/DeepLearning/Random.swift b/Sources/DeepLearning/Random.swift index 44e55223c..ade75a826 100644 --- a/Sources/DeepLearning/Random.swift +++ b/Sources/DeepLearning/Random.swift @@ -19,7 +19,7 @@ import Glibc #endif //===------------------------------------------------------------------------------------------===// -// Random number generators +// Random Number Generators //===------------------------------------------------------------------------------------------===// /// A type that provides seedable deterministic pseudo-random data. @@ -409,8 +409,8 @@ private func makeUInt64Pair(_ vector: UInt32x4) -> (UInt64, UInt64) { //===------------------------------------------------------------------------------------------===// public protocol RandomDistribution { - associatedtype Sample - func next(using generator: inout G) -> Sample + associatedtype Sample + func next(using generator: inout G) -> Sample } @_fixed_layout @@ -429,8 +429,8 @@ public struct UniformIntegerDistribution: RandomDistributi } @_fixed_layout -public struct UniformFloatingPointDistribution: RandomDistribution - where T.RawSignificand : FixedWidthInteger { +public struct UniformFloatingPointDistribution: RandomDistribution + where T.RawSignificand: FixedWidthInteger { public let lowerBound: T public let upperBound: T @@ -445,8 +445,8 @@ public struct UniformFloatingPointDistribution: RandomD } @_fixed_layout -public struct NormalDistribution: RandomDistribution - where T.RawSignificand : FixedWidthInteger { +public struct NormalDistribution: RandomDistribution + where T.RawSignificand: FixedWidthInteger { public let mean: T public let standardDeviation: T private let uniformDist = UniformFloatingPointDistribution() @@ -503,10 +503,10 @@ public struct BetaDistribution: RandomDistribution { /// /// - Returns: Sample obtained using Cheng's BB algorithm. private static func chengsAlgorithmBB( - _ alpha0: Float, - _ a: Float, - _ b: Float, - using rng: inout G + _ alpha0: Float, + _ a: Float, + _ b: Float, + using rng: inout G ) -> Float { let alpha = a + b let beta = sqrt((alpha - 2) / (2 * a * b - alpha)) @@ -536,7 +536,7 @@ public struct BetaDistribution: RandomDistribution { } while r + alpha * (log(alpha) - log(b + w)) < t w = min(w, Float.greatestFiniteMagnitude) - return a == alpha0 ? w / (b + w) : b / (b + w) + return a == alpha0 ? w / (b + w): b / (b + w) } /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC @@ -550,10 +550,10 @@ public struct BetaDistribution: RandomDistribution { /// /// - Returns: Sample obtained using Cheng's BB algorithm. private static func chengsAlgorithmBC( - _ alpha0: Float, - _ a: Float, - _ b: Float, - using rng: inout G + _ alpha0: Float, + _ a: Float, + _ b: Float, + using rng: inout G ) -> Float { let alpha = a + b let beta = 1 / b @@ -592,6 +592,6 @@ public struct BetaDistribution: RandomDistribution { } w = min(w, Float.greatestFiniteMagnitude) - return a == alpha0 ? w / (b + w) : b / (b + w) + return a == alpha0 ? w / (b + w): b / (b + w) } } diff --git a/Sources/DeepLearning/Tensors.swift b/Sources/DeepLearning/Tensors.swift new file mode 100644 index 000000000..1c1700649 --- /dev/null +++ b/Sources/DeepLearning/Tensors.swift @@ -0,0 +1,121 @@ +// Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if !COMPILING_TENSORFLOW_MODULE +import TensorFlow +#endif + +#if COMPILING_TENSORFLOW_MODULE +infix operator .==: ComparisonPrecedence +#endif + +//===------------------------------------------------------------------------------------------===// +// Tensor Properties +//===------------------------------------------------------------------------------------------===// + +public extension Tensor { + /// The rank of the tensor, represented as a `Tensor`. + @inlinable + var rankTensor: Tensor { + return Raw.rank(self) + } + + /// The dimensions of the tensor, represented as a `Tensor`. + @inlinable + var shapeTensor: Tensor { + return Raw.shape(self) + } + + /// The number of scalars in the tensor, represented as a `Tensor`. + @inlinable + var scalarCountTensor: Tensor { + return Raw.size(self) + } +} + +//===------------------------------------------------------------------------------------------===// +// Description and Visualization +//===------------------------------------------------------------------------------------------===// + +// String conversion. +extension Tensor: CustomStringConvertible { + /// A textual representation of the tensor. + /// + /// - Note: use `fullDescription` for a non-pretty-printed description showing all scalars. + public var description: String { + return array.description + } +} + +public extension Tensor { + /// A textual representation of the tensor. Returns a summarized description if `summarize` is + /// true and the element count exceeds twice the `edgeElementCount`. + /// + /// - Parameters: + /// - lineWidth: The max line width for printing. Used to determine number of scalars to print + /// per line. + /// - edgeElementCount: The maximum number of elements to print before and after summarization + /// via ellipses (`...`). + /// - summarizing: If true, summarize description if element count exceeds twice + /// `edgeElementCount`. + func description( + lineWidth: Int = 80, + edgeElementCount: Int = 3, + summarizing: Bool = false + ) -> String { + return array.description( + lineWidth: lineWidth, + edgeElementCount: edgeElementCount, + summarizing: summarizing) + } + + /// A full, non-pretty-printed textual representation of the tensor, showing + /// all scalars. + var fullDescription: String { + return array.fullDescription + } +} + +// Xcode Playground display conversion. +extension Tensor: CustomPlaygroundDisplayConvertible { + public var playgroundDescription: Any { + return description + } +} + +// Mirror representation, used by debugger/REPL. 
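// Printing sketch (illustrative, not part of this diff): the `description` API above is
// forwarded to `ShapedArray`; a large tensor can be summarized with ellipses.
let values = Tensor<Float>(rangeFrom: 0, to: 100, stride: 1).reshaped(to: [10, 10])
print(values.description(summarizing: true)) // elided, since 100 elements > 2 * edgeElementCount
print(values.fullDescription)                // every scalar, no pretty-printing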
+extension Tensor: CustomReflectable { + public var customMirror: Mirror { + return Mirror(self, children: [], displayStyle: .struct) + } +} + +//===------------------------------------------------------------------------------------------===// +// Codable Conformance +//===------------------------------------------------------------------------------------------===// + +extension Tensor: Codable where Scalar: Codable { + @inlinable + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(array) + } + + @inlinable + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + let array = try container.decode(ShapedArray.self) + self.init(array) + } +} diff --git a/Tests/DeepLearningTests/InitializerTests.swift b/Tests/DeepLearningTests/InitializerTests.swift new file mode 100644 index 000000000..3407e5816 --- /dev/null +++ b/Tests/DeepLearningTests/InitializerTests.swift @@ -0,0 +1,108 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import DeepLearning + +final class InitializerTests: XCTestCase { + func testInitializers() { + let scalar = Tensor(1) + let matrix: Tensor = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + let broadcastScalar = Tensor(broadcasting: 10, rank: 3) + let some4d = Tensor( + shape: [2, 1, 2, 1], + scalars: AnyRandomAccessCollection([2, 3, 4, 5])) + XCTAssertEqual(ShapedArray(shape: [2, 1, 2, 1], scalars: [2, 3, 4, 5]), some4d.array) + XCTAssertEqual(ShapedArray(shape: [], scalars: [1]), scalar.array) + XCTAssertEqual(ShapedArray(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6]), matrix.array) + XCTAssertEqual(ShapedArray(shape: [1, 1, 1], scalars: [10]), broadcastScalar.array) + } + + func testFactoryInitializers() { + let x = Tensor(ones: [1, 10]) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [1, 10]), x.array) + } + + func testNumericInitializers() { + let x = Tensor(oneHotAtIndices: [0, 2, -1, 1], depth: 3) + XCTAssertEqual(ShapedArray( + shape: [4, 3], + scalars: [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]), x.array) + } + + func testScalarToTensorConversion() { + let tensor = Tensor(broadcasting: 42, rank: 4) + XCTAssertEqual([1, 1, 1, 1], tensor.shape) + XCTAssertEqual([42], tensor.scalars) + } + + func testArrayConversion() { + let array3D = ShapedArray(repeating: 1.0, shape: [2, 3, 4]) + let tensor3D = Tensor(array3D) + XCTAssertEqual(array3D, tensor3D.array) + } + + func testNonTPUDataTypeCast() { + // TPU does not support Int8 or 16 casting. 
+ guard !_RuntimeConfig.executionMode.isTPU else { return } + + let x = Tensor(ones: [5, 5]) + let ints = Tensor(x) + let floats = Tensor(x) + let i8s = Tensor(floats) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), ints.array) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), floats.array) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), i8s.array) + } + + func testTPUDataTypeCast() { + // Non-TPU mode (e.g. eager) does not support Uint32 casting. + guard _RuntimeConfig.executionMode.isTPU else { return } + + let x = Tensor(ones: [5, 5]) + let ints = Tensor(x) + let floats = Tensor(x) + let u32s = Tensor(floats) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), ints.array) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), floats.array) + XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), u32s.array) + } + + func testNonTPUBoolToNumericCast() { + // TPU does not support Int8 or 16 casting. + // + // When changing to UInt32, got another TPU/XLA compilation error when + // converting from bools to Uint32 (different from missing kernel error). + if _RuntimeConfig.executionMode.isTPU { return } + + let bools = Tensor(shape: [2, 2], scalars: [true, false, true, false]) + let ints = Tensor(bools) + let floats = Tensor(bools) + let i8s = Tensor(bools) + XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), ints.array) + XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), floats.array) + XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), i8s.array) + } + + static var allTests = [ + ("testInitializers", testInitializers), + ("testFactoryInitializers", testFactoryInitializers), + ("testNumericInitializers", testNumericInitializers), + ("testScalarToTensorConversion", testScalarToTensorConversion), + ("testArrayConversion", testArrayConversion), + ("testNonTPUDataTypeCast", testNonTPUDataTypeCast), + ("testTPUDataTypeCast", testTPUDataTypeCast), + ("testNonTPUBoolToNumericCast", testNonTPUBoolToNumericCast) + ] +} diff --git a/Tests/DeepLearningTests/OperatorTests/BasicTests.swift b/Tests/DeepLearningTests/OperatorTests/BasicTests.swift new file mode 100644 index 000000000..ae25efbc5 --- /dev/null +++ b/Tests/DeepLearningTests/OperatorTests/BasicTests.swift @@ -0,0 +1,479 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import DeepLearning + +final class BasicOperatorTests: XCTestCase { + func testElementIndexing() { + // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly + // until send and receive are implemented (without writing a bunch of mini + // tests). Instead, `Tensor.array` is called to make a ShapedArray host copy + // and the ShapedArray is tested. 
+ let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let element2D = tensor3D[2] + let element1D = tensor3D[1][3] + let element0D = tensor3D[2][0][3] + + let array2D = element2D.array + let array1D = element1D.array + let array0D = element0D.array + + /// Test shapes + XCTAssertEqual([4, 5], array2D.shape) + XCTAssertEqual([5], array1D.shape) + XCTAssertEqual([], array0D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 35.0, to: 40, by: 1)), array1D.scalars) + XCTAssertEqual([43], array0D.scalars) + } + + func testElementIndexingAssignment() { + // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly + // until send and receive are implemented (without writing a bunch of mini + // tests). Instead, `Tensor.array` is called to make a ShapedArray host copy + // and the ShapedArray is tested. + var tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + tensor3D[2] = Tensor( + shape: [4, 5], scalars: Array(stride(from: 20.0, to: 40, by: 1))) + let element2D = tensor3D[2] + let element1D = tensor3D[1][3] + let element0D = tensor3D[2][0][3] + + let array2D = element2D.array + let array1D = element1D.array + let array0D = element0D.array + + /// Test shapes + XCTAssertEqual([4, 5], array2D.shape) + XCTAssertEqual([5], array1D.shape) + XCTAssertEqual([], array0D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 20.0, to: 40, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 35.0, to: 40, by: 1)), array1D.scalars) + XCTAssertEqual([23], array0D.scalars) + } + + func testNestedElementIndexing() { + // NOTE: This test could use a clearer name, along with other "indexing" + // tests. Note to update corresponding test names in other files + // (shaped_array.test) as well. + let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let element1D = tensor3D[1, 3] + let element0D = tensor3D[2, 0, 3] + + let array1D = element1D.array + let array0D = element0D.array + + /// Test shapes + XCTAssertEqual([5], array1D.shape) + XCTAssertEqual([], array0D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 35.0, to: 40, by: 1)), array1D.scalars) + XCTAssertEqual([43], array0D.scalars) + } + + func testSliceIndexing() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). + // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. + let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let slice3D = tensor3D[2...] + let slice2D = tensor3D[1][0..<2] + let slice1D = tensor3D[0][0][3..<5] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([1, 4, 5], array3D.shape) + XCTAssertEqual([2, 5], array2D.shape) + XCTAssertEqual([2], array1D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) + XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) + } + + func testSliceIndexingAssignment() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). 
+ // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. + var tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + tensor3D[2, 0..<5, 0..<6] = Tensor( + shape: [4, 5], scalars: Array(stride(from: 20.0, to: 40, by: 1))) + let slice3D = tensor3D[2...] + let slice2D = tensor3D[1][0..<2] + let slice1D = tensor3D[0][0][3..<5] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([1, 4, 5], array3D.shape) + XCTAssertEqual([2, 5], array2D.shape) + XCTAssertEqual([2], array1D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 20.0, to: 40, by: 1)), array3D.scalars) + XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) + } + + func testEllipsisIndexing() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). + // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. + var tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + tensor3D[2, TensorRange.ellipsis] = Tensor( + shape: [4, 5], scalars: Array(stride(from: 20.0, to: 40, by: 1))) + let slice3D = tensor3D[2..., TensorRange.ellipsis] + let slice2D = tensor3D[1][0..<2] + let slice1D = tensor3D[0][0][3..<5] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([1, 4, 5], array3D.shape) + XCTAssertEqual([2, 5], array2D.shape) + XCTAssertEqual([2], array1D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 20.0, to: 40, by: 1)), array3D.scalars) + XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) + } + + func testNewAxisIndexing() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). + // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. + let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let newAxis = TensorRange.newAxis + let ellipsis = TensorRange.ellipsis + let slice3D = tensor3D[2..., newAxis, ellipsis] + let slice2D = tensor3D[1, newAxis][0..<1, 0..<2] + let slice1D = tensor3D[0][newAxis, 0][0..<1, 3..<5, newAxis] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([1, 1, 4, 5], array3D.shape) + XCTAssertEqual([1, 2, 5], array2D.shape) + XCTAssertEqual([1, 2, 1], array1D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) + XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) + } + + func testSqueezeAxisIndexing() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). + // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. 
+ let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let newAxis = TensorRange.newAxis + let ellipsis = TensorRange.ellipsis + let squeezeAxis = TensorRange.squeezeAxis + let slice3D = tensor3D[2..., newAxis, ellipsis][squeezeAxis, squeezeAxis] + let slice2D = tensor3D[1, newAxis][squeezeAxis, 0..<2] + let slice1D = tensor3D[0..<1, 0, 3..<5, newAxis][ + squeezeAxis, ellipsis, squeezeAxis] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([4, 5], array3D.shape) + XCTAssertEqual([2, 5], array2D.shape) + XCTAssertEqual([2], array1D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) + XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) + } + + func testStridedSliceIndexing() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). + // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. + let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let slice3D = tensor3D[2...] + let slice2D = tensor3D[1][0..<3..2] + let slice1D = tensor3D[0][0][1..<5..2] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([1, 4, 5], array3D.shape) + XCTAssertEqual([2, 5], array2D.shape) + XCTAssertEqual([2], array1D.shape) + + /// Test scalars + XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) + XCTAssertEqual( + Array(stride(from: 20.0, to: 25, by: 1)) + + Array(stride(from: 30.0, to: 35, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 1.0, to: 5, by: 2)), array1D.scalars) + } + + func testStridedSliceIndexingAssignment() { + // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send + // and receive are implemented (without writing a bunch of mini tests). + // Instead, `Tensor.array` is called to make a ShapedArray host copy and the + // ShapedArray is tested instead. + var tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + tensor3D[2, 0..<5..2, 0..<6] = Tensor( + shape: [2, 5], scalars: Array(stride(from: 20.0, to: 40, by: 2))) + let slice3D = tensor3D[2...] 
+ let slice2D = tensor3D[1][0..<2] + let slice1D = tensor3D[0][0][3..<5] + + let array3D = slice3D.array + let array2D = slice2D.array + let array1D = slice1D.array + + /// Test shapes + XCTAssertEqual([1, 4, 5], array3D.shape) + XCTAssertEqual([2, 5], array2D.shape) + XCTAssertEqual([2], array1D.shape) + + /// Test scalars + XCTAssertEqual( + Array(stride(from: 20.0, to: 30, by: 2)) + + Array(stride(from: 45.0, to: 50, by: 1)) + + Array(stride(from: 30.0, to: 40, by: 2)) + + Array(stride(from: 55.0, to: 60, by: 1)), array3D.scalars) + XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) + XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) + } + + func testWholeTensorSlicing() { + let t: Tensor = [[[1, 1, 1], [2, 2, 2]], + [[3, 3, 3], [4, 4, 4]], + [[5, 5, 5], [6, 6, 6]]] + let slice2 = t.slice(lowerBounds: [1, 0, 0], upperBounds: [2, 1, 3]) + XCTAssertEqual(ShapedArray(shape: [1, 1, 3], scalars: [3, 3, 3]), slice2.array) + } + + func testAdvancedIndexing() { + // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly + // until send and receive are implemented (without writing a bunch of mini + // tests). Instead, `Tensor.array` is called to make a ShapedArray host copy + // and the ShapedArray is tested. + let tensor3D = Tensor( + shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) + let element2D = tensor3D[1..<3, 0, 3...] + let array2D = element2D.array + + // Test shape + XCTAssertEqual([2, 2], array2D.shape) + + // Test scalars + XCTAssertEqual(Array([23.0, 24.0, 43.0, 44.0]), array2D.scalars) + } + + func testConcatenation() { + // 2 x 3 + let t1 = Tensor([[0, 1, 2], [3, 4, 5]]) + // 2 x 3 + let t2 = Tensor([[6, 7, 8], [9, 10, 11]]) + let concatenated = t1 ++ t2 + let concatenated0 = t1.concatenated(with: t2) + let concatenated1 = t1.concatenated(with: t2, alongAxis: 1) + XCTAssertEqual(ShapedArray(shape: [4, 3], scalars: Array(0..<12)), concatenated.array) + XCTAssertEqual(ShapedArray(shape: [4, 3], scalars: Array(0..<12)), concatenated0.array) + XCTAssertEqual( + ShapedArray(shape: [2, 6], scalars: [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11]), + concatenated1.array) + } + + func testVJPConcatenation() { + let a1 = Tensor([1,2,3,4]) + let b1 = Tensor([5,6,7,8,9,10]) + + let a2 = Tensor([1,1,1,1]) + let b2 = Tensor([1,1,1,1,1,1]) + + let grads = gradient(at: a2, b2) { a, b in + return ((a1 * a) ++ (b1 * b)).sum() + } + + XCTAssertEqual(a1, grads.0) + XCTAssertEqual(b1, grads.1) + } + + func testVJPConcatenationNegativeAxis() { + let a1 = Tensor([1,2,3,4]) + let b1 = Tensor([5,6,7,8,9,10]) + + let a2 = Tensor([1,1,1,1]) + let b2 = Tensor([1,1,1,1,1,1]) + + let grads = gradient(at: a2, b2) { a, b in + return (a1 * a).concatenated(with: b1 * b, alongAxis: -1).sum() + } + + XCTAssertEqual(a1, grads.0) + XCTAssertEqual(b1, grads.1) + } + + func testTranspose() { + // 3 x 2 -> 2 x 3 + let xT = Tensor([[1, 2], [3, 4], [5, 6]]).transposed() + let xTArray = xT.array + XCTAssertEqual(2, xTArray.rank) + XCTAssertEqual([2, 3], xTArray.shape) + XCTAssertEqual([1, 3, 5, 2, 4, 6], xTArray.scalars) + } + + func testReshape() { + // 2 x 3 -> 1 x 3 x 1 x 2 x 1 + let matrix = Tensor([[0, 1, 2], [3, 4, 5]]) + let reshaped = matrix.reshaped(to: [1, 3, 1, 2, 1]) + + XCTAssertEqual([1, 3, 1, 2, 1], reshaped.shape) + XCTAssertEqual(Array(0..<6), reshaped.scalars) + } + + func testFlatten() { + // 2 x 3 -> 6 + let matrix = Tensor([[0, 1, 2], [3, 4, 5]]) + let flattened = matrix.flattened() + + XCTAssertEqual([6], flattened.shape) + 
XCTAssertEqual(Array(0..<6), flattened.scalars) + } + + func testFlatten0D() { + let scalar = Tensor(5) + let flattened = scalar.flattened() + XCTAssertEqual([1], flattened.shape) + XCTAssertEqual([5], flattened.scalars) + } + + func testReshapeToScalar() { + // 1 x 1 -> scalar + let z = Tensor([[10]]).reshaped(to: []) + XCTAssertEqual([], z.shape) + } + + func testReshapeTensor() { + // 2 x 3 -> 1 x 3 x 1 x 2 x 1 + let x = Tensor(repeating: 0.0, shape: [2, 3]) + let y = Tensor(repeating: 0.0, shape: [1, 3, 1, 2, 1]) + let result = x.reshaped(like: y) + XCTAssertEqual([1, 3, 1, 2, 1], result.shape) + } + + func testUnbroadcast1() { + let x = Tensor(repeating: 1, shape: [2, 3, 4, 5]) + let y = Tensor(repeating: 1, shape: [4, 5]) + let z = x.unbroadcast(like: y) + XCTAssertEqual(ShapedArray(repeating: 6, shape: [4, 5]), z.array) + } + + func testUnbroadcast2() { + let x = Tensor(repeating: 1, shape: [2, 3, 4, 5]) + let y = Tensor(repeating: 1, shape: [3, 1, 5]) + let z = x.unbroadcast(like: y) + XCTAssertEqual(ShapedArray(repeating: 8, shape: [3, 1, 5]), z.array) + } + + func testSliceUpdate() { + guard !_RuntimeConfig.executionMode.isTPU else { return } + var t1 = Tensor([[1, 2, 3], [4, 5, 6]]) + t1[0] = Tensor(zeros: [3]) + XCTAssertEqual(ShapedArray(shape:[2, 3], scalars: [0, 0, 0, 4, 5, 6]), t1.array) + var t2 = t1 + t2[0][2] = Tensor(3) + XCTAssertEqual(ShapedArray(shape:[2, 3], scalars: [0, 0, 3, 4, 5, 6]), t2.array) + var t3 = Tensor([[true, true, true], [false, false, false]]) + t3[0][1] = Tensor(false) + XCTAssertEqual(ShapedArray( + shape:[2, 3], scalars: [true, false, true, false, false, false]), t3.array) + var t4 = Tensor([[true, true, true], [false, false, false]]) + t4[0] = Tensor(repeating: false, shape: [3]) + XCTAssertEqual(ShapedArray(repeating: false, shape: [2, 3]), t4.array) + } + + func testBroadcastTensor() { + // 1 -> 2 x 3 x 4 + let one = Tensor(1) + var target = Tensor(repeating: 0.0, shape: [2, 3, 4]) + let broadcasted = one.broadcast(like: target) + XCTAssertEqual(Tensor(repeating: 1, shape: [2, 3, 4]), broadcasted) + target .= Tensor(repeating: 1, shape: [1, 3, 1]) + XCTAssertEqual(Tensor(repeating: 1, shape: [2, 3, 4]), target) + } + + static var allTests = [ + ("testElementIndexing", testElementIndexing), + ("testElementIndexingAssignment", testElementIndexingAssignment), + ("testNestedElementIndexing", testNestedElementIndexing), + ("testSliceIndexing", testSliceIndexing), + ("testSliceIndexingAssignment", testSliceIndexingAssignment), + ("testEllipsisIndexing", testEllipsisIndexing), + ("testNewAxisIndexing", testNewAxisIndexing), + ("testSqueezeAxisIndexing", testSqueezeAxisIndexing), + ("testStridedSliceIndexing", testStridedSliceIndexing), + ("testStridedSliceIndexingAssignment", testStridedSliceIndexingAssignment), + ("testWholeTensorSlicing", testWholeTensorSlicing), + ("testAdvancedIndexing", testAdvancedIndexing), + ("testConcatenation", testConcatenation), + ("testVJPConcatenation", testVJPConcatenation), + ("testTranspose", testTranspose), + ("testReshape", testReshape), + ("testFlatten", testFlatten), + ("testFlatten0D", testFlatten0D), + ("testReshapeToScalar", testReshapeToScalar), + ("testReshapeTensor", testReshapeTensor), + ("testUnbroadcast1", testUnbroadcast1), + ("testUnbroadcast2", testUnbroadcast2), + ("testSliceUpdate", testSliceUpdate), + ("testBroadcastTensor", testBroadcastTensor) + ] +} diff --git a/Tests/DeepLearningTests/OperatorTests/ComparisonTests.swift b/Tests/DeepLearningTests/OperatorTests/ComparisonTests.swift new file mode 
100644 index 000000000..e20a9cdc9 --- /dev/null +++ b/Tests/DeepLearningTests/OperatorTests/ComparisonTests.swift @@ -0,0 +1,35 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import DeepLearning + +final class ComparisonOperatorTests: XCTestCase { + func testElementwiseComparison() { + let x = Tensor([0, 1, 2]) + let y = Tensor([2, 1, 3]) + XCTAssertEqual((x .< y).scalars, [true, false, true]) + } + + func testLexicographicalComparison() { + let x = Tensor([0, 1, 2, 3, 4]) + let y = Tensor([2, 3, 4, 5, 6]) + XCTAssertTrue(x < y) + } + + static var allTests = [ + ("testElementwiseComparison", testElementwiseComparison), + ("testLexicographicalComparison", testLexicographicalComparison) + ] +} diff --git a/Tests/DeepLearningTests/OperatorTests/MathTests.swift b/Tests/DeepLearningTests/OperatorTests/MathTests.swift new file mode 100644 index 000000000..3f769be07 --- /dev/null +++ b/Tests/DeepLearningTests/OperatorTests/MathTests.swift @@ -0,0 +1,211 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
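// Note on the reduction tests that follow (illustrative sketch, not part of this diff):
// `squeezingAxes` removes the reduced dimension, while `alongAxes` keeps it with size 1.
let m = Tensor<Float>([[1, 2, 3], [4, 5, 6]])  // shape [2, 3]
let squeezed = m.sum(squeezingAxes: 0)         // shape [3]    -> [5, 7, 9]
let kept = m.sum(alongAxes: 0)                 // shape [1, 3] -> [[5, 7, 9]]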
+ +import XCTest +@testable import DeepLearning + +final class MathOperatorTests: XCTestCase { + func testReduction() { + // 2 x 5 + let x = Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) + XCTAssertEqual(Tensor(30), x.sum().toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [5], scalars: [2, 4, 6, 8, 10]), + x.sum(squeezingAxes: 0).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [1, 5], scalars: [2, 4, 6, 8, 10]), + x.sum(alongAxes: 0).toHost(shape: [])) + + XCTAssertEqual(Tensor(14400), x.product().toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [5], scalars: [1, 4, 9, 16, 25]), + x.product(squeezingAxes: 0).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [1, 5], scalars: [1, 4, 9, 16, 25]), + x.product(alongAxes: 0).toHost(shape: [])) + + XCTAssertEqual(Tensor(3), x.mean().toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [5], scalars: [1, 2, 3, 4, 5]), + x.mean(squeezingAxes: 0).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [5], scalars: [1, 2, 3, 4, 5]), + x.mean(alongAxes: 0).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [2], scalars: [3, 3]), + x.mean(squeezingAxes: 1).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [1, 2], scalars: [3, 3]), + x.mean(alongAxes: 1).toHost(shape: [])) + + XCTAssertEqual(Tensor(2), x.variance().toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]), + x.variance(squeezingAxes: 0).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]), + x.variance(alongAxes: 0).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [2], scalars: [2, 2]), + x.variance(squeezingAxes: 1).toHost(shape: [])) + XCTAssertEqual( + Tensor(shape: [1, 2], scalars: [2, 2]), + x.variance(alongAxes: 1).toHost(shape: [])) + } + + func testArgmax() { + // 2 x 3 + let x = Tensor([[0, 1, 2], [3, 4, 5]]) + let argmax0 = x.argmax(squeezingAxis: 0) + let argmax1 = x.argmax(squeezingAxis: 1) + let scalarsArgmax = x.argmax() + XCTAssertEqual(ShapedArray(shape: [3], scalars: [1, 1, 1]), argmax0.array) + XCTAssertEqual(ShapedArray(shape: [2], scalars: [2, 2]), argmax1.array) + XCTAssertEqual(ShapedArray(shape: [], scalars: [5]), scalarsArgmax.array) + } + + func testCeilAndFloor() { + let x = Tensor([-1.3, -0.4, 0.5, 1.6]) + let xFloor = floor(x) + let xCeil = ceil(x) + XCTAssertEqual(ShapedArray(shape: [4], scalars: [-2, -1, 0, 1]), xFloor.array) + XCTAssertEqual(ShapedArray(shape: [4], scalars: [-1, 0, 1, 2]), xCeil.array) + } + + func testSimpleMath() { + let x = Tensor([1.2, 1.2]) + let y = tanh(x) + let array = y.array + XCTAssertEqual([2], array.shape) + XCTAssertEqual([0.833655, 0.833655], array.scalars, accuracy: 0.0001) + } + + func testStandardDeviation() { + XCTAssertEqual(Tensor(0), Tensor([1]).standardDeviation()) + XCTAssertEqual(Tensor(0.5), Tensor([0, 1]).standardDeviation(alongAxes: 0)) + XCTAssertEqual(Tensor(0.5), Tensor([0, 1]).standardDeviation()) + XCTAssertEqual( + 2.87228132, + Tensor(rangeFrom: 0, to: 10, stride: 1).standardDeviation().scalarized(), + accuracy: 0.001) + let matrix = Tensor(rangeFrom: 0, to: 10, stride: 1).reshaped(to: [2, 5]) + XCTAssertEqual(2.87228132, matrix.standardDeviation().scalarized(), accuracy: 0.001) + XCTAssertEqual( + [1.4142, 1.4142], + matrix.standardDeviation(alongAxes: 1).array.scalars, + accuracy: 0.001) + } + + func test3Adds() { + let a = Tensor([1]) + let b = Tensor([2]) + let c = Tensor([3]) + + let o = a + b + c + XCTAssertEqual([6], o.scalars) + } + + func testMultiOpMath() { + let x = Tensor([1.2, 1.2]) + let y = Tensor([2.4, 2.4]) + let 
t1 = x + y + let t2 = t1 * t1 + let t3 = sqrt(t2) + + let array1 = t1.array + let array2 = t2.array + let array3 = t3.array + XCTAssertEqual([2], array1.shape) + XCTAssertEqual([2], array2.shape) + XCTAssertEqual([2], array3.shape) + XCTAssertEqual([3.6, 3.6], array1.scalars, accuracy: 0.001) + XCTAssertEqual([12.96, 12.96], array2.scalars, accuracy: 0.001) + XCTAssertEqual([3.6, 3.6], array3.scalars, accuracy: 0.001) + } + + func testXWPlusB() { + // Shape: 1 x 4 + let x = Tensor([[1.0, 2.0, 2.0, 1.0]]) + // Shape: 4 x 2 + let w = Tensor([[1.0, 0.0], [3.0, 0.0], [2.0, 3.0], [1.0, 0.0]]) + // Shape: 2 + let b = Tensor([0.5, 0.5]) + // Shape: 1 x 2 (broadcasted) + let result = matmul(x, w) + b + XCTAssertEqual([1, 2], result.shape) + XCTAssertEqual([12.5, 6.5], result.scalars) + } + + func testXORInference() { + func xor(_ x: Float, _ y: Float) -> Float { + let x = Tensor([x, y]).reshaped(to: [1, 2]) + + // FIXME: If params are declared outside of `xor`, it would crash. + // 2 x 4 + let w1 = Tensor( + [[-1.83586664, -0.20809225, 0.47667537, 1.90780607], + [-1.83523219, -0.51167348, 0.15490439, 1.91018065]]) + // 1 x 4 + let b1 = Tensor([[2.54353216, 0.25132703, -0.16503136, -0.85754058]]) + // 4 x 1 + let w2 = Tensor([[3.04350065], [0.35590511], [-0.3252157], [3.49349223]]) + // 1 x 1 + let b2 = Tensor([[-0.74635993]]) + + let o1 = tanh(matmul(x, w1) + b1) + let y = tanh(matmul(o1, w2) + b2) + return y.array.scalars[0] // TODO: use better scalar getter + } + + XCTAssertEqual(0.0, xor(0.0, 0.0), accuracy: 0.1) + XCTAssertEqual(1.0, xor(0.0, 1.0), accuracy: 0.1) + XCTAssertEqual(1.0, xor(1.0, 0.0), accuracy: 0.1) + XCTAssertEqual(0.0, xor(1.0, 1.0), accuracy: 0.1) + } + + func testMLPClassifierStruct() { + struct MLPClassifier { + // 2 x 4 + var w1 = Tensor([[1.0, 0.8, 0.4, 0.4], + [0.4, 0.3, 0.2, 0.1]]) + // 4 x 1 + var w2 = Tensor([[0.4], [0.4], [0.3], [0.9]]) + var b1 = Tensor(zeros: [1, 4]) + var b2 = Tensor(zeros: [1, 1]) + + func prediction(for x: Tensor) -> Tensor { + let o1 = tanh(matmul(x, w1) + b1) + return tanh(matmul(o1, w2) + b2) + } + } + + let input = Tensor([[1, 0.5]]) + let classifier = MLPClassifier() + let prediction = classifier.prediction(for: input) + XCTAssertEqual([0.816997], prediction.scalars, accuracy: 0.001) + } + + static var allTests = [ + ("testReduction", testReduction), + ("testArgmax", testArgmax), + ("testCeilAndFloor", testCeilAndFloor), + ("testSimpleMath", testSimpleMath), + ("testStandardDeviation", testStandardDeviation), + ("test3Adds", test3Adds), + ("testMultiOpMath", testMultiOpMath), + ("testXWPlusB", testXWPlusB), + ("testXORInference", testXORInference), + ("testMLPClassifierStruct", testMLPClassifierStruct) + ] +} diff --git a/Tests/DeepLearningTests/TensorTests.swift b/Tests/DeepLearningTests/TensorTests.swift new file mode 100644 index 000000000..ec7d1f6e3 --- /dev/null +++ b/Tests/DeepLearningTests/TensorTests.swift @@ -0,0 +1,59 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import DeepLearning + +final class TensorTests: XCTestCase { + func testSimpleCond() { + func selectValue(_ pred: Bool) -> Tensor { + let a = Tensor(0) + let b = Tensor(1) + if pred { + return a + } + return b + } + + XCTAssertEqual(0, selectValue(true).scalar) + } + + func testRankGetter() { + let vector = Tensor([1]) + let matrix = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + let ones = Tensor(ones: [1, 2, 2, 2, 2, 2, 1]) + let tensor = Tensor(shape: [3, 4, 5], scalars: Array(0..<60)) + XCTAssertEqual(1, vector.rank) + XCTAssertEqual(2, matrix.rank) + XCTAssertEqual(7, ones.rank) + XCTAssertEqual(3, tensor.rank) + } + + func testShapeGetter() { + let vector = Tensor([1]) + let matrix = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + let ones = Tensor(ones: [1, 2, 2, 2, 2, 2, 1]) + let tensor = Tensor(shape: [3, 4, 5], scalars: Array(0..<60)) + XCTAssertEqual([1], vector.shape) + XCTAssertEqual([2, 3], matrix.shape) + XCTAssertEqual([1, 2, 2, 2, 2, 2, 1], ones.shape) + XCTAssertEqual([3, 4, 5], tensor.shape) + } + + static var allTests = [ + ("testSimpleCond", testSimpleCond), + ("testRankGetter", testRankGetter), + ("testShapeGetter", testShapeGetter) + ] +} diff --git a/Tests/DeepLearningTests/XCTestManifests.swift b/Tests/DeepLearningTests/XCTestManifests.swift index 96a9048a5..e75c25298 100644 --- a/Tests/DeepLearningTests/XCTestManifests.swift +++ b/Tests/DeepLearningTests/XCTestManifests.swift @@ -22,6 +22,10 @@ public func allTests() -> [XCTestCaseEntry] { testCase(TrivialModelTests.allTests), testCase(SequentialTests.allTests), testCase(LayerTests.allTests), + testCase(TensorTests.allTests), + testCase(BasicOperatorTests.allTests), + testCase(ComparisonOperatorTests.allTests), + testCase(MathOperatorTests.allTests), ] } #endif
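// Build/run sketch (illustrative, not shown in this diff): on Linux, the updated test manifest
// is typically consumed by a `LinuxMain.swift` along these lines, so `swift test` picks up the
// newly registered test cases.
import XCTest
import DeepLearningTests

var tests = [XCTestCaseEntry]()
tests += DeepLearningTests.allTests()
XCTMain(tests)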