Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.
25 changes: 10 additions & 15 deletions Sources/DeepLearning/Operators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ public extension Tensor where Scalar: BinaryFloatingPoint {
/// A padding scheme. Used by padding, convolution, and pooling ops.
// @_frozen // SR-9739
public enum Padding {
/// The "explicit" padding scheme, which is defined by an array indicating the explicit padding
/// sizes at the start and end of each dimension.
case explicit([Int32])
/// The "valid" padding scheme.
case valid
/// The "same" padding scheme.
Expand All @@ -115,20 +112,18 @@ public enum Padding {

public extension Padding {
@inlinable
var raw: Raw.Padding2 {
internal var raw: Raw.Padding {
switch self {
case .explicit: return .explicit
case .same: return .same
case .valid: return .valid
}
}

@inlinable
internal var explicitPaddings: [Int32] {
internal var raw2: Raw.Padding2 {
switch self {
case .explicit(let paddings): return paddings
case .same: return []
case .valid: return []
case .same: return .same
case .valid: return .valid
}
}
}
Expand All @@ -148,8 +143,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
filter: filter,
outBackprop: self,
strides: [strides.0, strides.1, strides.2, strides.3],
padding: padding.raw,
explicitPaddings: padding.explicitPaddings)
padding: padding.raw2,
explicitPaddings: [])
}

/// TensorFlow builtin conv2d gradient helper for the filter.
Expand All @@ -166,8 +161,8 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
filterSizes: filterSizes,
outBackprop: self,
strides: [strides.0, strides.1, strides.2, strides.3],
padding: padding.raw,
explicitPaddings: padding.explicitPaddings)
padding: padding.raw2,
explicitPaddings: [])
}

@inlinable
Expand Down Expand Up @@ -297,8 +292,8 @@ public extension Tensor where Scalar: FloatingPoint {
self,
filter: filter,
strides: [strides.0, strides.1, strides.2, strides.3],
padding: padding.raw,
explicitPaddings: padding.explicitPaddings)
padding: padding.raw2,
explicitPaddings: [])
}

/// Computes a 2-D max pooling, with the specified kernel sizes, strides, and
Expand Down