diff --git a/Sources/TensorFlow/Operators/NN.swift b/Sources/TensorFlow/Operators/NN.swift index 6c90ae4da..dc43087d3 100644 --- a/Sources/TensorFlow/Operators/NN.swift +++ b/Sources/TensorFlow/Operators/NN.swift @@ -185,7 +185,7 @@ func _vjpConv2DBackpropInput( let value = conv2DBackpropInput(x, shape: shape, filter: filter, strides: strides, padding: padding, dilations: dilations) return (value, { v in - (conv2DBackpropFilter(x, input: v, filterSizes: shape, strides: strides, + (conv2DBackpropFilter(x, input: v, filterSizes: filter.shapeTensor, strides: strides, padding: padding, dilations: dilations), conv2D(v, filter: filter, strides: strides, padding: padding, dilations: dilations)) })