diff --git a/Sources/TensorFlow/Layers/Convolutional.swift b/Sources/TensorFlow/Layers/Convolutional.swift index cd3bf75a4..a4f99d6b0 100644 --- a/Sources/TensorFlow/Layers/Convolutional.swift +++ b/Sources/TensorFlow/Layers/Convolutional.swift @@ -632,6 +632,113 @@ public struct ZeroPadding3D: ParameterlessLayer } } +/// A 1-D separable convolution layer. +/// +/// This layer performs a depthwise convolution that acts separately on channels followed by +/// a pointwise convolution that mixes channels. +@frozen +public struct SeparableConv1D: Layer { + /// The 3-D depthwise convolution kernel. + public var depthwiseFilter: Tensor + /// The 3-D pointwise convolution kernel. + public var pointwiseFilter: Tensor + /// The bias vector. + public var bias: Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The strides of the sliding window for spatial dimensions. + @noDerivative public let stride: Int + /// The padding algorithm for convolution. + @noDerivative public let padding: Padding + + /// The element-wise activation function type. + public typealias Activation = @differentiable (Tensor) -> Tensor + + /// Creates a `SeparableConv1D` layer with the specified depthwise and pointwise filter, + /// bias, activation function, strides, and padding. + /// + /// - Parameters: + /// - depthwiseFilter: The 3-D depthwise convolution kernel + /// `[filter width, input channels count, channel multiplier]`. + /// - pointwiseFilter: The 3-D pointwise convolution kernel + /// `[1, channel multiplier * input channels count, output channels count]`. + /// - bias: The bias vector. + /// - activation: The element-wise activation function. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + public init( + depthwiseFilter: Tensor, + pointwiseFilter: Tensor, + bias: Tensor, + activation: @escaping Activation = identity, + stride: Int = 1, + padding: Padding = .valid + ) { + self.depthwiseFilter = depthwiseFilter + self.pointwiseFilter = pointwiseFilter + self.bias = bias + self.activation = activation + self.stride = stride + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + let depthwise = depthwiseConv2D( + input.expandingShape(at: 1), + filter: depthwiseFilter.expandingShape(at: 1), + strides: (1, stride, stride, 1), + padding: padding) + let x = conv2D( + depthwise, + filter: pointwiseFilter.expandingShape(at: 1), + strides: (1, 1, 1, 1), + padding: padding, + dilations: (1, 1, 1, 1)) + return activation(x.squeezingShape(at: 1) + bias) + } +} + +public extension SeparableConv1D { + /// Creates a `SeparableConv1D` layer with the specified depthwise and pointwise filter shape, + /// strides, padding, and element-wise activation function. + /// + /// - Parameters: + /// - depthwiseFilterShape: The shape of the 3-D depthwise convolution kernel. + /// - pointwiseFilterShape: The shape of the 3-D pointwise convolution kernel. + /// - strides: The strides of the sliding window for temporal dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - filterInitializer: Initializer to use for the filter parameters. + /// - biasInitializer: Initializer to use for the bias parameters. + init( + depthwiseFilterShape: (Int, Int, Int), + pointwiseFilterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + depthwiseFilterInitializer: ParameterInitializer = glorotUniform(), + pointwiseFilterInitializer: ParameterInitializer = glorotUniform(), + biasInitializer: ParameterInitializer = zeros() + ) { + let depthwiseFilterTensorShape = TensorShape([ + depthwiseFilterShape.0, depthwiseFilterShape.1, depthwiseFilterShape.2]) + let pointwiseFilterTensorShape = TensorShape([ + pointwiseFilterShape.0, pointwiseFilterShape.1, pointwiseFilterShape.2]) + self.init( + depthwiseFilter: depthwiseFilterInitializer(depthwiseFilterTensorShape), + pointwiseFilter: pointwiseFilterInitializer(pointwiseFilterTensorShape), + bias: biasInitializer([pointwiseFilterShape.2]), + activation: activation, + stride: stride, + padding: padding) + } +} + /// A 2-D Separable convolution layer. /// /// This layer performs a depthwise convolution that acts separately on channels followed by diff --git a/Tests/TensorFlowTests/LayerTests.swift b/Tests/TensorFlowTests/LayerTests.swift index 87812fcf6..c0cd714cf 100644 --- a/Tests/TensorFlowTests/LayerTests.swift +++ b/Tests/TensorFlowTests/LayerTests.swift @@ -113,28 +113,28 @@ final class LayerTests: XCTestCase { func testConv2DGradient() { let filter = Tensor(shape: [3, 3, 2, 4], scalars: (0..<72).map(Float.init)) let bias = Tensor(zeros: [4]) - let layer = Conv2D(filter: filter, - bias: bias, + let layer = Conv2D(filter: filter, + bias: bias, activation: identity, - strides: (2, 2), + strides: (2, 2), padding: .valid) let input = Tensor(shape: [2, 4, 4, 2], scalars: (0..<64).map(Float.init)) let grads = gradient( at: input, layer) { $1($0).sum() } - // The expected gradients were computed using the following Python code: + // The expected gradients were computed using the following Python code: // ``` // x = tf.reshape(tf.range(64, dtype=tf.float32), [2, 4, 4, 2]) // filter = tf.reshape(tf.range(72, dtype=tf.float32), [3, 3, 2, 4]) // bias = tf.zeros([4]) // with tf.GradientTape() as t: // t.watch([x, filter, bias]) - // y = tf.math.reduce_sum(tf.nn.conv2d(input=x, + // y = tf.math.reduce_sum(tf.nn.conv2d(input=x, // filters=filter, // strides=[1, 2, 2, 1], // data_format="NHWC", // padding="VALID") + bias) // grads = t.gradient(y, [x, filter, bias]) // ``` - XCTAssertEqual(grads.0, + XCTAssertEqual(grads.0, [[[[ 6, 22], [ 38, 54], [ 70, 86], [ 0, 0]], [[102, 118], [134, 150], [166, 182], [ 0, 0]], [[198, 214], [230, 246], [262, 278], [ 0, 0]], @@ -143,7 +143,7 @@ final class LayerTests: XCTestCase { [[102, 118], [134, 150], [166, 182], [ 0, 0]], [[198, 214], [230, 246], [262, 278], [ 0, 0]], [[ 0, 0], [ 0, 0], [ 0, 0], [ 0, 0]]]]) - XCTAssertEqual(grads.1.filter, + XCTAssertEqual(grads.1.filter, [[[[32, 32, 32, 32], [34, 34, 34, 34]], [[36, 36, 36, 36], [38, 38, 38, 38]], [[40, 40, 40, 40], [42, 42, 42, 42]]], @@ -210,6 +210,22 @@ final class LayerTests: XCTestCase { XCTAssertEqual(output, expected) } + func testSeparableConv1D() { + let depthwiseFilter = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init)) + let pointwiseFilter = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init)) + let bias = Tensor([4]) + let layer = SeparableConv1D(depthwiseFilter: depthwiseFilter, + pointwiseFilter: pointwiseFilter, + bias: bias, + activation: identity, + stride: 1, + padding: .same) + let input = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init)) + let output = layer.inferring(from: input) + let expected = Tensor(shape: [2, 2, 1], scalars: [17, 45, 73, 101]) + XCTAssertEqual(output, expected) + } + func testSeparableConv2D() { let depthwiseFilter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init)) let pointwiseFilter = Tensor(shape: [1, 1, 4, 1], scalars: (0..<4).map(Float.init)) @@ -582,7 +598,7 @@ final class LayerTests: XCTestCase { // [ 0.0, 0.0, 0.0, 0.0]]) // XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457]) } - + func testLSTM() { let x = Tensor(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted() let inputs: [Tensor] = Array(repeating: x, count: 4) @@ -710,7 +726,7 @@ final class LayerTests: XCTestCase { // lnLayer = tf.keras.layers.LayerNormalization(axis=1, epsilon=0.001) // with tf.GradientTape() as t: // t.watch(x) - // y = lnLayer(x) + // y = lnLayer(x) // z = tf.math.reduce_sum(tf.math.square(y)) // print(y, t.gradient(z, [x] + lnLayer.trainable_variables)) // ``` @@ -729,8 +745,8 @@ final class LayerTests: XCTestCase { [-0.0019815 , 0.00164783, 0.00130618, 0.00119543, -0.00216818]], accuracy: 1e-5) assertEqual( - grad.1.offset, - [-0.645803 , -5.8017054 , 0.03168535, 5.973418 , 0.44240427], + grad.1.offset, + [-0.645803 , -5.8017054 , 0.03168535, 5.973418 , 0.44240427], accuracy: 1e-5) assertEqual( grad.1.scale, @@ -747,6 +763,7 @@ final class LayerTests: XCTestCase { ("testConv2DDilation", testConv2DDilation), ("testConv3D", testConv3D), ("testDepthConv2D", testDepthConv2D), + ("testSeparableConv1D", testSeparableConv1D), ("testSeparableConv2D", testSeparableConv2D), ("testZeroPadding1D", testZeroPadding1D), ("testZeroPadding2D", testZeroPadding2D),