
[spmd] complete softmax and _softmax_backward_data to support aggregate on sharding dim #440

Closed
wants to merge 31 commits

Conversation

Contributor

@XilunWu XilunWu commented Sep 7, 2022

  1. Adapt the softmax and _softmax_backward_data ops (added in [Not for landing] Local change to enable TP in ViT model prototyping #382) to the shard propagation rule and move them to tensor_ops.py.
  2. Move the relevant tests to test_tensor_ops.py.
  3. Extend test coverage over the (batch_dim, softmax_dim) combinations, except for the case batch_dim == softmax_dim (see the sketch below).
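
A hedged sketch of what that coverage sweep could look like. The spmd import path, the mesh argument, and the helper name check_softmax_coverage are assumptions; distribute_tensor, Shard, and Replicate are the DTensor primitives used in the tests quoted later in this thread.

```python
import itertools

import torch
import torch.nn.functional as F

# Assumed import path based on the pytorch/tau layout at the time.
from spmd import DeviceMesh, Shard, Replicate, distribute_tensor


def check_softmax_coverage(mesh: DeviceMesh, device_type: str) -> None:
    # Sketch only: sweep every (shard_dim, softmax_dim) pair and skip the
    # unsupported case where the sharding dim equals the softmax dim.
    x = torch.rand(8, 12, 16, device=device_type)
    for shard_dim, softmax_dim in itertools.product(range(x.ndim), [0, 1, 2, -1]):
        if softmax_dim % x.ndim == shard_dim:
            continue  # aggregation on the sharding dim is left to a follow-up
        local_y = F.softmax(x, dim=softmax_dim)
        dist_x = distribute_tensor(x, mesh, [Shard(shard_dim)])
        dist_y = F.softmax(dist_x, dim=softmax_dim)
        full_y = dist_y.redistribute(mesh, [Replicate()]).to_local()
        assert torch.allclose(full_y, local_y)
```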

)
input = torch.rand(8, 12, 16, device=self.device_type)
shard0_spec = Shard(0)
shard1_spec = Shard(1)
Contributor

Curious why we need to test this if we already have it in the list for the dtensor_ops_db test?
nit: shard1_spec and shard2_spec are unused.

Contributor Author

@XilunWu XilunWu Sep 7, 2022

@anj-s Sorry, I'm not familiar with the dtensor_ops_db test, but one example I think is appropriate to refer to is https://github.com/pytorch/tau/blob/6bbe0872aeb8faa8fab73862bae9d3f806ec7836/test/spmd/tensor/test_dtensor_ops.py#L180

bmm is in the list; however, it's still tested in https://github.com/pytorch/tau/blob/6bbe0872aeb8faa8fab73862bae9d3f806ec7836/test/spmd/tensor/test_dtensor_ops.py#L180

Does this resolve your concern?

nit: shard1_spec and shard2_spec are unused.

Thanks for catching it! I plan to add tests for the whole space of (batch dim, softmax dim) combinations, e.g. [0, 1, 2, -1] × [0, 1, 2].

Contributor

Agreed. @XilunWu could you enable this by deleting xfail("softmax") in test_dtensor_ops.py? It will test the forward automatically for you; you can probably change this test to only test the backward.

Contributor

QQ: since we have not supported softmax on the sharding dim yet, will this be a problem for test_dtensor_ops.py?

Contributor Author

I tried running pytest test/spmd/tensor/test_dtensor_ops.py -s -k softmax and the "softmax" test is currently passing. The set of parameters tested is as follows:

Tensor Dim     Softmax Dim      Sharding
0D (scalar)    0                Replicate
1D             0                Replicate
2D             {0, -1}          Replicate
3D             {2}              Replicate

And the current "test_softmax" parameter set is:

Tensor Dim     Softmax Dim      Sharding
3D             {0, 1, 2, -1}    {0, 1, 2, -1}

My question is, since test_softmax in test_tensor_ops.py is not a duplicate of the "softmax" test in test_dtensor_ops.py, should we keep it?


@@ -39,6 +39,7 @@ def no_shard_prop_rule(op_schema: OpSchema) -> OutputSharding:
"aten.is_same_size.default",
"aten.ones_like.default",
"aten.new_empty_strided.default",
"aten._softmax.default",
Contributor

This is essentially a math op, not a tensor op; we should add it to math_ops.py, i.e. just @register_prop_rule in math_ops.py for aten._softmax.default and aten._softmax_backward_data. Note that in the rule you should explicitly check whether the sharding dim is the same as the softmax dim (if it is, we should error out for now).
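
A minimal sketch of such a rule, under the assumption that register_prop_rule, OpSchema, OutputSharding, and DTensorSpec behave as in the snippets quoted later in this thread; the exact OutputSharding constructor arguments are assumptions.

```python
from typing import cast

# Sketch only: register_prop_rule, OpSchema, OutputSharding, and DTensorSpec
# come from the DTensor prop-rule modules in this PR; imports omitted here.
@register_prop_rule("aten._softmax.default")
def softmax_rule(op_schema: OpSchema) -> OutputSharding:
    # ATen schema: _softmax(Tensor self, int dim, bool half_to_float)
    input_spec, softmax_dim, _half_to_float = op_schema.args_schema
    input_spec = cast(DTensorSpec, input_spec)
    softmax_dim = cast(int, softmax_dim)
    dim_map = input_spec.dim_map
    if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
        # Explicitly reject softmax over the sharding dim for now.
        return OutputSharding(
            output_spec=None,
            failed_reason="softmax over the sharding dim is not supported yet",
        )
    # Softmax is shape-preserving, so the output keeps the input sharding.
    return OutputSharding(output_spec=input_spec)
```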

Contributor

@fduwjj fduwjj left a comment

Thanks for working on this one and sending out the PR so quickly. Left some comments.

@@ -146,7 +146,7 @@ def wrapped(fn):
xfail("_masked.norm"),
xfail("_masked.prod"),
xfail("_masked.softmin"),
xfail("_masked.softmax"),
#xfail("_masked.softmax"),
Contributor

Instead of commenting it out, can we remove it directly?

Contributor Author

Will put this one back since it's the wrong test. skip("softmax") has been removed.

@@ -39,6 +39,7 @@ def no_shard_prop_rule(op_schema: OpSchema) -> OutputSharding:
"aten.is_same_size.default",
"aten.ones_like.default",
"aten.new_empty_strided.default",
"aten._softmax_backward_data.default",
Contributor

Curious, is this enough for it to work? We might want to give it a prop rule, because down the road, if we want to add sharding-dim softmax, we need to call collectives in the backward as well.

Contributor Author

Right. I currently leave it as a default rule and it kind of works for y.backward() but not for y.sum().backward(), as I mentioned earlier. I need to investigate what is missing.


Comment on lines 136 to 141
dist_y_grad = torch.ones_like(dist_y)
# sum().backward() on dist_y has issue:
# dist_y.sum().backward(dist_y_grad)
# RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([8, 12, 16]) and output[0] has a shape of torch.Size([]).
dist_y.backward(dist_y_grad)
self.assertIsNotNone(dist_x.grad)
Contributor

As we discussed, let's just use sum() for now. Thanks!
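
A minimal sketch of the sum()-based backward being suggested, reusing the dist_x/dist_y names from the snippet above:

```python
# Sketch only: sum() yields a scalar loss, so backward() needs no explicit
# grad_output; passing the [8, 12, 16]-shaped grad to a scalar output is what
# raised the "Mismatch in shape" RuntimeError quoted in the snippet.
dist_y.sum().backward()
self.assertIsNotNone(dist_x.grad)
```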

Contributor Author

#440 (comment)
Is there a quick way to check? Running pytest test/spmd/tensor/test_dtensor_ops.py takes a long time. Can I use the -s option with this file, e.g. pytest test/spmd/tensor/test_dtensor_ops.py -s -k test_softmax?

Contributor Author

XilunWu commented Sep 10, 2022

test_softmax_with_bwd in test_tensor_ops.py has a result mismatch in the backward pass when doing softmax on dim=-1 on CPU, but is bug-free on GPU.

How to reproduce: pytest test/spmd/tensor/test_tensor_ops.py -s -k test_softmax_with_bwd

@XilunWu XilunWu marked this pull request as ready for review September 10, 2022 18:53
@XilunWu XilunWu marked this pull request as draft September 12, 2022 16:29
Contributor

fduwjj commented Sep 12, 2022

What kind of difference did we observe between CPU and GPU?

Contributor Author

XilunWu commented Sep 12, 2022

Bug triage: ran _softmax_backward_data on a [4, 4, 4] tensor with (shard_dim = 0, aggregation_dim = 2) and with (shard_dim = 0, aggregation_dim = -1). PiPPy produces the correct result for the first pair of parameters on both CPU and GPU, but a wrong result for the second pair on CPU. Note: this error only happens when aggregation_dim = -1 and the tensor is not replicated (i.e. sharded on dim 0 or 1; otherwise it is auto-replicated by my softmax rule).

Here is the output report:
local_grad=
tensor([[[ 0.0000e+00, -2.5286e-08, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 8.4356e-09, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.7677e-08, 0.0000e+00, 0.0000e+00]],

    [[ 0.0000e+00, -3.2625e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00,  1.0334e-08,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  9.0607e-09,  0.0000e+00,  0.0000e+00]],

    [[ 0.0000e+00, -2.3336e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00,  2.0839e-08,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  1.5786e-08,  0.0000e+00,  0.0000e+00]],

    [[ 0.0000e+00, -3.7963e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00,  1.9997e-08,  0.0000e+00],
     [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  1.7081e-08,  0.0000e+00,  0.0000e+00]]])

(Correct gradients from the local computation: small values. Same result on GPU and CPU.)

dist_grad=
tensor([[[ 0.0203, -0.0075, 0.0298, -0.0105],
[-0.0075, 0.0298, -0.0105, 0.0140],
[ 0.0298, -0.0105, 0.0140, -0.0277],
[-0.0105, 0.0140, -0.0277, 0.0214]],

    [[ 0.0203, -0.0101,  0.0208, -0.0153],
     [-0.0101,  0.0208, -0.0153,  0.0259],
     [ 0.0208, -0.0153,  0.0259, -0.0234],
     [-0.0153,  0.0259, -0.0234,  0.0471]],

    [[ 0.0375, -0.0085,  0.0458, -0.0149],
     [-0.0085,  0.0458, -0.0149,  0.0281],
     [ 0.0458, -0.0149,  0.0281, -0.0288],
     [-0.0149,  0.0281, -0.0288,  0.0297]],

    [[ 0.0407, -0.0105,  0.0289, -0.0205],
     [-0.0105,  0.0289, -0.0205,  0.0181],
     [ 0.0289, -0.0205,  0.0181, -0.0265],
     [-0.0205,  0.0181, -0.0265,  0.0301]]])

(Wrong gradients in distributed tensor on CPU)

dist_grad=
tensor([[[ 0.0000e+00, -4.4105e-08, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.7012e-08, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 4.1126e-08, 0.0000e+00, 1.4592e-08],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

    [[ 0.0000e+00, -2.1497e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  1.6550e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  2.6419e-08,  0.0000e+00,  1.8104e-08],
     [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

    [[ 0.0000e+00, -2.3221e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  8.7823e-09,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  2.7092e-08,  0.0000e+00,  1.4168e-08],
     [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

    [[ 0.0000e+00, -3.0386e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  1.7260e-08,  0.0000e+00,  0.0000e+00],
     [ 0.0000e+00,  2.4572e-08,  0.0000e+00,  1.2741e-08],
     [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],
   device='cuda:2')

(Correct gradients in distributed tensor on GPU, from another run)

@register_prop_rule("aten._softmax_backward_data.default")
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
    input_specs = cast(List[DTensorSpec], op_schema.args_spec)
    ops_dim_map = pytree.tree_map(lambda spec: spec.dim_map, input_specs)
Contributor

I think we'd better not use pytree in sharding rules unless it's absolutely needed. There are only two tensor arguments; we can just access dim_map directly instead of using pytree.
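
A minimal sketch of that direct access, assuming the _softmax_backward_data(grad_output, output, dim, input_dtype) schema quoted below; the variable names are illustrative.

```python
# Sketch only: unwrap the two tensor specs directly instead of mapping over a
# pytree. cast and DTensorSpec come from the same prop-rule module as above.
grad_out_spec, out_spec, softmax_dim, _input_dtype = op_schema.args_schema
grad_out_spec = cast(DTensorSpec, grad_out_spec)
out_spec = cast(DTensorSpec, out_spec)
grad_out_dim_map = grad_out_spec.dim_map
out_dim_map = out_spec.dim_map
```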

def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
    input_specs = cast(List[DTensorSpec], op_schema.args_spec)
    ops_dim_map = pytree.tree_map(lambda spec: spec.dim_map, input_specs)
    softmax_dim = cast(int, op_schema.args_schema[len(op_schema.args_spec)])
Contributor

Could you just unwrap it like

grad_out_spec, out_spec, dim, input_dtype = op_schema.args_schema

since the backward op has this signature:

_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype)

schema_suggestion = None
failed_reason = None
if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0:
    # suggest replicating the input tensor
Contributor

nit: let's do this suggestion in a follow-up PR.

dim_map = input_spec.dim_map
softmax_dim = cast(
    int, op_schema.args_schema[len(op_schema.args_spec)]
)  # Is it better to put it into kwargs? e.g. op_schema.kwargs_schema['dim']
Contributor

https://github.com/pytorch/pytorch/blob/1cad744694d7feb7c55e5f4ff4a6ae749686bfb5/aten/src/ATen/native/native_functions.yaml#L4721

softmax only has a positional argument for dim, so let's just keep it as a positional arg.
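
A hedged sketch of reading dim positionally for aten._softmax.default; the variable names are illustrative.

```python
# Sketch only: the ATen schema in native_functions.yaml is
#   _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
# so `dim` is always the second positional argument and can be unpacked
# directly rather than indexed via len(op_schema.args_spec).
input_spec, softmax_dim, _half_to_float = op_schema.args_schema
softmax_dim = cast(int, softmax_dim)
dim_map = cast(DTensorSpec, input_spec).dim_map
```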



@register_prop_rule("aten._softmax_backward_data.default")
def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding:
Contributor

I know softmax_bwd_rule may be a pointwise rule, but maybe it's better to keep the softmax ops categorized together. Let's put them together in math_ops.py; you can call into pointwise_ops from math_ops.py.

Contributor Author

XilunWu commented Sep 13, 2022

Split #440 into 2 parts:
Part 1: productionization of the original softmax op prototyped in #382 (#455)
Part 2: complete the softmax ops by enabling the case shard_dim == softmax_dim (#440)

@XilunWu XilunWu changed the title [spmd] adapt softmax and _softmax_backward_data to shard prop rule [spmd] complete softmax and _softmax_backward_data to support aggregate on sharding dim Sep 19, 2022
@wanchaol
Contributor

DTensor now lives in pytorch; related PRs need to be submitted to pytorch directly. See #576 for context.

@wanchaol wanchaol closed this Nov 28, 2022