From a686a76eccf5d44d21aec4af9544c22655aa73f4 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Tue, 2 Apr 2019 08:17:10 -0400 Subject: [PATCH 1/3] Updated the convolution ops to support explicit paddings. --- Sources/DeepLearning/Operators.swift | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index 51b2406a8..3a660c07a 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -98,12 +98,14 @@ public extension Tensor where Scalar: BinaryFloatingPoint { } //===------------------------------------------------------------------------------------------===// -// Convolution and pooling +// Convolution and Pooling //===------------------------------------------------------------------------------------------===// /// A padding scheme. Used by padding, convolution, and pooling ops. // @_frozen // SR-9739 public enum Padding { + /// The "explicit" padding scheme. + case explicit(_ paddings: [Int32]) /// The "valid" padding scheme. case valid /// The "same" padding scheme. @@ -112,12 +114,22 @@ public enum Padding { public extension Padding { @inlinable - var raw: Raw.Padding { + var raw: Raw.Padding2 { switch self { + case .explicit: return .explicit case .same: return .same case .valid: return .valid } } + + @inlinable + internal var explicitPaddings: [Int32] { + switch self { + case .explicit(let paddings): return paddings + case .same: return [] + case .valid: return [] + } + } } public extension Tensor where Scalar: TensorFlowFloatingPoint { @@ -135,7 +147,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { filter: filter, outBackprop: self, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw) + padding: padding.raw, + explicitPaddings: padding.explicitPaddings) } /// TensorFlow builtin conv2d gradient helper for the filter. @@ -152,7 +165,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { filterSizes: filterSizes, outBackprop: self, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw) + padding: padding.raw, + explicitPaddings: padding.explicitPaddings) } @inlinable @@ -282,7 +296,8 @@ public extension Tensor where Scalar: FloatingPoint { self, filter: filter, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw) + padding: padding.raw, + explicitPaddings: padding.explicitPaddings) } /// Computes a 2-D max pooling, with the specified kernel sizes, strides, and From 8494c04d0ef849119b73cd566bccb304bb11a999 Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Tue, 2 Apr 2019 08:22:58 -0400 Subject: [PATCH 2/3] Small fix. --- Sources/DeepLearning/Operators.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index 3a660c07a..501722a9e 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -105,7 +105,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint { // @_frozen // SR-9739 public enum Padding { /// The "explicit" padding scheme. - case explicit(_ paddings: [Int32]) + case explicit([Int32]) /// The "valid" padding scheme. case valid /// The "same" padding scheme. From 5dfaaeecf5b48e4784ec2c3b9c0607f2bb7ceced Mon Sep 17 00:00:00 2001 From: Anthony Platanios Date: Tue, 2 Apr 2019 08:32:29 -0400 Subject: [PATCH 3/3] Added documentation string for the "explicit" padding scheme. --- Sources/DeepLearning/Operators.swift | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index 501722a9e..827890b3d 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -104,7 +104,8 @@ public extension Tensor where Scalar: BinaryFloatingPoint { /// A padding scheme. Used by padding, convolution, and pooling ops. // @_frozen // SR-9739 public enum Padding { - /// The "explicit" padding scheme. + /// The "explicit" padding scheme, which is defined by an array indicating the explicit padding + /// sizes at the start and end of each dimension. case explicit([Int32]) /// The "valid" padding scheme. case valid