diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index 51b2406a8..827890b3d 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -98,12 +98,15 @@ public extension Tensor where Scalar: BinaryFloatingPoint { } //===------------------------------------------------------------------------------------------===// -// Convolution and pooling +// Convolution and Pooling //===------------------------------------------------------------------------------------------===// /// A padding scheme. Used by padding, convolution, and pooling ops. // @_frozen // SR-9739 public enum Padding { + /// The "explicit" padding scheme, which is defined by an array indicating the explicit padding + /// sizes at the start and end of each dimension. + case explicit([Int32]) /// The "valid" padding scheme. case valid /// The "same" padding scheme. @@ -112,12 +115,22 @@ public enum Padding { public extension Padding { @inlinable - var raw: Raw.Padding { + var raw: Raw.Padding2 { switch self { + case .explicit: return .explicit case .same: return .same case .valid: return .valid } } + + @inlinable + internal var explicitPaddings: [Int32] { + switch self { + case .explicit(let paddings): return paddings + case .same: return [] + case .valid: return [] + } + } } public extension Tensor where Scalar: TensorFlowFloatingPoint { @@ -135,7 +148,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { filter: filter, outBackprop: self, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw) + padding: padding.raw, + explicitPaddings: padding.explicitPaddings) } /// TensorFlow builtin conv2d gradient helper for the filter. @@ -152,7 +166,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { filterSizes: filterSizes, outBackprop: self, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw) + padding: padding.raw, + explicitPaddings: padding.explicitPaddings) } @inlinable @@ -282,7 +297,8 @@ public extension Tensor where Scalar: FloatingPoint { self, filter: filter, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw) + padding: padding.raw, + explicitPaddings: padding.explicitPaddings) } /// Computes a 2-D max pooling, with the specified kernel sizes, strides, and