diff --git a/Sources/TensorFlow/Layers/Pooling.swift b/Sources/TensorFlow/Layers/Pooling.swift index 14a6fdfeb..7a9ca9e0e 100644 --- a/Sources/TensorFlow/Layers/Pooling.swift +++ b/Sources/TensorFlow/Layers/Pooling.swift @@ -340,3 +340,54 @@ public struct GlobalAvgPool3D: Layer { return input.mean(squeezingAxes: [1, 2, 3]) } } + +/// A global max pooling layer for temporal data. +@_fixed_layout +public struct GlobalMaxPool1D: Layer { + /// Creates a global max pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.max(squeezingAxes: 1) + } +} + +/// A global max pooling layer for spatial data. +@_fixed_layout +public struct GlobalMaxPool2D: Layer { + /// Creates a global max pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.max(squeezingAxes: [1, 2]) + } +} + +/// A global max pooling layer for spatial and spatio-temporal data. +@_fixed_layout +public struct GlobalMaxPool3D: Layer { + /// Creates a global max pooling layer. + public init() {} + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return input.max(squeezingAxes: [1, 2, 3]) + } +} diff --git a/Tests/TensorFlowTests/LayerTests.swift b/Tests/TensorFlowTests/LayerTests.swift index 313a70859..d8e258d5d 100644 --- a/Tests/TensorFlowTests/LayerTests.swift +++ b/Tests/TensorFlowTests/LayerTests.swift @@ -124,6 +124,30 @@ final class LayerTests: XCTestCase { XCTAssertEqual(output, expected) } + func testGlobalMaxPool1D() { + let layer = GlobalMaxPool1D() + let input = Tensor(shape: [1, 10, 1], scalars: (0..<10).map(Float.init)) + let output = layer.inferring(from: input) + let expected = Tensor([9]) + XCTAssertEqual(output, expected) + } + + func testGlobalMaxPool2D() { + let layer = GlobalMaxPool2D() + let input = Tensor(shape: [1, 2, 10, 1], scalars: (0..<20).map(Float.init)) + let output = layer.inferring(from: input) + let expected = Tensor([19]) + XCTAssertEqual(output, expected) + } + + func testGlobalMaxPool3D() { + let layer = GlobalMaxPool3D() + let input = Tensor(shape: [1, 2, 3, 5, 1], scalars: (0..<30).map(Float.init)) + let output = layer.inferring(from: input) + let expected = Tensor([29]) + XCTAssertEqual(output, expected) + } + func testUpSampling1D() { let size = 6 let layer = UpSampling1D(size: size) @@ -226,6 +250,9 @@ final class LayerTests: XCTestCase { ("testGlobalAvgPool1D", testGlobalAvgPool1D), ("testGlobalAvgPool2D", testGlobalAvgPool2D), ("testGlobalAvgPool3D", testGlobalAvgPool3D), + ("testGlobalMaxPool1D", testGlobalMaxPool1D), + ("testGlobalMaxPool2D", testGlobalMaxPool2D), + ("testGlobalMaxPool3D", testGlobalMaxPool3D), ("testUpSampling1D", testUpSampling1D), ("testUpSampling2D", testUpSampling2D), ("testUpSampling3D", testUpSampling3D),