Conversation

@ricardoV94 (Member) commented Dec 10, 2025

Closes #1744
Closes #1743

The MinimizeOps were very liberal in assuming that all inputs would be tensors, and that the gradients with respect to them would be too.

This PR makes sure we only work with supported types: TensorType and ScalarType. We could also try to support SparseVariables.

Right now the Op may still fail in the gradient pass if args are of other types (the main ones I can think of are slices in index operations, and maybe RNGs, if those make sense). I added a test with a custom string type. A robust implementation of the Minimize gradient would need to:

  1. Return disconnected_type() for disconnected inputs
  2. Implement a connection_pattern method to tell grad about this in advance (it complains otherwise).

That can be left for a future PR.
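As a hedged sketch of those two pieces (not this PR's code: the class is hypothetical and the actual gradient math is elided), it could look something like:

```python
from pytensor.gradient import DisconnectedType
from pytensor.scalar import ScalarType
from pytensor.tensor.type import TensorType


class MinimizeOpSketch:
    def connection_pattern(self, node):
        # result[i][j] is True iff output j is connected to input i.
        differentiable = [
            isinstance(inp.type, (TensorType, ScalarType)) for inp in node.inputs
        ]
        return [[d] * len(node.outputs) for d in differentiable]

    def L_op(self, inputs, outputs, output_grads):
        grads = []
        for inp in inputs:
            if isinstance(inp.type, (TensorType, ScalarType)):
                grads.append(...)  # the real implicit-gradient math goes here
            else:
                # Disconnected input (e.g. a slice or an RNG): return a
                # DisconnectedType variable instead of raising.
                grads.append(DisconnectedType()())
        return grads
```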

Gave up on mypy

Mypy didn't like my bugfix / more precise type-hints. Given its empty track record of helping find bugs in our codebase, and my dislike for it, I removed all `typing.cast` calls and added the file to mypy-failing. Now we have both: more correct type-hints and more readable code. If someone wants to revert this, be my guest, but I won't be helping.

@ricardoV94 changed the title from "Optimize: Guard against unsupported input types" to "Optimize: Handle gradient wrt scalar inputs and guard against unsupported types" on Dec 10, 2025
@ricardoV94 force-pushed the minimize_grad_scalar branch 3 times, most recently from f26580f to 32dc544 on December 10, 2025 at 12:44
```python
    x: TensorVariable,
    *args: Variable,
    objective: Variable,
    method: str = "BFGS",
```
@ricardoV94 (Member Author) commented:
I removed the default method from the Ops, since the helpers already have one. This helps avoid mismatches between the two defaults.
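A minimal sketch of the design choice (the names `MinimizeOpLike` and `minimize_like` are stand-ins, not the actual classes): the user-facing helper is the single place that carries the default, and the Op demands an explicit value, so the two can never silently disagree.

```python
class MinimizeOpLike:
    def __init__(self, method: str):  # no default: the caller must pass it
        self.method = method


def minimize_like(x, *args, objective, method: str = "BFGS"):
    # The single source of truth for the default lives in the helper.
    return MinimizeOpLike(method=method)
```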

```diff
-    outputs: Sequence[Variable],
-    output_grads: Sequence[Variable],
-) -> list[Variable]:
+def L_op(self, inputs, outputs, output_grads):
```
@ricardoV94 (Member Author) commented:
This was the only L_op with type hints; I removed them for consistency.

Comment on lines +166 to +169
```python
grad_wrt_arg = dot(output_grad, arg_grad)
if isinstance(arg.type, ScalarType):
    grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
grad_wrt_args.append(grad_wrt_arg)
```
@ricardoV94 (Member Author) commented:
As an LLM would have put it: "🥳 The actual bugfix"
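A minimal sketch of why the conversion matters (the variables here are illustrative, not from the PR): PyTensor expects the gradient returned for an input to have the same Type as that input, so the 0-d tensor produced by `dot` has to be converted back whenever the arg is a `ScalarType`.

```python
import pytensor.tensor as pt
from pytensor.gradient import grad
from pytensor.scalar import float64
from pytensor.tensor.basic import scalar_from_tensor, tensor_from_scalar

a = float64("a")                  # ScalarType input
a_t = tensor_from_scalar(a)       # promote to a 0-d tensor for tensor math
y = (a_t * pt.arange(3)).sum()    # some downstream tensor computation
g = grad(y, wrt=a_t)              # the gradient is a 0-d TensorVariable
g_scalar = scalar_from_tensor(g)  # convert back so its type matches `a`
```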

@ricardoV94 (Member Author) commented:
CC @Michal-Novomestsky

Merging this pull request may close these issues:

BUG: Passing check_parameters through minimize and tg.grad raises TypeError