sgugger
Owner

@sgugger sgugger commented Apr 24, 2019

Replace this paragraph with a description of your changes and rationale. Provide links to external references/discussions if appropriate.

Resolves SR-NNNN.

rxwei and others added 6 commits April 19, 2019 13:17
)

* [TF] Fix gradients in sum(squeezingAxes:) and mean(squeezingAxes:)

sum(squeezingAxes:) and mean(squeezingAxes:) were throwing an error during the backward pass because the gradients weren't unsqueezed before being broadcast.
Note that this could be refactored nicely if we had a function that took a list of ints for `expandingShape`.
Second note: I may be wrong, but it seems like `_vjpMean(squeezingAxes axes: [Int])` is never used and only the Tensor<Int32> version is.

* Remove unused `_vjpMean` function.

* Update Gradients.swift

* Add test

* Minor edit for consistency.
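
The shape problem described in the first commit can be sketched in plain Swift, without any TensorFlow dependency. This is an illustration only, not the code from the patch: `canBroadcast(_:to:)` and `unsqueezed(_:at:)` are hypothetical helpers, and the check assumes NumPy-style broadcasting (trailing dimensions are aligned, and each dimension must match or be 1). Note that `unsqueezed(_:at:)` takes a list of ints, which is the kind of helper the note about `expandingShape` wishes for.

```swift
/// Hypothetical helper: returns true when `shape` can be broadcast to `target`,
/// assuming NumPy-style rules (align trailing dimensions; each pair must match or be 1).
func canBroadcast(_ shape: [Int], to target: [Int]) -> Bool {
    guard shape.count <= target.count else { return false }
    for (s, t) in zip(shape.reversed(), target.reversed()) {
        if s != t && s != 1 { return false }
    }
    return true
}

/// Hypothetical helper: re-inserts size-1 dimensions at `axes`.
/// The axes are non-negative and refer to positions in the original (unsqueezed) shape.
func unsqueezed(_ shape: [Int], at axes: [Int]) -> [Int] {
    var result = shape
    for axis in axes.sorted() {
        result.insert(1, at: axis)
    }
    return result
}

// Summing a [2, 3] input over axis 1 with squeezing yields a result of shape [2],
// so the incoming gradient also has shape [2].
let inputShape = [2, 3]
let gradientShape = [2]

// Broadcasting the squeezed gradient directly fails: trailing dimension 2 vs 3.
print(canBroadcast(gradientShape, to: inputShape))   // false

// Unsqueezing first restores the reduced axis as size 1, and broadcasting succeeds.
let restored = unsqueezed(gradientShape, at: [1])
print(restored)                                      // [2, 1]
print(canBroadcast(restored, to: inputShape))        // true
```

The same reasoning applies to the mean: its gradient is the sum's gradient divided by the number of reduced elements, so it needs the same unsqueeze step before it is broadcast back to the input shape.
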
@sgugger sgugger merged commit 070511a into sgugger:tensorflow Apr 24, 2019
3 participants