diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index 827890b3d..e7a0b2ade 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -104,9 +104,6 @@ public extension Tensor where Scalar: BinaryFloatingPoint { /// A padding scheme. Used by padding, convolution, and pooling ops. // @_frozen // SR-9739 public enum Padding { - /// The "explicit" padding scheme, which is defined by an array indicating the explicit padding - /// sizes at the start and end of each dimension. - case explicit([Int32]) /// The "valid" padding scheme. case valid /// The "same" padding scheme. @@ -115,20 +112,18 @@ public enum Padding { public extension Padding { @inlinable - var raw: Raw.Padding2 { + internal var raw: Raw.Padding { switch self { - case .explicit: return .explicit case .same: return .same case .valid: return .valid } } @inlinable - internal var explicitPaddings: [Int32] { + internal var raw2: Raw.Padding2 { switch self { - case .explicit(let paddings): return paddings - case .same: return [] - case .valid: return [] + case .same: return .same + case .valid: return .valid } } } @@ -148,8 +143,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { filter: filter, outBackprop: self, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw, - explicitPaddings: padding.explicitPaddings) + padding: padding.raw2, + explicitPaddings: []) } /// TensorFlow builtin conv2d gradient helper for the filter. @@ -166,8 +161,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { filterSizes: filterSizes, outBackprop: self, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw, - explicitPaddings: padding.explicitPaddings) + padding: padding.raw2, + explicitPaddings: []) } @inlinable @@ -297,8 +292,8 @@ public extension Tensor where Scalar: FloatingPoint { self, filter: filter, strides: [strides.0, strides.1, strides.2, strides.3], - padding: padding.raw, - explicitPaddings: padding.explicitPaddings) + padding: padding.raw2, + explicitPaddings: []) } /// Computes a 2-D max pooling, with the specified kernel sizes, strides, and