Optimize: Handle gradient wrt scalar inputs and guard against unsupported types #1784
Conversation
```python
x: TensorVariable,
*args: Variable,
objective: Variable,
method: str = "BFGS",
```
I removed the default method from the Ops, since the helpers already have one. This helps avoid mismatches between the two defaults.
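A minimal sketch of that pattern (class and helper names are illustrative, not the exact PR code): the Op requires `method` explicitly, and only the user-facing helper carries the default, so the two can never silently disagree.

```python
class MinimizeOp:
    """Sketch only; the real Op subclasses pytensor.graph.op.Op."""

    def __init__(self, method: str):  # no default: callers must be explicit
        self.method = method


def minimize(objective, x, method: str = "BFGS"):
    # The default lives in exactly one place, the helper, so the Op's
    # behavior cannot drift from what the helper advertises.
    op = MinimizeOp(method=method)
    ...
```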
```diff
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
```
This was the only `L_op` with type hints; I removed them for consistency.
```python
grad_wrt_arg = dot(output_grad, arg_grad)
if isinstance(arg.type, ScalarType):
    grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
grad_wrt_args.append(grad_wrt_arg)
```
As an LLM would have put it: "🥳 The actual bugfix"
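To see why the conversion is needed: `dot` always produces a `TensorVariable`, but the gradient returned by `L_op` must have the same type as the corresponding input. A small self-contained sketch (import paths and variable names are assumptions based on current pytensor, not the PR's exact code):

```python
import pytensor.tensor as pt
from pytensor.scalar import ScalarType, float64
from pytensor.tensor.basic import scalar_from_tensor

arg = float64("arg")  # a ScalarType input, as allowed by this PR
output_grad = pt.vector("g_out", dtype="float64")  # hypothetical incoming gradient
arg_grad = pt.vector("d_arg", dtype="float64")     # hypothetical jacobian column for `arg`

grad_wrt_arg = pt.dot(output_grad, arg_grad)  # always a 0-d TensorVariable
if isinstance(arg.type, ScalarType):
    # Without this, grad() would receive a tensor where a scalar is expected
    grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)

assert grad_wrt_arg.type == arg.type
```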
Closes #1744
Closes #1743
MinimizeOps were very liberal in assuming that all inputs would be tensors, and so would the gradients with respect to them.
This PR makes sure we only work with supported types: TensorType and ScalarType. We could also try to support SparseVariables.
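The guard could look something like this sketch (the function name and error message are assumptions, not the PR's exact code):

```python
from pytensor.scalar import ScalarType
from pytensor.tensor.type import TensorType


def check_arg_types(args):
    """Reject inputs the Op cannot handle, instead of failing later."""
    for arg in args:
        if not isinstance(arg.type, (TensorType, ScalarType)):
            raise TypeError(
                f"Optimize Ops only support TensorType and ScalarType "
                f"arguments, got {type(arg.type).__name__} for {arg}"
            )
```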
Right now the Op may still fail in the gradient pass if args are of other types (the main ones I can think of are slices in index operations, and maybe RNGs, if those make sense). I added a test with a custom string type. A robust implementation of the Minimize gradient would need to treat those inputs as disconnected and implement a `connection_pattern` method to tell `grad` about this in advance (it complains otherwise). That can be left for a future PR; a sketch of the mechanism follows.
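A hedged sketch of what that future `connection_pattern` could look like (the surrounding class is illustrative): it returns one row of booleans per input, with one entry per output, and `grad` treats input/output pairs marked `False` as disconnected.

```python
from pytensor.scalar import ScalarType
from pytensor.tensor.type import TensorType


class MinimizeOp:  # sketch; the real Op subclasses pytensor.graph.op.Op
    def connection_pattern(self, node):
        # Inputs with unsupported types are declared disconnected from
        # every output, so grad() returns a disconnected gradient for
        # them instead of raising.
        return [
            [isinstance(inp.type, (TensorType, ScalarType))] * len(node.outputs)
            for inp in node.inputs
        ]
```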
Gave up on mypy
Mypy didn't like my bugfix and the more precise type hints. Given its empty track record of helping find bugs in our codebase, and my dislike for it, I removed all the `typing.cast` calls and added the module to mypy-failing. Now we have both: more correct type hints and more readable code. If someone wants to revert this, be my guest, but I won't be helping.