Add NLLLoss to DTensor prop rule #98512
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98512. Note: links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit fd1105d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@register_prop_rule(aten.nll_loss_backward.default)  # pyre-ignore
def _prop_nll_loss_forward(op_schema: OpSchema) -> OutputSharding:
nit: this function should be renamed to backward?
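For reference, a minimal sketch of what the suggested rename might look like; the import paths are assumptions based on this era of the DTensor internals, and the rule body is elided because it is not shown in this thread.

```python
import torch
# Import paths are assumptions and may differ across PyTorch versions.
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule

aten = torch.ops.aten

@register_prop_rule(aten.nll_loss_backward.default)  # pyre-ignore
def _prop_nll_loss_backward(op_schema: OpSchema) -> OutputSharding:
    # Body elided; only the registration/naming is the point of the nit above.
    ...
```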
DTensorSpec(
    mesh=self.mesh,
    placements=[_Partial()],
    tensor_meta=TensorMetadata(
I don't think you need to fill this out, as sharding_prop.propagate_op_sharding should automatically populate this field.
I see, let me update
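Concretely, the updated construction would presumably just drop the tensor_meta argument. A minimal sketch, with an assumed import path and a hypothetical helper name (not the PR's code):

```python
from torch.distributed._tensor.placement_types import DTensorSpec, _Partial  # path assumed

def partial_loss_spec(mesh):
    # Hypothetical helper: a DTensorSpec can be built without tensor_meta, since
    # sharding_prop.propagate_op_sharding populates that field during propagation.
    return DTensorSpec(mesh=mesh, placements=[_Partial()])
```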
DTensorSpec(
    mesh=self.mesh,
    placements=[Replicate()],
    tensor_meta=TensorMetadata(
same, we don't need to fill in tensor_meta
# a scalar tensor, and hence the _Partial placements.
DTensorSpec(
    mesh=self.mesh,
    placements=[_Partial()],
If we produce a partial, I thought it would trigger a pending reduction in the next op that consumes this partial placement during DTensor expansion?
This is the graph after DTensor expansion. I think the reason it doesn't trigger an allreduce is that nll_loss_forward is the last operator before the backward? (Same as the sparse loss issue that @yifuwang encountered?)
opcode name target args kwargs
------------- -------------------------- --------------------------------------- ----------------------------------------------------------- -----------------
placeholder arg0_1 arg0_1 () {}
placeholder arg0_2 arg0_2 () {}
placeholder arg1_1 arg1_1 () {}
placeholder arg1_2 arg1_2 () {}
placeholder arg2_1 arg2_1 () {}
placeholder arg2_2 arg2_2 () {}
placeholder arg2_3 arg2_3 () {}
placeholder arg2_4 arg2_4 () {}
call_function t aten.t.default (arg1_1,) {}
call_function addmm aten.addmm.default (arg1_2, arg2_3, t) {}
call_function _log_softmax aten._log_softmax.default (addmm, 1, False) {}
call_function detach aten.detach.default (_log_softmax,) {}
call_function nll_loss_forward aten.nll_loss_forward.default (_log_softmax, arg2_4, None, 1, -100) {}
call_function getitem <built-in function getitem> (nll_loss_forward, 0) {}
call_function getitem_1 <built-in function getitem> (nll_loss_forward, 1) {}
call_function ones_like aten.ones_like.default (getitem,) {'pin_memory': False, 'memory_format': torch.preserve_format}
call_function nll_loss_backward aten.nll_loss_backward.default (ones_like, _log_softmax, arg2_4, None, 1, -100, getitem_1) {}
call_function detach_1 aten.detach.default (detach,) {}
call_function _log_softmax_backward_data aten._log_softmax_backward_data.default (nll_loss_backward, detach_1, 1, torch.float32) {}
call_function t_1 aten.t.default (_log_softmax_backward_data,) {}
call_function mm aten.mm.default (t_1, arg2_3) {}
call_function t_2 aten.t.default (mm,) {}
call_function sum_1 aten.sum.dim_IntList (_log_softmax_backward_data, [0], True) {}
call_function view aten.view.default (sum_1, [10]) {}
call_function detach_2 aten.detach.default (view,) {}
call_function detach_3 aten.detach.default (detach_2,) {}
call_function t_3 aten.t.default (t_2,) {}
call_function detach_4 aten.detach.default (t_3,) {}
call_function detach_5 aten.detach.default (detach_4,) {}
call_function all_reduce aten.all_reduce.default (detach_5, 'SUM', 'ptd:0', [0, 1], 2) {}
call_function wait_tensor aten.wait_tensor.default (all_reduce,) {}
call_function all_reduce_1 aten.all_reduce.default (detach_3, 'SUM', 'ptd:0', [0, 1], 2) {}
call_function wait_tensor_1 aten.wait_tensor.default (all_reduce_1,) {}
call_function _foreach_add_1 aten._foreach_add.List ([arg1_1, arg1_2], [wait_tensor, wait_tensor_1]) {'alpha': -0.01}
call_function getitem_2 <built-in function getitem> (_foreach_add_1, 0) {}
call_function getitem_3 <built-in function getitem> (_foreach_add_1, 1) {}
call_function copy_ aten.copy_.default (arg1_1, getitem_2) {}
call_function copy__1 aten.copy_.default (arg1_2, getitem_3) {}
output output output ([None, copy_, copy__1],) {}
Oh I see, this is probably because the ones_like rule takes in a partial but produces a replicate IIRC, so it doesn't trigger an allreduce.
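To make the explanation above concrete, here is a hedged sketch of the placement flow being described (the placements come from this discussion; the import path is an assumption and the comments are illustrative, not generated output):

```python
from torch.distributed._tensor.placement_types import Replicate, _Partial  # path assumed

# Placement flow through the tail of the expanded graph above (illustrative only):
loss_placements = [_Partial()]       # nll_loss_forward output: scalar loss with a pending SUM
grad_out_placements = [Replicate()]  # the ones_like rule takes the partial but produces a replicate
# nll_loss_backward then consumes a Replicate grad_output, so no allreduce is ever
# materialized for the scalar loss itself.
```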
new_self = DTensorSpec(
    mesh=self.mesh,
    placements=target.placements,
    tensor_meta=self.tensor_meta,
Hey @wanchaol, I assume we still need this tensor_meta, as sharding_prop only populates tensor_meta for outputs?
Oh yes, if this is resharding on inputs, then we need to assign it.
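In other words, something along these lines: when resharding an input, the metadata has to be carried over explicitly. The helper below is a hypothetical illustration of the snippet above, with an assumed import path, not the PR's code.

```python
from torch.distributed._tensor.placement_types import DTensorSpec  # path assumed

def reshard_input_spec(self_spec: DTensorSpec, target_spec: DTensorSpec) -> DTensorSpec:
    # Hypothetical helper mirroring the diff above: sharding propagation only fills
    # tensor_meta for outputs, so an input being resharded keeps its own metadata.
    return DTensorSpec(
        mesh=self_spec.mesh,
        placements=target_spec.placements,
        tensor_meta=self_spec.tensor_meta,
    )
```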
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#98512 Approved by: https://github.com/wanchaol
Stack from ghstack (oldest at bottom):