diff --git a/Sources/TensorFlow/Layers/Convolutional.swift b/Sources/TensorFlow/Layers/Convolutional.swift index 45edbaf4c..88fe23728 100644 --- a/Sources/TensorFlow/Layers/Convolutional.swift +++ b/Sources/TensorFlow/Layers/Convolutional.swift @@ -377,7 +377,7 @@ public struct TransposedConv2D: Layer { /// /// - Parameters: /// - filter: A 4-D tensor of shape - /// `[width, height, input channel count, output channel count]`. + /// `[height, width, output channel count, input channel count]`. /// - bias: The bias tensor of shape `[output channel count]`. /// - activation: The element-wise activation function. /// - strides: The strides of the sliding window for spatial dimensions. @@ -404,12 +404,12 @@ public struct TransposedConv2D: Layer { @differentiable public func callAsFunction(_ input: Tensor) -> Tensor { let batchSize = input.shape[0] - let w = (input.shape[1] - (1 * paddingIndex)) * + let h = (input.shape[1] - (1 * paddingIndex)) * strides.0 + (filter.shape[0] * paddingIndex) - let h = (input.shape[2] - (1 * paddingIndex)) * + let w = (input.shape[2] - (1 * paddingIndex)) * strides.1 + (filter.shape[1] * paddingIndex) let c = filter.shape[2] - let newShape = Tensor([Int32(batchSize), Int32(w), Int32(h), Int32(c)]) + let newShape = Tensor([Int32(batchSize), Int32(h), Int32(w), Int32(c)]) return activation(conv2DBackpropInput( input, shape: newShape, diff --git a/Tests/TensorFlowTests/LayerTests.swift b/Tests/TensorFlowTests/LayerTests.swift index 070d40f36..0f177dce2 100644 --- a/Tests/TensorFlowTests/LayerTests.swift +++ b/Tests/TensorFlowTests/LayerTests.swift @@ -302,6 +302,18 @@ final class LayerTests: XCTestCase { XCTAssertEqual(grads.1.bias, [4, 4, 4, 4]) } + func testTransposedConv2D() { + let filter = Tensor(shape: [4, 2, 1, 1], scalars: (0..<8).map(Float.init)) + let bias = Tensor([8]) + let layer = TransposedConv2D(filter: filter, bias: bias, activation: identity, + strides: (1, 1), padding: .same) + let input = Tensor(shape: [1, 4, 2, 1], scalars: (0..<8).map(Float.init)) + let output = layer.inferring(from: input) + let expected = Tensor(shape: [1, 4, 2, 1], + scalars: [8, 12, 12, 28, 24, 64, 48, 112]) + XCTAssertEqual(output, expected) + } + func testSeparableConv1D() { let depthwiseFilter = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init)) let pointwiseFilter = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init)) @@ -1318,6 +1330,7 @@ final class LayerTests: XCTestCase { ("testConv2DDilation", testConv2DDilation), ("testConv3D", testConv3D), ("testConv3DGradient", testConv3DGradient), + ("testTransposedConv2D", testTransposedConv2D), ("testDepthwiseConv2D", testDepthwiseConv2D), ("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient), ("testSeparableConv1D", testSeparableConv1D),