
Fix Bad Interaction Between Aligning Inputs and CUDAGraphs in Backward (LayoutLMForMaskedLM) #103126

@eellison

Description


πŸ› Describe the bug

```
python benchmarks/dynamo/huggingface.py --device cuda --performance --backend inductor --amp --training --only LayoutLMForMaskedLM
```

Gives an assertion error:

```
  888                 continue
  889             if data_ptr is not None:
  890                 # static input, e.g., parameter
❱ 891                 assert data_ptr == new_inputs[idx].data_ptr()
  892             else:
  893                 # non-static input, need to copy it into CUDA graph
  894                 dst = self.reconstructed_inputs[idx]
```

Our cudagraph implementation assumes that certain tensors have a fixed memory location: parameters, and tensors saved by the forward of a graph that was cudagraph'd.

We also annotate all tensor inputs to Triton kernels as 16-byte aligned because it improves perf, and we copy inputs over to an aligned address if they are not already aligned.
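To make the conflict concrete, here is a minimal sketch of the two mechanisms (not the actual inductor code; the names `copy_misaligned_inputs`, `CudagraphifiedCallable`, and `static_input_idxs` are illustrative): cudagraphs records the address of each static input at capture time, while the alignment workaround clones misaligned inputs, which moves their address.

```python
import torch

ALIGNMENT = 16  # bytes; the alignment assumed for Triton kernel inputs

def copy_misaligned_inputs(inputs):
    # Alignment workaround: clone any input whose storage is not 16-byte
    # aligned. Cloning allocates fresh storage, so data_ptr() changes.
    for i, t in enumerate(inputs):
        if isinstance(t, torch.Tensor) and t.data_ptr() % ALIGNMENT != 0:
            inputs[i] = t.clone()
    return inputs

class CudagraphifiedCallable:
    def __init__(self, static_input_idxs):
        # Indices of inputs (parameters, tensors saved by the forward) whose
        # addresses are assumed to stay fixed across replays.
        self.static_input_idxs = set(static_input_idxs)
        self.recorded_ptrs = {}

    def record(self, inputs):
        for i in self.static_input_idxs:
            self.recorded_ptrs[i] = inputs[i].data_ptr()

    def replay(self, inputs):
        for i in self.static_input_idxs:
            # This is the assertion that fires in the backward above: the
            # alignment workaround cloned a "static" input, so its address moved.
            assert self.recorded_ptrs[i] == inputs[i].data_ptr()
```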

This led to the bad interaction in the backward shown above. If you add the following prints, you get:

```
runtime misaligned inputs: [] of # 212
cudagraphs removing unaligned input idxs set() of 212
runtime misaligned inputs: [31, 33] of # 326
cudagraphs removing unaligned input idxs set() of 326
runtime misaligned inputs: [] of # 212
runtime misaligned inputs: [31, 33] of # 326
```

We need to make sure we are not removing the misaligned input indices before cudagraphs checks for misalignment, so that cudagraphs knows not to expect a static address for the misaligned tensors (see the sketch below).
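Roughly, in terms of the illustrative sketch above (again, a sketch rather than a patch against the real code), the ordering should look like this:

```python
def build_cudagraph_callable(inputs, candidate_static_idxs):
    # Detect misalignment *before* deciding which inputs cudagraphs may treat
    # as static, so anything that might get copied to an aligned address is
    # excluded from the static set.
    misaligned = {
        i for i, t in enumerate(inputs)
        if isinstance(t, torch.Tensor) and t.data_ptr() % ALIGNMENT != 0
    }
    # Misaligned inputs will be cloned by copy_misaligned_inputs, so their
    # addresses cannot be assumed fixed across replays.
    static_idxs = set(candidate_static_idxs) - misaligned
    return CudagraphifiedCallable(static_idxs)
```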

I set cudagraph_trees to False here because it leads to a simpler repro (no extra warmup), but the same issue is present regardless.
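For reference, cudagraph trees can be toggled through the inductor config (assuming a recent master where these flags live under torch._inductor.config.triton):

```python
import torch._inductor.config as inductor_config

inductor_config.triton.cudagraphs = True        # keep plain cudagraphs on
inductor_config.triton.cudagraph_trees = False  # skip the tree / extra-warmup machinery
```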

Versions

master

cc @mcarilli @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225

Labels

module: cuda graphs, module: inductor, oncall: pt2, triaged
