Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.
51 changes: 51 additions & 0 deletions Sources/TensorFlow/Layers/Pooling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,54 @@ public struct GlobalAvgPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
return input.mean(squeezingAxes: [1, 2, 3])
}
}

/// A global max pooling layer for temporal data.
@_fixed_layout
public struct GlobalMaxPool1D<Scalar: TensorFlowFloatingPoint>: Layer {
/// Creates a global max pooling layer.
public init() {}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameters:
/// - input: The input to the layer.
/// - context: The contextual information for the layer application, e.g. the current learning
/// phase.
/// - Returns: The output.
@differentiable
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
return input.max(squeezingAxes: 1)
}
}

/// A global max pooling layer for spatial data.
@_fixed_layout
public struct GlobalMaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
/// Creates a global max pooling layer.
public init() {}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
return input.max(squeezingAxes: [1, 2])
}
}

/// A global max pooling layer for spatial and spatio-temporal data.
@_fixed_layout
public struct GlobalMaxPool3D<Scalar: TensorFlowFloatingPoint>: Layer {
/// Creates a global max pooling layer.
public init() {}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func call(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
return input.max(squeezingAxes: [1, 2, 3])
}
}
27 changes: 27 additions & 0 deletions Tests/TensorFlowTests/LayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,30 @@ final class LayerTests: XCTestCase {
XCTAssertEqual(output, expected)
}

func testGlobalMaxPool1D() {
let layer = GlobalMaxPool1D<Float>()
let input = Tensor(shape: [1, 10, 1], scalars: (0..<10).map(Float.init))
let output = layer.inferring(from: input)
let expected = Tensor<Float>([9])
XCTAssertEqual(output, expected)
}

func testGlobalMaxPool2D() {
let layer = GlobalMaxPool2D<Float>()
let input = Tensor(shape: [1, 2, 10, 1], scalars: (0..<20).map(Float.init))
let output = layer.inferring(from: input)
let expected = Tensor<Float>([19])
XCTAssertEqual(output, expected)
}

func testGlobalMaxPool3D() {
let layer = GlobalMaxPool3D<Float>()
let input = Tensor<Float>(shape: [1, 2, 3, 5, 1], scalars: (0..<30).map(Float.init))
let output = layer.inferring(from: input)
let expected = Tensor<Float>([29])
XCTAssertEqual(output, expected)
}

func testUpSampling1D() {
let size = 6
let layer = UpSampling1D<Float>(size: size)
Expand Down Expand Up @@ -226,6 +250,9 @@ final class LayerTests: XCTestCase {
("testGlobalAvgPool1D", testGlobalAvgPool1D),
("testGlobalAvgPool2D", testGlobalAvgPool2D),
("testGlobalAvgPool3D", testGlobalAvgPool3D),
("testGlobalMaxPool1D", testGlobalMaxPool1D),
("testGlobalMaxPool2D", testGlobalMaxPool2D),
("testGlobalMaxPool3D", testGlobalMaxPool3D),
("testUpSampling1D", testUpSampling1D),
("testUpSampling2D", testUpSampling2D),
("testUpSampling3D", testUpSampling3D),
Expand Down