
Conversation

mrshenli (Contributor) commented Apr 6, 2023

pytorch-bot bot commented Apr 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98512

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fd1105d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

mrshenli added a commit that referenced this pull request Apr 6, 2023
mrshenli changed the title from "[WIP][Don't Review Yet] Add NLLLoss to DTensor prop rule" to "Add NLLLoss to DTensor prop rule" on Apr 7, 2023
mrshenli added a commit that referenced this pull request Apr 7, 2023
ghstack-source-id: 712d129
Pull Request resolved: #98512


@register_prop_rule(aten.nll_loss_backward.default) # pyre-ignore
def _prop_nll_loss_forward(op_schema: OpSchema) -> OutputSharding:
Collaborator

nit: this function should be renamed to backward?
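For concreteness, a minimal sketch of the suggested rename (decorator and signature copied from the hunk above; only the function name changes):

@register_prop_rule(aten.nll_loss_backward.default)  # pyre-ignore
def _prop_nll_loss_backward(op_schema: OpSchema) -> OutputSharding:
    ...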

DTensorSpec(
    mesh=self.mesh,
    placements=[_Partial()],
    tensor_meta=TensorMetadata(
Collaborator

I don't think you need to fill this out, as sharding_prop.propagate_op_sharding should automatically populate this field.

Contributor Author

I see, let me update

DTensorSpec(
    mesh=self.mesh,
    placements=[Replicate()],
    tensor_meta=TensorMetadata(
Collaborator

same, we don't need to fill in tensor_meta
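For illustration, a minimal sketch of the simplified spec, assuming propagate_op_sharding fills in tensor_meta for outputs as noted above:

# Sketch only: output spec with tensor_meta omitted; sharding propagation is
# expected to populate it for output specs.
DTensorSpec(
    mesh=self.mesh,
    placements=[Replicate()],
)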

# a scalar tensor, and hence the _Partial placements.
DTensorSpec(
    mesh=self.mesh,
    placements=[_Partial()],
Collaborator

If we produce a partial, I thought it would trigger a pending reduction in the next op that consumes this partial placement during DTensor expansion?

Contributor Author

This is the graph after DTensor expansion. I think the reason it doesn't trigger an allreduce is that nll_loss_forward is the last operator before backward? (Same as the sparse loss case that @yifuwang encountered?)

opcode         name                        target                                   args                                                         kwargs
-------------  --------------------------  ---------------------------------------  -----------------------------------------------------------  --------------------------------------------------------------
placeholder    arg0_1                      arg0_1                                   ()                                                           {}
placeholder    arg0_2                      arg0_2                                   ()                                                           {}
placeholder    arg1_1                      arg1_1                                   ()                                                           {}
placeholder    arg1_2                      arg1_2                                   ()                                                           {}
placeholder    arg2_1                      arg2_1                                   ()                                                           {}               
placeholder    arg2_2                      arg2_2                                   ()                                                           {}
placeholder    arg2_3                      arg2_3                                   ()                                                           {}
placeholder    arg2_4                      arg2_4                                   ()                                                           {}
call_function  t                           aten.t.default                           (arg1_1,)                                                    {}
call_function  addmm                       aten.addmm.default                       (arg1_2, arg2_3, t)                                          {}
call_function  _log_softmax                aten._log_softmax.default                (addmm, 1, False)                                            {}
call_function  detach                      aten.detach.default                      (_log_softmax,)                                              {}
call_function  nll_loss_forward            aten.nll_loss_forward.default            (_log_softmax, arg2_4, None, 1, -100)                        {}
call_function  getitem                     <built-in function getitem>              (nll_loss_forward, 0)                                        {}
call_function  getitem_1                   <built-in function getitem>              (nll_loss_forward, 1)                                        {}
call_function  ones_like                   aten.ones_like.default                   (getitem,)                                                   {'pin_memory': False, 'memory_format': torch.preserve_format}
call_function  nll_loss_backward           aten.nll_loss_backward.default           (ones_like, _log_softmax, arg2_4, None, 1, -100, getitem_1)  {}
call_function  detach_1                    aten.detach.default                      (detach,)                                                    {}
call_function  _log_softmax_backward_data  aten._log_softmax_backward_data.default  (nll_loss_backward, detach_1, 1, torch.float32)              {}
call_function  t_1                         aten.t.default                           (_log_softmax_backward_data,)                                {}
call_function  mm                          aten.mm.default                          (t_1, arg2_3)                                                {}
call_function  t_2                         aten.t.default                           (mm,)                                                        {}
call_function  sum_1                       aten.sum.dim_IntList                     (_log_softmax_backward_data, [0], True)                      {}
call_function  view                        aten.view.default                        (sum_1, [10])                                                {}
call_function  detach_2                    aten.detach.default                      (view,)                                                      {}
call_function  detach_3                    aten.detach.default                      (detach_2,)                                                  {}
call_function  t_3                         aten.t.default                           (t_2,)                                                       {}
call_function  detach_4                    aten.detach.default                      (t_3,)                                                       {}
call_function  detach_5                    aten.detach.default                      (detach_4,)                                                  {}
call_function  all_reduce                  aten.all_reduce.default                  (detach_5, 'SUM', 'ptd:0', [0, 1], 2)                        {}
call_function  wait_tensor                 aten.wait_tensor.default                 (all_reduce,)                                                {}
call_function  all_reduce_1                aten.all_reduce.default                  (detach_3, 'SUM', 'ptd:0', [0, 1], 2)                        {}
call_function  wait_tensor_1               aten.wait_tensor.default                 (all_reduce_1,)                                              {}
call_function  _foreach_add_1              aten._foreach_add.List                   ([arg1_1, arg1_2], [wait_tensor, wait_tensor_1])             {'alpha': -0.01}
call_function  getitem_2                   <built-in function getitem>              (_foreach_add_1, 0)                                          {}
call_function  getitem_3                   <built-in function getitem>              (_foreach_add_1, 1)                                          {}
call_function  copy_                       aten.copy_.default                       (arg1_1, getitem_2)                                          {}
call_function  copy__1                     aten.copy_.default                       (arg1_2, getitem_3)                                          {}
output         output                      output                                   ([None, copy_, copy__1],)                                    {}

Collaborator

Oh I see, this is probably because the ones_like rule takes in a partial but produces a replicate IIRC, so it doesn't trigger an allreduce.
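To make that concrete, a hypothetical sketch of the placements involved (not the actual DTensor rules; the import path is assumed for the codebase at the time of this PR):

from torch.distributed._tensor.placement_types import Replicate, _Partial

# Per the discussion above: nll_loss_forward emits the scalar loss with a
# _Partial placement (a pending sum across ranks), but the ones_like rule
# produces a Replicate grad_output, so no allreduce is inserted between
# forward and backward. The only allreduces in the expanded graph above are
# the ones on the weight/bias gradients.
loss_placements = [_Partial()]          # output of nll_loss_forward
grad_output_placements = [Replicate()]  # output of ones_like on the loss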

mrshenli added a commit that referenced this pull request Apr 7, 2023
ghstack-source-id: 19342e2
Pull Request resolved: #98512
new_self = DTensorSpec(
    mesh=self.mesh,
    placements=target.placements,
    tensor_meta=self.tensor_meta,
Contributor Author

Hey @wanchaol, I assume we still need this tensor_meta, as sharding_prop only populates tensor_meta for outputs?

Collaborator

Oh yes, if this is resharding on inputs, then we need to assign it.
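Summarizing the distinction with a short sketch (names taken from the hunk above; out_spec is a hypothetical name for an output spec):

# Resharding an *input*: carry the existing tensor_meta forward explicitly,
# since sharding prop only populates tensor_meta for outputs.
new_self = DTensorSpec(
    mesh=self.mesh,
    placements=target.placements,
    tensor_meta=self.tensor_meta,
)
# An *output* spec can omit tensor_meta; propagate_op_sharding fills it in.
out_spec = DTensorSpec(
    mesh=self.mesh,
    placements=[Replicate()],
)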

@mrshenli mrshenli added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 7, 2023
mrshenli (Contributor Author) commented Apr 8, 2023

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

skotapati pushed a commit to kulinseth/pytorch that referenced this pull request Apr 10, 2023
@facebook-github-bot facebook-github-bot deleted the gh/mrshenli/382/head branch June 8, 2023 18:04
