diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift
index f7560d282..e86887e8b 100644
--- a/Sources/DeepLearning/Layer.swift
+++ b/Sources/DeepLearning/Layer.swift
@@ -1395,6 +1395,32 @@ public struct UpSampling3D: Layer {
         self.size = size
     }
 
+    /// Repeats each element of `input` `count` times along `axis`, like `np.repeat`.
+    /// Adapted from `repeat_elements` in the Keras backend:
+    /// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/backend.py
+    @differentiable(vjp: _vjpRepeatingElements)
+    private func repeatingElements(
+        _ input: Tensor, alongAxis axis: Int, count: Int
+    ) -> Tensor {
+        let slices = Raw.split(
+            splitDim: Tensor(Int32(axis)), value: input, numSplit: Int64(input.shape[axis]))
+        let repeated = slices.flatMap { Array(repeating: $0, count: count) }
+        return Tensor(concatenating: repeated, alongAxis: axis)
+    }
+
+    /// VJP of `repeatingElements`; the pullback sums each run of `count` repeated slices.
+    private func _vjpRepeatingElements(
+        _ input: Tensor, alongAxis axis: Int, count: Int
+    ) -> (Tensor, (Tensor) -> (AllDifferentiableVariables, Tensor)) {
+        let sliceCount = Int64(input.shape[axis])  // Keeps `input` out of the pullback capture.
+        let value = repeatingElements(input, alongAxis: axis, count: count)
+        return (value, { v in
+            let runs = Raw.split(splitDim: Tensor(Int32(axis)), value: v, numSplit: sliceCount)
+            let summed = runs.map { $0.sum(alongAxes: axis) }
+            return (.zero, Tensor(concatenating: summed, alongAxis: axis))
+        })
+    }
+
     /// Returns the output obtained from applying the layer to the given input.
     ///
     /// - Parameter input: The input to the layer.
@@ -1404,11 +1430,8 @@ public struct UpSampling3D: Layer {
-        let shape = input.shape
-        let (batchSize, height, width, depth, channels) =
-            (shape[0], shape[1], shape[2], shape[3], shape[4])
-        let scaleOnes = Tensor(ones: [1, 1, size, 1, size, 1, size, 1])
-        let upSampling = input.reshaped(
-            to: [batchSize, height, 1, width, 1, depth, 1, channels]) * scaleOnes
-        return upSampling.reshaped(
-            to: [batchSize, height * size, width * size, depth * size, channels])
+        // Input is `[batch, height, width, depth, channels]`; repeat along axes 1, 2 and 3.
+        var result = repeatingElements(input, alongAxis: 1, count: size)
+        result = repeatingElements(result, alongAxis: 2, count: size)
+        result = repeatingElements(result, alongAxis: 3, count: size)
+        return result
     }
 }
 
diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift
index 9fd998852..9fffb0866 100644
--- a/Tests/DeepLearningTests/LayerTests.swift
+++ b/Tests/DeepLearningTests/LayerTests.swift
@@ -134,15 +134,9 @@ final class LayerTests: XCTestCase {
         let size = 6
         let layer = UpSampling3D(size: size)
         let input = Tensor(shape: [1, 4, 3, 2, 1], scalars: (0..<24).map(Float.init))
-        // TODO(TF-525): Fix `UpSampling3D.call`.
-        // Broadcasting does not support tensors with high rank:
-        // Broadcast between [1,4,1,3,1,2,1,1] and [1,1,6,1,6,1,6,1] is not supported yet.
-        /*
         let output = layer.inferring(from: input)
         let expected = TensorShape([1, input.shape[1] * size, input.shape[2] * size, input.shape[3] * size, 1])
         XCTAssertEqual(output.shape, expected)
-        XCTAssertEqual(output.shape, expected)
-        */
     }
 
     func testReshape() {