diff --git a/Sources/TensorFlow/Operators/NN.swift b/Sources/TensorFlow/Operators/NN.swift index 094226bf5..0b11305b8 100644 --- a/Sources/TensorFlow/Operators/NN.swift +++ b/Sources/TensorFlow/Operators/NN.swift @@ -229,12 +229,10 @@ func _vjpConv3D( let value = conv3D(x, filter: filter, strides: strides, padding: padding) return (value, { v in - return ( - conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter, - strides: strides, padding: padding), - conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor, - strides: strides, padding: padding) - ) + (conv3DBackpropInput(v, shape: x.shapeTensor, filter: filter, + strides: strides, padding: padding), + conv3DBackpropFilter(v, input: x, filterSizes: filter.shapeTensor, + strides: strides, padding: padding)) }) } @@ -268,11 +266,9 @@ func _vjpConv3DBackpropInput( let value = conv3DBackpropInput(x, shape: shape, filter: filter, strides: strides, padding: padding) return (value, { v in - return ( - conv3DBackpropFilter(x, input: v, filterSizes: shape, strides: strides, - padding: padding), - conv3D(v, filter: filter, strides: strides, padding: padding) - ) + (conv3D(v, filter: filter, strides: strides, padding: padding), + conv3DBackpropFilter(x, input: v, filterSizes: filter.shapeTensor, strides: strides, + padding: padding)) }) } @@ -287,7 +283,7 @@ func conv3DBackpropFilter( padding: Padding = .valid ) -> Tensor { return Raw.conv3DBackpropFilterV2( - x, + input, filterSizes: filterSizes, outBackprop: x, strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), @@ -306,11 +302,9 @@ func _vjpConv3DBackpropFilter( let value = conv3DBackpropFilter(x, input: input, filterSizes: filterSizes, strides: strides, padding: padding) return (value, { v in - return ( - conv3DBackpropInput(x, shape: filterSizes, filter: v, strides: strides, - padding: padding), - conv3D(input, filter: v, strides: strides, padding: padding) - ) + (conv3D(input, filter: v, strides: strides, padding: padding), + conv3DBackpropInput(x, shape: x.shapeTensor, filter: v, strides: strides, + padding: padding)) }) } diff --git a/Tests/TensorFlowTests/LayerTests.swift b/Tests/TensorFlowTests/LayerTests.swift index acfe4bb60..85b8f6837 100644 --- a/Tests/TensorFlowTests/LayerTests.swift +++ b/Tests/TensorFlowTests/LayerTests.swift @@ -198,6 +198,55 @@ final class LayerTests: XCTestCase { XCTAssertEqual(output, expected) } + func testConv3DGradient() { + let filter = Tensor(shape: [1, 4, 4, 1, 1], scalars: (0..<16).map(Float.init)) + let bias = Tensor(ones: [2]) + let layer = Conv3D(filter: filter, + bias: bias, + activation: identity, + strides: (2, 2, 2), + padding: .same) + let input = Tensor(shape: [1, 4, 4, 4, 1], scalars: (0..<64).map(Float.init)) + let grads = gradient(at: input, layer) { $1($0).sum() } + // The expected value of the gradient was computed using the following Python code: + // ``` + // import tensorflow as tf + // x = tf.reshape(tf.range(64, dtype=tf.float32), [1, 4, 4, 4, 1]) + // filter = tf.reshape(tf.range(72, dtype=tf.float32), [1, 4, 4, 1, 1]) + // bias = tf.ones([2]) + // with tf.GradientTape() as tape: + // tape.watch([x, filter, bias]) + // y = tf.math.reduce_sum(tf.nn.conv3d(input=x, + // filters=filter, + // strides=[1, 2, 2, 2, 1], + // padding="SAME") + bias) + // print(tape.gradient(y, [x, filter, bias])) + // ``` + XCTAssertEqual(grads.0, + [[[[[10.0], [20.0], [24.0], [12.0]], + [[20.0], [40.0], [48.0], [24.0]], + [[36.0], [72.0], [80.0], [40.0]], + [[18.0], [36.0], [40.0], [20.0]]], + [[[ 0.0], [ 0.0], [ 0.0], [ 0.0]], + [[ 0.0], [ 0.0], [ 0.0], [ 0.0]], + [[ 0.0], [ 0.0], [ 0.0], [ 0.0]], + [[ 0.0], [ 0.0], [ 0.0], [ 0.0]]], + [[[10.0], [20.0], [24.0], [12.0]], + [[20.0], [40.0], [48.0], [24.0]], + [[36.0], [72.0], [80.0], [40.0]], + [[18.0], [36.0], [40.0], [20.0]]], + [[[ 0.0], [ 0.0], [ 0.0], [ 0.0]], + [[ 0.0], [ 0.0], [ 0.0], [ 0.0]], + [[ 0.0], [ 0.0], [ 0.0], [ 0.0]], + [[ 0.0], [ 0.0], [ 0.0], [ 0.0]]]]]) + XCTAssertEqual(grads.1.filter, + [[[[[ 84.0]], [[168.0]], [[176.0]], [[ 88.0]]], + [[[168.0]], [[336.0]], [[352.0]], [[176.0]]], + [[[200.0]], [[400.0]], [[416.0]], [[208.0]]], + [[[100.0]], [[200.0]], [[208.0]], [[104.0]]]]]) + XCTAssertEqual(grads.1.bias, [8.0, 8.0]) + } + func testDepthwiseConv2D() { let filter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init)) let bias = Tensor([1, 2, 3, 4]) @@ -1237,6 +1286,7 @@ final class LayerTests: XCTestCase { ("testConv2DGradient", testConv2DGradient), ("testConv2DDilation", testConv2DDilation), ("testConv3D", testConv3D), + ("testConv3DGradient", testConv3DGradient), ("testDepthwiseConv2D", testDepthwiseConv2D), ("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient), ("testSeparableConv1D", testSeparableConv1D),