diff --git a/Sources/TensorFlow/Operators/NN.swift b/Sources/TensorFlow/Operators/NN.swift index 2e30f521f..ae03bf15a 100644 --- a/Sources/TensorFlow/Operators/NN.swift +++ b/Sources/TensorFlow/Operators/NN.swift @@ -189,9 +189,9 @@ func _vjpConv2DBackpropFilter( let value = conv2DBackpropFilter(x, input: input, filterSizes: filterSizes, strides: strides, padding: padding, dilations: dilations) return (value, { v in - (conv2DBackpropInput(x, shape: filterSizes, filter: v, strides: strides, - padding: padding, dilations: dilations), - conv2D(input, filter: v, strides: strides, padding: padding, dilations: dilations)) + (conv2D(input, filter: v, strides: strides, padding: padding, dilations: dilations), + conv2DBackpropInput(x, shape: x.shapeTensor, filter: v, strides: strides, + padding: padding, dilations: dilations)) }) }