FSDP+PP tracer issue with cast-to-bf16 #1104

Open · wconstab opened this issue May 2, 2024 · 9 comments

Comments
@wconstab (Contributor) commented May 2, 2024

https://github.com/pytorch/torchtitan/pull/161/files#diff-80b04fce2b861d9470c6160853441793678ca13904dae2a9b8b7145f29cd017aR254


In principle, the issue is that the PP tracer traced the non-FSDP model; in that case the model code ran a .to(torch.float32) operation that was a no-op at trace time and therefore dropped out of the trace, or something along those lines.

The only proposal I recall was to change the tracer/export to handle this better and not drop the .to operation. Need to check whether this has already been resolved.
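For context, a minimal sketch of the suspected behavior (an assumption based on the description above, not the actual torchtitan model code): a .to(torch.float32) that is a no-op at trace time may leave no node in the exported graph, so the upcast is lost when FSDP mixed precision later feeds bf16 tensors through the traced stage.

    import torch
    from torch.export import export

    class Norm(torch.nn.Module):
        # Hypothetical stand-in for model code that upcasts before a reduction.
        def forward(self, x):
            # No-op when x is already fp32 at trace time; the tracer may elide it,
            # so the upcast never happens when bf16 inputs arrive at runtime.
            return x.to(torch.float32).pow(2).mean(-1)

    ep = export(Norm(), (torch.randn(4, 8),))
    print(ep.graph)  # check whether a to/_to_copy node survived the trace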

cc @zhxchen17 @kwen2501

@wconstab (Contributor, Author) commented May 3, 2024

pytorch/pytorch#123732 was intended to help this case but isn't quite enough.

  1. #123732 does not appear to help for calls to .float() - it only seems to work for explicit calls to .to(). I verified that if I replace .float() calls with .to(torch.float32) calls, the errors I previously saw went away (see the sketch below).

cc @zhxchen17
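A hedged sketch of that workaround (a hypothetical module, not the actual torchtitan code):

    import torch

    class Upcast(torch.nn.Module):
        # Hypothetical stand-in for model code that previously used .float().
        def forward(self, x):
            # return x.float() + 1          # not handled by the fix in #123732
            return x.to(torch.float32) + 1  # explicit .to() survives tracing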

  2. After the change for (1), I get a new error which I'm still trying to understand:
[screenshot of the new error]

I don't see any explicit casting operators in that trace. From the looks of it, FSDP is expected to cast the layer inputs to bf16 but isn't, or perhaps the inputs are in bf16 but for some reason the parameters are not?

These are the inputs to the exact op (bmm) that threw the exception:
target = <OpOverload(op='aten.bmm', overload='default')>
args = [
(torch.Size([64, 2048, 2048]), torch.float32),
(torch.Size([64, 2048, 16]), torch.bfloat16),
]

This paste shows one level higher in the graph - the whole attention module:
https://www.internalfb.com/phabricator/paste/view/P1229735878

Note that the traced code burns a float32 dtype kwarg into the view calls for xq, xk, and xv, while the actual model code does not pass float32 as part of the view.

I think this is the bug?
[screenshots comparing the traced view calls against the original model code]

zhxchen17 added a commit to zhxchen17/pytorch that referenced this issue May 8, 2024
Summary:

Previously we tried to convert all .to() calls to to_copy in the graph; now a user reports that other methods like .float() are not covered: pytorch/PiPPy#1104 (comment)

I think fundamentally .float() should look similar to .to() in export, and this diff expands the coverage of the tensor conversion methods here.

Test Plan: buck run mode/opt caffe2/test:test_export -- -r float_conversion

Differential Revision: D56951634
@kwen2501 (Contributor) commented:
An example program shows that torch.export does not burn the dtype into the ExportedProgram at trace time:
https://github.com/kwen2501/export-playground/blob/main/dtype.py
See the kwargs for zeros_like.

$ python dtype.py

opcode         name                target                   args                     kwargs
-------------  ------------------  -----------------------  -----------------------  ---------------------
placeholder    p_embedding_weight  p_embedding_weight       ()                       {}
placeholder    x                   x                        ()                       {}
call_function  embedding           aten.embedding.default   (p_embedding_weight, x)  {}
call_function  zeros_like          aten.zeros_like.default  (embedding,)             {'pin_memory': False}
output         output              output                   ((zeros_like,),)         {}
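For convenience, a self-contained sketch along the same lines (a hypothetical module mirroring the table above, not the linked dtype.py):

    import torch
    from torch.export import export

    class Embed(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, x):
            y = self.embedding(x)
            # No explicit dtype: export records zeros_like without a dtype kwarg,
            # so the node follows whatever dtype the input has at runtime.
            return torch.zeros_like(y)

    ep = export(Embed(), (torch.randint(0, 10, (2, 4)),))
    # Prints an opcode/name/target/args/kwargs table (requires tabulate);
    # the zeros_like kwargs should show only {'pin_memory': False}.
    ep.graph.print_tabular()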

@kwen2501 (Contributor) commented:
The zeros_like's dtype in the issue's program is likely the one that causes the dtype mismatch at bmm.

We can use graph_module.print_readable() to see the original stack trace to identify which part of the code set it to FP32.
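A quick sketch of that inspection (assuming ep is the ExportedProgram for the traced stage):

    # ep is assumed to be the ExportedProgram (or the stage's GraphModule).
    gm = ep.graph_module
    # Prints the generated forward() with "# File: ... in forward, code: ..."
    # comments, which point back to the original source line for each node.
    gm.print_readable()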

@kwen2501 (Contributor) commented:
Also confirmed that a line like z = torch.zeros_like(y, dtype=y.dtype) would burn the dtype into the kwargs:

# in forward, code: z = torch.zeros_like(y, dtype=y.dtype)
zeros_like: "f32[2, 4, 3]" = torch.ops.aten.zeros_like.default(embedding, dtype = torch.float32, pin_memory = False);  

@kwen2501 (Contributor) commented:
The doc of torch.zeros_like says:

torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)

dtype (torch.dtype, optional) – the desired data type of returned Tensor. Default: if None, defaults to the dtype of input.

Thus, it is safe to just write:

z = torch.zeros_like(y)

instead of

z = torch.zeros_like(y, dtype=y.dtype)

Action item: we'd need to find the code that is written in the second style above and fix it.
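As a quick check of the difference (a hedged sketch with hypothetical module names, following the two styles above):

    import torch
    from torch.export import export

    class Preferred(torch.nn.Module):
        def forward(self, y):
            return torch.zeros_like(y)                 # dtype follows y at runtime

    class Problematic(torch.nn.Module):
        def forward(self, y):
            return torch.zeros_like(y, dtype=y.dtype)  # dtype is resolved at trace time

    y = torch.randn(2, 4, 3)
    print(export(Preferred(), (y,)).graph)    # zeros_like node without a dtype kwarg
    print(export(Problematic(), (y,)).graph)  # zeros_like node with dtype=torch.float32 burned in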

@kwen2501 (Contributor) commented:
Exporting the llama model and printing the stack traces shows that the zeros_like comes from scaled_dot_product_attention:

        # File: /data/users/kw2501/torchtitan/torchtitan/models/llama/model.py:203 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        mul_4: "f32[8, 16, 2048, 16]" = torch.ops.aten.mul.Scalar(transpose, 0.5);  transpose = None
        ones: "b8[2048, 2048]" = torch.ops.aten.ones.default([2048, 2048], dtype = torch.bool, layout = torch.strided, device = device(type='meta'))
        tril: "b8[2048, 2048]" = torch.ops.aten.tril.default(ones);  ones = None
        zeros_like: "f32[2048, 2048]" = torch.ops.aten.zeros_like.default(tril, dtype = torch.float32)
        logical_not: "b8[2048, 2048]" = torch.ops.aten.logical_not.default(tril);  tril = None
        masked_fill: "f32[2048, 2048]" = torch.ops.aten.masked_fill.Scalar(zeros_like, logical_not, -inf);  zeros_like = logical_not = None
        transpose_3: "f32[8, 16, 16, 2048]" = torch.ops.aten.transpose.int(transpose_1, -2, -1);  transpose_1 = None
        mul_5: "f32[8, 16, 16, 2048]" = torch.ops.aten.mul.Scalar(transpose_3, 0.5);  transpose_3 = None
        expand: "f32[8, 16, 2048, 16]" = torch.ops.aten.expand.default(mul_4, [8, 16, 2048, 16]);  mul_4 = None
        clone: "f32[8, 16, 2048, 16]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        _unsafe_view: "f32[128, 2048, 16]" = torch.ops.aten._unsafe_view.default(clone, [128, 2048, 16]);  clone = None
        expand_1: "f32[8, 16, 16, 2048]" = torch.ops.aten.expand.default(mul_5, [8, 16, 16, 2048]);  mul_5 = None
        clone_1: "f32[8, 16, 16, 2048]" = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format);  expand_1 = None
        _unsafe_view_1: "f32[128, 16, 2048]" = torch.ops.aten._unsafe_view.default(clone_1, [128, 16, 2048]);  clone_1 = None
        bmm: "f32[128, 2048, 2048]" = torch.ops.aten.bmm.default(_unsafe_view, _unsafe_view_1);  _unsafe_view = _unsafe_view_1 = None
        view_14: "f32[8, 16, 2048, 2048]" = torch.ops.aten.view.default(bmm, [8, 16, 2048, 2048]);  bmm = None
        add_1: "f32[8, 16, 2048, 2048]" = torch.ops.aten.add.Tensor(view_14, masked_fill);  view_14 = masked_fill = None
        _softmax: "f32[8, 16, 2048, 2048]" = torch.ops.aten._softmax.default(add_1, -1, False);  add_1 = None
        expand_2: "f32[8, 16, 2048, 2048]" = torch.ops.aten.expand.default(_softmax, [8, 16, 2048, 2048]);  _softmax = None
        view_15: "f32[128, 2048, 2048]" = torch.ops.aten.view.default(expand_2, [128, 2048, 2048]);  expand_2 = None
        expand_3: "f32[8, 16, 2048, 16]" = torch.ops.aten.expand.default(transpose_2, [8, 16, 2048, 16]);  transpose_2 = None
        clone_2: "f32[8, 16, 2048, 16]" = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format);  expand_3 = None
        _unsafe_view_2: "f32[128, 2048, 16]" = torch.ops.aten._unsafe_view.default(clone_2, [128, 2048, 16]);  clone_2 = None
        bmm_1: "f32[128, 2048, 16]" = torch.ops.aten.bmm.default(view_15, _unsafe_view_2);  view_15 = _unsafe_view_2 = None
        view_16: "f32[8, 16, 2048, 16]" = torch.ops.aten.view.default(bmm_1, [8, 16, 2048, 16]);  bmm_1 = None

@kwen2501 (Contributor) commented:
More specifically:

  • In pytorch/aten/src/ATen/native/transformers/attention.cpp, the is_causal path builds a boolean causal mask and passes it to convert_boolean_attn_mask.

  • convert_boolean_attn_mask then materializes a float mask via zeros_like (in the query's dtype) followed by masked_fill with -inf, which matches the zeros_like / masked_fill nodes in the trace above.

https://github.com/pytorch/pytorch/blob/a5c93a6899c657832944cd2eeb5069449e28dbea/aten/src/ATen/native/transformers/attention.cpp#L523

@kwen2501 (Contributor) commented:
CC: @zhxchen17
@tugsbayasgalan, I understand you are preparing an improvement to un-burn the dtype as well (in addition to device)?
We would be thrilled to try that out. CC: @wconstab

@kwen2501 (Contributor) commented:
Meanwhile, @tugsbayasgalan mentioned that pre-dispatch mode is now the default mode of torch.export. That can also work around this issue, since the new mode avoids tracing into SDPA.
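A rough illustration of that workaround (a hedged sketch with a hypothetical module, assuming pre-dispatch export preserves SDPA as a single composite node rather than decomposing it into the zeros_like/masked_fill/bmm sequence shown above):

    import torch
    import torch.nn.functional as F
    from torch.export import export

    class Attn(torch.nn.Module):
        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v, is_causal=True)

    q = k = v = torch.randn(8, 16, 128, 16)
    ep = export(Attn(), (q, k, v))  # pre-dispatch is now the default mode
    # With SDPA kept as one node, no fp32 causal mask is burned into the graph.
    print(ep.graph)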

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue May 15, 2024
Summary:
Previously we tried to convert all .to() calls to to_copy in the graph; now a user reports that other methods like .float() are not covered: pytorch/PiPPy#1104 (comment)

I think fundamentally .float() should look similar to .to() in export, and this diff expands the coverage of the tensor conversion methods here.

Test Plan: buck run mode/opt caffe2/test:test_export -- -r float_conversion

Differential Revision: D56951634

Pull Request resolved: #125628
Approved by: https://github.com/tugsbayasgalan