From af1333e8e3086ed5c0b1546e237080820a07088d Mon Sep 17 00:00:00 2001
From: Dan Zheng
Date: Sat, 13 Apr 2019 08:20:01 +0100
Subject: [PATCH] [TF] Change APIs to use `Int` instead of `Int32`.

Friend PR: https://github.com/apple/swift/pull/24012
---
 Sources/DeepLearning/Initializers.swift |   4 +-
 Sources/DeepLearning/Layer.swift        | 140 ++++++++++++------------
 Sources/DeepLearning/Operators.swift    |  62 ++++++-----
 3 files changed, 102 insertions(+), 104 deletions(-)

diff --git a/Sources/DeepLearning/Initializers.swift b/Sources/DeepLearning/Initializers.swift
index ef4de6228..0655b817b 100644
--- a/Sources/DeepLearning/Initializers.swift
+++ b/Sources/DeepLearning/Initializers.swift
@@ -61,7 +61,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
                                 Int64.random(in: Int64.min..<Int64.max))
     ) {
         self = Raw.statelessRandomUniform(
-            shape: Tensor<Int32>((0..<shape.rank).map { shape[$0] }),
+            shape: Tensor<Int32>((0..<shape.rank).map { Int32(shape[$0]) }),
             seed: Tensor<Int64>([seed.0, seed.1])
         )
     }
@@ -79,7 +79,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
                                 Int64.random(in: Int64.min..<Int64.max))
     ) {
         self = Raw.statelessRandomNormal(
-            shape: Tensor<Int32>((0..<shape.rank).map { shape[$0] }),
+            shape: Tensor<Int32>((0..<shape.rank).map { Int32(shape[$0]) }),
             seed: Tensor<Int64>([seed.0, seed.1])
         )
     }
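After this change, shape dimensions are plain `Int`s at every call site. A minimal caller-side sketch (hypothetical sizes), using the `glorotUniform` and `zeros` initializers touched by this patch:

    let weight = Tensor<Float>(glorotUniform: [784, 30])  // [inputSize, outputSize]
    let bias = Tensor<Float>(zeros: [30])                 // no Int32(...) casts needed

`TensorShape` itself switches from `Int32` to `Int` dimensions in the friend PR, so `Int32` conversions survive only inside implementations, at the `Raw` op boundary.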
diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift
index 43894153b..4f78aa108 100644
--- a/Sources/DeepLearning/Layer.swift
+++ b/Sources/DeepLearning/Layer.swift
@@ -247,9 +247,9 @@ public extension Dense {
         activation: @escaping Activation = identity,
         generator: inout G
     ) {
-        self.init(weight: Tensor(glorotUniform: [Int32(inputSize), Int32(outputSize)],
+        self.init(weight: Tensor(glorotUniform: [inputSize, outputSize],
                                  generator: &generator),
-                  bias: Tensor(zeros: [Int32(outputSize)]),
+                  bias: Tensor(zeros: [outputSize]),
                   activation: activation)
     }
@@ -277,9 +277,9 @@ public extension Dense {
         seed: (Int64, Int64) = (Int64.random(in: Int64.min..<Int64.max),
                                 Int64.random(in: Int64.min..<Int64.max))
     ) {
-        self.init(weight: Tensor(glorotUniform: [Int32(inputSize), Int32(outputSize)],
+        self.init(weight: Tensor(glorotUniform: [inputSize, outputSize],
                                  seed: seed),
-                  bias: Tensor(zeros: [Int32(outputSize)]),
+                  bias: Tensor(zeros: [outputSize]),
                   activation: activation)
     }
@@ -???,7 +???,7 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The element-wise activation function.
     @noDerivative public let activation: Activation
     /// The stride of the sliding window for temporal dimension.
-    @noDerivative public let stride: Int32
+    @noDerivative public let stride: Int
     /// The padding algorithm for convolution.
     @noDerivative public let padding: Padding
@@ -322,7 +322,7 @@ public struct Conv1D<Scalar: TensorFlowFloatingPoint>: Layer {
         self.filter = filter
         self.bias = bias
         self.activation = activation
-        self.stride = Int32(stride)
+        self.stride = stride
         self.padding = padding
     }
@@ -362,10 +362,10 @@ public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger {
         generator: inout G
     ) {
         let filterTensorShape = TensorShape([
-            Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2)])
+            filterShape.0, filterShape.1, filterShape.2])
         self.init(
             filter: Tensor(glorotUniform: filterTensorShape),
-            bias: Tensor(zeros: TensorShape([Int32(filterShape.2)])),
+            bias: Tensor(zeros: TensorShape([filterShape.2])),
             activation: activation,
             stride: stride,
             padding: padding)
@@ -393,12 +393,12 @@ public extension Conv1D {
                                 Int64.random(in: Int64.min..<Int64.max))
     ) {
         let filterTensorShape = TensorShape([
-            Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2)])
+            filterShape.0, filterShape.1, filterShape.2])
         self.init(
             filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
-            bias: Tensor(zeros: TensorShape([Int32(filterShape.2)])),
+            bias: Tensor(zeros: TensorShape([filterShape.2])),
             activation: activation,
             stride: stride,
             padding: padding)
@@ -???,7 +???,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The element-wise activation function.
     @noDerivative public let activation: Activation
     /// The strides of the sliding window for spatial dimensions.
-    @noDerivative public let strides: (Int32, Int32)
+    @noDerivative public let strides: (Int, Int)
     /// The padding algorithm for convolution.
     @noDerivative public let padding: Padding
@@ -441,7 +441,7 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
         self.filter = filter
         self.bias = bias
         self.activation = activation
-        (self.strides.0, self.strides.1) = (Int32(strides.0), Int32(strides.1))
+        self.strides = strides
         self.padding = padding
     }
@@ -480,11 +480,10 @@ public extension Conv2D {
         generator: inout G
     ) {
         let filterTensorShape = TensorShape([
-            Int32(filterShape.0), Int32(filterShape.1),
-            Int32(filterShape.2), Int32(filterShape.3)])
+            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
         self.init(
             filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
-            bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
+            bias: Tensor(zeros: TensorShape([filterShape.3])),
             activation: activation,
             strides: strides,
             padding: padding)
@@ -511,13 +510,12 @@ public extension Conv2D {
                                 Int64.random(in: Int64.min..<Int64.max))
     ) {
         let filterTensorShape = TensorShape([
-            Int32(filterShape.0), Int32(filterShape.1),
-            Int32(filterShape.2), Int32(filterShape.3)])
+            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
         self.init(
             filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
-            bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
+            bias: Tensor(zeros: TensorShape([filterShape.3])),
             activation: activation,
             strides: strides,
             padding: padding)
@@ -???,10 +???,12 @@ public struct TransposedConv2D: Layer {
     public func applied(to input: Tensor<Float>) -> Tensor<Float> {
         let batchSize = input.shape[0]
-        let w = (input.shape[1] - (1 * paddingIndex)) * strides.0 + (filter.shape[0] * paddingIndex)
-        let h = (input.shape[2] - (1 * paddingIndex)) * strides.1 + (filter.shape[1] * paddingIndex)
+        let w = (input.shape[1] - (1 * paddingIndex)) *
+            strides.0 + (filter.shape[0] * paddingIndex)
+        let h = (input.shape[2] - (1 * paddingIndex)) *
+            strides.1 + (filter.shape[1] * paddingIndex)
         let c = filter.shape[2]
-        let newShape = Tensor<Int32>([batchSize, w, h, c])
+        let newShape = Tensor<Int32>([Int32(batchSize), Int32(w), Int32(h), Int32(c)])
         return activation(input.conv2DBackpropInput(shape: newShape, filter: filter,
-                                                    strides: (1, strides.0, strides.1, 1),
-                                                    padding: padding) + bias)
+                                                     strides: (1, strides.0, strides.1, 1),
+                                                     padding: padding) + bias)
     }
@@ -606,11 +606,10 @@ public extension TransposedConv2D {
         generator: inout G
     ) {
         let filterTensorShape = TensorShape([
-            Int32(filterShape.0), Int32(filterShape.1),
-            Int32(filterShape.2), Int32(filterShape.3)])
+            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
         self.init(
             filter: Tensor(glorotUniform: filterTensorShape, generator: &generator),
-            bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
+            bias: Tensor(zeros: TensorShape([filterShape.3])),
             activation: activation,
             strides: strides,
             padding: padding)
@@ -637,11 +636,10 @@ public extension TransposedConv2D {
                                 Int64.random(in: Int64.min..<Int64.max))
     ) {
         let filterTensorShape = TensorShape([
-            Int32(filterShape.0), Int32(filterShape.1),
-            Int32(filterShape.2), Int32(filterShape.3)])
+            filterShape.0, filterShape.1, filterShape.2, filterShape.3])
         self.init(
             filter: Tensor(glorotUniform: filterTensorShape, seed: seed),
-            bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
+            bias: Tensor(zeros: TensorShape([filterShape.3])),
             activation: activation,
             strides: strides,
             padding: padding)
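The convolution layers now store `Int`-based strides and take `Int` filter-shape tuples. A hypothetical call site under the new API, assuming the defaulted-seed `Conv2D` initializer above with its remaining parameters left at their defaults:

    let conv = Conv2D<Float>(
        filterShape: (5, 5, 3, 32),  // (height, width, inChannels, outChannels)
        strides: (1, 1),
        padding: .same)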
@@ -???,7 +???,7 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The feature dimension.
-    @noDerivative public let axis: Int32
+    @noDerivative public let axis: Int
     /// The momentum for the running mean and running variance.
     @noDerivative public let momentum: Tensor<Scalar>
     /// The offset value, also known as beta.
@@ -692,7 +690,7 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         runningMean: Tensor<Scalar>,
         runningVariance: Tensor<Scalar>
     ) {
-        self.axis = Int32(axis)
+        self.axis = axis
         self.momentum = momentum
         self.offset = offset
         self.scale = scale
@@ -705,7 +703,7 @@
     private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let positiveAxis = (input.rank + axis) % input.rank
         var normalizedAxes = Array(0..<input.rank)
-        normalizedAxes.remove(at: Int(positiveAxis))
+        normalizedAxes.remove(at: positiveAxis)
         let moments = input.moments(alongAxes: normalizedAxes)
@@ -???,10 +???,10 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         axis: Int = -1,
         momentum: Tensor<Scalar> = Tensor(0.99),
         epsilon: Tensor<Scalar> = Tensor(0.001)) {
-        self.axis = Int32(axis)
+        self.axis = axis
         self.momentum = momentum
-        self.scale = Tensor(ones: [Int32(featureCount)])
-        self.offset = Tensor(zeros: [Int32(featureCount)])
+        self.scale = Tensor(ones: [featureCount])
+        self.offset = Tensor(zeros: [featureCount])
         self.epsilon = epsilon
         self.runningMean = Parameter(Tensor(0))
         self.runningVariance = Parameter(Tensor(1))
@@ -777,9 +775,9 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
 @_fixed_layout
 public struct MaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The size of the sliding reduction window for pooling.
-    @noDerivative let poolSize: Int32
+    @noDerivative let poolSize: Int
     /// The stride of the sliding window for temporal dimension.
-    @noDerivative let stride: Int32
+    @noDerivative let stride: Int
     /// The padding algorithm for pooling.
     @noDerivative let padding: Padding
@@ -794,8 +792,8 @@ public struct MaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
         stride: Int,
         padding: Padding
     ) {
-        self.poolSize = Int32(poolSize)
-        self.stride = Int32(stride)
+        self.poolSize = poolSize
+        self.stride = stride
         self.padding = padding
     }
@@ -816,10 +814,10 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The size of the sliding reduction window for pooling.
-    @noDerivative let poolSize: (Int32, Int32, Int32, Int32)
+    @noDerivative let poolSize: (Int, Int, Int, Int)
     /// The strides of the sliding window for each dimension of a 4-D input.
     /// Strides in non-spatial dimensions must be `1`.
-    @noDerivative let strides: (Int32, Int32, Int32, Int32)
+    @noDerivative let strides: (Int, Int, Int, Int)
     /// The padding algorithm for pooling.
     @noDerivative let padding: Padding
@@ -829,10 +827,8 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
         strides: (Int, Int, Int, Int),
         padding: Padding
     ) {
-        (self.poolSize.0, self.poolSize.1, self.poolSize.2, self.poolSize.3)
-            = (Int32(poolSize.0), Int32(poolSize.1), Int32(poolSize.2), Int32(poolSize.3))
-        (self.strides.0, self.strides.1, self.strides.2, self.strides.3)
-            = (Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3))
+        self.poolSize = poolSize
+        self.strides = strides
         self.padding = padding
     }
@@ -843,8 +839,8 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - strides: The strides.
     /// - padding: The padding.
     public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
-        self.poolSize = (1, Int32(poolSize.0), Int32(poolSize.1), 1)
-        self.strides = (1, Int32(strides.0), Int32(strides.1), 1)
+        self.poolSize = (1, poolSize.0, poolSize.1, 1)
+        self.strides = (1, strides.0, strides.1, 1)
         self.padding = padding
     }
@@ -864,9 +860,9 @@
 @_fixed_layout
 public struct AvgPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The size of the sliding reduction window for pooling.
-    @noDerivative let poolSize: Int32
+    @noDerivative let poolSize: Int
     /// The stride of the sliding window for temporal dimension.
-    @noDerivative let stride: Int32
+    @noDerivative let stride: Int
     /// The padding algorithm for pooling.
     @noDerivative let padding: Padding
@@ -881,8 +877,8 @@ public struct AvgPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
         stride: Int,
         padding: Padding
     ) {
-        self.poolSize = Int32(poolSize)
-        self.stride = Int32(stride)
+        self.poolSize = poolSize
+        self.stride = stride
         self.padding = padding
     }
@@ -903,10 +899,10 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The size of the sliding reduction window for pooling.
-    @noDerivative let poolSize: (Int32, Int32, Int32, Int32)
+    @noDerivative let poolSize: (Int, Int, Int, Int)
     /// The strides of the sliding window for each dimension of a 4-D input.
     /// Strides in non-spatial dimensions must be `1`.
-    @noDerivative let strides: (Int32, Int32, Int32, Int32)
+    @noDerivative let strides: (Int, Int, Int, Int)
     /// The padding algorithm for pooling.
     @noDerivative let padding: Padding
@@ -916,10 +912,8 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
         strides: (Int, Int, Int, Int),
         padding: Padding
     ) {
-        (self.poolSize.0, self.poolSize.1, self.poolSize.2, self.poolSize.3)
-            = (Int32(poolSize.0), Int32(poolSize.1), Int32(poolSize.2), Int32(poolSize.3))
-        (self.strides.0, self.strides.1, self.strides.2, self.strides.3)
-            = (Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3))
+        self.poolSize = poolSize
+        self.strides = strides
         self.padding = padding
     }
@@ -930,8 +924,8 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// - strides: The strides.
     /// - padding: The padding.
     public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
-        self.poolSize = (1, Int32(poolSize.0), Int32(poolSize.1), 1)
-        self.strides = (1, Int32(strides.0), Int32(strides.1), 1)
+        self.poolSize = (1, poolSize.0, poolSize.1, 1)
+        self.strides = (1, strides.0, strides.1, 1)
         self.padding = padding
     }
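The pooling layers follow the same pattern, and the two-argument convenience initializers still expand an `(Int, Int)` window into the internal 4-D `(1, height, width, 1)` form. A sketch with hypothetical sizes:

    let pool = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2), padding: .valid)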
@@ -1008,7 +1002,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The scale value, also known as gamma.
     public var scale: Tensor<Scalar>
     /// The axis.
-    @noDerivative public let axis: Int32
+    @noDerivative public let axis: Int
     /// The variance epsilon value.
     @noDerivative public let epsilon: Tensor<Scalar>
@@ -1021,7 +1015,7 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     ) {
         self.offset = offset
         self.scale = scale
-        self.axis = Int32(axis)
+        self.axis = axis
         self.epsilon = epsilon
     }
@@ -1035,8 +1029,8 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
         axis: Int,
         epsilon: Tensor<Scalar> = Tensor(0.001)) {
         self.init(
-            offset: Tensor(zeros: [Int32(featureCount)]),
-            scale: Tensor(ones: [Int32(featureCount)]),
+            offset: Tensor(zeros: [featureCount]),
+            scale: Tensor(ones: [featureCount]),
             axis: axis,
             epsilon: epsilon
         )
@@ -1127,12 +1121,12 @@ public struct Dropout<Scalar: TensorFlowFloatingPoint>: Layer {
 /// An upsampling layer for 1-D inputs.
 @_fixed_layout
 public struct UpSampling1D<Scalar: TensorFlowFloatingPoint>: Layer {
-    @noDerivative public let size: Int32
+    @noDerivative public let size: Int

     /// Creates an upsampling layer.
     ///
     /// - Parameter size: The upsampling factor for timesteps.
-    public init(size: Int32) {
+    public init(size: Int) {
         self.size = size
     }
@@ -1154,12 +1148,12 @@ public struct UpSampling1D<Scalar: TensorFlowFloatingPoint>: Layer {
 /// An upsampling layer for 2-D inputs.
 @_fixed_layout
 public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
-    @noDerivative public let size: Int32
+    @noDerivative public let size: Int

     /// Creates an upsampling layer.
     ///
     /// - Parameter size: The upsampling factor for rows and columns.
-    public init(size: Int32) {
+    public init(size: Int) {
         self.size = size
     }
@@ -1220,7 +1214,7 @@ public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
     ///
     /// - Parameter shape: The target shape.
     public init(_ shape: TensorShape) {
-        self.init(shape: Tensor<Int32>(shape.dimensions))
+        self.init(shape: Tensor<Int32>(shape.dimensions.map(Int32.init)))
     }

     /// Returns the output obtained from applying the layer to the given input.
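With `axis` stored as a plain `Int`, negative indices such as `-1` (the last axis) pass through the normalization layers without casts. A sketch, assuming `featureCount` is the leading parameter of the initializer shown above:

    let norm = BatchNorm<Float>(featureCount: 32, axis: -1)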
diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift
index e7a0b2ade..918a197ae 100644
--- a/Sources/DeepLearning/Operators.swift
+++ b/Sources/DeepLearning/Operators.swift
@@ -35,7 +35,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     // TODO: Verify that these calculations are correct.
     @inlinable
     internal func _vjpBatchNormalized(
-        alongAxis axis: Int32,
+        alongAxis axis: Int,
         offset: Tensor,
         scale: Tensor,
         epsilon: Scalar
@@ -84,7 +84,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
         where Scalar : TensorFlowFloatingPoint
     )
     func batchNormalized(
-        alongAxis axis: Int32,
+        alongAxis axis: Int,
         offset: Tensor = Tensor(0),
         scale: Tensor = Tensor(1),
         epsilon: Scalar = 0.001
@@ -135,14 +135,14 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     internal func conv2DBackpropInput(
         shape: Tensor<Int32>,
         filter: Tensor,
-        strides: (Int32, Int32, Int32, Int32),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> Tensor {
         return Raw.conv2DBackpropInput(
             inputSizes: shape,
             filter: filter,
             outBackprop: self,
-            strides: [strides.0, strides.1, strides.2, strides.3],
+            strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
             padding: padding.raw2,
             explicitPaddings: [])
     }
@@ -153,14 +153,14 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     internal func conv2DBackpropFilter(
         input: Tensor,
         filterSizes: Tensor<Int32>,
-        strides: (Int32, Int32, Int32, Int32),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> Tensor {
         return Raw.conv2DBackpropFilter(
             input,
             filterSizes: filterSizes,
             outBackprop: self,
-            strides: [strides.0, strides.1, strides.2, strides.3],
+            strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
             padding: padding.raw2,
             explicitPaddings: [])
     }
@@ -169,7 +169,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     internal func _vjpConv2DBackpropInput(
         _ shape: Tensor<Int32>,
         _ filter: Tensor,
-        _ strides: (Int32, Int32, Int32, Int32),
+        _ strides: (Int, Int, Int, Int),
         _ padding: Padding
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
         let value = conv2DBackpropInput(shape: shape, filter: filter, strides: strides, padding: padding)
@@ -187,7 +187,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     internal func _vjpConv2DBackpropFilter(
         _ input: Tensor,
         _ filterSizes: Tensor<Int32>,
-        _ strides: (Int32, Int32, Int32, Int32),
+        _ strides: (Int, Int, Int, Int),
         _ padding: Padding
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
         let value = conv2DBackpropFilter(input: input, filterSizes: filterSizes, strides: strides, padding: padding)
@@ -204,7 +204,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     internal func _vjpConvolved2D(
         filter: Tensor,
-        strides: (Int32, Int32, Int32, Int32),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
         let value = convolved2D(withFilter: filter, strides: strides, padding: padding)
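The remaining hunks repeat one pattern: public and internal signatures take `(Int, Int, Int, Int)` tuples, and values are narrowed to `Int32` only where a `Raw` op is invoked. As a standalone illustration of that conversion (the helper name is ours, not the patch's):

    func int32List(_ t: (Int, Int, Int, Int)) -> [Int32] {
        // Narrow each Int stride/kernel component for the TensorFlow op boundary.
        return [Int32(t.0), Int32(t.1), Int32(t.2), Int32(t.3)]
    }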
@@ -225,8 +225,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     internal func _vjpMaxPooled(
-        kernelSize: (Int32, Int32, Int32, Int32),
-        strides: (Int32, Int32, Int32, Int32),
+        kernelSize: (Int, Int, Int, Int),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> (Tensor, (Tensor) -> Tensor) {
         // TODO: Currently this is not higher order differentiable. Redefine in
@@ -238,8 +238,10 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
                 origInput: self,
                 origOutput: value,
                 grad: v,
-                ksize: Tensor<Int32>([kernelSize.0, kernelSize.1, kernelSize.2, kernelSize.3]),
-                strides: Tensor<Int32>([strides.0, strides.1, strides.2, strides.3]),
+                ksize: Tensor<Int32>([Int32(kernelSize.0), Int32(kernelSize.1),
+                                      Int32(kernelSize.2), Int32(kernelSize.3)]),
+                strides: Tensor<Int32>([Int32(strides.0), Int32(strides.1),
+                                        Int32(strides.2), Int32(strides.3)]),
                 padding: padding.raw
             )
         })
@@ -249,8 +251,9 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     internal func _vjpAveragePooled(
-        kernelSize: (Int32, Int32, Int32, Int32),
-        strides: (Int32, Int32, Int32, Int32),
+        kernelSize: (Int, Int, Int, Int),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> (Tensor, (Tensor) -> Tensor) {
         // TODO: Currently this is not higher order differentiable. Redefine in
@@ -259,8 +261,9 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
             return Raw.avgPoolGrad(
                 origInputShape: self.shapeTensor,
                 grad: v,
-                ksize: [kernelSize.0, kernelSize.1, kernelSize.2, kernelSize.3],
-                strides: [strides.0, strides.1, strides.2, strides.3],
+                ksize: [Int32(kernelSize.0), Int32(kernelSize.1),
+                        Int32(kernelSize.2), Int32(kernelSize.3)],
+                strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
                 padding: padding.raw
             )
         })
@@ -285,13 +288,13 @@ public extension Tensor where Scalar: FloatingPoint {
     )
     func convolved2D(
         withFilter filter: Tensor,
-        strides: (Int32, Int32, Int32, Int32),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> Tensor {
         return Raw.conv2D(
             self,
             filter: filter,
-            strides: [strides.0, strides.1, strides.2, strides.3],
+            strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
             padding: padding.raw2,
             explicitPaddings: [])
     }
@@ -310,16 +313,16 @@ public extension Tensor where Scalar: FloatingPoint {
         where Scalar : TensorFlowFloatingPoint
     )
     func maxPooled(
-        kernelSize: (Int32, Int32, Int32, Int32),
-        strides: (Int32, Int32, Int32, Int32),
+        kernelSize: (Int, Int, Int, Int),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> Tensor {
         return Raw.maxPoolV2(
             self,
-            ksize: Tensor<Int32>([kernelSize.0, kernelSize.1,
-                                  kernelSize.2, kernelSize.3]),
-            strides: Tensor<Int32>([strides.0, strides.1,
-                                    strides.2, strides.3]),
+            ksize: Tensor<Int32>([Int32(kernelSize.0), Int32(kernelSize.1),
+                                  Int32(kernelSize.2), Int32(kernelSize.3)]),
+            strides: Tensor<Int32>([Int32(strides.0), Int32(strides.1),
+                                    Int32(strides.2), Int32(strides.3)]),
             padding: padding.raw)
     }
@@ -337,14 +340,15 @@ public extension Tensor where Scalar: FloatingPoint {
         where Scalar : TensorFlowFloatingPoint
     )
     func averagePooled(
-        kernelSize: (Int32, Int32, Int32, Int32),
-        strides: (Int32, Int32, Int32, Int32),
+        kernelSize: (Int, Int, Int, Int),
+        strides: (Int, Int, Int, Int),
         padding: Padding
     ) -> Tensor {
         return Raw.avgPool(
             value: self,
-            ksize: [kernelSize.0, kernelSize.1, kernelSize.2, kernelSize.3],
-            strides: [strides.0, strides.1, strides.2, strides.3],
+            ksize: [Int32(kernelSize.0), Int32(kernelSize.1),
+                    Int32(kernelSize.2), Int32(kernelSize.3)],
+            strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3)],
             padding: padding.raw)
     }
 }
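A hypothetical end-to-end call site once the whole patch is applied: kernel sizes and strides are `Int` tuples throughout, and the `Int32` narrowing happens inside `maxPooled(kernelSize:strides:padding:)` as shown above:

    let x = Tensor<Float>(ones: [1, 4, 4, 1])
    let y = x.maxPooled(kernelSize: (1, 2, 2, 1), strides: (1, 2, 2, 1), padding: .valid)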