[spmd] complete softmax and _softmax_backward_data to support aggregate on sharding dim #440
Conversation

XilunWu commented Sep 7, 2022
- Adapt the softmax and _softmax_backward_data ops (added in [Not for landing] Local change to enable TP in ViT model prototyping #382) to the shard propagation rule and move them to tensor_ops.py
- Move the relevant tests to test_tensor_ops.py
- Extend test coverage over the (batch_dim, softmax_dim) combinations, except for the case batch_dim == softmax_dim.
… on different batching dim; 2. add _softmax_backward_data
test/spmd/tensor/test_tensor_ops.py (Outdated)
)
input = torch.rand(8, 12, 16, device=self.device_type)
shard0_spec = Shard(0)
shard1_spec = Shard(1)
Curious why we need to test this if we already have it in the list for the dtensor_ops_db test?
nit: shard1_spec and shard2_spec are unused..
@anj-s sorry, I'm not familiar with the dtensor_ops_db test, but one example I think is appropriate to refer to is https://github.com/pytorch/tau/blob/6bbe0872aeb8faa8fab73862bae9d3f806ec7836/test/spmd/tensor/test_dtensor_ops.py#L180: bmm is in the list, yet it's still tested in https://github.com/pytorch/tau/blob/6bbe0872aeb8faa8fab73862bae9d3f806ec7836/test/spmd/tensor/test_dtensor_ops.py#L180. Does this resolve your concern?
> nit: shard1_spec and shard2_spec are unused..

Thanks for catching it! I plan to add tests for the whole space of (batch dim, softmax dim) combinations, e.g. [0, 1, 2, -1].
Agreed. @XilunWu could you enable this by deleting xfail("softmax") in test_dtensor_ops.py? It will test forward automatically for you; you can probably change this test to only test backward.
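For reference, a backward-only variant of the test along those lines could look like the sketch below. This is only a sketch, not the PR's code: it assumes the helpers already used in test_tensor_ops.py (DeviceMesh, distribute_tensor, Shard, self.device_type, self.world_size), and the exact requires_grad handling may differ from what the PR ends up doing.

```python
# Hedged sketch of a backward-only softmax test, meant to live in the existing
# test class in test_tensor_ops.py. DeviceMesh, distribute_tensor, and Shard are
# assumed to be the helpers already imported in that file.
def test_softmax_backward(self):
    device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
    x = torch.rand(8, 12, 16, device=self.device_type)
    for softmax_dim in [0, 1, 2, -1]:
        for shard_dim in [0, 1, 2]:
            if shard_dim == softmax_dim % x.ndim:
                # softmax over the sharded dim is not supported yet; skip
                continue
            dist_x = distribute_tensor(
                x, device_mesh, [Shard(shard_dim)]
            ).requires_grad_()  # make dist_x a leaf so dist_x.grad is populated
            dist_y = torch.nn.functional.softmax(dist_x, dim=softmax_dim)
            # backward only: the forward path is covered by test_dtensor_ops
            # once xfail("softmax") is removed
            dist_y.backward(torch.ones_like(dist_y))
            self.assertIsNotNone(dist_x.grad)
```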
QQ: since we have not supported softmax on the sharding dim, will this be a problem for test_dtensor_ops.py?
I tried running pytest test/spmd/tensor/test_dtensor_ops.py -s -k softmax and the "softmax" test is currently passing. The set of parameters tested is as follows:
Tensor Dim | Softmax Dim | Sharding |
---|---|---|
0D (scalar) | 0 | Replicate |
1D | 0 | Replicate |
2D | {0, -1} | Replicate |
3D | {2} | Replicate |
And the current "test_softmax" parameter set is:
Tensor Dim | Softmax Dim | Sharding |
---|---|---|
3D | {0, 1, 2, -1} | {0, 1, 2, -1} |
My question is: since test_softmax in test_tensor_ops.py is not a duplicate of the "softmax" test in test_dtensor_ops.py, should we keep it?
spmd/tensor/ops/tensor_ops.py (Outdated)
@@ -39,6 +39,7 @@ def no_shard_prop_rule(op_schema: OpSchema) -> OutputSharding:
     "aten.is_same_size.default",
     "aten.ones_like.default",
     "aten.new_empty_strided.default",
+    "aten._softmax.default",
This is essentially a math op, not a tensor op; we should add it to math_ops.py, i.e. just @register_prop_rule in math_ops.py for aten._softmax.default and aten._softmax_backward_data. Note that in the rule you should explicitly check whether the sharding dim is the same as the softmax dim (if it is, we should error out for now).
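A minimal sketch of what that registration in math_ops.py could look like is below. It is not the PR's code: the import paths and the OutputSharding fields (output_spec, failed_reason) are assumptions inferred from the other snippets in this thread, and the rule simply rejects sharding on the softmax dim rather than suggesting a redistribution.

```python
# Hedged sketch of the suggested forward rule in math_ops.py.
# Import paths and OutputSharding field names are assumptions.
from typing import cast

from spmd.tensor.dispatch import OpSchema, OutputSharding
from spmd.tensor.ops.utils import register_prop_rule
from spmd.tensor.placement_types import DTensorSpec


@register_prop_rule("aten._softmax.default")
def softmax_rule(op_schema: OpSchema) -> OutputSharding:
    # aten schema: _softmax(Tensor self, int dim, bool half_to_float)
    input_spec = cast(DTensorSpec, op_schema.args_schema[0])
    softmax_dim = cast(int, op_schema.args_schema[1])
    dim_map = input_spec.dim_map
    if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
        # the softmax dim is sharded: error out for now, as suggested above
        return OutputSharding(
            output_spec=None,
            failed_reason="softmax on the sharding dim is not supported",
        )
    # softmax along a non-sharded dim is shape-preserving and purely local,
    # so the output keeps the input's sharding
    return OutputSharding(output_spec=input_spec)
```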
Thanks for working on this one and sending out the PR so quickly. Left some comments.
test/spmd/tensor/test_dtensor_ops.py (Outdated)
@@ -146,7 +146,7 @@ def wrapped(fn):
     xfail("_masked.norm"),
     xfail("_masked.prod"),
     xfail("_masked.softmin"),
-    xfail("_masked.softmax"),
+    #xfail("_masked.softmax"),
Instead of commenting it out, can we remove it directly?
Will put this one back since it's the wrong test. skip("softmax") has been removed.
spmd/tensor/ops/tensor_ops.py (Outdated)
@@ -39,6 +39,7 @@ def no_shard_prop_rule(op_schema: OpSchema) -> OutputSharding:
     "aten.is_same_size.default",
     "aten.ones_like.default",
     "aten.new_empty_strided.default",
+    "aten._softmax_backward_data.default",
Curious, is this enough for it to work? We might want to give it a prop rule, because down the road, if we want to add sharding-dim softmax, we need to call collectives in the backward as well.
Right. I currently leave it as the default rule and it kind of works on y.backward() but not on y.sum().backward(), as I mentioned earlier. Need to investigate what is missing.
test/spmd/tensor/test_tensor_ops.py (Outdated)
dist_y_grad = torch.ones_like(dist_y)
# sum().backward() on dist_y has issue:
# dist_y.sum().backward(dist_y_grad)
# RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([8, 12, 16]) and output[0] has a shape of torch.Size([]).
dist_y.backward(dist_y_grad)
self.assertIsNotNone(dist_x.grad)
As we discussed, let's just use sum() for now. Thanks!
#440 (comment)

Is there a quick way to check? Running pytest test/spmd/tensor/test_dtensor_ops.py takes a long time. Can I use the -s option with this file? Something like pytest test/spmd/tensor/test_dtensor_ops.py -s -k test_softmax?
… test can be removed if review is good since softmax test is turned on in test_dtensor_ops; 2.backward softmax has issue in op dispatch. need further investigation. 3. test on higher dimension mesh can be added in a separate PR
…_schema_suggestion
How to reproduce:
What kind of difference did we observe for CPU/GPU?
Bug triage: ran it; here is the output report:
(Correct gradients in local computation result: small gradients. Same result on GPU & CPU) dist_grad=
(Wrong gradients in distributed tensor on CPU) dist_grad=
(Correct gradients in distributed tensor on GPU, from another run)
spmd/tensor/ops/pointwise_ops.py (Outdated)
@register_prop_rule("aten._softmax_backward_data.default")
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
    input_specs = cast(List[DTensorSpec], op_schema.args_spec)
    ops_dim_map = pytree.tree_map(lambda spec: spec.dim_map, input_specs)
I think we'd better not use pytree in sharding rules unless it's absolutely needed. There are only two tensor arguments; we'd better just access dim_map directly instead of using pytree.
spmd/tensor/ops/pointwise_ops.py (Outdated)
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
    input_specs = cast(List[DTensorSpec], op_schema.args_spec)
    ops_dim_map = pytree.tree_map(lambda spec: spec.dim_map, input_specs)
    softmax_dim = cast(int, op_schema.args_schema[len(op_schema.args_spec)])
Could you just unwrap like grad_out_spec, out_spec, dim, input_dtype = op_schema.args_schema, since the backward op has this signature: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype)?
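Following that suggestion, the direct unwrap could look roughly like the sketch below (it reuses the imports and the assumed OutputSharding fields from the earlier sketch, and applies the same sharded-dim check as the forward rule; it is illustrative, not the PR's code).

```python
# Hedged sketch of the backward rule with a direct positional unwrap (no pytree),
# matching the quoted signature:
#   _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype)
@register_prop_rule("aten._softmax_backward_data.default")
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
    grad_out_spec, out_spec, dim, input_dtype = op_schema.args_schema
    grad_out_spec = cast(DTensorSpec, grad_out_spec)
    out_spec = cast(DTensorSpec, out_spec)  # unpacked to match the schema
    softmax_dim = cast(int, dim)
    dim_map = grad_out_spec.dim_map
    if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
        # same restriction as the forward rule: no softmax on a sharded dim yet
        return OutputSharding(
            output_spec=None,
            failed_reason="_softmax_backward_data on the sharding dim is not supported",
        )
    # grad_input has the same shape and sharding as grad_output
    return OutputSharding(output_spec=grad_out_spec)
```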
schema_suggestion = None
failed_reason = None
if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
    # suggest replicating the input tensor
nit: let's do this suggestion in a follow-up PR.
spmd/tensor/ops/math_ops.py (Outdated)
dim_map = input_spec.dim_map
softmax_dim = cast(
    int, op_schema.args_schema[len(op_schema.args_spec)]
)  # Is it better to put it into kwargs? e.g. op_schema.kwargs_schema['dim']
softmax only has a positional argument for dim, so let's just keep the positional arg.
spmd/tensor/ops/pointwise_ops.py (Outdated)
@register_prop_rule("aten._softmax_backward_data.default")
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
I know softmax_bwd_rule may be a pointwise rule, but maybe it's better to categorize the softmax ops together? Let's put them together in math_ops.py; you can call into pointwise_ops from math_ops.py.
…PU and GPU so it's problem in DTensor's softmax rule
…U but cpu/gpu discrepancy exists within softmax_backward_data when sharding on dim -1
…ta discrepancy; reconsider schema suggestion which is now simply replicating tensors
DTensor now lives in pytorch; related PRs need to be submitted to pytorch directly. See #576 for context.