107 changes: 107 additions & 0 deletions Sources/TensorFlow/Layers/Convolutional.swift
@@ -632,6 +632,113 @@ public struct ZeroPadding3D<Scalar: TensorFlowFloatingPoint>: ParameterlessLayer
}
}

/// A 1-D separable convolution layer.
///
/// This layer performs a depthwise convolution that acts separately on channels followed by
/// a pointwise convolution that mixes channels.
@frozen
public struct SeparableConv1D<Scalar: TensorFlowFloatingPoint>: Layer {
/// The 3-D depthwise convolution kernel.
public var depthwiseFilter: Tensor<Scalar>
/// The 3-D pointwise convolution kernel.
public var pointwiseFilter: Tensor<Scalar>
/// The bias vector.
public var bias: Tensor<Scalar>
/// The element-wise activation function.
@noDerivative public let activation: Activation
/// The stride of the sliding window for the temporal dimension.
@noDerivative public let stride: Int
/// The padding algorithm for convolution.
@noDerivative public let padding: Padding

/// The element-wise activation function type.
public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>

/// Creates a `SeparableConv1D` layer with the specified depthwise and pointwise filters,
/// bias, activation function, stride, and padding.
///
/// - Parameters:
/// - depthwiseFilter: The 3-D depthwise convolution kernel
/// `[filter width, input channels count, channel multiplier]`.
/// - pointwiseFilter: The 3-D pointwise convolution kernel
/// `[1, channel multiplier * input channels count, output channels count]`.
/// - bias: The bias vector.
/// - activation: The element-wise activation function.
/// - stride: The stride of the sliding window for the temporal dimension.
/// - padding: The padding algorithm for convolution.
public init(
depthwiseFilter: Tensor<Scalar>,
pointwiseFilter: Tensor<Scalar>,
bias: Tensor<Scalar>,
activation: @escaping Activation = identity,
stride: Int = 1,
padding: Padding = .valid
) {
self.depthwiseFilter = depthwiseFilter
self.pointwiseFilter = pointwiseFilter
self.bias = bias
self.activation = activation
self.stride = stride
self.padding = padding
}

/// Returns the output obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
let depthwise = depthwiseConv2D(
input.expandingShape(at: 1),
filter: depthwiseFilter.expandingShape(at: 1),
strides: (1, stride, stride, 1),
padding: padding)
let x = conv2D(
depthwise,
filter: pointwiseFilter.expandingShape(at: 1),
strides: (1, 1, 1, 1),
padding: padding,
dilations: (1, 1, 1, 1))
return activation(x.squeezingShape(at: 1) + bias)
}
}
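// Illustrative usage sketch of the layer above; the shapes are hypothetical,
// chosen only to show the expected kernel layouts:
// ```
// // Depthwise kernel: [filter width, input channels count, channel multiplier].
// let depthwiseFilter = Tensor<Float>(glorotUniform: [3, 4, 1])
// // Pointwise kernel: [1, channel multiplier * input channels count, output channels count].
// let pointwiseFilter = Tensor<Float>(glorotUniform: [1, 4, 8])
// let layer = SeparableConv1D<Float>(
//     depthwiseFilter: depthwiseFilter,
//     pointwiseFilter: pointwiseFilter,
//     bias: Tensor<Float>(zeros: [8]),
//     activation: relu,
//     stride: 1,
//     padding: .same)
// let input = Tensor<Float>(randomNormal: [16, 32, 4])  // [batch size, width, channels]
// let output = layer(input)  // shape: [16, 32, 8]
// ```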

public extension SeparableConv1D {
/// Creates a `SeparableConv1D` layer with the specified depthwise and pointwise filter shapes,
/// stride, padding, and element-wise activation function.
///
/// - Parameters:
/// - depthwiseFilterShape: The shape of the 3-D depthwise convolution kernel.
/// - pointwiseFilterShape: The shape of the 3-D pointwise convolution kernel.
/// - stride: The stride of the sliding window for the temporal dimension.
/// - padding: The padding algorithm for convolution.
/// - activation: The element-wise activation function.
/// - depthwiseFilterInitializer: Initializer to use for the depthwise filter parameters.
/// - pointwiseFilterInitializer: Initializer to use for the pointwise filter parameters.
/// - biasInitializer: Initializer to use for the bias parameters.
init(
depthwiseFilterShape: (Int, Int, Int),
pointwiseFilterShape: (Int, Int, Int),
stride: Int = 1,
padding: Padding = .valid,
activation: @escaping Activation = identity,
depthwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
pointwiseFilterInitializer: ParameterInitializer<Scalar> = glorotUniform(),
biasInitializer: ParameterInitializer<Scalar> = zeros()
) {
let depthwiseFilterTensorShape = TensorShape([
depthwiseFilterShape.0, depthwiseFilterShape.1, depthwiseFilterShape.2])
let pointwiseFilterTensorShape = TensorShape([
pointwiseFilterShape.0, pointwiseFilterShape.1, pointwiseFilterShape.2])
self.init(
depthwiseFilter: depthwiseFilterInitializer(depthwiseFilterTensorShape),
pointwiseFilter: pointwiseFilterInitializer(pointwiseFilterTensorShape),
bias: biasInitializer([pointwiseFilterShape.2]),
activation: activation,
stride: stride,
padding: padding)
}
}
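// The same layer built via the shape-based initializer above (again a sketch
// with hypothetical sizes):
// ```
// let layer = SeparableConv1D<Float>(
//     depthwiseFilterShape: (3, 4, 1),  // width 3, 4 input channels, multiplier 1
//     pointwiseFilterShape: (1, 4, 8),  // mixes the 4 depthwise channels into 8 outputs
//     stride: 1,
//     padding: .same,
//     activation: relu)
// // The filters default to Glorot-uniform initialization and the bias to zeros.
// ```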

/// A 2-D separable convolution layer.
///
/// This layer performs a depthwise convolution that acts separately on channels followed by
39 changes: 28 additions & 11 deletions Tests/TensorFlowTests/LayerTests.swift
@@ -113,28 +113,28 @@ final class LayerTests: XCTestCase {
func testConv2DGradient() {
let filter = Tensor(shape: [3, 3, 2, 4], scalars: (0..<72).map(Float.init))
let bias = Tensor<Float>(zeros: [4])
let layer = Conv2D<Float>(filter: filter,
bias: bias,
activation: identity,
strides: (2, 2),
padding: .valid)
let input = Tensor(shape: [2, 4, 4, 2], scalars: (0..<64).map(Float.init))
let grads = gradient( at: input, layer) { $1($0).sum() }
// The expected gradients were computed using the following Python code:
// ```
// x = tf.reshape(tf.range(64, dtype=tf.float32), [2, 4, 4, 2])
// filter = tf.reshape(tf.range(72, dtype=tf.float32), [3, 3, 2, 4])
// bias = tf.zeros([4])
// with tf.GradientTape() as t:
// t.watch([x, filter, bias])
// y = tf.math.reduce_sum(tf.nn.conv2d(input=x,
// filters=filter,
// strides=[1, 2, 2, 1],
// data_format="NHWC",
// padding="VALID") + bias)
// grads = t.gradient(y, [x, filter, bias])
// ```
XCTAssertEqual(grads.0,
[[[[ 6, 22], [ 38, 54], [ 70, 86], [ 0, 0]],
[[102, 118], [134, 150], [166, 182], [ 0, 0]],
[[198, 214], [230, 246], [262, 278], [ 0, 0]],
@@ -143,7 +143,7 @@
[[102, 118], [134, 150], [166, 182], [ 0, 0]],
[[198, 214], [230, 246], [262, 278], [ 0, 0]],
[[ 0, 0], [ 0, 0], [ 0, 0], [ 0, 0]]]])
XCTAssertEqual(grads.1.filter,
[[[[32, 32, 32, 32], [34, 34, 34, 34]],
[[36, 36, 36, 36], [38, 38, 38, 38]],
[[40, 40, 40, 40], [42, 42, 42, 42]]],
@@ -210,6 +210,22 @@
XCTAssertEqual(output, expected)
}

func testSeparableConv1D() {
let depthwiseFilter = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init))
let pointwiseFilter = Tensor(shape: [1, 4, 1], scalars: (0..<4).map(Float.init))
let bias = Tensor<Float>([4])
let layer = SeparableConv1D<Float>(depthwiseFilter: depthwiseFilter,
pointwiseFilter: pointwiseFilter,
bias: bias,
activation: identity,
stride: 1,
padding: .same)
let input = Tensor(shape: [2, 2, 2], scalars: (0..<8).map(Float.init))
let output = layer.inferring(from: input)
let expected = Tensor<Float>(shape: [2, 2, 1], scalars: [17, 45, 73, 101])
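// Hand check (a sketch, assuming TensorFlow's trailing `.same` padding): after
// `expandingShape(at: 1)` the two depthwise taps lie along the singleton height
// axis, whose single pad row comes after the data row, so only
// `depthwiseFilter[0, c, m]` meets real data here. Per position, summing over
// (c, m):
//   out[b, w] = bias + sum of pointwiseFilter[0, 2*c + m, 0]
//                        * input[b, w, c] * depthwiseFilter[0, c, m]
// e.g. input[0, 1, :] = [2, 3] gives
//   1*(2*1) + 2*(3*2) + 3*(3*3) = 2 + 12 + 27 = 41, plus bias 4 = 45.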
XCTAssertEqual(output, expected)
}

func testSeparableConv2D() {
let depthwiseFilter = Tensor(shape: [2, 2, 2, 2], scalars: (0..<16).map(Float.init))
let pointwiseFilter = Tensor(shape: [1, 1, 4, 1], scalars: (0..<4).map(Float.init))
@@ -582,7 +598,7 @@ final class LayerTests: XCTestCase {
// [ 0.0, 0.0, 0.0, 0.0]])
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
}

func testLSTM() {
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
@@ -710,7 +726,7 @@ final class LayerTests: XCTestCase {
// lnLayer = tf.keras.layers.LayerNormalization(axis=1, epsilon=0.001)
// with tf.GradientTape() as t:
// t.watch(x)
// y = lnLayer(x)
// z = tf.math.reduce_sum(tf.math.square(y))
// print(y, t.gradient(z, [x] + lnLayer.trainable_variables))
// ```
@@ -729,8 +745,8 @@
[-0.0019815 , 0.00164783, 0.00130618, 0.00119543, -0.00216818]],
accuracy: 1e-5)
assertEqual(
grad.1.offset,
[-0.645803 , -5.8017054 , 0.03168535, 5.973418 , 0.44240427],
accuracy: 1e-5)
assertEqual(
grad.1.scale,
@@ -747,6 +763,7 @@
("testConv2DDilation", testConv2DDilation),
("testConv3D", testConv3D),
("testDepthConv2D", testDepthConv2D),
("testSeparableConv1D", testSeparableConv1D),
("testSeparableConv2D", testSeparableConv2D),
("testZeroPadding1D", testZeroPadding1D),
("testZeroPadding2D", testZeroPadding2D),