diff --git a/Sources/DeepLearning/DifferentialOperators.swift b/Sources/DeepLearning/DifferentialOperators.swift
deleted file mode 100644
index bfb53db77..000000000
--- a/Sources/DeepLearning/DifferentialOperators.swift
+++ /dev/null
@@ -1,178 +0,0 @@
-// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#if !COMPILING_TENSORFLOW_MODULE
-import TensorFlow
-#endif
-
-//===------------------------------------------------------------------------------------------===//
-// Method-style Differential Operators
-//===------------------------------------------------------------------------------------------===//
-
-public extension Differentiable {
-    @inlinable
-    func gradient<R: TensorFlowFloatingPoint>(
-        in f: @differentiable (Self) -> Tensor<R>
-    ) -> CotangentVector {
-        return self.pullback(in: f)(Tensor(1))
-    }
-
-    @inlinable
-    func valueWithGradient<R: TensorFlowFloatingPoint>(
-        in f: @differentiable (Self) -> Tensor<R>
-    ) -> (value: Tensor<R>, gradient: CotangentVector) {
-        let (y, pb) = self.valueWithPullback(in: f)
-        return (y, pb(Tensor(1)))
-    }
-
-    @inlinable
-    func gradient<T: Differentiable, R: TensorFlowFloatingPoint>(
-        at x: T,
-        in f: @differentiable (Self, T) -> Tensor<R>
-    ) -> (CotangentVector, T.CotangentVector) {
-        return self.pullback(at: x, in: f)(Tensor(1))
-    }
-
-    @inlinable
-    func valueWithGradient<T: Differentiable, R: TensorFlowFloatingPoint>(
-        at x: T,
-        in f: @differentiable (Self, T) -> Tensor<R>
-    ) -> (value: Tensor<R>, gradient: (CotangentVector, T.CotangentVector)) {
-        let (y, pb) = self.valueWithPullback(at: x, in: f)
-        return (y, pb(Tensor(1)))
-    }
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Free-Function-Style Differential Operators
-//===------------------------------------------------------------------------------------------===//
-
-// Value with gradient
-
-@inlinable
-public func valueWithGradient<T, R>(
-    at x: T,
-    in f: @differentiable (T) -> Tensor<R>
-) -> (value: Tensor<R>, gradient: T.CotangentVector)
-where T: Differentiable, R: TensorFlowFloatingPoint {
-    let (y, pullback) = valueWithPullback(at: x, in: f)
-    return (y, pullback(Tensor(1)))
-}
-
-@inlinable
-public func valueWithGradient<T, U, R>(
-    at x: T,
-    _ y: U,
-    in f: @differentiable (T, U) -> Tensor<R>
-) -> (value: Tensor<R>, gradient: (T.CotangentVector, U.CotangentVector))
-    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
-    let (y, pullback) = valueWithPullback(at: x, y, in: f)
-    return (y, pullback(Tensor(1)))
-}
-
-@inlinable
-public func valueWithGradient<T, U, V, R>(
-    at x: T,
-    _ y: U,
-    _ z: V,
-    in f: @differentiable (T, U, V) -> Tensor<R>
-) -> (value: Tensor<R>, gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector))
-    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
-    let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
-    return (y, pullback(Tensor(1)))
-}
-
-// Value with gradient (curried)
-
-@inlinable
-public func valueWithGradient<T, R>(
-    of f: @escaping @differentiable (T) -> Tensor<R>
-) -> (T) -> (value: Tensor<R>, gradient: T.CotangentVector)
-    where T: Differentiable, R: TensorFlowFloatingPoint {
-    return { x in valueWithGradient(at: x, in: f) }
-}
-
-@inlinable
-public func valueWithGradient<T, U, R>(
-    of f: @escaping @differentiable (T, U) -> Tensor<R>
-) -> (T, U) -> (value: Tensor<R>, gradient: (T.CotangentVector, U.CotangentVector))
-    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
-    return { x, y in valueWithGradient(at: x, y, in: f) }
-}
-
-@inlinable
-public func valueWithGradient<T, U, V, R>(
-    of f: @escaping @differentiable (T, U, V) -> Tensor<R>
-) -> (T, U, V) -> (
-    value: Tensor<R>,
-    gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector))
-    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
-    return { x, y, z in valueWithGradient(at: x, y, z, in: f) }
-}
-
-// Gradient
-
-@inlinable
-public func gradient<T, R>(
-    at x: T,
-    in f: @differentiable (T) -> Tensor<R>
-) -> T.CotangentVector where T: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, in: f)(Tensor(1))
-}
-
-@inlinable
-public func gradient<T, U, R>(
-    at x: T,
-    _ y: U,
-    in f: @differentiable (T, U) -> Tensor<R>
-) -> (T.CotangentVector, U.CotangentVector)
-    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, y, in: f)(Tensor(1))
-}
-
-@inlinable
-public func gradient<T, U, V, R>(
-    at x: T,
-    _ y: U,
-    _ z: V,
-    in f: @differentiable (T, U, V) -> Tensor<R>
-) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector)
-    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, y, z, in: f)(Tensor(1))
-}
-
-// Gradient (curried)
-
-@inlinable
-public func gradient<T, R>(
-    of f: @escaping @differentiable (T) -> Tensor<R>
-) -> (T) -> T.CotangentVector where T: Differentiable, R: TensorFlowFloatingPoint {
-    return { x in gradient(at: x, in: f) }
-}
-
-@inlinable
-public func gradient<T, U, R>(
-    of f: @escaping @differentiable (T, U) -> Tensor<R>
-) -> (T, U) -> (T.CotangentVector, U.CotangentVector)
-    where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
-    return { x, y in gradient(at: x, y, in: f) }
-}
-
-@inlinable
-public func gradient<T, U, V, R>(
-    of f: @escaping @differentiable (T, U, V) -> Tensor<R>
-) -> (T, U, V) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector)
-    where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
-    return { x, y, z in gradient(at: x, y, z, in: f) }
-}
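The file deleted above provided the user-facing differentiation entry points. For orientation, here is a minimal usage sketch of those operators; the values and closures are illustrative, not taken from the source:

```swift
import TensorFlow

let x = Tensor<Float>([2, 3])

// Free-function style: seed the pullback of a scalar-valued closure with 1.
let (value, grad) = valueWithGradient(at: x) { x in (x * x).sum() }
// value == 13.0, grad == [4.0, 6.0]

// Method style: `self` is the point of differentiation.
let gradOnly = x.gradient { x in (x * x).sum() }
```

Both styles reduce to `pullback` seeded with `Tensor(1)`, which is why `f` is expected to be scalar-valued.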
diff --git a/Sources/DeepLearning/Helpers.swift b/Sources/DeepLearning/Helpers.swift
index 4d9c0217b..86aec74bf 100644
--- a/Sources/DeepLearning/Helpers.swift
+++ b/Sources/DeepLearning/Helpers.swift
@@ -13,20 +13,12 @@
 // limitations under the License.
 
 #if !COMPILING_TENSORFLOW_MODULE
-@_exported import TensorFlow
+import TensorFlow
 #endif
 
-/// Returns a tensor with the same shape and scalars as the specified tensor.
-@inlinable
-@differentiable
-public func identity<Scalar: TensorFlowFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
-    return x
-}
-
 // `pow` is defined in Darwin/Glibc on `Float` and `Double`, but there doesn't exist a generic
 // version for `FloatingPoint`.
 // This is a manual definition.
-@inlinable
-func pow<T: BinaryFloatingPoint>(_ x: T, _ y: T) -> T {
+func pow<T: BinaryFloatingPoint>(_ x: T, _ y: T) -> T {
     return T(pow(Double(x), Double(y)))
 }
diff --git a/Sources/DeepLearning/Initializers.swift b/Sources/DeepLearning/Initializers.swift
index 3ab3f5654..0655b817b 100644
--- a/Sources/DeepLearning/Initializers.swift
+++ b/Sources/DeepLearning/Initializers.swift
@@ -13,294 +13,9 @@
 // limitations under the License.
#if !COMPILING_TENSORFLOW_MODULE -import TensorFlow +@_exported import TensorFlow #endif -public extension Tensor { - /// Creates a tensor with the specified shape and a single, repeated scalar - /// value. - /// - /// - Parameters: - /// - shape: The dimensions of the tensor. - /// - repeatedValue: The scalar value to repeat. - @inlinable - @available(*, deprecated, renamed: "init(repeating:shape:)") - init(shape: TensorShape, repeating repeatedValue: Scalar) { - self.init(repeating: repeatedValue, shape: shape) - } - - /// Creates a tensor with the specified shape and a single, repeated scalar value. - /// - /// - Parameters: - /// - repeatedValue: The scalar value to repeat. - /// - shape: The dimensions of the tensor. - @inlinable - @differentiable( - vjp: _vjpInit(repeating:shape:) where Scalar: TensorFlowFloatingPoint) - init(repeating repeatedValue: Scalar, shape: TensorShape) { - self = Raw.fill( - dims: Tensor(shape.dimensions.map(Int32.init)), - value: Tensor(repeatedValue)) - } - - /// Creates a tensor by broadcasting the given scalar to a given rank with - /// all dimensions being 1. - @inlinable - // @differentiable(where Scalar: TensorFlowFloatingPoint) - init(broadcasting scalar: Scalar, rank: Int) { - self = Tensor(scalar).reshaped(to: TensorShape(repeating: 1, count: rank)) - } - - /// Creates a tensor of shape `[4]` from a 4-tuple. - /// - Note: This is intended for internal use, for example, to initialize a - /// tensor attribute from `convolved2D`'s `strides` argument. - @inlinable - internal init(_ scalars: (Scalar, Scalar, Scalar, Scalar)) { - self.init([scalars.0, scalars.1, scalars.2, scalars.3]) - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - static func _vjpInit( - repeating repeatedValue: Scalar, - shape: TensorShape - ) -> (Tensor, (Tensor) -> Scalar) { - return (Tensor(repeating: repeatedValue, shape: shape), { - $0.sum().scalarized() - }) - } -} - -//===------------------------------------------------------------------------------------------===// -// Casting -//===------------------------------------------------------------------------------------------===// - -public extension Tensor where Scalar: Numeric { - /// Perform an element-wise type conversion from a `Bool` tensor. - @inlinable - init(_ other: Tensor) { - self = Raw.cast(other) - } - - /// Perform an element-wise conversion from another `Tensor`. - @inlinable - @differentiable( - vjp: _vjpCast where Scalar: TensorFlowFloatingPoint, OtherScalar: TensorFlowFloatingPoint) - init(_ other: Tensor) { - self = Raw.cast(other) - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - static func _vjpCast( - _ other: Tensor - ) -> (Tensor, (Tensor) -> Tensor) { - return (Tensor(other), { v in Tensor(v) }) - } -} - -//===------------------------------------------------------------------------------------------===// -// Stacking / Concatenating / Tiling -//===------------------------------------------------------------------------------------------===// - -public extension Tensor { - /// Creates a tensor from an array of tensors (which may themselves be scalars). - @inlinable - @differentiable(where Scalar: TensorFlowFloatingPoint) - init(_ elements: [Tensor]) { - self = Tensor(stacking: elements) - } - - /// Stacks `tensors`, along the `axis` dimension, into a new tensor with rank one higher than - /// the current tensor and each tensor in `tensors`. 
- /// - /// Given that `tensors` all have shape `[A, B, C]`, and `tensors.count = N`, then: - /// - if `axis == 0` then the resulting tensor will have the shape `[N, A, B, C]`. - /// - if `axis == 1` then the resulting tensor will have the shape `[A, N, B, C]`. - /// - etc. - /// - /// For example: - /// ``` - /// // 'x' is [1, 4] - /// // 'y' is [2, 5] - /// // 'z' is [3, 6] - /// Tensor(stacking: [x, y, z]) // is [[1, 4], [2, 5], [3, 6]] - /// Tensor(stacking: [x, y, z], alongAxis: 1) // is [[1, 2, 3], [4, 5, 6]] - /// ``` - /// - /// This is the opposite of `Tensor.unstack(alongAxis:)`. - /// - /// - Parameters: - /// - tensors: Tensors to stack. - /// - axis: Dimension along which to stack. Negative values wrap around. - /// - /// - Precondition: All tensors must have the same shape. - /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the - /// provided tensors. - /// - /// - Returns: The stacked tensor. - @inlinable - @differentiable(vjp: _vjpStacking where Scalar: TensorFlowFloatingPoint) - init(stacking tensors: [Tensor], alongAxis axis: Int = 0) { - self = Raw.pack(tensors, axis: Int64(axis)) - } - - /// Concatenates `tensors` along the `axis` dimension. - /// - /// Given that `tensors[i].shape = [D0, D1, ... Daxis(i), ...Dn]`, then the concatenated result - /// has shape `[D0, D1, ... Raxis, ...Dn]`, where `Raxis = sum(Daxis(i))`. That is, the data - /// from the input tensors is joined along the `axis` dimension. - /// - /// For example: - /// ``` - /// // t1 is [[1, 2, 3], [4, 5, 6]] - /// // t2 is [[7, 8, 9], [10, 11, 12]] - /// Tensor(concatenating: [t1, t2]) // is [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]] - /// Tensor(concatenating: [t1, t2], alongAxis: 1) // is [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]] - /// - /// // t3 has shape [2, 3] - /// // t4 has shape [2, 3] - /// Tensor(concatenating: [t3, t4]) // has shape [4, 3] - /// Tensor(concatenating: [t3, t4], alongAxis: 1) // has shape [2, 6] - /// ``` - /// - /// - Note: If you are concatenating along a new axis consider using - /// `Tensor.init(stacking:alongAxis:)`. - /// - /// - Parameters: - /// - tensors: Tensors to concatenate. - /// - axis: Dimension along which to concatenate. Negative values wrap around. - /// - /// - Precondition: All tensors must have the same rank and all dimensions except `axis` - /// must be equal. - /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the - /// provided tensors. - /// - /// - Returns: The concatenated tensor. - @inlinable - @differentiable(vjp: _vjpConcatenating where Scalar: TensorFlowFloatingPoint) - init(concatenating tensors: [Tensor], alongAxis axis: Int = 0) { - precondition(tensors.count > 0) - self = Raw.concatV2(tensors, axis: Tensor(Int32(axis))) - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - static func _vjpStacking( - stacking tensors: [Tensor], - alongAxis axis: Int = 0 - ) -> (Tensor, (Tensor) -> Array.DifferentiableView) { - let result = Tensor(stacking: tensors, alongAxis: axis) - return (result, { v in - Array.DifferentiableView(v.unstack(alongAxis: axis)) - }) - } - - @inlinable - static func _vjpConcatenating( - concatenating tensors: [Tensor], - alongAxis axis: Int = 0 - ) -> (Tensor, (Tensor) -> Array.DifferentiableView) { - let result = Tensor(concatenating: tensors, alongAxis: axis) - let posAxis = axis < 0 ? 
axis + tensors[0].rank: axis - let sizes = Tensor(stacking: tensors.map { $0.shapeTensor[posAxis] }) - return (result, { [count = tensors.count] v in - if count == 1 { return Array.DifferentiableView([v]) } - let splits = v.split(sizes: sizes, alongAxis: posAxis) - return Array.DifferentiableView(splits) - }) - } -} - -//===------------------------------------------------------------------------------------------===// -// Numeric -//===------------------------------------------------------------------------------------------===// - -public extension Tensor where Scalar: Numeric { - /// Creates a tensor with all scalars set to zero. - /// - /// - Parameter shape: Shape of the tensor. - @inlinable - init(zeros shape: TensorShape) { - self.init(repeating: 0, shape: shape) - } - - /// Creates a tensor with all scalars set to one. - /// - /// - Parameter shape: Shape of the tensor. - @inlinable - init(ones shape: TensorShape) { - self.init(repeating: 1, shape: shape) - } - - /// Creates a 1-D tensor representing a sequence from a starting value to, but not including, - /// an end value, stepping by the specified amount. - /// - /// - Parameters: - /// - start: The starting value to use for the sequence. If the sequence - /// contains any values, the first one is `start`. - /// - end: An end value to limit the sequence. `end` is never an element of - /// the resulting sequence. - /// - stride: The amount to step by with each iteration. `stride` must be - /// positive. - /// - @inlinable - init(rangeFrom start: Scalar, to end: Scalar, stride: Scalar) { - self = Raw.range(start: Tensor(start), limit: Tensor(end), delta: Tensor(stride)) - } - - /// Creates a one-hot tensor at given indices. The locations represented by - /// `indices` take value `onValue` (`1` by default), while all other locations - /// take value `offValue` (`0` by default). If the input `indices` is rank - /// `n`, the new tensor will have rank `n+1`. The new axis is created at - /// dimension `axis` (by default, the new axis is appended at the end). - /// - /// If `indices` is a scalar, the new tensor's shape will be a vector of - /// length `depth`. - /// - /// If `indices` is a vector of length `features`, the output shape will be: - /// features x depth, if axis == -1 - /// depth x features, if axis == 0 - /// - /// If `indices` is a matrix (batch) with shape `[batch, features]`, the - /// output shape will be: - /// batch x features x depth, if axis == -1 - /// batch x depth x features, if axis == 1 - /// depth x batch x features, if axis == 0 - /// - /// - Parameters: - /// - indices: A `Tensor` of indices. - /// - depth: A scalar defining the depth of the one hot dimension. - /// - onValue: A scalar defining the value at the location referred to by - /// some index in `indices`. - /// - offValue: A scalar defining the value at a location that is not - /// referred to by any index in `indices`. - /// - axis: The axis to fill. The default is `-1`, a new inner-most axis. 
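A brief, hypothetical usage sketch of the one-hot initializer documented above (values illustrative):

```swift
let labels = Tensor<Int32>([0, 2, 1])
let oneHot = Tensor<Float>(oneHotAtIndices: labels, depth: 3)
// [[1, 0, 0],
//  [0, 0, 1],
//  [0, 1, 0]]
```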
- /// - @inlinable - init( - oneHotAtIndices indices: Tensor, - depth: Int, - onValue: Scalar = 1, - offValue: Scalar = 0, - axis: Int = -1 - ) { - self = Raw.oneHot( - indices: indices, - depth: Tensor(Int32(depth)), - onValue: Tensor(onValue), - offValue: Tensor(offValue), - axis: Int64(axis)) - } -} - -//===------------------------------------------------------------------------------------------===// -// Random -//===------------------------------------------------------------------------------------------===// - public extension Tensor where Scalar == Int32 { /// Creates a tensor with the specified shape, randomly sampling scalar values /// from a discrete uniform distribution. @@ -309,10 +24,8 @@ public extension Tensor where Scalar == Int32 { /// - shape: The dimensions of the tensor. /// - generator: Random number generator to use. /// - init( - randomStandardUniform shape: TensorShape, - generator: inout G - ) { + init(randomStandardUniform shape: TensorShape, + generator: inout G) { let dist = UniformIntegerDistribution() var scalars: [Scalar] = [] for _ in 0 ..< shape.contiguousSize { @@ -381,10 +94,8 @@ public extension Tensor where Scalar: BinaryFloatingPoint, /// - shape: The dimensions of the tensor. /// - generator: Random number generator to use. /// - init( - randomUniform shape: TensorShape, - generator: inout G - ) { + init(randomUniform shape: TensorShape, + generator: inout G) { let dist = UniformFloatingPointDistribution() var scalars: [Scalar] = [] for _ in 0 ..< shape.contiguousSize { @@ -402,12 +113,10 @@ public extension Tensor where Scalar: BinaryFloatingPoint, /// - stddev: The standard deviation of the distribution. /// - generator: Random number generator to use. /// - init( - randomNormal shape: TensorShape, - mean: Scalar = 0, - stddev: Scalar = 1, - generator: inout G - ) { + init(randomNormal shape: TensorShape, + mean: Scalar = 0, + stddev: Scalar = 1, + generator: inout G) { let dist = NormalDistribution(mean: mean, standardDeviation: stddev) var scalars: [Scalar] = [] for _ in 0 ..< shape.contiguousSize { @@ -417,7 +126,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint, } } -fileprivate extension Tensor where Scalar: BinaryFloatingPoint { +fileprivate extension Tensor where Scalar : BinaryFloatingPoint { private static func glorot( fromStandardUniform randomUniform: __shared Tensor, shape: __shared TensorShape @@ -441,11 +150,9 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { /// - Parameters: /// - shape: The dimensions of the tensor. /// - init( - glorotUniform shape: TensorShape, - seed: (Int64, Int64) = (Int64.random(in: Int64.min..(_ x: Tensor) -> Tensor { + return Raw.round(x) +} + +/// Returns a tensor with the same shape and scalars as the specified tensor. +@differentiable +public func identity(_ x: Tensor) -> Tensor { + return x +} + //===------------------------------------------------------------------------------------------===// // Normalization //===------------------------------------------------------------------------------------------===// public extension Tensor where Scalar: TensorFlowFloatingPoint { - /// Computes the batch normalized tensor along the specified axis. - /// - /// Specifically, returns `(self - mu) / (var + epsilon) * gamma + beta` where `mu` and `var` - /// are respectively the mean and variance of `self` along `axis`. - /// - /// - Parameters: - /// - axis: The batch dimension. - /// - offset: The offset, also known as beta. - /// - scale: The scale, also known as gamma. 
- /// - epsilon: A small value added to the denominator for numerical stability. - @inlinable - @differentiable(wrt: (self, offset, scale), vjp: _vjpBatchNormalized) - func batchNormalized( - alongAxis axis: Int, - offset: Tensor = Tensor(0), - scale: Tensor = Tensor(1), - epsilon: Scalar = 0.001 - ) -> Tensor { - let mean = self.mean(alongAxes: axis) - let squaredDiff: Tensor = Raw.squaredDifference(self, mean) - let variance = squaredDiff.mean(alongAxes: axis) - let inv = rsqrt(variance + epsilon) * scale - return self * inv + offset - mean * inv - } - // TODO: Verify that these calculations are correct. @inlinable internal func _vjpBatchNormalized( @@ -54,15 +40,18 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { scale: Tensor, epsilon: Scalar ) -> (Tensor, (Tensor) -> (Tensor, Tensor, Tensor)) { - let value = batchNormalized(alongAxis: axis, offset: offset, scale: scale, epsilon: epsilon) + let value = batchNormalized(alongAxis: axis, offset: offset, scale: scale, + epsilon: epsilon) return (value, { v in let mean = self.mean(alongAxes: axis) let squaredDiff: Tensor = Raw.squaredDifference(self, mean) let variance = squaredDiff.mean(alongAxes: axis) - let diff = self - mean + + let diff = self - mean let inv = rsqrt(variance + epsilon) let norm = diff * inv - let dNorm = v * scale + + let dNorm = v * scale let dVariance = -(dNorm * diff).sum(alongAxes: axis) / 2 * pow(inv, -3) let dMean = (-dNorm * inv).sum(alongAxes: axis) + dVariance * (-diff * 2).mean(alongAxes: axis) @@ -91,8 +80,9 @@ public extension Tensor where Scalar: BinaryFloatingPoint { /// stability. @inlinable @differentiable( - wrt: (self, offset, scale), - vjp: _vjpBatchNormalized where Scalar: TensorFlowFloatingPoint) + wrt: (self, offset, scale), vjp: _vjpBatchNormalized + where Scalar : TensorFlowFloatingPoint + ) func batchNormalized( alongAxis axis: Int, offset: Tensor = Tensor(0), @@ -182,12 +172,14 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { _ strides: (Int, Int, Int, Int), _ padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = conv2DBackpropInput( - shape: shape, filter: filter, strides: strides, padding: padding) + let value = conv2DBackpropInput(shape: shape, filter: filter, strides: strides, + padding: padding) return (value, { v in - (self.conv2DBackpropFilter( - input: v, filterSizes: shape, strides: strides, padding: padding), - v.convolved2D(withFilter: filter, strides: strides, padding: padding)) + return ( + self.conv2DBackpropFilter(input: v, filterSizes: shape, strides: strides, + padding: padding), + v.convolved2D(withFilter: filter, strides: strides, padding: padding) + ) }) } @@ -198,12 +190,14 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { _ strides: (Int, Int, Int, Int), _ padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = conv2DBackpropFilter( - input: input, filterSizes: filterSizes, strides: strides, padding: padding) + let value = conv2DBackpropFilter(input: input, filterSizes: filterSizes, + strides: strides, padding: padding) return (value, { v in - (self.conv2DBackpropInput( - shape: filterSizes, filter: v, strides: strides, padding: padding), - input.convolved2D(withFilter: v, strides: strides, padding: padding)) + return ( + self.conv2DBackpropInput(shape: filterSizes, filter: v, strides: strides, + padding: padding), + input.convolved2D(withFilter: v, strides: strides, padding: padding) + ) }) } @@ -213,12 +207,19 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint 
{ strides: (Int, Int, Int, Int), padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = convolved2D(withFilter: filter, strides: strides, padding: padding) + let value = convolved2D(withFilter: filter, strides: strides, + padding: padding) return (value, { v in - (v.conv2DBackpropInput( - shape: self.shapeTensor, filter: filter, strides: strides, padding: padding), - v.conv2DBackpropFilter( - input: self, filterSizes: filter.shapeTensor, strides: strides, padding: padding)) + return ( + v.conv2DBackpropInput( + shape: self.shapeTensor, filter: filter, + strides: strides, padding: padding + ), + v.conv2DBackpropFilter( + input: self, filterSizes: filter.shapeTensor, + strides: strides, padding: padding + ) + ) }) } @@ -230,9 +231,10 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> Tensor) { // TODO: Currently this is not higher order differentiable. Redefine in // closed form. - let value = maxPooled(kernelSize: kernelSize, strides: strides, padding: padding) + let value = maxPooled(kernelSize: kernelSize, strides: strides, + padding: padding) return (value, { v in - Raw.maxPoolGradV2( + return Raw.maxPoolGradV2( origInput: self, origOutput: value, grad: v, @@ -240,7 +242,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { Int32(kernelSize.2), Int32(kernelSize.3)]), strides: Tensor([Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)]), - padding: padding.raw) + padding: padding.raw + ) }) } @@ -252,15 +255,17 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> Tensor) { // TODO: Currently this is not higher order differentiable. Redefine in // closed form. - let value = averagePooled(kernelSize: kernelSize, strides: strides, padding: padding) + let value = averagePooled(kernelSize: kernelSize, strides: strides, + padding: padding) return (value, { v in - Raw.avgPoolGrad( + return Raw.avgPoolGrad( origInputShape: self.shapeTensor, grad: v, ksize: [Int32(kernelSize.0), Int32(kernelSize.1), Int32(kernelSize.2), Int32(kernelSize.3)], strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)], - padding: padding.raw) + padding: padding.raw + ) }) } } @@ -276,10 +281,11 @@ public extension Tensor where Scalar: FloatingPoint { /// - padding: The padding for the operation. /// - Precondition: `self` must have rank 4. /// - Precondition: `filter` must have rank 4. - @inlinable + @inlinable @inline(__always) @differentiable( - wrt: (self, filter), - vjp: _vjpConvolved2D where Scalar: TensorFlowFloatingPoint) + wrt: (self, filter), vjp: _vjpConvolved2D + where Scalar: TensorFlowFloatingPoint + ) func convolved2D( withFilter filter: Tensor, strides: (Int, Int, Int, Int), @@ -301,10 +307,11 @@ public extension Tensor where Scalar: FloatingPoint { /// - strides: The strides of the sliding filter for each dimension of the /// input. /// - padding: The padding for the operation. - @inlinable + @inlinable @inline(__always) @differentiable( - wrt: self, - vjp: _vjpMaxPooled(kernelSize:strides:padding:) where Scalar: TensorFlowFloatingPoint) + wrt: self, vjp: _vjpMaxPooled(kernelSize:strides:padding:) + where Scalar : TensorFlowFloatingPoint + ) func maxPooled( kernelSize: (Int, Int, Int, Int), strides: (Int, Int, Int, Int), @@ -327,10 +334,11 @@ public extension Tensor where Scalar: FloatingPoint { /// - strides: The strides of the sliding filter for each dimension of the /// input. /// - padding: The padding for the operation. 
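As a usage sketch of the convolution and pooling methods defined here; the shapes are illustrative, and `.same`/`.valid` are assumed to be the library's `Padding` cases:

```swift
let image = Tensor<Float>(zeros: [1, 28, 28, 1])          // NHWC, batch of 1
let filter = Tensor<Float>(glorotUniform: [5, 5, 1, 8])   // 5x5 kernel, 1 -> 8 channels
let conv = image.convolved2D(withFilter: filter, strides: (1, 1, 1, 1), padding: .same)
let pooled = conv.maxPooled(kernelSize: (1, 2, 2, 1), strides: (1, 2, 2, 1), padding: .valid)
// conv has shape [1, 28, 28, 8]; pooled has shape [1, 14, 14, 8]
```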
- @inlinable + @inlinable @inline(__always) @differentiable( - wrt: self, - vjp: _vjpAveragePooled(kernelSize:strides:padding:) where Scalar: TensorFlowFloatingPoint) + wrt: self, vjp: _vjpAveragePooled(kernelSize:strides:padding:) + where Scalar : TensorFlowFloatingPoint + ) func averagePooled( kernelSize: (Int, Int, Int, Int), strides: (Int, Int, Int, Int), diff --git a/Sources/DeepLearning/Operators/Basic.swift b/Sources/DeepLearning/Operators/Basic.swift deleted file mode 100644 index 54f60ec3d..000000000 --- a/Sources/DeepLearning/Operators/Basic.swift +++ /dev/null @@ -1,716 +0,0 @@ -// Copyright 2018 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#if !COMPILING_TENSORFLOW_MODULE -import TensorFlow -#endif - -//===------------------------------------------------------------------------------------------===// -// Shape Transformations -//===------------------------------------------------------------------------------------------===// - -public extension TensorFlowScalar { - /// Convert to a tensor with the specified rank, with all dimensions equal to 1. - @inlinable - func makeTensor(rank: Int) -> Tensor { - return Tensor(repeating: self, shape: TensorShape(rank)) - } -} - -public extension Tensor { - /// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors. Unpacks - /// `N` tensors from this tensor by chipping it along the `axis` dimension, where `N` is - /// inferred from this tensor's shape. For example, given a tensor with shape `[A, B, C, D]`: - /// - /// - If `axis == 0` then the `i`-th tensor in the returned array is the slice - /// `self[i, :, :, :]` and each tensor in that array will have shape `[B, C, D]`. - /// (Note that the dimension unpacked along is gone, unlike - /// `Tensor.split(numSplits:alongAxis)`, or `Tensor.split(sizes:alongAxis)`). - /// - If `axis == 1` then the `i`-th tensor in the returned array is the slice - /// `value[:, i, :, :]` and each tensor in that array will have shape `[A, C, D]`. - /// - Etc. - /// - /// This is the opposite of `Tensor.init(stacking:alongAxis:)`. - /// - /// - Parameters: - /// - axis: Dimension along which to unstack. Negative values wrap around. - /// - /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the - /// provided tensors. - /// - /// - Returns: Array containing the unstacked tensors. - @inlinable - @differentiable(vjp: _vjpUnstack(alongAxis:) where Scalar: TensorFlowFloatingPoint) - func unstack(alongAxis axis: Int = 0) -> [Tensor] { - return Raw.unpack(value: self, num: Int64(shape[axis]), axis: Int64(axis)) - } - - /// Splits a tensor into multiple tensors. The tensor is split along dimension `axis` into - /// `numSplits` smaller tensors. This requires that `numSplits` evenly divides `shape[axis]`. 
- /// - /// For example: - /// ``` - /// // 'value' is a tensor with shape [5, 30] - /// // Split 'value' into 3 tensors along dimension 1: - /// let parts = value.split(numSplits: 3, alongAxis: 1) - /// parts[0] // has shape [5, 10] - /// parts[1] // has shape [5, 10] - /// parts[2] // has shape [5, 10] - /// ``` - /// - /// - Parameters: - /// - numSplits: Number of splits to create. - /// - axis: Dimension along which to split this tensor. Negative values wrap around. - /// - /// - Precondition: `numSplits` must divide the size of dimension `axis` evenly. - /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the - /// provided tensors. - /// - /// - Returns: Array containing the tensors parts. - @inlinable - @differentiable(vjp: _vjpSplit(numSplits:alongAxis:) where Scalar: TensorFlowFloatingPoint) - func split(numSplits: Int, alongAxis axis: Int = 0) -> [Tensor] { - return Raw.split( - splitDim: Tensor(Int32(axis)), value: self, numSplit: Int64(numSplits)) - } - - /// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces. - /// The shape of the `i`-th piece has the same shape as this tensor except along dimension - /// `axis` where the size is `sizes[i]`. - /// - /// For example: - /// ``` - /// // 'value' is a tensor with shape [5, 30] - /// // Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1: - /// let parts = value.split(sizes: Tensor([4, 15, 11]), alongAxis: 1) - /// parts[0] // has shape [5, 4] - /// parts[1] // has shape [5, 15] - /// parts[2] // has shape [5, 11] - /// ``` - /// - /// - Parameters: - /// - sizes: 1-D tensor containing the size of each split. - /// - axis: Dimension along which to split this tensor. Negative values wrap around. - /// - /// - Precondition: The values in `sizes` must add up to the size of dimension `axis`. - /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the - /// provided tensors. - /// - /// - Returns: Array containing the tensors parts. - @inlinable - @differentiable( - wrt: self, - vjp: _vjpSplit(sizes:alongAxis:) where Scalar: TensorFlowFloatingPoint) - func split(sizes: Tensor, alongAxis axis: Int = 0) -> [Tensor] { - return Raw.splitV( - value: self, - sizeSplits: sizes, - splitDim: Tensor(Int32(axis)), - numSplit: Int64(sizes.shape[0])) - } - - /// Reshape to the shape of the specified `Tensor`. - /// - Precondition: The number of scalars matches the new shape. - @inlinable - @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) - func reshaped(like other: Tensor) -> Tensor { - return reshaped(toShape: other.shapeTensor) - } - - /// Reshape to the specified shape. - /// - Precondition: The number of scalars matches the new shape. - @inlinable - @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) - func reshaped(to newShape: TensorShape) -> Tensor { - // TODO(TF-433): Remove workaround for differentiating `map`. - return reshaped(toShape: Tensor({newShape.dimensions.map(Int32.init)}())) - } - - /// Reshape to the specified `Tensor` representing a shape. - /// - Precondition: The number of scalars matches the new shape. - @inlinable - @differentiable( - wrt: self, - vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint) - func reshaped(toShape newShape: Tensor) -> Tensor { - return Raw.reshape(self, shape: newShape) - } - - /// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order. 
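A compact sketch of the reshaping helpers in this extension (values illustrative):

```swift
let v = Tensor<Float>([1, 2, 3, 4, 5, 6])
let m = v.reshaped(to: [2, 3])    // [[1, 2, 3], [4, 5, 6]]
let e = m.expandingShape(at: 0)   // shape [1, 2, 3]
let f = e.flattened()             // shape [6], scalars in row-major order
```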
- @inlinable - @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) - func flattened() -> Tensor { - return reshaped(to: [-1]) - } - - /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the specified shape - /// indices. - @inlinable - @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) - func expandingShape(at axes: Int...) -> Tensor { - return expandingShape(at: axes) - } - - /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the - /// specified shape indices. - @inlinable - @differentiable(wrt: self, vjp: _vjpExpandingShape(at:) where Scalar: TensorFlowFloatingPoint) - func expandingShape(at axes: [Int]) -> Tensor { - var result = self - for i in axes { result = Raw.expandDims(result, dim: Tensor(Int32(i))) } - return result - } - - /// Returns a rank-lifted `Tensor` with a leading dimension of 1. - @inlinable - @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) - func rankLifted() -> Tensor { - return expandingShape(at: 0) - } - - /// Remove the specified dimensions of size 1 from the shape of a tensor. If no dimensions are - /// specified, then all dimensions of size 1 will be removed. - @inlinable - @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint) - func squeezingShape(at axes: Int...) -> Tensor { - return squeezingShape(at: axes) - } - - /// Remove the specified dimensions of size 1 from the shape of a tensor. If no dimensions are - /// specified, then all dimensions of size 1 will be removed. - @inlinable - @differentiable(wrt: self, vjp: _vjpSqueezingShape(at:) where Scalar: TensorFlowFloatingPoint) - func squeezingShape(at axes: [Int]) -> Tensor { - return Raw.squeeze(self, squeezeDims: axes.map(Int32.init)) - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - func _vjpUnstack( - alongAxis axis: Int = 0 - ) -> ([Tensor], (Array.CotangentVector) -> Tensor) { - let result = unstack(alongAxis: axis) - return (result, { v in Tensor(stacking: v.base, alongAxis: axis) }) - } - - @inlinable - func _vjpSplit( - numSplits: Int, - alongAxis axis: Int = 0 - ) -> ([Tensor], (Array.CotangentVector) -> Tensor) { - let result = split(numSplits: numSplits, alongAxis: axis) - return (result, { v in Tensor(concatenating: v.base, alongAxis: axis) }) - } - - @inlinable - func _vjpSplit( - sizes: Tensor, - alongAxis axis: Int = 0 - ) -> ([Tensor], (Array.CotangentVector) -> Tensor) { - let result = split(sizes: sizes, alongAxis: axis) - return (result, { v in Tensor(concatenating: v.base, alongAxis: axis) }) - } - - @inlinable - func _vjpReshaped(toShape newShape: Tensor) -> (Tensor, (Tensor) -> Tensor) { - let value = reshaped(toShape: newShape) - return (value, { [shape = shapeTensor] v in v.reshaped(toShape: shape) }) - } - - @inlinable - func _vjpExpandingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { - let value = self.expandingShape(at: axes) - return (value, { v in v.squeezingShape(at: axes) }) - } - - @inlinable - func _vjpSqueezingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { - let value = squeezingShape(at: axes) - return (value, { [shape = shapeTensor] v in v.reshaped(toShape: shape) }) - } -} - -//===------------------------------------------------------------------------------------------===// -// Other Tensor Transformations -//===------------------------------------------------------------------------------------------===// - -infix operator ++: AdditionPrecedence - -public extension Tensor { - /// Returns a transposed tensor, with 
dimensions permuted in the specified order. - @inlinable - @differentiable( - wrt: self, - vjp: _vjpTransposed(withPermutations:) where Scalar: TensorFlowFloatingPoint) - func transposed(withPermutations permutations: Tensor) -> Tensor { - return Raw.transpose(self, perm: permutations) - } - - /// Returns a transposed tensor, with dimensions permuted in the specified order. - @inlinable - @differentiable( - wrt: self, - vjp: _vjpTransposed(withPermutations:) where Scalar: TensorFlowFloatingPoint) - func transposed(withPermutations permutations: [Int]) -> Tensor { - let permutations = permutations.map(Int32.init) - return transposed(withPermutations: Tensor(permutations)) - } - - /// Returns a transposed tensor, with dimensions permuted in the specified order. - @inlinable - @differentiable( - wrt: self, vjp: _vjpTransposed(withPermutations:) where Scalar: TensorFlowFloatingPoint) - func transposed(withPermutations permutations: Int...) -> Tensor { - return transposed(withPermutations: permutations) - } - - /// Returns a transposed tensor, with dimensions permuted in reverse order. - @inlinable - @differentiable(wrt: self, vjp: _vjpTransposed() where Scalar: TensorFlowFloatingPoint) - func transposed() -> Tensor { - let defaultPermutations = rankTensor - 1 - Tensor( - rangeFrom: 0, to: Int32(rank), stride: 1) - return transposed(withPermutations: Tensor(defaultPermutations)) - } - - /// Concatenates tensors along the specified axis. - /// - Precondition: The tensors must have the same dimensions, except for the - /// specified axis. - /// - Precondition: The axis must be in the range `-rank.. Tensor { - return Tensor(concatenating: [self, other], alongAxis: axis) - } - - /// Concatenation operator. - /// - Note: `++` is a custom operator that does not exist in Swift, but does - /// in Haskell/Scala. Its addition is not an insignificant language change - /// and may be controversial. The existence/naming of `++` will be discussed - /// during a later API design phase. - @inlinable - @differentiable(where Scalar: TensorFlowFloatingPoint) - static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor { - return lhs.concatenated(with: rhs) - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - func _vjpTransposed( - withPermutations permutations: Tensor - ) -> (Tensor, (Tensor) -> Tensor) { - let value = transposed(withPermutations: permutations) - return (value, { $0.transposed(withPermutations: permutations) }) - } - - @inlinable - func _vjpTransposed(withPermutations permutations: [Int]) -> (Tensor, (Tensor) -> Tensor) { - let value = transposed(withPermutations: permutations) - return (value, { $0.transposed(withPermutations: permutations) }) - } - - @inlinable - func _vjpTransposed(withPermutations permutations: Int...) -> (Tensor, (Tensor) -> Tensor) { - let value = transposed(withPermutations: permutations) - return (value, { $0.transposed(withPermutations: permutations) }) - } - - @inlinable - func _vjpTransposed() -> (Tensor, (Tensor) -> Tensor) { - return (transposed(), { $0.transposed() }) - } - - @inlinable - func _vjpConcatenated( - with other: Tensor, - alongAxis axis: Int - ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let idx = axis < 0 ? 
axis + rank: axis - let splits = Tensor([shapeTensor[idx], other.shapeTensor[idx]]) - return (concatenated(with: other, alongAxis: axis), { result in - let gradients = result.split(sizes: splits, alongAxis: axis) - return (gradients[0], gradients[1]) - }) - } -} - -//===------------------------------------------------------------------------------------------===// -// Broadcasting -//===------------------------------------------------------------------------------------------===// - -// TODO: What about precedence? Also, why is this operator meaningful for broadcasting? -infix operator .= - -public extension Tensor { - @inlinable - func broadcast(toShape shape: Tensor) -> Tensor { - return Raw.broadcastTo(self, shape: shape) - } - - @inlinable - func broadcast(to shape: TensorShape) -> Tensor { - return broadcast(toShape: Tensor(shape.dimensions.map(Int32.init))) - } - - /// Broadcast to the same shape as the specified `Tensor`. - /// - Precondition: The specified shape must be compatible for broadcasting. - @inlinable - func broadcast(like other: Tensor) -> Tensor { - return broadcast(toShape: other.shapeTensor) - } - - @inlinable - static func .= (lhs: inout Tensor, rhs: Tensor) { - lhs = rhs.broadcast(like: lhs) - } -} - -// TODO: Why is this limited only to numeric data types whereas `broadcast` is not? -public extension Tensor where Scalar: Numeric { - @inlinable - func unbroadcast(toShape otherShape: Tensor) -> Tensor { - let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted() - let ones: Tensor = Raw.fill(dims: rankDiff, value: Tensor(1)) - let paddedShape = ones ++ otherShape - let nonEqualIndices = paddedShape .!= shapeTensor - let broadcastIndices = Raw.where_(nonEqualIndices).flattened() - let unbroadcasted: Tensor = Raw.sum( - self, reductionIndices: Tensor(broadcastIndices), keepDims: false) - return Raw.reshape(unbroadcasted, shape: otherShape) - } - - @inlinable - func unbroadcast(like other: Tensor) -> Tensor { - return unbroadcast(toShape: other.shapeTensor) - } - - @inlinable - func unbroadcast(to shape: TensorShape) -> Tensor { - return unbroadcast(toShape: Tensor(shape.dimensions.map(Int32.init))) - } -} - -//===------------------------------------------------------------------------------------------===// -// Padding -//===------------------------------------------------------------------------------------------===// - -public extension Tensor where Scalar: Numeric { - /// Returns a padded tensor according to the specified padding sizes. - @inlinable - func padded(forSizes sizes: [(before: Int, after: Int)], with value: Scalar = 0) -> Tensor { - let paddings = Tensor( - shape: [sizes.count, 2], - scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] }) - return Raw.padV2(self, paddings: paddings, constantValues: Tensor(value)) - } -} - -//===------------------------------------------------------------------------------------------===// -// Indexing and Slicing -//===------------------------------------------------------------------------------------------===// - -// TODO: Negative indexing and strides syntax. - -public extension Tensor { - /// Extracts a slice from the tensor defined by lower and upper bounds for - /// each dimension. - /// - /// - Parameter lowerBounds: The lower bounds at each dimension. - /// - Parameter upperBounds: The upper bounds at each dimension. 
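A usage sketch of bound-based slicing as documented above (shape and bounds illustrative):

```swift
let t = Tensor<Float>(rangeFrom: 0, to: 24, stride: 1).reshaped(to: [2, 3, 4])
let s = t.slice(lowerBounds: [0, 1, 0], upperBounds: [1, 3, 4])
// s has shape [1, 2, 4]: the first batch entry, rows 1..<3, all columns
```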
- @inlinable - @differentiable(wrt: self) - func slice(lowerBounds: [Int], upperBounds: [Int]) -> Tensor { - // TODO: Precondition `lowerBounds.count == upperBounds.count`, - // preferably in graph. - // TODO: Differentiating control flow is not supported yet, thus the thunks. - let lowerBoundsTensor = Tensor({lowerBounds.map(Int32.init)}()) - let upperBoundsTensor = Tensor({upperBounds.map(Int32.init)}()) - return slice(lowerBounds: lowerBoundsTensor, sizes: upperBoundsTensor - lowerBoundsTensor) - } - - @inlinable - @differentiable(wrt: self, vjp: _vjpSlice) - func slice(lowerBounds: Tensor, sizes: Tensor) -> Tensor { - return Raw.slice(self, begin: lowerBounds, size: sizes) - } - - @inlinable - internal func _vjpSlice( - lowerBounds: Tensor, - sizes: Tensor - ) -> (Tensor, (Tensor) -> Tensor) { - let value = slice(lowerBounds: lowerBounds, sizes: sizes) - let afterPaddings = shapeTensor - value.shapeTensor - lowerBounds - return (value, { [after = afterPaddings] v in - let beforePaddings = lowerBounds.expandingShape(at: 1) - let afterPaddings = after.expandingShape(at: 1) - let paddings = Tensor( - concatenating: [beforePaddings, afterPaddings], alongAxis: 1) - return Raw.pad(v, paddings: paddings) - }) - } -} - -public enum TensorRange: TensorRangeExpression { - case ellipsis - case newAxis - case squeezeAxis - case index(Int) - case range(Range, stride: Int) - case closedRange(ClosedRange, stride: Int) - case partialRangeFrom(PartialRangeFrom, stride: Int) - case partialRangeUpTo(PartialRangeUpTo, stride: Int) - case partialRangeThrough(PartialRangeThrough, stride: Int) - - public var tensorRange: TensorRange { return self } -} - -extension TensorRange: Equatable { - public static func == (lhs: TensorRange, rhs: TensorRange) -> Bool { - switch (lhs, rhs) { - case (.ellipsis, .ellipsis), - (.newAxis, .newAxis), - (.squeezeAxis, .squeezeAxis): - return true - case (let .index(i1), let .index(i2)): return i1 == i2 - case (let .range(r1, s1), let .range(r2, s2)): return r1 == r2 && s1 == s2 - case (let .closedRange(r1, s1), let .closedRange(r2, s2)): - return r1 == r2 && s1 == s2 - case (let .partialRangeFrom(r1, s1), let .partialRangeFrom(r2, s2)): - return r1.lowerBound == r2.lowerBound && s1 == s2 - case (let .partialRangeUpTo(r1, s1), let .partialRangeUpTo(r2, s2)): - return r1.upperBound == r2.upperBound && s1 == s2 - case (let .partialRangeThrough(r1, s1), let .partialRangeThrough(r2, s2)): - return r1.upperBound == r2.upperBound && s1 == s2 - default: return false - } - } -} - -public protocol TensorRangeExpression { - var tensorRange: TensorRange { get } -} - -// TODO: Cannot extend non-nominal type 'UnboundedRange'. 
-// extension UnboundedRange: TensorRangeExpression { -// public var tensorRange: TensorRange { return .ellipsis } -// } - -extension Int: TensorRangeExpression { - public var tensorRange: TensorRange { return .index(self) } -} - -extension Range: TensorRangeExpression where Bound == Int { - public var tensorRange: TensorRange { - return .range(self, stride: 1) - } -} - -extension ClosedRange: TensorRangeExpression where Bound == Int { - public var tensorRange: TensorRange { - return .closedRange(self, stride: 1) - } -} - -extension PartialRangeFrom: TensorRangeExpression where Bound == Int { - public var tensorRange: TensorRange { - return .partialRangeFrom(self, stride: 1) - } -} - -extension PartialRangeUpTo: TensorRangeExpression where Bound == Int { - public var tensorRange: TensorRange { - return .partialRangeUpTo(self, stride: 1) - } -} - -extension PartialRangeThrough: TensorRangeExpression where Bound == Int { - public var tensorRange: TensorRange { - return .partialRangeThrough(self, stride: 1) - } -} - -infix operator ..: StridedRangeFormationPrecedence -precedencegroup StridedRangeFormationPrecedence { - associativity: left - higherThan: CastingPrecedence - lowerThan: RangeFormationPrecedence -} - -public extension Range where Bound == Int { - static func .. (range: Range, stride: Int) -> TensorRange { - return .range(range, stride: stride) - } -} - -public extension ClosedRange where Bound == Int { - static func .. (range: ClosedRange, stride: Int) -> TensorRange { - return .closedRange(range, stride: stride) - } -} - -public extension PartialRangeFrom where Bound == Int { - static func .. (range: PartialRangeFrom, stride: Int) -> TensorRange { - return .partialRangeFrom(range, stride: stride) - } -} - -public extension PartialRangeUpTo where Bound == Int { - static func .. (range: PartialRangeUpTo, stride: Int) -> TensorRange { - return .partialRangeUpTo(range, stride: stride) - } -} - -public extension PartialRangeThrough where Bound == Int { - static func .. 
(range: PartialRangeThrough, stride: Int) -> TensorRange {
-        return .partialRangeThrough(range, stride: stride)
-    }
-}
-
-public extension Tensor {
-    @_fixed_layout @usableFromInline
-    internal struct IndexPath {
-        @usableFromInline
-        let begin, end, strides: Tensor<Int32>
-
-        @usableFromInline
-        let beginMask, endMask, ellipsisMask, newAxisMask, squeezeAxisMask: Int64
-
-        @inlinable
-        public init(
-            begin: Tensor<Int32>, end: Tensor<Int32>, strides: Tensor<Int32>,
-            beginMask: Int64, endMask: Int64, ellipsisMask: Int64, newAxisMask: Int64,
-            squeezeAxisMask: Int64
-        ) {
-            self.begin = begin
-            self.end = end
-            self.strides = strides
-            self.beginMask = beginMask
-            self.endMask = endMask
-            self.ellipsisMask = ellipsisMask
-            self.newAxisMask = newAxisMask
-            self.squeezeAxisMask = squeezeAxisMask
-        }
-    }
-
-    @inlinable
-    @differentiable(wrt: self, vjp: _vjpSubscript)
-    internal subscript(_ indexPath: IndexPath) -> Tensor {
-        get {
-            return Raw.stridedSlice(
-                self, begin: indexPath.begin, end: indexPath.end,
-                strides: indexPath.strides, beginMask: indexPath.beginMask,
-                endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
-                newAxisMask: indexPath.newAxisMask,
-                shrinkAxisMask: indexPath.squeezeAxisMask)
-        }
-        set {
-            self = Raw.tensorStridedSliceUpdate(
-                self, begin: indexPath.begin, end: indexPath.end,
-                strides: indexPath.strides, value: newValue,
-                beginMask: indexPath.beginMask, endMask: indexPath.endMask,
-                ellipsisMask: indexPath.ellipsisMask,
-                newAxisMask: indexPath.newAxisMask,
-                shrinkAxisMask: indexPath.squeezeAxisMask)
-        }
-    }
-
-    @inlinable
-    // TODO: @differentiable(wrt: self)
-    subscript(_ ranges: TensorRangeExpression...) -> Tensor {
-        get {
-            return self[IndexPath(ranges.map { $0.tensorRange })]
-        }
-        set {
-            self[IndexPath(ranges.map { $0.tensorRange })] = newValue
-        }
-    }
-
-    @usableFromInline
-    internal func _vjpSubscript(
-        _ indexPath: IndexPath
-    ) -> (Tensor, (Tensor) -> Tensor) {
-        return (self[indexPath], { [shape = shapeTensor] v in
-            Raw.stridedSliceGrad(
-                shape: shape, begin: indexPath.begin, end: indexPath.end,
-                strides: indexPath.strides, dy: v, beginMask: indexPath.beginMask,
-                endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
-                newAxisMask: indexPath.newAxisMask,
-                shrinkAxisMask: indexPath.squeezeAxisMask)
-        })
-    }
-}
-
-internal extension Tensor.IndexPath {
-    @inlinable
-    init(_ ranges: [TensorRange]) {
-        precondition(!ranges.isEmpty, "The tensor range collection cannot be empty.")
-        precondition(ranges.filter { $0 == TensorRange.ellipsis }.count < 2,
-                     "Only one ellipsis is allowed per tensor range collection.")
-
-        var begin = [Int32](repeating: 0, count: ranges.count)
-        var end = [Int32](repeating: 0, count: ranges.count)
-        var strides = [Int32](repeating: 1, count: ranges.count)
-        var beginMask: Int64 = 0
-        var endMask: Int64 = 0
-        var ellipsisMask: Int64 = 0
-        var newAxisMask: Int64 = 0
-        var squeezeAxisMask: Int64 = 0
-        for (i, index) in ranges.enumerated() {
-            switch index {
-            case .ellipsis: ellipsisMask |= 1 << i
-            case .newAxis: newAxisMask |= 1 << i
-            case .squeezeAxis: squeezeAxisMask |= 1 << i
-            case .index(let index):
-                begin[i] = Int32(index)
-                end[i] = Int32(index) + 1
-                squeezeAxisMask |= 1 << i
-            case .range(let range, let stride):
-                begin[i] = Int32(range.lowerBound)
-                end[i] = Int32(range.upperBound)
-                strides[i] = Int32(stride)
-            case .closedRange(let range, let stride):
-                begin[i] = Int32(range.lowerBound)
-                switch Int32(range.upperBound) {
-                case -1: endMask |= 1 << i
-                case let u: end[i] = u + 1
-                }
-                strides[i] = Int32(stride)
-            case .partialRangeFrom(let range, let stride):
-                begin[i] = Int32(range.lowerBound)
-                strides[i] = Int32(stride)
-                endMask |= 1 << i
-            case .partialRangeUpTo(let range, let stride):
-                end[i] = Int32(range.upperBound)
-                strides[i] = Int32(stride)
-                beginMask |= 1 << i
-            case .partialRangeThrough(let range, let stride):
-                end[i] = Int32(range.upperBound) + 1
-                strides[i] = Int32(stride)
-                beginMask |= 1 << i
-            }
-        }
-
-        self.begin = Tensor(begin)
-        self.end = Tensor(end)
-        self.strides = Tensor(strides)
-        self.beginMask = beginMask
-        self.endMask = endMask
-        self.ellipsisMask = ellipsisMask
-        self.newAxisMask = newAxisMask
-        self.squeezeAxisMask = squeezeAxisMask
-    }
-}
diff --git a/Sources/DeepLearning/Operators/Comparison.swift b/Sources/DeepLearning/Operators/Comparison.swift
deleted file mode 100644
index 02bf5fadf..000000000
--- a/Sources/DeepLearning/Operators/Comparison.swift
+++ /dev/null
@@ -1,237 +0,0 @@
-// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#if !COMPILING_TENSORFLOW_MODULE
-import TensorFlow
-#endif
-
-infix operator .<: ComparisonPrecedence
-infix operator .<=: ComparisonPrecedence
-infix operator .>=: ComparisonPrecedence
-infix operator .>: ComparisonPrecedence
-infix operator .==: ComparisonPrecedence
-infix operator .!=: ComparisonPrecedence
-
-public extension Tensor where Scalar: Numeric & Comparable {
-    /// Computes `lhs < rhs` element-wise and returns a `Tensor` of Boolean
-    /// scalars.
-    @inlinable
-    static func .< (lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.less(lhs, rhs)
-    }
-
-    /// Computes `lhs <= rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    @inlinable
-    static func .<= (lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.lessEqual(lhs, rhs)
-    }
-
-    /// Computes `lhs > rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    @inlinable
-    static func .> (lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.greater(lhs, rhs)
-    }
-
-    /// Computes `lhs >= rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    @inlinable
-    static func .>= (lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.greaterEqual(lhs, rhs)
-    }
-
-    /// Computes `lhs < rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.<` supports broadcasting.
-    @inlinable
-    static func .< (lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.less(Tensor(lhs), rhs)
-    }
-
-    /// Computes `lhs <= rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.<=` supports broadcasting.
-    @inlinable
-    static func .<= (lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.lessEqual(Tensor(lhs), rhs)
-    }
-
-    /// Computes `lhs > rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.>` supports broadcasting.
-    @inlinable
-    static func .> (lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.greater(Tensor(lhs), rhs)
-    }
-
-    /// Computes `lhs >= rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.>=` supports broadcasting.
-    @inlinable
-    static func .>= (lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.greaterEqual(Tensor(lhs), rhs)
-    }
-
-    /// Computes `lhs < rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.<` supports broadcasting.
-    @inlinable
-    static func .< (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
-        return Raw.less(lhs, Tensor(rhs))
-    }
-
-    /// Computes `lhs <= rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.<=` supports broadcasting.
-    @inlinable
-    static func .<= (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
-        return Raw.lessEqual(lhs, Tensor(rhs))
-    }
-
-    /// Computes `lhs > rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.>` supports broadcasting.
-    @inlinable
-    static func .> (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
-        return Raw.greater(lhs, Tensor(rhs))
-    }
-
-    /// Computes `lhs >= rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.>=` supports broadcasting.
-    @inlinable
-    static func .>= (lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
-        return Raw.greaterEqual(lhs, Tensor(rhs))
-    }
-}
-
-extension Tensor: Equatable where Scalar: Equatable {
-    @inlinable
-    public static func == (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .== rhs).all()
-    }
-
-    @inlinable
-    public static func != (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .!= rhs).any()
-    }
-}
-
-extension Tensor: Comparable where Scalar: Numeric & Comparable {
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// less than the corresponding scalar of the second argument.
-    @inlinable
-    public static func < (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .< rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// less than or equal to the corresponding scalar of the second argument.
-    @inlinable
-    public static func <= (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .<= rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// greater than the corresponding scalar of the second argument.
-    @inlinable
-    public static func > (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .> rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// greater than or equal to the corresponding scalar of the second argument.
-    @inlinable
-    public static func >= (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .>= rhs).all()
-    }
-}
-
-public extension Tensor where Scalar: Numeric & Comparable {
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// less than the second argument.
-    @inlinable
-    static func < (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .< rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// less than or equal to the second argument.
-    @inlinable
-    static func <= (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .<= rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether every scalar of the first argument is
-    /// greater than the second argument.
-    @inlinable
-    static func > (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .> rhs).all()
-    }
-
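To make the element-wise versus whole-tensor distinction concrete, a small sketch (values illustrative):

```swift
let a = Tensor<Float>([1, 2])
let b = Tensor<Float>([1, 3])
let eq = a .== b       // Tensor<Bool> [true, false] (element-wise)
let lt = a .< b        // Tensor<Bool> [false, true]
let allEqual = a == b  // false: `==` reduces with `.all()`
let allLess = a < b    // false: every element pair must satisfy `<`
```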
-extension Tensor: Comparable where Scalar: Numeric & Comparable {
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically less than that of the second argument.
-    @inlinable
-    public static func < (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .< rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically less than or equal to that of the second argument.
-    @inlinable
-    public static func <= (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .<= rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically greater than that of the second argument.
-    @inlinable
-    public static func > (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .> rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically greater than or equal to that of the second argument.
-    @inlinable
-    public static func >= (lhs: Tensor, rhs: Tensor) -> Bool {
-        return (lhs .>= rhs).all()
-    }
-}
-
-public extension Tensor where Scalar: Numeric & Comparable {
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically less than that of the second argument.
-    @inlinable
-    static func < (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .< rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically less than or equal to that of the second argument.
-    @inlinable
-    static func <= (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .<= rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically greater than that of the second argument.
-    @inlinable
-    static func > (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .> rhs).all()
-    }
-
-    /// Returns a Boolean value indicating whether the value of the first argument is
-    /// lexicographically greater than or equal to that of the second argument.
-    @inlinable
-    static func >= (lhs: Tensor, rhs: Scalar) -> Bool {
-        return (lhs .>= rhs).all()
-    }
-}
-
-public extension Tensor where Scalar: Equatable {
-    /// Computes `lhs == rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.==` supports broadcasting.
-    @inlinable
-    static func .==(lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.equal(lhs, rhs)
-    }
-
-    /// Computes `lhs != rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.!=` supports broadcasting.
-    @inlinable
-    static func .!=(lhs: Tensor, rhs: Tensor) -> Tensor<Bool> {
-        return Raw.notEqual(lhs, rhs)
-    }
-
-    /// Computes `lhs == rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.==` supports broadcasting.
-    @inlinable
-    static func .==(lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
-        return Tensor(lhs) .== rhs
-    }
-
-    /// Computes `lhs != rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.!=` supports broadcasting.
-    @inlinable
-    static func .!=(lhs: Scalar, rhs: Tensor) -> Tensor<Bool> {
-        return Tensor(lhs) .!= rhs
-    }
-
-    /// Computes `lhs == rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.==` supports broadcasting.
-    @inlinable
-    static func .==(lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
-        return lhs .== Tensor(rhs)
-    }
-
-    /// Computes `lhs != rhs` element-wise and returns a `Tensor` of Boolean scalars.
-    /// - Note: `.!=` supports broadcasting.
-    @inlinable
-    static func .!=(lhs: Tensor, rhs: Scalar) -> Tensor<Bool> {
-        return lhs .!= Tensor(rhs)
-    }
-}
-
-// TODO: infix operator ≈: ComparisonPrecedence
-
-public extension Tensor where Scalar: FloatingPoint & Equatable {
-    /// Returns a `Tensor` of Boolean values indicating whether the elements of `self` are
-    /// approximately equal to those of `other`.
-    @inlinable
-    func elementsApproximatelyEqual(
-        _ other: Tensor,
-        tolerance: Double = 0.00001
-    ) -> Tensor<Bool> {
-        return Raw.approximateEqual(self, other, tolerance: tolerance)
-    }
-}
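// Example (a sketch; the tolerance shown is the default): exact floating-point
// comparison is brittle after arithmetic, so the per-element approximate check
// is often preferable to `.==`:
//
//     let a = Tensor<Float>([0.1, 0.2]) + Tensor<Float>([0.2, 0.1])
//     let b = Tensor<Float>([0.3, 0.3])
//     let close = a.elementsApproximatelyEqual(b, tolerance: 0.00001)
//     // close == [true, true] even where `a .== b` may be false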
diff --git a/Sources/DeepLearning/Operators/Math.swift b/Sources/DeepLearning/Operators/Math.swift
deleted file mode 100644
index 3255aea10..000000000
--- a/Sources/DeepLearning/Operators/Math.swift
+++ /dev/null
@@ -1,1505 +0,0 @@
-// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#if !COMPILING_TENSORFLOW_MODULE
-import TensorFlow
-#endif
-
-#if COMPILING_TENSORFLOW_MODULE
-infix operator .>: ComparisonPrecedence
-infix operator .==: ComparisonPrecedence
-#endif
-
-// TODO:
-// - Consider explicit broadcasting for elementwise binary ops when
-//   scalarization and rank getter are implemented.
-
-//===------------------------------------------------------------------------------------------===//
-// Additive Group
-//===------------------------------------------------------------------------------------------===//
-
-extension Tensor: AdditiveArithmetic where Scalar: Numeric {
-    /// A scalar zero tensor.
-    @inlinable
-    public static var zero: Tensor {
-        return Tensor(zeros: [])
-    }
-
-    /// Adds two tensors and produces their sum.
-    /// - Note: `+` supports broadcasting.
-    @inlinable
-    @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    public static func + (lhs: Tensor, rhs: Tensor) -> Tensor {
-        return Raw.add(lhs, rhs)
-    }
-
-    /// Subtracts one tensor from another and produces their difference.
-    /// - Note: `-` supports broadcasting.
-    @inlinable
-    @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    public static func - (lhs: Tensor, rhs: Tensor) -> Tensor {
-        return Raw.sub(lhs, rhs)
-    }
-}
-
-internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    static func _vjpAdd(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
-            (v.unbroadcast(toShape: lhsShape), v.unbroadcast(toShape: rhsShape))
-        })
-    }
-
-    @inlinable
-    static func _vjpSubtract(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
-            (v.unbroadcast(toShape: lhsShape), -v.unbroadcast(toShape: rhsShape))
-        })
-    }
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Vector Space
-//===------------------------------------------------------------------------------------------===//
-
-extension Tensor: VectorNumeric where Scalar: Numeric {
-    /// Multiplies the scalar with every scalar of the tensor and produces the product.
-    @inlinable
-    @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    public static func * (lhs: Scalar, rhs: Tensor) -> Tensor {
-        return Tensor(lhs) * rhs
-    }
-}
-
-internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    static func _vjpMultiply(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs * rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
-            ((rhs * v).unbroadcast(toShape: lhsShape), (lhs * v).unbroadcast(toShape: rhsShape))
-        })
-    }
-}
-
-extension Tensor: ShapedVectorNumeric where Scalar: Numeric {}
-
-extension Tensor: Differentiable where Scalar: TensorFlowFloatingPoint {
-    public typealias TangentVector = Tensor
-    public typealias CotangentVector = Tensor
-    public typealias AllDifferentiableVariables = Tensor
-
-    @inlinable
-    public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
-        return cotangent
-    }
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Additional Element-wise Operators
-//===------------------------------------------------------------------------------------------===//
-
-public extension Tensor where Scalar: Numeric {
-    /// Adds the scalar to every scalar of the tensor and produces the sum.
-    @inlinable
-    @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func + (lhs: Scalar, rhs: Tensor) -> Tensor {
-        return Tensor(lhs) + rhs
-    }
-
-    /// Adds the scalar to every scalar of the tensor and produces the sum.
-    @inlinable
-    @differentiable(vjp: _vjpAdd(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func + (lhs: Tensor, rhs: Scalar) -> Tensor {
-        return lhs + Tensor(rhs)
-    }
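// Example (a sketch): scalar operands are promoted via `Tensor(lhs)` /
// `Tensor(rhs)`, so these operators broadcast a scalar against every element,
// and `AdditiveArithmetic.zero` is the identity for `+`:
//
//     let v = Tensor<Float>([1, 2, 3])
//     let shifted = 1 + v                     // [2, 3, 4]
//     let unchanged = v + Tensor<Float>.zero  // [1, 2, 3]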
-
-    /// Subtracts every scalar of the tensor from the scalar and produces the difference.
-    @inlinable
-    @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func - (lhs: Scalar, rhs: Tensor) -> Tensor {
-        return Tensor(lhs) - rhs
-    }
-
-    /// Subtracts the scalar from every scalar of the tensor and produces the difference.
-    @inlinable
-    @differentiable(vjp: _vjpSubtract(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func - (lhs: Tensor, rhs: Scalar) -> Tensor {
-        return lhs - Tensor(rhs)
-    }
-
-    /// Adds two tensors and stores the result in the left-hand-side variable.
-    /// - Note: `+=` supports broadcasting.
-    @inlinable
-    static func += (lhs: inout Tensor, rhs: Tensor) {
-        lhs = lhs + rhs
-    }
-
-    /// Adds the scalar to every scalar of the tensor and stores the result in the left-hand-side
-    /// variable.
-    @inlinable
-    static func += (lhs: inout Tensor, rhs: Scalar) {
-        lhs = lhs + rhs
-    }
-
-    /// Subtracts the second tensor from the first and stores the result in the left-hand-side
-    /// variable.
-    /// - Note: `-=` supports broadcasting.
-    @inlinable
-    static func -= (lhs: inout Tensor, rhs: Tensor) {
-        lhs = lhs - rhs
-    }
-
-    /// Subtracts the scalar from every scalar of the tensor and stores the result in the
-    /// left-hand-side variable.
-    @inlinable
-    static func -= (lhs: inout Tensor, rhs: Scalar) {
-        lhs = lhs - rhs
-    }
-
-    /// Multiplies two tensors and produces their product.
-    /// - Note: `*` supports broadcasting.
-    @inlinable
-    @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func * (lhs: Tensor, rhs: Tensor) -> Tensor {
-        return Raw.mul(lhs, rhs)
-    }
-
-    /// Multiplies the scalar with every scalar of the tensor and produces the product.
-    @inlinable
-    @differentiable(vjp: _vjpMultiply(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func * (lhs: Tensor, rhs: Scalar) -> Tensor {
-        return lhs * Tensor(rhs)
-    }
-
-    /// Multiplies two tensors and stores the result in the left-hand-side variable.
-    /// - Note: `*=` supports broadcasting.
-    @inlinable
-    static func *= (lhs: inout Tensor, rhs: Tensor) {
-        lhs = lhs * rhs
-    }
-
-    /// Multiplies the tensor with the scalar, broadcasting the scalar, and stores the result in
-    /// the left-hand-side variable.
-    @inlinable
-    static func *= (lhs: inout Tensor, rhs: Scalar) {
-        lhs = lhs * rhs
-    }
-
-    /// Returns the quotient of dividing the first tensor by the second.
-    /// - Note: `/` supports broadcasting.
-    @inlinable
-    @differentiable(vjp: _vjpDivide(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func / (lhs: Tensor, rhs: Tensor) -> Tensor {
-        return Raw.div(lhs, rhs)
-    }
-
-    /// Returns the quotient of dividing the scalar by the tensor, broadcasting the scalar.
-    @inlinable
-    @differentiable(vjp: _vjpDivide(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func / (lhs: Scalar, rhs: Tensor) -> Tensor {
-        return Tensor(lhs) / rhs
-    }
-
-    /// Returns the quotient of dividing the tensor by the scalar, broadcasting the scalar.
-    @inlinable
-    @differentiable(vjp: _vjpDivide(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func / (lhs: Tensor, rhs: Scalar) -> Tensor {
-        return lhs / Tensor(rhs)
-    }
-
-    /// Divides the first tensor by the second and stores the quotient in the left-hand-side
-    /// variable.
-    @inlinable
-    static func /= (lhs: inout Tensor, rhs: Tensor) {
-        lhs = lhs / rhs
-    }
-
-    /// Divides the tensor by the scalar, broadcasting the scalar, and stores the quotient in the
-    /// left-hand-side variable.
- @inlinable - static func /= (lhs: inout Tensor, rhs: Scalar) { - lhs = lhs / rhs - } - - /// Returns the remainder of dividing the first tensor by the second. - /// - Note: `%` supports broadcasting. - @inlinable - static func % (lhs: Tensor, rhs: Tensor) -> Tensor { - return Raw.mod(lhs, rhs) - } - - /// Returns the remainder of dividing the tensor by the scalar, broadcasting the scalar. - @inlinable - static func % (lhs: Tensor, rhs: Scalar) -> Tensor { - return lhs % Tensor(rhs) - } - - /// Returns the remainder of dividing the scalar by the tensor, broadcasting the scalar. - @inlinable - static func % (lhs: Scalar, rhs: Tensor) -> Tensor { - return Tensor(lhs) % rhs - } - - /// Divides the first tensor by the second and stores the remainder in the left-hand-side - /// variable. - @inlinable - static func %= (lhs: inout Tensor, rhs: Tensor) { - lhs = lhs % rhs - } - - /// Divides the tensor by the scalar and stores the remainder in the left-hand-side variable. - @inlinable - static func %= (lhs: inout Tensor, rhs: Scalar) { - lhs = lhs % rhs - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - static func _vjpAdd(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { - return (lhs + rhs, { v in (v, v.sum().scalarized()) }) - } - - @inlinable - static func _vjpAdd(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { - return (lhs + rhs, { v in (v.sum().scalarized(), v) }) - } - - @inlinable - static func _vjpSubtract(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { - return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) }) - } - - @inlinable - static func _vjpSubtract(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { - return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) }) - } - - @inlinable - static func _vjpMultiply(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { - return (lhs * rhs, { v in (v * rhs, (v * lhs).sum().scalarized()) }) - } - - @inlinable - static func _vjpMultiply(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { - return (lhs * rhs, { v in ((v * rhs).sum().scalarized(), v * lhs) }) - } - - @inlinable - static func _vjpDivide(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - return (lhs / rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in - ((v / rhs).unbroadcast(toShape: lhsShape), - ((-lhs) / rhs.squared() * v).unbroadcast(toShape: rhsShape)) - }) - } - - @inlinable - static func _vjpDivide(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { - return (lhs / rhs, { v in - (v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized()) - }) - } - - @inlinable - static func _vjpDivide(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { - return (lhs / rhs, { v in ((v / rhs).sum().scalarized(), v * -lhs / rhs.squared()) }) - } -} - -public extension Tensor where Scalar == Bool { - /// Computes `!self` element-wise. - @inlinable - func elementsLogicalNot() -> Tensor { - return Raw.logicalNot(self) - } - - /// Computes `self && other` element-wise. - /// - Note: `&&` supports broadcasting. - @inlinable - func elementsLogicalAnd(_ other: Tensor) -> Tensor { - return Raw.logicalAnd(self, other) - } - - /// Computes `self && other` element-wise, broadcasting `other`. - @inlinable - func elementsLogicalAnd(_ other: Scalar) -> Tensor { - return elementsLogicalAnd(Tensor(other)) - } - - /// Computes `self || other` element-wise. 
- @inlinable - func elementsLogicalOr(_ other: Tensor) -> Tensor { - return Raw.logicalOr(self, other) - } - - /// Computes `self || other` element-wise, broadcasting `other`. - @inlinable - func elementsLogicalOr(_ other: Scalar) -> Tensor { - return elementsLogicalOr(Tensor(other)) - } -} - -//===------------------------------------------------------------------------------------------===// -// Element-wise Unary Math Functions -//===------------------------------------------------------------------------------------------===// - -// Export Glibc/Darwin math functions. We should not require users to import -// Foundation/Darwin/Glibc in order to use scalar math functions. -// -#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) -@_exported import Darwin.C -#else -@_exported import Glibc -#endif -// -// FIXME(rxwei): Scoped imports are not yet supported in parseable module -// interfaces, so `@_exported import` won't work. When that becomes supported, -// switch to `@_exported import` by removing `import Darwin.C/Glibc` above and -// uncommenting the following lines. In the meantime, consider using indirect -// wrappers for each function so that random libc symbols won't be leaked to -// users' code completion. -// -// #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) -// @_exported import func Darwin.C.sin -// @_exported import func Darwin.C.cos -// @_exported import func Darwin.C.tan -// @_exported import func Darwin.C.sinf -// @_exported import func Darwin.C.cosf -// @_exported import func Darwin.C.tanf -// @_exported import func Darwin.C.sinh -// @_exported import func Darwin.C.cosh -// @_exported import func Darwin.C.tanh -// @_exported import func Darwin.C.sinhf -// @_exported import func Darwin.C.coshf -// @_exported import func Darwin.C.tanhf -// @_exported import func Darwin.C.log -// @_exported import func Darwin.C.logf -// @_exported import func Darwin.C.exp -// @_exported import func Darwin.C.expf -// @_exported import func Darwin.C.pow -// @_exported import func Darwin.C.powf -// #else -// @_exported import func Glibc.sin -// @_exported import func Glibc.cos -// @_exported import func Glibc.tan -// @_exported import func Glibc.sinf -// @_exported import func Glibc.cosf -// @_exported import func Glibc.tanf -// @_exported import func Glibc.sinh -// @_exported import func Glibc.cosh -// @_exported import func Glibc.tanh -// @_exported import func Glibc.sinhf -// @_exported import func Glibc.coshf -// @_exported import func Glibc.tanhf -// @_exported import func Glibc.log -// @_exported import func Glibc.logf -// @_exported import func Glibc.exp -// @_exported import func Glibc.expf -// @_exported import func Glibc.pow -// @_exported import func Glibc.powf -// #endif - -public extension Tensor where Scalar: SignedNumeric { - /// Computes the negation of the specified tensor element-wise. - @inlinable - @differentiable(vjp: _vjpNegate(_:) where Scalar: TensorFlowFloatingPoint) - static prefix func - (rhs: Tensor) -> Tensor { - return Raw.neg(rhs) - } -} - -internal extension Tensor where Scalar: TensorFlowFloatingPoint { - @inlinable - static func _vjpNegate(_ x: Tensor) -> (Tensor, (Tensor) -> Tensor) { - return (-x, { v in -v }) - } -} - -/// Computes the absolute value of the specified tensor element-wise. 
-@inlinable
-@differentiable(vjp: _vjpAbs(_:) where T: TensorFlowFloatingPoint)
-public func abs<T: SignedNumeric & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.abs(x)
-}
-
-@inlinable
-internal func _vjpAbs<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let sign = Raw.sign(x)
-    return (abs(x), { v in v * sign })
-}
-
-/// Computes the natural logarithm of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpLog(_:) where T: TensorFlowFloatingPoint)
-public func log<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.log(x)
-}
-
-@inlinable
-internal func _vjpLog<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (log(x), { v in v / x })
-}
-
-/// Computes `sin` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpSin(_:) where T: TensorFlowFloatingPoint)
-public func sin<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.sin(x)
-}
-
-@inlinable
-internal func _vjpSin<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (sin(x), { v in v * cos(x) })
-}
-
-/// Computes `cos` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpCos(_:) where T: TensorFlowFloatingPoint)
-public func cos<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.cos(x)
-}
-
-@inlinable
-internal func _vjpCos<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (cos(x), { v in -v * sin(x) })
-}
-
-/// Computes `tan` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpTan(_:) where T: TensorFlowFloatingPoint)
-public func tan<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.tan(x)
-}
-
-@inlinable
-internal func _vjpTan<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = tan(x)
-    return (value, { v in v * (1 + value.squared()) })
-}
-
-/// Computes `sinh` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpSinh(_:) where T: TensorFlowFloatingPoint)
-public func sinh<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.sinh(x)
-}
-
-@inlinable
-internal func _vjpSinh<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (sinh(x), { v in v * cosh(x) })
-}
-
-/// Computes `cosh` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpCosh(_:) where T: TensorFlowFloatingPoint)
-public func cosh<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.cosh(x)
-}
-
-@inlinable
-internal func _vjpCosh<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (cosh(x), { v in v * sinh(x) })
-}
-
-/// Computes `tanh` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpTanh(_:) where T: TensorFlowFloatingPoint)
-public func tanh<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.tanh(x)
-}
-
-@inlinable
-internal func _vjpTanh<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = tanh(x)
-    return (value, { v in v * (1 - value.squared()) })
-}
-
-/// Computes the square of the tensor.
-public extension Tensor where Scalar: Numeric {
-    @inlinable
-    @differentiable(wrt: self, vjp: _vjpSquared() where Scalar: TensorFlowFloatingPoint)
-    func squared() -> Tensor {
-        return Raw.square(self)
-    }
-}
-
-internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    func _vjpSquared() -> (Tensor, (Tensor) -> Tensor) {
-        return (squared(), { 2 * self * $0 })
-    }
-}
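// Example (a sketch): because every function above registers a VJP, the
// element-wise math composes with the differential operators, e.g.:
//
//     let x = Tensor<Float>([0, 1])
//     let dydx = gradient(at: x) { sin($0).sum() }
//     // dydx == cos(x), i.e. approximately [1.0, 0.5403]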
-
-/// Computes the square root of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpSqrt(_:) where T: TensorFlowFloatingPoint)
-public func sqrt<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.sqrt(x)
-}
-
-@inlinable
-internal func _vjpSqrt<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = sqrt(x)
-    return (value, { v in v / (2 * value) })
-}
-
-/// Computes the inverse square root of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpRsqrt(_:) where T: TensorFlowFloatingPoint)
-public func rsqrt<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.rsqrt(x)
-}
-
-@inlinable
-internal func _vjpRsqrt<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = rsqrt(x)
-    // d/dx x^(-1/2) = -1/2 * x^(-3/2) = -value^3 / 2.
-    return (value, { v in -v / 2 * value * value * value })
-}
-
-/// Computes `exp` of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpExp(_:) where T: TensorFlowFloatingPoint)
-public func exp<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.exp(x)
-}
-
-@inlinable
-internal func _vjpExp<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = exp(x)
-    return (value, { v in value * v })
-}
-
-/// Computes the ceiling of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpCeil(_:) where T: TensorFlowFloatingPoint)
-public func ceil<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.ceil(x)
-}
-
-@inlinable
-internal func _vjpCeil<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (ceil(x), { _ in Tensor(0).broadcast(like: x) })
-}
-
-/// Computes the floor of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpFloor(_:) where T: TensorFlowFloatingPoint)
-public func floor<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.floor(x)
-}
-
-@inlinable
-internal func _vjpFloor<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (floor(x), { _ in Tensor(0).broadcast(like: x) })
-}
-
-/// Computes the sigmoid of the specified tensor element-wise.
-/// Specifically, computes `1 / (1 + exp(-x))`.
-@inlinable
-@differentiable(vjp: _vjpSigmoid)
-public func sigmoid<T: TensorFlowFloatingPoint>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.sigmoid(x)
-}
-
-@inlinable
-internal func _vjpSigmoid<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (sigmoid(x), { v in Raw.sigmoidGrad(x, dy: v) })
-}
-
-/// Computes the softmax of the specified tensor along the last axis.
-/// Specifically, computes `exp(x) / exp(x).sum(alongAxes: -1)`.
-@inlinable
-@differentiable(vjp: _vjpSoftmax(_:) where T: TensorFlowFloatingPoint)
-public func softmax<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.softmax(logits: x)
-}
-
-/// Computes the softmax of the specified tensor along the specified axis.
-/// Specifically, computes `exp(x) / exp(x).sum(alongAxes: axis)`.
-@inlinable
-// TODO: [AD].
-public func softmax<T: FloatingPoint & TensorFlowScalar>(
-    _ x: Tensor<T>, alongAxis axis: Int
-) -> Tensor<T> {
-    let xExp = exp(x)
-    return xExp / xExp.sum(alongAxes: Tensor(Int32(axis)))
-}
-
-@inlinable
-func _vjpSoftmax<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = softmax(x)
-    return (value, { v in
-        let sumChannels = (v * value).sum(alongAxes: -1)
-        return (v - sumChannels) * value
-    })
-}
-
-/// Computes the log-softmax of the specified tensor element-wise.
-@inlinable
-@differentiable(vjp: _vjpLogSoftmax(_:) where T: TensorFlowFloatingPoint)
-public func logSoftmax<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return Raw.logSoftmax(logits: x)
-}
-
-@inlinable
-func _vjpLogSoftmax<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    let value = logSoftmax(x)
-    return (value, { v in v - v.sum(alongAxes: -1) * exp(value) })
-}
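// Example (a sketch): `softmax` normalizes along the last axis so each row
// sums to 1, and `logSoftmax` computes `log(softmax(x))` in a numerically
// stabler fused form:
//
//     let logits = Tensor<Float>(shape: [1, 3], scalars: [2, 1, 0])
//     let probs = softmax(logits)        // rows sum to 1
//     let logProbs = logSoftmax(logits)  // == log(probs), more stable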
-
-/// Computes `relu` of the specified tensor element-wise.
-/// Specifically, computes `max(0, x)`.
-@inlinable
-@differentiable(vjp: _vjpRelu(_:) where T: TensorFlowFloatingPoint)
-public func relu<T: FloatingPoint & TensorFlowScalar>(_ x: Tensor<T>) -> Tensor<T> {
-    return max(0, x)
-}
-
-@inlinable
-func _vjpRelu<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
-    return (relu(x), { v in Tensor(x .> 0) * v })
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Element-wise Binary Math Functions
-//===------------------------------------------------------------------------------------------===//
-
-/// Computes the power of the first tensor to the second tensor.
-@inlinable
-@differentiable(vjp: _vjpPow(_:_:) where T: TensorFlowFloatingPoint)
-public func pow<T: TensorFlowScalar>(
-    _ lhs: Tensor<T>, _ rhs: Tensor<T>
-) -> Tensor<T> where T: FloatingPoint {
-    return Raw.pow(lhs, rhs)
-}
-
-@inlinable
-internal func _vjpPow<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>, _ y: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
-    let value = pow(x, y)
-    return (value, { v in
-        ((v * y * pow(x, y - 1)).unbroadcast(like: x),
-         (v * log(x) * value).unbroadcast(like: y))
-    })
-}
-
-/// Computes the power of the scalar to the tensor, broadcasting the scalar.
-@inlinable
-// @differentiable(where T: TensorFlowFloatingPoint)
-public func pow<T: TensorFlowScalar>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T>
-    where T: FloatingPoint {
-    return pow(Tensor(lhs), rhs)
-}
-
-/// Computes the power of the tensor to the scalar, broadcasting the scalar.
-@inlinable
-// @differentiable(where T: TensorFlowFloatingPoint)
-public func pow<T: TensorFlowScalar>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T>
-    where T: FloatingPoint {
-    return pow(lhs, Tensor(rhs))
-}
-
-/// Computes the element-wise maximum of two tensors.
-/// - Note: `max` supports broadcasting.
-@inlinable
-@differentiable(vjp: _vjpMax(_:_:) where T: TensorFlowFloatingPoint)
-public func max<T: TensorFlowScalar>(
-    _ lhs: Tensor<T>, _ rhs: Tensor<T>
-) -> Tensor<T> where T: Numeric & Comparable {
-    return Raw.maximum(lhs, rhs)
-}
-
-@inlinable
-internal func _vjpMax<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>, _ y: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
-    let value = max(x, y)
-    return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) })
-}
-
-/// Computes the element-wise maximum of the scalar and the tensor, broadcasting the scalar.
-@inlinable
-// @differentiable(where T: TensorFlowFloatingPoint)
-public func max<T: TensorFlowScalar>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T>
    where T: Numeric & Comparable {
-    return max(Tensor(lhs), rhs)
-}
-
-/// Computes the element-wise maximum of the tensor and the scalar, broadcasting the scalar.
-@inlinable
-// @differentiable(where T: TensorFlowFloatingPoint)
-public func max<T: TensorFlowScalar>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T>
    where T: Numeric & Comparable {
-    return max(lhs, Tensor(rhs))
-}
-
-/// Computes the element-wise minimum of two tensors.
-/// - Note: `min` supports broadcasting.
-@inlinable
-@differentiable(vjp: _vjpMin(_:_:) where T: TensorFlowFloatingPoint)
-public func min<T: TensorFlowScalar>(
-    _ lhs: Tensor<T>, _ rhs: Tensor<T>
-) -> Tensor<T> where T: Numeric & Comparable {
-    return Raw.minimum(lhs, rhs)
-}
-
-@inlinable
-internal func _vjpMin<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>, _ y: Tensor<T>
-) -> (Tensor<T>, (Tensor<T>) -> (Tensor<T>, Tensor<T>)) {
-    let value = min(x, y)
-    return (value, { v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) })
-}
-
-/// Computes the element-wise minimum of the scalar and the tensor, broadcasting the scalar.
-@inlinable
-// @differentiable(where T: TensorFlowFloatingPoint)
-public func min<T: TensorFlowScalar>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T>
    where T: Numeric & Comparable {
-    return min(Tensor(lhs), rhs)
-}
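// Example (a sketch): the broadcasting `min`/`max` overloads make clamping a
// one-liner:
//
//     let t = Tensor<Float>([-2, 0.5, 3])
//     let clipped = max(0, min(t, 1))  // clamp into [0, 1]: [0, 0.5, 1]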
-
-/// Computes the element-wise minimum of the tensor and the scalar, broadcasting the scalar.
-@inlinable
-// @differentiable(where T: TensorFlowFloatingPoint)
-public func min<T: TensorFlowScalar>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T>
    where T: Numeric & Comparable {
-    return min(lhs, Tensor(rhs))
-}
-
-@inlinable
-internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
-    _ x: Tensor<T>,
-    _ y: Tensor<T>,
-    originalValue: Tensor<T>,
-    vector: Tensor<T>
-) -> (Tensor<T>, Tensor<T>) {
-    let denom = 1 + Tensor<T>(x .== y)
-    let dfdx = vector * Tensor<T>(x .== originalValue) / denom
-    let dfdy = vector * Tensor<T>(y .== originalValue) / denom
-    return (dfdx.unbroadcast(like: x), dfdy.unbroadcast(like: y))
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Selection Functions
-//===------------------------------------------------------------------------------------------===//
-
-public extension Tensor where Scalar == Bool {
-    /// Returns a new tensor containing elements from either `left` or `right`,
-    /// depending on the elements of `self`.
-    ///
-    /// `self` acts as a mask that chooses, based on the value at each scalar,
-    /// whether the corresponding scalar in the output should be taken from
-    /// `left` (if `true`) or `right` (if `false`).
-    ///
-    /// - Precondition: `left` and `right` must have the same shape. If
-    ///   `left` and `right` are scalar, then `self` must also be scalar. If
-    ///   `left` and `right` have rank greater than or equal to 1, then `self`
-    ///   must either have the same shape as `left` or be a 1-D `Tensor` such
-    ///   that `self.scalarCount == left.shape[0]`.
-    @available(*, deprecated, message: "Use '.replacing(with:mask:)' instead")
-    @inlinable
-    func selecting<T>(_ left: Tensor<T>, _ right: Tensor<T>) -> Tensor<T> {
-        return right.replacing(with: left, where: self)
-    }
-}
-
-public extension Tensor {
-    /// Replaces elements of this tensor with `other` in the lanes where `mask` is
-    /// `true`.
-    ///
-    /// - Precondition: `self` and `other` must have the same shape. If
-    ///   `self` and `other` are scalar, then `mask` must also be scalar. If
-    ///   `self` and `other` have rank greater than or equal to `1`, then `mask`
-    ///   must either have the same shape as `self` or be a 1-D `Tensor` such
-    ///   that `mask.scalarCount == self.shape[0]`.
-    @inlinable
-    @differentiable(wrt: (self, other), vjp: _vjpReplacing where Scalar: TensorFlowFloatingPoint)
-    func replacing(with other: Tensor, where mask: Tensor<Bool>) -> Tensor {
-        return Raw.select(condition: mask, t: other, e: self)
-    }
-}
-
-internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    func _vjpReplacing(
-        with other: Tensor,
-        where mask: Tensor<Bool>
-    ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (replacing(with: other, where: mask), { v in
-            let zeros = Tensor(zeros: v.shape)
-            return (v.replacing(with: zeros, where: mask), zeros.replacing(with: v, where: mask))
-        })
-    }
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Reduction Functions
-//===------------------------------------------------------------------------------------------===//
-
-public extension Tensor where Scalar == Bool {
-    /// Returns `true` if all scalars are equal to `true`. Otherwise, returns `false`.
-    // NOTE: This overload is necessary, otherwise `all()` would refer to the variadic method
-    // `all(squeezingAxes:)` with zero indices.
-    @inlinable
-    func all() -> Bool {
-        let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1)
-        return _TFGetScalarOrDie(Raw.all(self, reductionIndices: axes).handle)
-    }
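// Example (a sketch): `replacing(with:where:)` takes replacement values in the
// lanes where the mask is `true` and keeps the originals elsewhere:
//
//     let t = Tensor<Float>([-1, 2, -3])
//     let zeroed = t.replacing(with: Tensor(zeros: t.shape), where: t .< 0)
//     // zeroed == [0, 2, 0]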
-
-    /// Returns `true` if any scalars are equal to `true`. Otherwise, returns `false`.
-    // NOTE: This overload is necessary, otherwise `any()` would refer to the variadic method
-    // `any(squeezingAxes:)` with zero indices.
-    @inlinable
-    func any() -> Bool {
-        let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1)
-        return _TFGetScalarOrDie(Raw.any(self, reductionIndices: axes).handle)
-    }
-
-    /// Performs a logical AND operation along the specified axes. The reduced dimensions are
-    /// removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func all(squeezingAxes axes: Int...) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.all(self, reductionIndices: Tensor(axes), keepDims: false)
-    }
-
-    /// Performs a logical OR operation along the specified axes. The reduced dimensions are
-    /// removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func any(squeezingAxes axes: Int...) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.any(self, reductionIndices: Tensor(axes), keepDims: false)
-    }
-
-    /// Performs a logical AND operation along the specified axes. The reduced dimensions are
-    /// retained with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func all(alongAxes axes: Int...) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.all(self, reductionIndices: Tensor(axes), keepDims: true)
-    }
-
-    /// Performs a logical OR operation along the specified axes. The reduced
-    /// dimensions are retained with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func any(alongAxes axes: Int...) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.any(self, reductionIndices: Tensor(axes), keepDims: true)
-    }
-}
-
-public extension Tensor where Scalar: Numeric & Comparable {
-    // NOTE: This overload is necessary, otherwise `min()` would refer to the variadic method
-    // `min(squeezingAxes:)` with zero indices.
-    @inlinable
-    func min() -> Tensor {
-        let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1)
-        return Raw.min(self, reductionIndices: axes)
-    }
-
-    // NOTE: This overload is necessary, otherwise `max()` would refer to the variadic method
-    // `max(squeezingAxes:)` with zero indices.
-    @inlinable
-    func max() -> Tensor {
-        let axes = Tensor(rangeFrom: 0, to: Int32(rank), stride: 1)
-        return Raw.max(self, reductionIndices: axes)
-    }
-
-    /// Returns the maximum values along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func max(squeezingAxes axes: [Int]) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.max(self, reductionIndices: Tensor(axes), keepDims: false)
-    }
-
-    /// Returns the maximum values along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func max(squeezingAxes axes: Int...) -> Tensor {
-        return max(squeezingAxes: axes)
-    }
-
-    /// Returns the minimum values along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func min(squeezingAxes axes: [Int]) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.min(self, reductionIndices: Tensor(axes), keepDims: false)
-    }
-
-    /// Returns the minimum values along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func min(squeezingAxes axes: Int...) -> Tensor {
-        return min(squeezingAxes: axes)
-    }
-
-    /// Returns the indices of the maximum values along the specified axis. The reduced dimension
-    /// is removed.
-    /// - Parameter axis: The dimension to reduce.
-    /// - Precondition: `axis` must be in the range `-rank..<rank`.
-    @inlinable
-    func argmax(squeezingAxis axis: Int) -> Tensor<Int32> {
-        return Raw.argMax(self, dimension: Tensor(Int32(axis)))
-    }
-
-    /// Returns the indices of the minimum values along the specified axis. The reduced dimension
-    /// is removed.
-    /// - Parameter axis: The dimension to reduce.
-    /// - Precondition: `axis` must be in the range `-rank..<rank`.
-    @inlinable
-    func argmin(squeezingAxis axis: Int) -> Tensor<Int32> {
-        return Raw.argMin(self, dimension: Tensor(Int32(axis)))
-    }
-
-    /// Returns the minimum along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func min(alongAxes axes: [Int]) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.min(self, reductionIndices: Tensor(axes), keepDims: true)
-    }
-
-    /// Returns the minimum along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func min(alongAxes axes: Int...) -> Tensor {
-        return min(alongAxes: axes)
-    }
-
-    /// Returns the maximum along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func max(alongAxes axes: [Int]) -> Tensor {
-        let axes = axes.map(Int32.init)
-        return Raw.max(self, reductionIndices: Tensor(axes), keepDims: true)
-    }
-
-    /// Returns the maximum along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func max(alongAxes axes: Int...) -> Tensor {
-        return max(alongAxes: axes)
-    }
-
-    /// Returns the index of the maximum value of the flattened scalars.
-    @inlinable
-    func argmax() -> Tensor<Int32> {
-        return flattened().argmax(squeezingAxis: 0)
-    }
-
-    /// Returns the index of the minimum value of the flattened scalars.
-    @inlinable
-    func argmin() -> Tensor<Int32> {
-        return flattened().argmin(squeezingAxis: 0)
-    }
-}
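// Example (a sketch): `squeezingAxes` drops the reduced axis, `alongAxes`
// keeps it with size 1 (convenient for broadcasting), and `argmax` returns
// indices as `Tensor<Int32>`:
//
//     let scores = Tensor<Float>(shape: [2, 3], scalars: [1, 5, 2, 4, 0, 3])
//     let rowMax = scores.max(squeezingAxes: 1)   // shape [2]: [5, 4]
//     let kept = scores.max(alongAxes: 1)         // shape [2, 1]: [[5], [4]]
//     let best = scores.argmax(squeezingAxis: 1)  // Tensor<Int32>: [1, 0]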
-
-// MARK: - Numeric Reductions
-
-public extension Tensor where Scalar: Numeric {
-    // MARK: - Sum
-
-    /// Returns the sum along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    @differentiable(wrt: self, vjp: _vjpSum(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
-    func sum(squeezingAxes axes: Tensor<Int32>) -> Tensor {
-        return Raw.sum(self, reductionIndices: axes, keepDims: false)
-    }
-
-    /// Returns the sum along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func sum(squeezingAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return sum(squeezingAxes: Tensor(axes))
-    }
-
-    /// Returns the sum along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func sum(squeezingAxes axes: Int...) -> Tensor {
-        return sum(squeezingAxes: axes)
-    }
-
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func sum() -> Tensor {
-        return flattened().sum(squeezingAxes: 0)
-    }
-
-    /// Returns the sum along the specified axes. The reduced dimensions are retained with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self, vjp: _vjpSum(alongAxes:) where Scalar: TensorFlowFloatingPoint)
-    func sum(alongAxes axes: Tensor<Int32>) -> Tensor {
-        return Raw.sum(self, reductionIndices: axes, keepDims: true)
-    }
-
-    /// Returns the sum along the specified axes. The reduced dimensions are retained with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func sum(alongAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return sum(alongAxes: Tensor(axes))
-    }
-
-    /// Returns the sum along the specified axes. The reduced dimensions are retained with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func sum(alongAxes axes: Int...) -> Tensor {
-        return sum(alongAxes: axes)
-    }
-
-    // MARK: - Product
-
-    /// Returns the product along the specified axes. The reduced dimensions are removed.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    // TODO: Make this @differentiable.
-    @inlinable
-    func product(squeezingAxes axes: Tensor<Int32>) -> Tensor {
-        return Raw.prod(self, reductionIndices: axes, keepDims: false)
-    }
-
-    /// Returns the product along the specified axes. The reduced dimensions are removed.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    func product(squeezingAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return product(squeezingAxes: Tensor(axes))
-    }
-
-    /// Returns the product along the specified axes. The reduced dimensions are removed.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    func product(squeezingAxes axes: Int...) -> Tensor {
-        return product(squeezingAxes: axes)
-    }
-
-    @inlinable
-    func product() -> Tensor {
-        return flattened().product(squeezingAxes: 0)
-    }
-
-    /// Returns the product along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func product(alongAxes axes: Tensor<Int32>) -> Tensor {
-        return Raw.prod(self, reductionIndices: axes, keepDims: true)
-    }
-
-    /// Returns the product along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func product(alongAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return product(alongAxes: Tensor(axes))
-    }
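// Example (a sketch): because the `alongAxes` reductions keep the reduced
// dimension, their results broadcast cleanly against the original tensor,
// e.g. for row normalization:
//
//     let counts = Tensor<Float>(shape: [2, 2], scalars: [1, 3, 2, 2])
//     let rowTotals = counts.sum(alongAxes: 1)  // shape [2, 1]: [[4], [4]]
//     let rowShares = counts / rowTotals        // [[0.25, 0.75], [0.5, 0.5]]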
-
-    /// Returns the product along the specified axes. The reduced dimensions are retained with
-    /// value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    func product(alongAxes axes: Int...) -> Tensor {
-        return product(alongAxes: axes)
-    }
-
-    // MARK: - Mean
-
-    /// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    @differentiable(wrt: self, vjp: _vjpMean(squeezingAxes:) where Scalar: TensorFlowFloatingPoint)
-    func mean(squeezingAxes axes: Tensor<Int32>) -> Tensor {
-        return Raw.mean(self, reductionIndices: axes, keepDims: false)
-    }
-
-    /// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func mean(squeezingAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return mean(squeezingAxes: Tensor(axes))
-    }
-
-    /// Returns the arithmetic mean along the specified axes. The reduced dimensions are removed.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank...rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func mean(squeezingAxes axes: Int...) -> Tensor {
-        return mean(squeezingAxes: axes)
-    }
-
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func mean() -> Tensor {
-        return flattened().mean(squeezingAxes: [0])
-    }
-
-    /// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained
-    /// with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self, vjp: _vjpMean(alongAxes:) where Scalar: TensorFlowFloatingPoint)
-    func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
-        return Raw.mean(self, reductionIndices: axes, keepDims: true)
-    }
-
-    /// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained
-    /// with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func mean(alongAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return mean(alongAxes: Tensor(axes))
-    }
-
-    /// Returns the arithmetic mean along the specified axes. The reduced dimensions are retained
-    /// with value 1.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func mean(alongAxes axes: Int...) -> Tensor {
-        return mean(alongAxes: axes)
-    }
-
-    // MARK: - Variance
-
-    /// Returns the variance along the specified axes. The reduced dimensions are removed. Does not
-    /// apply Bessel's correction.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func variance(squeezingAxes axes: Tensor<Int32>) -> Tensor {
-        let squaredDiff = (self - mean(alongAxes: axes)).squared()
-        return squaredDiff.mean(squeezingAxes: axes)
-    }
-
-    /// Returns the variance along the specified axes. The reduced dimensions are removed. Does not
-    /// apply Bessel's correction.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func variance(squeezingAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return variance(squeezingAxes: Tensor(axes))
-    }
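// Example (a sketch): `variance` divides by the element count (no Bessel's
// correction), so for [1, 2, 3, 4] it is 1.25 rather than the sample
// variance 5/3:
//
//     let samples = Tensor<Float>([1, 2, 3, 4])
//     let mu = samples.mean()          // 2.5
//     let sigma2 = samples.variance()  // 1.25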
-
-    /// Returns the variance along the specified axes. The reduced dimensions are removed.
-    /// Does not apply Bessel's correction.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func variance(squeezingAxes axes: Int...) -> Tensor {
-        return variance(squeezingAxes: axes)
-    }
-
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    @inlinable
-    func variance() -> Tensor {
-        let mean = self.mean()
-        let squaredDiff = (self - mean).squared()
-        return squaredDiff.mean()
-    }
-
-    /// Returns the variance along the specified axes. The reduced dimensions are retained with
-    /// value 1. Does not apply Bessel's correction.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func variance(alongAxes axes: Tensor<Int32>) -> Tensor {
-        let squaredDiff = (self - mean(alongAxes: axes)).squared()
-        return squaredDiff.mean(alongAxes: axes)
-    }
-
-    /// Returns the variance along the specified axes. The reduced dimensions are retained with
-    /// value 1. Does not apply Bessel's correction.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func variance(alongAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return variance(alongAxes: Tensor(axes))
-    }
-
-    /// Returns the variance along the specified axes. The reduced dimensions are retained with
-    /// value 1. Does not apply Bessel's correction.
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
-    func variance(alongAxes axes: Int...) -> Tensor {
-        return variance(alongAxes: axes)
-    }
-}
-
-internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    func _vjpSum(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
-        let value = sum(alongAxes: axes)
-        return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
-    }
-
-    @inlinable
-    func _vjpSum(squeezingAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
-        let value = sum(squeezingAxes: axes)
-        return (value, { [shape = shapeTensor] in
-            var result = $0
-            for i in axes.array.scalars { result = result.expandingShape(at: Int(i)) }
-            return result.broadcast(toShape: shape)
-        })
-    }
-
-    @inlinable
-    func _vjpMean(alongAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
-        let value = mean(alongAxes: axes)
-        let count = Raw.gather(params: shapeTensor, indices: axes).product()
-        return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) / Tensor(count) })
-    }
-
-    @inlinable
-    func _vjpMean(squeezingAxes axes: Tensor<Int32>) -> (Tensor, (Tensor) -> Tensor) {
-        let value = mean(squeezingAxes: axes)
-        let count = Raw.gather(params: shapeTensor, indices: axes).product()
-        return (value, { [shape = shapeTensor] in
-            var result = $0
-            for i in axes.array.scalars { result = result.expandingShape(at: Int(i)) }
-            return result.broadcast(toShape: shape) / Tensor(count)
-        })
-    }
-}
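// Example (a sketch): the pullbacks above broadcast the incoming gradient
// back to the input shape, so each element that fed a `mean` receives the
// output gradient divided by the element count:
//
//     let g = gradient(at: Tensor<Float>([1, 2, 3])) { $0.mean() }
//     // g == [1/3, 1/3, 1/3]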
-
-// TODO: Consider making the return type be generic over `FloatingPoint` types
-// so that `self`'s scalar type can be any `Numeric` type.
-public extension Tensor where Scalar: TensorFlowFloatingPoint {
-    /// Returns the standard deviation of the elements along the specified axes. The reduced
-    /// dimensions are removed. Does not apply Bessel's correction.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation(squeezingAxes axes: Tensor<Int32>) -> Tensor {
-        return sqrt(variance(squeezingAxes: axes))
-    }
-
-    /// Returns the standard deviation of the elements along the specified axes. The reduced
-    /// dimensions are removed. Does not apply Bessel's correction.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation(squeezingAxes axes: [Int]) -> Tensor {
-        return sqrt(variance(squeezingAxes: axes))
-    }
-
-    /// Returns the standard deviation of the elements along the specified axes. The reduced
-    /// dimensions are removed. Does not apply Bessel's correction.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation(squeezingAxes axes: Int...) -> Tensor {
-        return standardDeviation(squeezingAxes: axes)
-    }
-
-    /// Returns the standard deviation of all elements in this tensor.
-    /// Does not apply Bessel's correction.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation() -> Tensor {
-        // Reduce along all dimensions.
-        return standardDeviation(squeezingAxes: Array(0..<rank))
-    }
-
-    /// Returns the standard deviation of the elements along the specified axes. The reduced
-    /// dimensions are retained with value `1`. Does not apply Bessel's correction.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation(alongAxes axes: Tensor<Int32>) -> Tensor {
-        return sqrt(variance(alongAxes: axes))
-    }
-
-    /// Returns the standard deviation of the elements along the specified axes. The reduced
-    /// dimensions are retained with value `1`. Does not apply Bessel's correction.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation(alongAxes axes: [Int]) -> Tensor {
-        // TODO(TF-433): Remove workaround for differentiating `map`.
-        let axes = {axes.map(Int32.init)}()
-        return standardDeviation(alongAxes: Tensor(axes))
-    }
-
-    /// Returns the standard deviation of the elements along the specified axes. The reduced
-    /// dimensions are retained with value `1`. Does not apply Bessel's correction.
-    ///
-    /// - Parameter axes: The dimensions to reduce.
-    /// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
-    @inlinable
-    @differentiable(wrt: self)
-    func standardDeviation(alongAxes axes: Int...) -> Tensor {
-        return sqrt(variance(alongAxes: axes))
-    }
-}
-
-//===------------------------------------------------------------------------------------------===//
-// Linear Algebra
-//===------------------------------------------------------------------------------------------===//
-
-/// Performs matrix multiplication with another tensor and produces the result.
-@inlinable
-@differentiable(vjp: _vjpMatmul(_:_:) where Scalar: TensorFlowFloatingPoint)
-public func matmul<Scalar: Numeric & TensorFlowScalar>(
-    _ lhs: Tensor<Scalar>,
-    _ rhs: Tensor<Scalar>
-) -> Tensor<Scalar> {
-    // Default arguments specified explicitly to avoid "external declarations of SILFunctions with
-    // shared visibility is not allowed" SILVerifier error in
-    // "tests/AutoDiff/tensor_autodiff_runtime.swift".
-    return Raw.matMul(lhs, rhs, transposeA: false, transposeB: false)
-}
-
-@inlinable
-internal func _vjpMatmul<Scalar: TensorFlowFloatingPoint>(
-    _ lhs: Tensor<Scalar>,
-    _ rhs: Tensor<Scalar>
-) -> (Tensor<Scalar>, (Tensor<Scalar>) -> (Tensor<Scalar>, Tensor<Scalar>)) {
-    let value = matmul(lhs, rhs)
-    return (value, { v in
-        (matmul(v, rhs.transposed()), matmul(lhs.transposed(), v))
-    })
-}
-
-infix operator •: MultiplicationPrecedence
-
-public extension Tensor where Scalar: Numeric {
-    // TODO: We have to define a custom VJP on • because AD can't yet differentiate generic methods.
-    // After AD can differentiate generic methods, remove the custom VJP.
-
-    /// Performs matrix multiplication between two tensors and produces the result.
-    @inlinable
-    @differentiable(vjp: _vjpMatmulOperator(lhs:rhs:) where Scalar: TensorFlowFloatingPoint)
-    static func • (lhs: Tensor, rhs: Tensor) -> Tensor {
-        return matmul(lhs, rhs)
-    }
-}
-
-// TODO: We have to define a custom VJP on • because AD can't yet
-// differentiate generic methods. After AD can differentiate generic methods,
-// remove the custom VJP.
-internal extension Tensor where Scalar: TensorFlowFloatingPoint {
-    @inlinable
-    static func _vjpMatmulOperator(
-        lhs: Tensor,
-        rhs: Tensor
-    ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return _vjpMatmul(lhs, rhs)
-    }
-}
diff --git a/Sources/DeepLearning/PythonConversion.swift b/Sources/DeepLearning/PythonConversion.swift
deleted file mode 100644
index 5e52548c4..000000000
--- a/Sources/DeepLearning/PythonConversion.swift
+++ /dev/null
@@ -1,171 +0,0 @@
-// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#if !COMPILING_TENSORFLOW_MODULE
-import TensorFlow
-#endif
-
-#if canImport(Python)
-import Python
-
-/// The `numpy` Python module.
-/// Note: Global variables are lazy, so the following declaration won't produce
-/// a Python import error until it is first used.
-private let np = Python.import("numpy")
-
-private func debugLogNumpyError(_ message: String) {
-    debugLog("NumPy conversion error: " + message)
-}
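// Example (a sketch; requires a Python environment with NumPy): conversion
// copies the buffer in both directions, and the failable initializers return
// nil on a dtype mismatch:
//
//     let nd = np.ones([2, 3])            // float64 by default
//     let t = Tensor<Double>(numpy: nd)!  // dtypes match
//     let bad = Tensor<Float>(numpy: nd)  // nil: Float vs. float64
//     let roundTripped = t.makeNumpyArray()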
-
-extension ShapedArray: ConvertibleFromNumpyArray
-    where Scalar: NumpyScalarCompatible {
-    /// Creates a `ShapedArray` with the same shape and scalars as the specified
-    /// `numpy.ndarray` instance.
-    ///
-    /// - Parameter numpyArray: The `numpy.ndarray` instance to convert.
-    /// - Precondition: The `numpy` Python package must be installed.
-    /// - Precondition: `numpyArray` must have a compatible scalar `dtype`.
-    public init?(numpy numpyArray: PythonObject) {
-        // Check if input is a `numpy.ndarray` instance.
-        guard Python.isinstance(numpyArray, np.ndarray) == true else {
-            debugLogNumpyError("""
-                PythonObject input has type '\(Python.type(numpyArray))' and is not \
-                an instance of 'numpy.ndarray'.
-                """)
-            return nil
-        }
-        // Check if the dtype of the `ndarray` is compatible with the `Scalar`
-        // type.
-        guard Scalar.numpyScalarTypes.contains(numpyArray.dtype) else {
-            debugLogNumpyError("""
-                'numpy.ndarray' dtype '\(numpyArray.dtype)' is incompatible with \
-                Swift type '\(Scalar.self)'.
-                """)
-            return nil
-        }
-
-        let pyShape = numpyArray.__array_interface__["shape"]
-        guard let shape = [Int](pyShape) else {
-            debugLogNumpyError("cannot access shape of 'numpy.ndarray' instance.")
-            return nil
-        }
-
-        // Make sure that the array is contiguous in memory. This does a copy if
-        // the array is not already contiguous in memory.
-        let contiguousNumpyArray = np.ascontiguousarray(numpyArray)
-
-        guard let ptrVal =
-            UInt(contiguousNumpyArray.__array_interface__["data"].tuple2.0) else {
-            debugLogNumpyError("cannot access data of 'numpy.ndarray' instance.")
-            return nil
-        }
-        // Note: `ptr` is not nil even if the `ndarray` is empty (i.e. has a shape
-        // of `(0,)`).
-        guard let ptr = UnsafePointer<Scalar>(bitPattern: ptrVal) else {
-            fatalError("'numpy.ndarray' data pointer was nil")
-        }
-        // This code avoids calling `init<S: Sequence>(shape: [Int], scalars: S)`,
-        // which inefficiently copies scalars one by one. Instead,
-        // `init(shape: [Int], scalars: [Scalar])` is called, which efficiently
-        // does a `memcpy` of the entire `scalars` array.
-        // Unnecessary copying is minimized.
-        let dummyPointer = UnsafeMutablePointer<Scalar>.allocate(capacity: 1)
-        let scalarCount = shape.reduce(1, *)
-        var scalars: [Scalar] = Array(repeating: dummyPointer.move(), count: scalarCount)
-        dummyPointer.deallocate()
-        scalars.withUnsafeMutableBufferPointer { buffPtr in
-            buffPtr.baseAddress!.assign(from: ptr, count: scalarCount)
-        }
-        self.init(shape: shape, scalars: scalars)
-    }
-}
-
-extension Tensor: ConvertibleFromNumpyArray
-    where Scalar: NumpyScalarCompatible {
-    /// Creates a tensor with the same shape and scalars as the specified
-    /// `numpy.ndarray` instance.
-    ///
-    /// - Parameter numpyArray: The `numpy.ndarray` instance to convert.
-    /// - Precondition: The `numpy` Python package must be installed.
-    /// - Returns: `numpyArray` converted to a `Tensor`. Returns `nil` if
-    ///   `numpyArray` does not have a compatible scalar `dtype`.
-    public init?(numpy numpyArray: PythonObject) {
-        // Check if input is a `numpy.ndarray` instance.
-        guard Python.isinstance(numpyArray, np.ndarray) == true else {
-            debugLogNumpyError("""
-                PythonObject input has type '\(Python.type(numpyArray))' and is not \
-                an instance of 'numpy.ndarray'.
-                """)
-            return nil
-        }
-        // Check if the dtype of the `ndarray` is compatible with the `Scalar`
-        // type.
-        guard Scalar.numpyScalarTypes.contains(numpyArray.dtype) else {
-            debugLogNumpyError("""
-                'numpy.ndarray' dtype '\(numpyArray.dtype)' is incompatible with \
-                Swift type '\(Scalar.self)'.
-                """)
-            return nil
-        }
-
-        let pyShape = numpyArray.__array_interface__["shape"]
-        guard let dimensions = [Int](pyShape) else {
-            debugLogNumpyError("cannot access shape of 'numpy.ndarray' instance.")
-            return nil
-        }
-        let shape = TensorShape(dimensions)
-
-        // Make sure that the array is contiguous in memory. This does a copy if
-        // the array is not already contiguous in memory.
-        let contiguousNumpyArray = np.ascontiguousarray(numpyArray)
-
-        guard let ptrVal = UInt(contiguousNumpyArray.__array_interface__["data"].tuple2.0) else {
-            debugLogNumpyError("cannot access data of 'numpy.ndarray' instance.")
-            return nil
-        }
-        // Note: `ptr` is not nil even if the `ndarray` is empty (i.e. has a shape
-        // of `(0,)`).
-        guard let ptr = UnsafePointer<Scalar>(bitPattern: ptrVal) else {
-            fatalError("'numpy.ndarray' data pointer was nil")
-        }
-        let buffPtr = UnsafeBufferPointer(start: ptr, count: Int(shape.contiguousSize))
-        self.init(shape: shape, scalars: buffPtr)
-    }
-}
-
-extension ShapedArray where Scalar: NumpyScalarCompatible {
-    /// Creates a `numpy.ndarray` instance with the same shape and scalars as
-    /// this `ShapedArray`.
-    ///
-    /// - Precondition: The `numpy` Python package must be installed.
-    public func makeNumpyArray() -> PythonObject {
-        return scalars.makeNumpyArray().reshape(shape)
-    }
-}
- public func makeNumpyArray() -> PythonObject { return array.makeNumpyArray() } -} - -extension TensorShape: PythonConvertible { - public var pythonObject: PythonObject { - return dimensions.pythonObject - } -} - -#endif // canImport(Python) diff --git a/Sources/DeepLearning/Random.swift b/Sources/DeepLearning/Random.swift index ade75a826..44e55223c 100644 --- a/Sources/DeepLearning/Random.swift +++ b/Sources/DeepLearning/Random.swift @@ -19,7 +19,7 @@ import Glibc #endif //===------------------------------------------------------------------------------------------===// -// Random Number Generators +// Random number generators //===------------------------------------------------------------------------------------------===// /// A type that provides seedable deterministic pseudo-random data. @@ -409,8 +409,8 @@ private func makeUInt64Pair(_ vector: UInt32x4) -> (UInt64, UInt64) { //===------------------------------------------------------------------------------------------===// public protocol RandomDistribution { - associatedtype Sample - func next(using generator: inout G) -> Sample + associatedtype Sample + func next(using generator: inout G) -> Sample } @_fixed_layout @@ -429,8 +429,8 @@ public struct UniformIntegerDistribution: RandomDistributi } @_fixed_layout -public struct UniformFloatingPointDistribution: RandomDistribution - where T.RawSignificand: FixedWidthInteger { +public struct UniformFloatingPointDistribution: RandomDistribution + where T.RawSignificand : FixedWidthInteger { public let lowerBound: T public let upperBound: T @@ -445,8 +445,8 @@ public struct UniformFloatingPointDistribution: RandomDi } @_fixed_layout -public struct NormalDistribution: RandomDistribution - where T.RawSignificand: FixedWidthInteger { +public struct NormalDistribution: RandomDistribution + where T.RawSignificand : FixedWidthInteger { public let mean: T public let standardDeviation: T private let uniformDist = UniformFloatingPointDistribution() @@ -503,10 +503,10 @@ public struct BetaDistribution: RandomDistribution { /// /// - Returns: Sample obtained using Cheng's BB algorithm. private static func chengsAlgorithmBB( - _ alpha0: Float, - _ a: Float, - _ b: Float, - using rng: inout G + _ alpha0: Float, + _ a: Float, + _ b: Float, + using rng: inout G ) -> Float { let alpha = a + b let beta = sqrt((alpha - 2) / (2 * a * b - alpha)) @@ -536,7 +536,7 @@ public struct BetaDistribution: RandomDistribution { } while r + alpha * (log(alpha) - log(b + w)) < t w = min(w, Float.greatestFiniteMagnitude) - return a == alpha0 ? w / (b + w): b / (b + w) + return a == alpha0 ? w / (b + w) : b / (b + w) } /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC @@ -550,10 +550,10 @@ public struct BetaDistribution: RandomDistribution { /// /// - Returns: Sample obtained using Cheng's BB algorithm. private static func chengsAlgorithmBC( - _ alpha0: Float, - _ a: Float, - _ b: Float, - using rng: inout G + _ alpha0: Float, + _ a: Float, + _ b: Float, + using rng: inout G ) -> Float { let alpha = a + b let beta = 1 / b @@ -592,6 +592,6 @@ public struct BetaDistribution: RandomDistribution { } w = min(w, Float.greatestFiniteMagnitude) - return a == alpha0 ? w / (b + w): b / (b + w) + return a == alpha0 ? w / (b + w) : b / (b + w) } } diff --git a/Sources/DeepLearning/Tensors.swift b/Sources/DeepLearning/Tensors.swift deleted file mode 100644 index 1c1700649..000000000 --- a/Sources/DeepLearning/Tensors.swift +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#if !COMPILING_TENSORFLOW_MODULE -import TensorFlow -#endif - -#if COMPILING_TENSORFLOW_MODULE -infix operator .==: ComparisonPrecedence -#endif - -//===------------------------------------------------------------------------------------------===// -// Tensor Properties -//===------------------------------------------------------------------------------------------===// - -public extension Tensor { - /// The rank of the tensor, represented as a `Tensor`. - @inlinable - var rankTensor: Tensor { - return Raw.rank(self) - } - - /// The dimensions of the tensor, represented as a `Tensor`. - @inlinable - var shapeTensor: Tensor { - return Raw.shape(self) - } - - /// The number of scalars in the tensor, represented as a `Tensor`. - @inlinable - var scalarCountTensor: Tensor { - return Raw.size(self) - } -} - -//===------------------------------------------------------------------------------------------===// -// Description and Visualization -//===------------------------------------------------------------------------------------------===// - -// String conversion. -extension Tensor: CustomStringConvertible { - /// A textual representation of the tensor. - /// - /// - Note: use `fullDescription` for a non-pretty-printed description showing all scalars. - public var description: String { - return array.description - } -} - -public extension Tensor { - /// A textual representation of the tensor. Returns a summarized description if `summarize` is - /// true and the element count exceeds twice the `edgeElementCount`. - /// - /// - Parameters: - /// - lineWidth: The max line width for printing. Used to determine number of scalars to print - /// per line. - /// - edgeElementCount: The maximum number of elements to print before and after summarization - /// via ellipses (`...`). - /// - summarizing: If true, summarize description if element count exceeds twice - /// `edgeElementCount`. - func description( - lineWidth: Int = 80, - edgeElementCount: Int = 3, - summarizing: Bool = false - ) -> String { - return array.description( - lineWidth: lineWidth, - edgeElementCount: edgeElementCount, - summarizing: summarizing) - } - - /// A full, non-pretty-printed textual representation of the tensor, showing - /// all scalars. - var fullDescription: String { - return array.fullDescription - } -} - -// Xcode Playground display conversion. -extension Tensor: CustomPlaygroundDisplayConvertible { - public var playgroundDescription: Any { - return description - } -} - -// Mirror representation, used by debugger/REPL. 
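-// An empty `children` collection means the debugger and REPL display a tensor -// without recursing into its internal storage.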
-extension Tensor: CustomReflectable { - public var customMirror: Mirror { - return Mirror(self, children: [], displayStyle: .struct) - } -} - -//===------------------------------------------------------------------------------------------===// -// Codable Conformance -//===------------------------------------------------------------------------------------------===// - -extension Tensor: Codable where Scalar: Codable { - @inlinable - public func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - try container.encode(array) - } - - @inlinable - public init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - let array = try container.decode(ShapedArray.self) - self.init(array) - } -} diff --git a/Tests/DeepLearningTests/InitializerTests.swift b/Tests/DeepLearningTests/InitializerTests.swift deleted file mode 100644 index 3407e5816..000000000 --- a/Tests/DeepLearningTests/InitializerTests.swift +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import XCTest -@testable import DeepLearning - -final class InitializerTests: XCTestCase { - func testInitializers() { - let scalar = Tensor(1) - let matrix: Tensor = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - let broadcastScalar = Tensor(broadcasting: 10, rank: 3) - let some4d = Tensor( - shape: [2, 1, 2, 1], - scalars: AnyRandomAccessCollection([2, 3, 4, 5])) - XCTAssertEqual(ShapedArray(shape: [2, 1, 2, 1], scalars: [2, 3, 4, 5]), some4d.array) - XCTAssertEqual(ShapedArray(shape: [], scalars: [1]), scalar.array) - XCTAssertEqual(ShapedArray(shape: [2, 3], scalars: [1, 2, 3, 4, 5, 6]), matrix.array) - XCTAssertEqual(ShapedArray(shape: [1, 1, 1], scalars: [10]), broadcastScalar.array) - } - - func testFactoryInitializers() { - let x = Tensor(ones: [1, 10]) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [1, 10]), x.array) - } - - func testNumericInitializers() { - let x = Tensor(oneHotAtIndices: [0, 2, -1, 1], depth: 3) - XCTAssertEqual(ShapedArray( - shape: [4, 3], - scalars: [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0]), x.array) - } - - func testScalarToTensorConversion() { - let tensor = Tensor(broadcasting: 42, rank: 4) - XCTAssertEqual([1, 1, 1, 1], tensor.shape) - XCTAssertEqual([42], tensor.scalars) - } - - func testArrayConversion() { - let array3D = ShapedArray(repeating: 1.0, shape: [2, 3, 4]) - let tensor3D = Tensor(array3D) - XCTAssertEqual(array3D, tensor3D.array) - } - - func testNonTPUDataTypeCast() { - // TPU does not support Int8 or 16 casting. 
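- // The guard below therefore skips this test in TPU execution mode.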
- guard !_RuntimeConfig.executionMode.isTPU else { return } - - let x = Tensor(ones: [5, 5]) - let ints = Tensor(x) - let floats = Tensor(x) - let i8s = Tensor(floats) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), ints.array) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), floats.array) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), i8s.array) - } - - func testTPUDataTypeCast() { - // Non-TPU mode (e.g. eager) does not support UInt32 casting. - guard _RuntimeConfig.executionMode.isTPU else { return } - - let x = Tensor(ones: [5, 5]) - let ints = Tensor(x) - let floats = Tensor(x) - let u32s = Tensor(floats) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), ints.array) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), floats.array) - XCTAssertEqual(ShapedArray(repeating: 1, shape: [5, 5]), u32s.array) - } - - func testNonTPUBoolToNumericCast() { - // TPU does not support Int8 or 16 casting. - // - // When changing to UInt32, we got another TPU/XLA compilation error when - // converting from bools to UInt32 (different from the missing-kernel error). - if _RuntimeConfig.executionMode.isTPU { return } - - let bools = Tensor(shape: [2, 2], scalars: [true, false, true, false]) - let ints = Tensor(bools) - let floats = Tensor(bools) - let i8s = Tensor(bools) - XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), ints.array) - XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), floats.array) - XCTAssertEqual(ShapedArray(shape: [2, 2], scalars: [1, 0, 1, 0]), i8s.array) - } - - static var allTests = [ - ("testInitializers", testInitializers), - ("testFactoryInitializers", testFactoryInitializers), - ("testNumericInitializers", testNumericInitializers), - ("testScalarToTensorConversion", testScalarToTensorConversion), - ("testArrayConversion", testArrayConversion), - ("testNonTPUDataTypeCast", testNonTPUDataTypeCast), - ("testTPUDataTypeCast", testTPUDataTypeCast), - ("testNonTPUBoolToNumericCast", testNonTPUBoolToNumericCast) - ] -} diff --git a/Tests/DeepLearningTests/OperatorTests/BasicTests.swift b/Tests/DeepLearningTests/OperatorTests/BasicTests.swift deleted file mode 100644 index ae25efbc5..000000000 --- a/Tests/DeepLearningTests/OperatorTests/BasicTests.swift +++ /dev/null @@ -1,479 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import XCTest -@testable import DeepLearning - -final class BasicOperatorTests: XCTestCase { - func testElementIndexing() { - // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly - // until send and receive are implemented (without writing a bunch of mini - // tests). Instead, `Tensor.array` is called to make a ShapedArray host copy - // and the ShapedArray is tested. 
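- // Single-integer indexing drops one leading dimension: `tensor3D[2]` below is - // the [4, 5] slice holding scalars 40...59, and `tensor3D[2][0][3]` is the - // scalar 43.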
- let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let element2D = tensor3D[2] - let element1D = tensor3D[1][3] - let element0D = tensor3D[2][0][3] - - let array2D = element2D.array - let array1D = element1D.array - let array0D = element0D.array - - /// Test shapes - XCTAssertEqual([4, 5], array2D.shape) - XCTAssertEqual([5], array1D.shape) - XCTAssertEqual([], array0D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 35.0, to: 40, by: 1)), array1D.scalars) - XCTAssertEqual([43], array0D.scalars) - } - - func testElementIndexingAssignment() { - // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly - // until send and receive are implemented (without writing a bunch of mini - // tests). Instead, `Tensor.array` is called to make a ShapedArray host copy - // and the ShapedArray is tested. - var tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - tensor3D[2] = Tensor( - shape: [4, 5], scalars: Array(stride(from: 20.0, to: 40, by: 1))) - let element2D = tensor3D[2] - let element1D = tensor3D[1][3] - let element0D = tensor3D[2][0][3] - - let array2D = element2D.array - let array1D = element1D.array - let array0D = element0D.array - - /// Test shapes - XCTAssertEqual([4, 5], array2D.shape) - XCTAssertEqual([5], array1D.shape) - XCTAssertEqual([], array0D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 20.0, to: 40, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 35.0, to: 40, by: 1)), array1D.scalars) - XCTAssertEqual([23], array0D.scalars) - } - - func testNestedElementIndexing() { - // NOTE: This test could use a clearer name, along with other "indexing" - // tests. Note to update corresponding test names in other files - // (shaped_array.test) as well. - let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let element1D = tensor3D[1, 3] - let element0D = tensor3D[2, 0, 3] - - let array1D = element1D.array - let array0D = element0D.array - - /// Test shapes - XCTAssertEqual([5], array1D.shape) - XCTAssertEqual([], array0D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 35.0, to: 40, by: 1)), array1D.scalars) - XCTAssertEqual([43], array0D.scalars) - } - - func testSliceIndexing() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). - // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. - let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let slice3D = tensor3D[2...] - let slice2D = tensor3D[1][0..<2] - let slice1D = tensor3D[0][0][3..<5] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([1, 4, 5], array3D.shape) - XCTAssertEqual([2, 5], array2D.shape) - XCTAssertEqual([2], array1D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) - XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) - } - - func testSliceIndexingAssignment() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). 
- // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. - var tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - tensor3D[2, 0..<5, 0..<6] = Tensor( - shape: [4, 5], scalars: Array(stride(from: 20.0, to: 40, by: 1))) - let slice3D = tensor3D[2...] - let slice2D = tensor3D[1][0..<2] - let slice1D = tensor3D[0][0][3..<5] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([1, 4, 5], array3D.shape) - XCTAssertEqual([2, 5], array2D.shape) - XCTAssertEqual([2], array1D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 20.0, to: 40, by: 1)), array3D.scalars) - XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) - } - - func testEllipsisIndexing() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). - // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. - var tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - tensor3D[2, TensorRange.ellipsis] = Tensor( - shape: [4, 5], scalars: Array(stride(from: 20.0, to: 40, by: 1))) - let slice3D = tensor3D[2..., TensorRange.ellipsis] - let slice2D = tensor3D[1][0..<2] - let slice1D = tensor3D[0][0][3..<5] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([1, 4, 5], array3D.shape) - XCTAssertEqual([2, 5], array2D.shape) - XCTAssertEqual([2], array1D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 20.0, to: 40, by: 1)), array3D.scalars) - XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) - } - - func testNewAxisIndexing() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). - // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. - let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let newAxis = TensorRange.newAxis - let ellipsis = TensorRange.ellipsis - let slice3D = tensor3D[2..., newAxis, ellipsis] - let slice2D = tensor3D[1, newAxis][0..<1, 0..<2] - let slice1D = tensor3D[0][newAxis, 0][0..<1, 3..<5, newAxis] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([1, 1, 4, 5], array3D.shape) - XCTAssertEqual([1, 2, 5], array2D.shape) - XCTAssertEqual([1, 2, 1], array1D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) - XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) - } - - func testSqueezeAxisIndexing() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). - // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. 
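- // `TensorRange.squeezeAxis` removes a dimension that is known to have size 1, - // undoing the `newAxis` insertions performed below.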
- let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let newAxis = TensorRange.newAxis - let ellipsis = TensorRange.ellipsis - let squeezeAxis = TensorRange.squeezeAxis - let slice3D = tensor3D[2..., newAxis, ellipsis][squeezeAxis, squeezeAxis] - let slice2D = tensor3D[1, newAxis][squeezeAxis, 0..<2] - let slice1D = tensor3D[0..<1, 0, 3..<5, newAxis][ - squeezeAxis, ellipsis, squeezeAxis] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([4, 5], array3D.shape) - XCTAssertEqual([2, 5], array2D.shape) - XCTAssertEqual([2], array1D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) - XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) - } - - func testStridedSliceIndexing() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). - // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. - let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let slice3D = tensor3D[2...] - let slice2D = tensor3D[1][0..<3..2] - let slice1D = tensor3D[0][0][1..<5..2] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([1, 4, 5], array3D.shape) - XCTAssertEqual([2, 5], array2D.shape) - XCTAssertEqual([2], array1D.shape) - - /// Test scalars - XCTAssertEqual(Array(stride(from: 40.0, to: 60, by: 1)), array3D.scalars) - XCTAssertEqual( - Array(stride(from: 20.0, to: 25, by: 1)) + - Array(stride(from: 30.0, to: 35, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 1.0, to: 5, by: 2)), array1D.scalars) - } - - func testStridedSliceIndexingAssignment() { - // NOTE: cannot test `Tensor.shape` or `Tensor.scalars` directly until send - // and receive are implemented (without writing a bunch of mini tests). - // Instead, `Tensor.array` is called to make a ShapedArray host copy and the - // ShapedArray is tested instead. - var tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - tensor3D[2, 0..<5..2, 0..<6] = Tensor( - shape: [2, 5], scalars: Array(stride(from: 20.0, to: 40, by: 2))) - let slice3D = tensor3D[2...] 
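- // The strided write above replaced rows 0 and 2 of layer 2 with the even - // values 20...28 and 30...38, leaving rows 1 and 3 (45...49 and 55...59) - // untouched; `slice3D` captures that layer for the assertions below.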
- let slice2D = tensor3D[1][0..<2] - let slice1D = tensor3D[0][0][3..<5] - - let array3D = slice3D.array - let array2D = slice2D.array - let array1D = slice1D.array - - /// Test shapes - XCTAssertEqual([1, 4, 5], array3D.shape) - XCTAssertEqual([2, 5], array2D.shape) - XCTAssertEqual([2], array1D.shape) - - /// Test scalars - XCTAssertEqual( - Array(stride(from: 20.0, to: 30, by: 2)) + - Array(stride(from: 45.0, to: 50, by: 1)) + - Array(stride(from: 30.0, to: 40, by: 2)) + - Array(stride(from: 55.0, to: 60, by: 1)), array3D.scalars) - XCTAssertEqual(Array(stride(from: 20.0, to: 30, by: 1)), array2D.scalars) - XCTAssertEqual(Array(stride(from: 3.0, to: 5, by: 1)), array1D.scalars) - } - - func testWholeTensorSlicing() { - let t: Tensor = [[[1, 1, 1], [2, 2, 2]], - [[3, 3, 3], [4, 4, 4]], - [[5, 5, 5], [6, 6, 6]]] - let slice2 = t.slice(lowerBounds: [1, 0, 0], upperBounds: [2, 1, 3]) - XCTAssertEqual(ShapedArray(shape: [1, 1, 3], scalars: [3, 3, 3]), slice2.array) - } - - func testAdvancedIndexing() { - // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly - // until send and receive are implemented (without writing a bunch of mini - // tests). Instead, `Tensor.array` is called to make a ShapedArray host copy - // and the ShapedArray is tested. - let tensor3D = Tensor( - shape: [3, 4, 5], scalars: Array(stride(from: 0.0, to: 60, by: 1))) - let element2D = tensor3D[1..<3, 0, 3...] - let array2D = element2D.array - - // Test shape - XCTAssertEqual([2, 2], array2D.shape) - - // Test scalars - XCTAssertEqual(Array([23.0, 24.0, 43.0, 44.0]), array2D.scalars) - } - - func testConcatenation() { - // 2 x 3 - let t1 = Tensor([[0, 1, 2], [3, 4, 5]]) - // 2 x 3 - let t2 = Tensor([[6, 7, 8], [9, 10, 11]]) - let concatenated = t1 ++ t2 - let concatenated0 = t1.concatenated(with: t2) - let concatenated1 = t1.concatenated(with: t2, alongAxis: 1) - XCTAssertEqual(ShapedArray(shape: [4, 3], scalars: Array(0..<12)), concatenated.array) - XCTAssertEqual(ShapedArray(shape: [4, 3], scalars: Array(0..<12)), concatenated0.array) - XCTAssertEqual( - ShapedArray(shape: [2, 6], scalars: [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11]), - concatenated1.array) - } - - func testVJPConcatenation() { - let a1 = Tensor([1,2,3,4]) - let b1 = Tensor([5,6,7,8,9,10]) - - let a2 = Tensor([1,1,1,1]) - let b2 = Tensor([1,1,1,1,1,1]) - - let grads = gradient(at: a2, b2) { a, b in - return ((a1 * a) ++ (b1 * b)).sum() - } - - XCTAssertEqual(a1, grads.0) - XCTAssertEqual(b1, grads.1) - } - - func testVJPConcatenationNegativeAxis() { - let a1 = Tensor([1,2,3,4]) - let b1 = Tensor([5,6,7,8,9,10]) - - let a2 = Tensor([1,1,1,1]) - let b2 = Tensor([1,1,1,1,1,1]) - - let grads = gradient(at: a2, b2) { a, b in - return (a1 * a).concatenated(with: b1 * b, alongAxis: -1).sum() - } - - XCTAssertEqual(a1, grads.0) - XCTAssertEqual(b1, grads.1) - } - - func testTranspose() { - // 3 x 2 -> 2 x 3 - let xT = Tensor([[1, 2], [3, 4], [5, 6]]).transposed() - let xTArray = xT.array - XCTAssertEqual(2, xTArray.rank) - XCTAssertEqual([2, 3], xTArray.shape) - XCTAssertEqual([1, 3, 5, 2, 4, 6], xTArray.scalars) - } - - func testReshape() { - // 2 x 3 -> 1 x 3 x 1 x 2 x 1 - let matrix = Tensor([[0, 1, 2], [3, 4, 5]]) - let reshaped = matrix.reshaped(to: [1, 3, 1, 2, 1]) - - XCTAssertEqual([1, 3, 1, 2, 1], reshaped.shape) - XCTAssertEqual(Array(0..<6), reshaped.scalars) - } - - func testFlatten() { - // 2 x 3 -> 6 - let matrix = Tensor([[0, 1, 2], [3, 4, 5]]) - let flattened = matrix.flattened() - - XCTAssertEqual([6], flattened.shape) - 
XCTAssertEqual(Array(0..<6), flattened.scalars) - } - - func testFlatten0D() { - let scalar = Tensor(5) - let flattened = scalar.flattened() - XCTAssertEqual([1], flattened.shape) - XCTAssertEqual([5], flattened.scalars) - } - - func testReshapeToScalar() { - // 1 x 1 -> scalar - let z = Tensor([[10]]).reshaped(to: []) - XCTAssertEqual([], z.shape) - } - - func testReshapeTensor() { - // 2 x 3 -> 1 x 3 x 1 x 2 x 1 - let x = Tensor(repeating: 0.0, shape: [2, 3]) - let y = Tensor(repeating: 0.0, shape: [1, 3, 1, 2, 1]) - let result = x.reshaped(like: y) - XCTAssertEqual([1, 3, 1, 2, 1], result.shape) - } - - func testUnbroadcast1() { - let x = Tensor(repeating: 1, shape: [2, 3, 4, 5]) - let y = Tensor(repeating: 1, shape: [4, 5]) - let z = x.unbroadcast(like: y) - XCTAssertEqual(ShapedArray(repeating: 6, shape: [4, 5]), z.array) - } - - func testUnbroadcast2() { - let x = Tensor(repeating: 1, shape: [2, 3, 4, 5]) - let y = Tensor(repeating: 1, shape: [3, 1, 5]) - let z = x.unbroadcast(like: y) - XCTAssertEqual(ShapedArray(repeating: 8, shape: [3, 1, 5]), z.array) - } - - func testSliceUpdate() { - guard !_RuntimeConfig.executionMode.isTPU else { return } - var t1 = Tensor([[1, 2, 3], [4, 5, 6]]) - t1[0] = Tensor(zeros: [3]) - XCTAssertEqual(ShapedArray(shape:[2, 3], scalars: [0, 0, 0, 4, 5, 6]), t1.array) - var t2 = t1 - t2[0][2] = Tensor(3) - XCTAssertEqual(ShapedArray(shape:[2, 3], scalars: [0, 0, 3, 4, 5, 6]), t2.array) - var t3 = Tensor([[true, true, true], [false, false, false]]) - t3[0][1] = Tensor(false) - XCTAssertEqual(ShapedArray( - shape:[2, 3], scalars: [true, false, true, false, false, false]), t3.array) - var t4 = Tensor([[true, true, true], [false, false, false]]) - t4[0] = Tensor(repeating: false, shape: [3]) - XCTAssertEqual(ShapedArray(repeating: false, shape: [2, 3]), t4.array) - } - - func testBroadcastTensor() { - // 1 -> 2 x 3 x 4 - let one = Tensor(1) - var target = Tensor(repeating: 0.0, shape: [2, 3, 4]) - let broadcasted = one.broadcast(like: target) - XCTAssertEqual(Tensor(repeating: 1, shape: [2, 3, 4]), broadcasted) - target .= Tensor(repeating: 1, shape: [1, 3, 1]) - XCTAssertEqual(Tensor(repeating: 1, shape: [2, 3, 4]), target) - } - - static var allTests = [ - ("testElementIndexing", testElementIndexing), - ("testElementIndexingAssignment", testElementIndexingAssignment), - ("testNestedElementIndexing", testNestedElementIndexing), - ("testSliceIndexing", testSliceIndexing), - ("testSliceIndexingAssignment", testSliceIndexingAssignment), - ("testEllipsisIndexing", testEllipsisIndexing), - ("testNewAxisIndexing", testNewAxisIndexing), - ("testSqueezeAxisIndexing", testSqueezeAxisIndexing), - ("testStridedSliceIndexing", testStridedSliceIndexing), - ("testStridedSliceIndexingAssignment", testStridedSliceIndexingAssignment), - ("testWholeTensorSlicing", testWholeTensorSlicing), - ("testAdvancedIndexing", testAdvancedIndexing), - ("testConcatenation", testConcatenation), - ("testVJPConcatenation", testVJPConcatenation), - ("testTranspose", testTranspose), - ("testReshape", testReshape), - ("testFlatten", testFlatten), - ("testFlatten0D", testFlatten0D), - ("testReshapeToScalar", testReshapeToScalar), - ("testReshapeTensor", testReshapeTensor), - ("testUnbroadcast1", testUnbroadcast1), - ("testUnbroadcast2", testUnbroadcast2), - ("testSliceUpdate", testSliceUpdate), - ("testBroadcastTensor", testBroadcastTensor) - ] -} diff --git a/Tests/DeepLearningTests/OperatorTests/ComparisonTests.swift b/Tests/DeepLearningTests/OperatorTests/ComparisonTests.swift deleted file 
mode 100644 index e20a9cdc9..000000000 --- a/Tests/DeepLearningTests/OperatorTests/ComparisonTests.swift +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import XCTest -@testable import DeepLearning - -final class ComparisonOperatorTests: XCTestCase { - func testElementwiseComparison() { - let x = Tensor([0, 1, 2]) - let y = Tensor([2, 1, 3]) - XCTAssertEqual((x .< y).scalars, [true, false, true]) - } - - func testLexicographicalComparison() { - let x = Tensor([0, 1, 2, 3, 4]) - let y = Tensor([2, 3, 4, 5, 6]) - XCTAssertTrue(x < y) - } - - static var allTests = [ - ("testElementwiseComparison", testElementwiseComparison), - ("testLexicographicalComparison", testLexicographicalComparison) - ] -} diff --git a/Tests/DeepLearningTests/OperatorTests/MathTests.swift b/Tests/DeepLearningTests/OperatorTests/MathTests.swift deleted file mode 100644 index 3f769be07..000000000 --- a/Tests/DeepLearningTests/OperatorTests/MathTests.swift +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -import XCTest -@testable import DeepLearning - -final class MathOperatorTests: XCTestCase { - func testReduction() { - // 2 x 5 - let x = Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]) - XCTAssertEqual(Tensor(30), x.sum().toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [5], scalars: [2, 4, 6, 8, 10]), - x.sum(squeezingAxes: 0).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [1, 5], scalars: [2, 4, 6, 8, 10]), - x.sum(alongAxes: 0).toHost(shape: [])) - - XCTAssertEqual(Tensor(14400), x.product().toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [5], scalars: [1, 4, 9, 16, 25]), - x.product(squeezingAxes: 0).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [1, 5], scalars: [1, 4, 9, 16, 25]), - x.product(alongAxes: 0).toHost(shape: [])) - - XCTAssertEqual(Tensor(3), x.mean().toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [5], scalars: [1, 2, 3, 4, 5]), - x.mean(squeezingAxes: 0).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [5], scalars: [1, 2, 3, 4, 5]), - x.mean(alongAxes: 0).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [2], scalars: [3, 3]), - x.mean(squeezingAxes: 1).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [1, 2], scalars: [3, 3]), - x.mean(alongAxes: 1).toHost(shape: [])) - - XCTAssertEqual(Tensor(2), x.variance().toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]), - x.variance(squeezingAxes: 0).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [5], scalars: [0, 0, 0, 0, 0]), - x.variance(alongAxes: 0).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [2], scalars: [2, 2]), - x.variance(squeezingAxes: 1).toHost(shape: [])) - XCTAssertEqual( - Tensor(shape: [1, 2], scalars: [2, 2]), - x.variance(alongAxes: 1).toHost(shape: [])) - } - - func testArgmax() { - // 2 x 3 - let x = Tensor([[0, 1, 2], [3, 4, 5]]) - let argmax0 = x.argmax(squeezingAxis: 0) - let argmax1 = x.argmax(squeezingAxis: 1) - let scalarsArgmax = x.argmax() - XCTAssertEqual(ShapedArray(shape: [3], scalars: [1, 1, 1]), argmax0.array) - XCTAssertEqual(ShapedArray(shape: [2], scalars: [2, 2]), argmax1.array) - XCTAssertEqual(ShapedArray(shape: [], scalars: [5]), scalarsArgmax.array) - } - - func testCeilAndFloor() { - let x = Tensor([-1.3, -0.4, 0.5, 1.6]) - let xFloor = floor(x) - let xCeil = ceil(x) - XCTAssertEqual(ShapedArray(shape: [4], scalars: [-2, -1, 0, 1]), xFloor.array) - XCTAssertEqual(ShapedArray(shape: [4], scalars: [-1, 0, 1, 2]), xCeil.array) - } - - func testSimpleMath() { - let x = Tensor([1.2, 1.2]) - let y = tanh(x) - let array = y.array - XCTAssertEqual([2], array.shape) - XCTAssertEqual([0.833655, 0.833655], array.scalars, accuracy: 0.0001) - } - - func testStandardDeviation() { - XCTAssertEqual(Tensor(0), Tensor([1]).standardDeviation()) - XCTAssertEqual(Tensor(0.5), Tensor([0, 1]).standardDeviation(alongAxes: 0)) - XCTAssertEqual(Tensor(0.5), Tensor([0, 1]).standardDeviation()) - XCTAssertEqual( - 2.87228132, - Tensor(rangeFrom: 0, to: 10, stride: 1).standardDeviation().scalarized(), - accuracy: 0.001) - let matrix = Tensor(rangeFrom: 0, to: 10, stride: 1).reshaped(to: [2, 5]) - XCTAssertEqual(2.87228132, matrix.standardDeviation().scalarized(), accuracy: 0.001) - XCTAssertEqual( - [1.4142, 1.4142], - matrix.standardDeviation(alongAxes: 1).array.scalars, - accuracy: 0.001) - } - - func test3Adds() { - let a = Tensor([1]) - let b = Tensor([2]) - let c = Tensor([3]) - - let o = a + b + c - XCTAssertEqual([6], o.scalars) - } - - func testMultiOpMath() { - let x = Tensor([1.2, 1.2]) - let y = Tensor([2.4, 2.4]) - let 
t1 = x + y - let t2 = t1 * t1 - let t3 = sqrt(t2) - - let array1 = t1.array - let array2 = t2.array - let array3 = t3.array - XCTAssertEqual([2], array1.shape) - XCTAssertEqual([2], array2.shape) - XCTAssertEqual([2], array3.shape) - XCTAssertEqual([3.6, 3.6], array1.scalars, accuracy: 0.001) - XCTAssertEqual([12.96, 12.96], array2.scalars, accuracy: 0.001) - XCTAssertEqual([3.6, 3.6], array3.scalars, accuracy: 0.001) - } - - func testXWPlusB() { - // Shape: 1 x 4 - let x = Tensor([[1.0, 2.0, 2.0, 1.0]]) - // Shape: 4 x 2 - let w = Tensor([[1.0, 0.0], [3.0, 0.0], [2.0, 3.0], [1.0, 0.0]]) - // Shape: 2 - let b = Tensor([0.5, 0.5]) - // Shape: 1 x 2 (broadcasted) - let result = matmul(x, w) + b - XCTAssertEqual([1, 2], result.shape) - XCTAssertEqual([12.5, 6.5], result.scalars) - } - - func testXORInference() { - func xor(_ x: Float, _ y: Float) -> Float { - let x = Tensor([x, y]).reshaped(to: [1, 2]) - - // FIXME: If params are declared outside of `xor`, it would crash. - // 2 x 4 - let w1 = Tensor( - [[-1.83586664, -0.20809225, 0.47667537, 1.90780607], - [-1.83523219, -0.51167348, 0.15490439, 1.91018065]]) - // 1 x 4 - let b1 = Tensor([[2.54353216, 0.25132703, -0.16503136, -0.85754058]]) - // 4 x 1 - let w2 = Tensor([[3.04350065], [0.35590511], [-0.3252157], [3.49349223]]) - // 1 x 1 - let b2 = Tensor([[-0.74635993]]) - - let o1 = tanh(matmul(x, w1) + b1) - let y = tanh(matmul(o1, w2) + b2) - return y.array.scalars[0] // TODO: use better scalar getter - } - - XCTAssertEqual(0.0, xor(0.0, 0.0), accuracy: 0.1) - XCTAssertEqual(1.0, xor(0.0, 1.0), accuracy: 0.1) - XCTAssertEqual(1.0, xor(1.0, 0.0), accuracy: 0.1) - XCTAssertEqual(0.0, xor(1.0, 1.0), accuracy: 0.1) - } - - func testMLPClassifierStruct() { - struct MLPClassifier { - // 2 x 4 - var w1 = Tensor([[1.0, 0.8, 0.4, 0.4], - [0.4, 0.3, 0.2, 0.1]]) - // 4 x 1 - var w2 = Tensor([[0.4], [0.4], [0.3], [0.9]]) - var b1 = Tensor(zeros: [1, 4]) - var b2 = Tensor(zeros: [1, 1]) - - func prediction(for x: Tensor) -> Tensor { - let o1 = tanh(matmul(x, w1) + b1) - return tanh(matmul(o1, w2) + b2) - } - } - - let input = Tensor([[1, 0.5]]) - let classifier = MLPClassifier() - let prediction = classifier.prediction(for: input) - XCTAssertEqual([0.816997], prediction.scalars, accuracy: 0.001) - } - - static var allTests = [ - ("testReduction", testReduction), - ("testArgmax", testArgmax), - ("testCeilAndFloor", testCeilAndFloor), - ("testSimpleMath", testSimpleMath), - ("testStandardDeviation", testStandardDeviation), - ("test3Adds", test3Adds), - ("testMultiOpMath", testMultiOpMath), - ("testXWPlusB", testXWPlusB), - ("testXORInference", testXORInference), - ("testMLPClassifierStruct", testMLPClassifierStruct) - ] -} diff --git a/Tests/DeepLearningTests/TensorTests.swift b/Tests/DeepLearningTests/TensorTests.swift deleted file mode 100644 index ec7d1f6e3..000000000 --- a/Tests/DeepLearningTests/TensorTests.swift +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2019 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -import XCTest -@testable import DeepLearning - -final class TensorTests: XCTestCase { - func testSimpleCond() { - func selectValue(_ pred: Bool) -> Tensor { - let a = Tensor(0) - let b = Tensor(1) - if pred { - return a - } - return b - } - - XCTAssertEqual(0, selectValue(true).scalar) - } - - func testRankGetter() { - let vector = Tensor([1]) - let matrix = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - let ones = Tensor(ones: [1, 2, 2, 2, 2, 2, 1]) - let tensor = Tensor(shape: [3, 4, 5], scalars: Array(0..<60)) - XCTAssertEqual(1, vector.rank) - XCTAssertEqual(2, matrix.rank) - XCTAssertEqual(7, ones.rank) - XCTAssertEqual(3, tensor.rank) - } - - func testShapeGetter() { - let vector = Tensor([1]) - let matrix = Tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - let ones = Tensor(ones: [1, 2, 2, 2, 2, 2, 1]) - let tensor = Tensor(shape: [3, 4, 5], scalars: Array(0..<60)) - XCTAssertEqual([1], vector.shape) - XCTAssertEqual([2, 3], matrix.shape) - XCTAssertEqual([1, 2, 2, 2, 2, 2, 1], ones.shape) - XCTAssertEqual([3, 4, 5], tensor.shape) - } - - static var allTests = [ - ("testSimpleCond", testSimpleCond), - ("testRankGetter", testRankGetter), - ("testShapeGetter", testShapeGetter) - ] -} diff --git a/Tests/DeepLearningTests/XCTestManifests.swift b/Tests/DeepLearningTests/XCTestManifests.swift index e75c25298..96a9048a5 100644 --- a/Tests/DeepLearningTests/XCTestManifests.swift +++ b/Tests/DeepLearningTests/XCTestManifests.swift @@ -22,10 +22,6 @@ public func allTests() -> [XCTestCaseEntry] { testCase(TrivialModelTests.allTests), testCase(SequentialTests.allTests), testCase(LayerTests.allTests), - testCase(TensorTests.allTests), - testCase(BasicOperatorTests.allTests), - testCase(ComparisonOperatorTests.allTests), - testCase(MathOperatorTests.allTests), ] } #endif
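For context, the `allTests()` manifest edited above drives XCTest's manual test discovery on Linux, where runtime reflection is unavailable. A minimal sketch of the consuming side, assuming the conventional SwiftPM `LinuxMain.swift` layout (the file and module names here are the usual SwiftPM conventions, not part of this diff):

// LinuxMain.swift (hypothetical, follows the standard SwiftPM template)
import XCTest
import DeepLearningTests

// Runs every XCTestCaseEntry registered in XCTestManifests.swift.
XCTMain(allTests())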