From f72084f26524cfbfaf5762e52761268adc7e3296 Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Tue, 21 May 2019 23:25:09 +0530 Subject: [PATCH 1/6] adding convolutional 3-d --- Sources/DeepLearning/Layer.swift | 116 +++++++++++++++++++++++++ Sources/DeepLearning/Operators.swift | 122 +++++++++++++++++++++++++++ 2 files changed, 238 insertions(+) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 2eeb63f01..5a92c863f 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -513,6 +513,122 @@ public extension Conv2D { } } +/// A 3-D convolution layer for spatial/spatio-temporal convolution over images. +/// +/// This layer creates a convolution filter that is convolved with the layer input to produce a +/// tensor of outputs. +@_fixed_layout +public struct Conv3D: Layer { + /// The 5-D convolution kernel. + public var filter: Tensor + /// The bias vector. + public var bias: Tensor + /// An activation function. + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The strides of the sliding window for spatial dimensions. + @noDerivative public let strides: (Int, Int, Int) + /// The padding algorithm for convolution. + @noDerivative public let padding: Padding + + /// Creates a `Conv3D` layer with the specified filter, bias, activation function, strides, and + /// padding. + /// + /// - Parameters: + /// - filter: The 5-D convolution kernel. + /// - bias: The bias vector. + /// - activation: The element-wise activation function. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + public init( + filter: Tensor, + bias: Tensor, + activation: @escaping Activation, + strides: (Int, Int), + padding: Padding + ) { + self.filter = filter + self.bias = bias + self.activation = activation + self.strides = strides + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameter input: The input to the layer. + /// - Returns: The output. + @differentiable + public func call(_ input: Tensor) -> Tensor { + return activation(input.convolved3D(withFilter: filter, + strides: (1, strides.0, strides.1, strides.2, 1), + padding: padding) + bias) + } +} + +public extension Conv3D { + /// Creates a `Conv3D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified generator. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the 5-D convolution kernel. + /// - strides: The strides of the sliding window for spatial/spatio-temporal dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random + /// initialization. + init( + filterShape: (Int, Int, Int, Int, Int), + strides: (Int, Int, Int) = (1, 1, 1), + padding: Padding = .valid, + activation: @escaping Activation = identity, + generator: inout G + ) { + let filterTensorShape = TensorShape([ + filterShape.0, filterShape.1, filterShape.2, filterShape.3, filterShape.4]) + self.init( + filter: Tensor(glorotUniform: filterTensorShape, generator: &generator), + bias: Tensor(zeros: TensorShape([filterShape.4])), + activation: activation, + strides: strides, + padding: padding) + } +} + +public extension Conv3D { + /// Creates a `Conv3D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified seed. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the 5-D convolution kernel. + /// - strides: The strides of the sliding window for spatial/spatio-temporal dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - seed: The random seed for initialization. The default value is random. + init( + filterShape: (Int, Int, Int, Int, Int), + strides: (Int, Int, Int) = (1, 1, 1), + padding: Padding = .valid, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min.. Tensor { + return Raw.conv3DBackpropInput( + self, + filter: filter, + outBackprop: self, + strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), + Int32(strides.3), Int32(strides.4)], + padding: padding.raw) + } + + /// TensorFlow builtin conv3d gradient helper for the filter. + @inlinable + @differentiable(wrt: (self, input), vjp: _vjpConv3DBackpropFilter) + internal func conv3DBackpropFilter( + input: Tensor, + filter: Tensor, + strides: (Int, Int, Int, Int, Int), + padding: Padding + ) -> Tensor { + return Raw.conv3DBackpropFilter( + self, + filter: filter, + outBackprop: self, + strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), + Int32(strides.3), Int32(strides.4)], + padding: padding.raw) + } + + @inlinable + internal func _vjpConv3DBackpropInput( + _ input: Tensor, + _ filter: Tensor, + _ strides: (Int, Int, Int, Int, Int), + _ padding: Padding + ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = conv3DBackpropInput(input: input, filter: filter, strides: strides, + padding: padding) + return (value, { v in + return ( + self.conv3DBackpropFilter(input: v, filter: filter, strides: strides, + padding: padding), + v.convolved3D(withFilter: filter, strides: strides, padding: padding) + ) + }) + } + + @inlinable + internal func _vjpConv3DBackpropFilter( + _ input: Tensor, + _ filter: Tensor, + _ strides: (Int, Int, Int, Int, Int), + _ padding: Padding + ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = conv3DBackpropFilter(input: input, filter: filter, + strides: strides, padding: padding) + return (value, { v in + return ( + self.conv3DBackpropInput(input: input, filter: v, strides: strides, + padding: padding), + input.convolved3D(withFilter: filter, strides: strides, padding: padding) + ) + }) + } + + @inlinable + internal func _vjpConvolved3D( + filter: Tensor, + strides: (Int, Int, Int, Int, Int), + padding: Padding + ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { + let value = convolved3D(withFilter: filter, strides: strides, + padding: padding) + return (value, { v in + return ( + v.conv3DBackpropInput( + input: self, filter: filter, + strides: strides, padding: padding + ), + v.conv3DBackpropFilter( + input: self, filter: filter, + strides: strides, padding: padding + ) + ) + }) + } + @inlinable internal func _vjpMaxPooled2D( kernelSize: (Int, Int, Int, Int), @@ -345,6 +439,34 @@ public extension Tensor where Scalar: FloatingPoint { explicitPaddings: []) } + /// Computes a 3-D convolution using `self` as input, with the specified + /// filter, strides, and padding. + /// + /// - Parameters: + /// - filter: The convolution filter. + /// - strides: The strides of the sliding filter for each dimension of the + /// input. + /// - padding: The padding for the operation. + /// - Precondition: `self` must have rank 5. + /// - Precondition: `filter` must have rank 5. + @inlinable @inline(__always) + @differentiable( + wrt: (self, filter), vjp: _vjpConvolved3D + where Scalar: TensorFlowFloatingPoint + ) + func convolved3D( + withFilter filter: Tensor, + strides: (Int, Int, Int, Int, Int), + padding: Padding + ) -> Tensor { + return Raw.conv3D( + self, + filter: filter, + strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), + Int32(strides.3), Int32(strides.4)], + padding: padding.raw) + } + /// Computes a 2-D max pooling, with the specified kernel sizes, strides, and /// padding. /// From 750e7bf510f189ee1d966865a247b715452068da Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Tue, 21 May 2019 23:28:39 +0530 Subject: [PATCH 2/6] stride error --- Sources/DeepLearning/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 5a92c863f..6469d660a 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -545,7 +545,7 @@ public struct Conv3D: Layer { filter: Tensor, bias: Tensor, activation: @escaping Activation, - strides: (Int, Int), + strides: (Int, Int, Int), padding: Padding ) { self.filter = filter From 0765d8be70c297892994a650af53406cf51f166a Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Wed, 22 May 2019 10:08:01 +0530 Subject: [PATCH 3/6] adding v2 apis from raw --- Sources/DeepLearning/Operators.swift | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index 0fbfca0a2..ee95aabbc 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -229,18 +229,18 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { @inlinable @differentiable(wrt: (self, filter), vjp: _vjpConv3DBackpropInput) internal func conv3DBackpropInput( - input: Tensor, + shape: Tensor, filter: Tensor, strides: (Int, Int, Int, Int, Int), padding: Padding ) -> Tensor { - return Raw.conv3DBackpropInput( - self, + return Raw.conv3DBackpropInputV2( + inputSizes: shape, filter: filter, outBackprop: self, strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3), Int32(strides.4)], - padding: padding.raw) + padding: padding.raw2) } /// TensorFlow builtin conv3d gradient helper for the filter. @@ -248,13 +248,13 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { @differentiable(wrt: (self, input), vjp: _vjpConv3DBackpropFilter) internal func conv3DBackpropFilter( input: Tensor, - filter: Tensor, + filterSizes: Tensor, strides: (Int, Int, Int, Int, Int), padding: Padding ) -> Tensor { - return Raw.conv3DBackpropFilter( + return Raw.conv3DBackpropFilterV2( self, - filter: filter, + filterSizes: filterSizes, outBackprop: self, strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3), Int32(strides.4)], @@ -263,16 +263,16 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { @inlinable internal func _vjpConv3DBackpropInput( - _ input: Tensor, + _ shape: Tensor, _ filter: Tensor, _ strides: (Int, Int, Int, Int, Int), _ padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = conv3DBackpropInput(input: input, filter: filter, strides: strides, + let value = conv3DBackpropInput(shape: shape, filter: filter, strides: strides, padding: padding) return (value, { v in return ( - self.conv3DBackpropFilter(input: v, filter: filter, strides: strides, + self.conv3DBackpropFilter(input: v, filterSizes: shape, strides: strides, padding: padding), v.convolved3D(withFilter: filter, strides: strides, padding: padding) ) @@ -282,15 +282,15 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { @inlinable internal func _vjpConv3DBackpropFilter( _ input: Tensor, - _ filter: Tensor, + _ filterSizes: Tensor, _ strides: (Int, Int, Int, Int, Int), _ padding: Padding ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - let value = conv3DBackpropFilter(input: input, filter: filter, + let value = conv3DBackpropFilter(input: input, filterSizes: filterSizes, strides: strides, padding: padding) return (value, { v in return ( - self.conv3DBackpropInput(input: input, filter: v, strides: strides, + self.conv3DBackpropInput(shape: filterSizes, filter: v, strides: strides, padding: padding), input.convolved3D(withFilter: filter, strides: strides, padding: padding) ) @@ -308,11 +308,11 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { return (value, { v in return ( v.conv3DBackpropInput( - input: self, filter: filter, + shape: self.shapeTensor, filter: filter, strides: strides, padding: padding ), v.conv3DBackpropFilter( - input: self, filter: filter, + input: self, filterSizes: filter.shapeTensor, strides: strides, padding: padding ) ) From f737777c03ba034b5ccc796f47baa6a8d0164af3 Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Wed, 22 May 2019 10:11:08 +0530 Subject: [PATCH 4/6] build errors --- Sources/DeepLearning/Operators.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/DeepLearning/Operators.swift b/Sources/DeepLearning/Operators.swift index ee95aabbc..d0e035acb 100644 --- a/Sources/DeepLearning/Operators.swift +++ b/Sources/DeepLearning/Operators.swift @@ -240,7 +240,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { outBackprop: self, strides: [Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3), Int32(strides.4)], - padding: padding.raw2) + padding: padding.raw) } /// TensorFlow builtin conv3d gradient helper for the filter. @@ -292,7 +292,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint { return ( self.conv3DBackpropInput(shape: filterSizes, filter: v, strides: strides, padding: padding), - input.convolved3D(withFilter: filter, strides: strides, padding: padding) + input.convolved3D(withFilter: v, strides: strides, padding: padding) ) }) } From 4264452b36ff3e44d7bc51a49aad03e9c51c3663 Mon Sep 17 00:00:00 2001 From: PAWAN SASANKA AMMANAMANCHI Date: Wed, 22 May 2019 15:09:30 +0530 Subject: [PATCH 5/6] adding test for conv3d --- Tests/DeepLearningTests/LayerTests.swift | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift index 4647f0366..d11535422 100644 --- a/Tests/DeepLearningTests/LayerTests.swift +++ b/Tests/DeepLearningTests/LayerTests.swift @@ -26,6 +26,18 @@ final class LayerTests: XCTestCase { XCTAssertEqual(round(output), expected) } + func testConv3D() { + let filter = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init)) + let bias = Tensor([0, 0]) + let layer = Conv3D(filter: filter, bias: bias, activation: identity, + strides: (1, 2, 1), padding: .valid) + let input = Tensor(shape: [2, 2, 2, 2, 2], scalars: (0..<32).map(Float.init)) + let output = layer.inferring(from: input) + let expected = Tensor([[[[[140],[62]]],[[[364],[142]]]],[[[[588],[222]]], + [[[812],[302]]]]]) + XCTAssertEqual(round(output), expected) + } + func testMaxPool1D() { let layer = MaxPool1D(poolSize: 3, stride: 1, padding: .valid) let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) @@ -187,6 +199,7 @@ final class LayerTests: XCTestCase { static var allTests = [ ("testConv1D", testConv1D), + ("testConv3D", testConv3D), ("testMaxPool1D", testMaxPool1D), ("testMaxPool2D", testMaxPool2D), ("testMaxPool3D", testMaxPool3D), From cbeea35e174bdab981a6fefe8f2511d146bc5982 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 22 May 2019 09:43:27 -0700 Subject: [PATCH 6/6] Fix tests. - Change `testConv3D` to use a non-trivial bias. - Remove calls to `round` when calling `XCTAssertEqual`. - Fix `testAvgPool3D`. --- Tests/DeepLearningTests/LayerTests.swift | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift index d11535422..22e2ec403 100644 --- a/Tests/DeepLearningTests/LayerTests.swift +++ b/Tests/DeepLearningTests/LayerTests.swift @@ -23,19 +23,19 @@ final class LayerTests: XCTestCase { let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) let output = layer.inferring(from: input) let expected = Tensor([[[1, 4], [2, 7], [3, 10]], [[11, 34], [12, 37], [13, 40]]]) - XCTAssertEqual(round(output), expected) + XCTAssertEqual(output, expected) } func testConv3D() { let filter = Tensor(shape: [1, 2, 2, 2, 1], scalars: (0..<8).map(Float.init)) - let bias = Tensor([0, 0]) + let bias = Tensor([-1, 1]) let layer = Conv3D(filter: filter, bias: bias, activation: identity, strides: (1, 2, 1), padding: .valid) let input = Tensor(shape: [2, 2, 2, 2, 2], scalars: (0..<32).map(Float.init)) let output = layer.inferring(from: input) - let expected = Tensor([[[[[140],[62]]],[[[364],[142]]]],[[[[588],[222]]], - [[[812],[302]]]]]) - XCTAssertEqual(round(output), expected) + let expected = Tensor(shape: [2, 2, 1, 1, 2], + scalars: [139, 141, 363, 365, 587, 589, 811, 813]) + XCTAssertEqual(output, expected) } func testMaxPool1D() { @@ -80,7 +80,7 @@ final class LayerTests: XCTestCase { func testAvgPool3D() { let layer = AvgPool3D(poolSize: (2, 4, 5), strides: (1, 1, 1), padding: .valid) - let input = Tensor(shape: [1, 2, 4, 5, 1], scalars: (0..<20).map(Float.init)) + let input = Tensor(shape: [1, 2, 4, 5, 1], scalars: (0..<40).map(Float.init)) let output = layer.inferring(from: input) let expected = Tensor([[[[[9.5]]]]]) XCTAssertEqual(output, expected)