
Conversation

@jansel (Contributor) commented Sep 1, 2022

Fixes #1085 (Inductor fails on optimizer.step)

if allow_alpha:
    if alpha is not None and alpha != 1:
        # Fold alpha into the last input: add(a, b, alpha=k) -> add(a, mul(b, k))
        inputs = list(inputs)
        inputs[-1] = mul(inputs[-1], alpha)
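For context, a minimal sketch of what this rewrite does, using eager PyTorch ops to stand in for the inductor lowerings (add_with_alpha is an illustrative name, not an actual function in the codebase):

import torch

def add_with_alpha(a, b, alpha=None):
    # Mirror the PR's rewrite: fold alpha into the second operand,
    # so add(a, b, alpha) becomes add(a, mul(b, alpha)).
    if alpha is not None and alpha != 1:
        b = torch.mul(b, alpha)
    return torch.add(a, b)

x, y = torch.randn(4), torch.randn(4)
assert torch.allclose(add_with_alpha(x, y, alpha=2.0), torch.add(x, y, alpha=2.0))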
Contributor:
assert alpha is a scalar?

@jansel (Author):
mul() already has asserts on its inputs.

Contributor:
Hmm, I assumed mul would be able to take tensors, whereas alpha was required to be a scalar. As long as mul asserts, sgtm!

@jansel (Author):
Constants often get wrapped as tensors by AOT Autograd, so we want this to work for tensors.
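A minimal illustration of that point (the AOT Autograd tracing machinery is elided; this just shows that a 0-d tensor is a drop-in for the scalar):

import torch

y = torch.randn(4)
alpha_scalar = 2.0
alpha_tensor = torch.tensor(2.0)  # 0-d tensor standing in for the constant

# mul() accepts either form, so the lowering above works for both.
assert torch.equal(torch.mul(y, alpha_scalar), torch.mul(y, alpha_tensor))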

Reviewer:
But we still need to check that it's 0-d and that its dtype is correct for the other inputs (e.g., integral inputs with a float alpha error out in eager).
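A quick repro of the eager behavior being referenced (the exact error message varies across versions):

import torch

a = torch.tensor([1, 2, 3])  # integral inputs
b = torch.tensor([4, 5, 6])

torch.add(a, b, alpha=2)    # fine: integer alpha
torch.add(a, b, alpha=0.5)  # RuntimeError in eager: for integral input
                            # tensors, alpha must not be a floating point number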

@jansel changed the title from "[inductor] Support alpha=1 in add/sub" to "[inductor] Support alpha= in add/sub" on Sep 1, 2022
@jansel merged commit 9292e40 into pytorch:main on Sep 1, 2022
@jansel deleted the alpha branch on September 1, 2022 22:50
@lezcano (Contributor) commented Oct 3, 2022

Why was this added here, rather than simply using the add decomposition from PrimTorch and letting the lowering fuse the pointwise operations? The PrimTorch operation already has the correct checks on alpha.
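For reference, a hedged sketch of the shape of a PrimTorch-style add reference (paraphrased; not the actual torch._refs code): the reference validates alpha against the computation dtype and then folds it into a multiply.

import torch

def ref_add(a: torch.Tensor, b: torch.Tensor, *, alpha=None) -> torch.Tensor:
    if alpha is not None:
        # Sketch of the alpha check: the real reference verifies that the
        # Python type of alpha is weakly castable to the result dtype
        # (e.g. a float alpha is rejected for integral tensors).
        is_float_ok = a.dtype.is_floating_point or a.dtype.is_complex
        if isinstance(alpha, float) and not is_float_ok:
            raise ValueError("alpha must not be a float for integral inputs")
        b = b * alpha
    return a + b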

@ngimel commented Oct 3, 2022

We can't use the aten -> prim decompositions from PrimTorch because they hardcode broadcasting.
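A minimal illustration of what "hardcode broadcasting" means here (conceptual only; prim-level traces make the broadcast explicit via broadcast-style ops rather than leaving it implicit for the lowering to handle):

import torch

a = torch.randn(4, 1)
b = torch.randn(1, 3)

# Implicit broadcasting, which the inductor lowering wants to keep flexible:
out_implicit = a + b

# What a prim-level decomposition bakes in: operands expanded to the
# common shape before the pointwise add.
out_explicit = a.expand(4, 3) + b.expand(4, 3)

assert torch.equal(out_implicit, out_explicit)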
