scatter's backward formula should call scatter instead of scatter_ #63430

@zou3519

🚀 Feature

Replace scatter_ in scatter's backward formula with scatter.

Motivation

functorch has a hard time with in-place operations, and I don't see a good reason why scatter_ gets used here (please correct me if I'm wrong!). cc @albanD @soulitzer for a sanity check; please let me know if I missed something.

- name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
  self: grad.clone().scatter_(dim, index, 0)
  index: non_differentiable
  src: grad.gather(dim, index)

- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
  self: grad.clone().scatter_(dim, index, 0)
  index: non_differentiable
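
As a rough sanity check of the proposal, the sketch below compares the current in-place formula with the out-of-place one on a small example. The helper names `scatter_backward_inplace` and `scatter_backward_out_of_place` are hypothetical and exist only for illustration; they are not part of PyTorch. The tensor shapes and index values are arbitrary assumptions.

```python
import torch


def scatter_backward_inplace(grad, dim, index):
    # Current formula: copy grad, then zero the scattered positions in place.
    return grad.clone().scatter_(dim, index, 0)


def scatter_backward_out_of_place(grad, dim, index):
    # Proposed formula: out-of-place scatter returns a new tensor,
    # so no explicit clone() is needed and grad is left untouched.
    return grad.scatter(dim, index, 0)


torch.manual_seed(0)
dim = 0
grad = torch.randn(3, 4)
# Arbitrary example indices; no duplicates within a column.
index = torch.tensor([[0, 1, 2, 0], [2, 0, 1, 1]])
grad_before = grad.clone()

a = scatter_backward_inplace(grad, dim, index)
b = scatter_backward_out_of_place(grad, dim, index)

print(torch.equal(a, b))
```

On this input the two formulas should agree and neither should modify `grad`. This is only a spot check on one example, not a proof that the two are interchangeable in every case (for instance, it says nothing about performance or about higher-order gradients).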

Labels: module: scatter & gather ops, triaged
