
Conversation

@FindHao (Member) commented May 23, 2023

Issue description

PR #100064 introduces a new RNG operation process. However, it causes every randint to load a separate random seed by default. TorchInductor generates a buffer to store all necessary random seeds and places the offsets as constant values in the subsequent compute buffers. As a result, in the ir_pre_fusion output generated by TorchInductor, some buffers differ by only one line: the load of the random seed at the corresponding offset. The codegen then generates Triton kernels following the same rule, so in output_code.py some Triton kernels differ by only one line, meaning redundant kernels are being generated.

Solution

This PR captures the seed offset and adds it to the existing self.sizevars structure. It generates variable names as placeholders, allowing the code wrapper to pass the offset as an argument to the kernels. I've also modified the divisible_by_16 check to exclude this argument.
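
For context, the core of the mechanism looks roughly like this (a sketch reconstructed from the seed_offset snippets quoted in the review below, not the verbatim inductor code):

def seed_offset(self, name, value):
    # Reuse the existing placeholder if this seed offset was already registered.
    if value in self.sizevars:
        return self.sizevars[value]
    # Disambiguate repeated names: load_seed_offset, load_seed_offset1, ...
    if name in self.sizevars.values():
        name = f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
    # Record the placeholder so the wrapper can pass the offset as a kernel argument.
    self.sizevars[value] = name
    return name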

This PR reduces the number of generated kernels from 50 to 17 for BertForMaskedLM forward.

In tests on my own environment, the compilation time of attention_is_all_you_need_pytorch is reduced from 94s to 66s. The speedup remains largely unchanged, at 1.37X.

The following is a comparison for a simple example.
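
An illustrative source program for such an example (my reconstruction, consistent with the randint64(..., 0, 10) calls and the 1024-element grids below, not the PR's actual test case):

import torch

def fn():
    # Two independent randint calls: each one loads its own seed from the seed buffer.
    a = torch.randint(10, [1024], device="cuda")
    b = torch.randint(10, [1024], device="cuda")
    return a, b

compiled_fn = torch.compile(fn)
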
Before:

triton_poi_fused_0 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + 0)
    tmp1 = x0
    tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)

triton_poi_fused_1 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + 1)
    tmp1 = x0
    tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
...''')

def call(args):
        triton_poi_fused_0.run(buf0, buf1, 1024, grid=grid(1024), stream=stream0)
        triton_poi_fused_1.run(buf0, buf2, 1024, grid=grid(1024), stream=stream0)

After:

triton_poi_fused_0 = async_compile.triton('triton_', '''
...
def triton_(in_ptr0, out_ptr0, load_seed_offset, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + load_seed_offset)
    tmp1 = x0
    tmp2 = triton_helpers.randint64(tmp0, (tmp1).to(tl.uint32), 0, 10)
    ...

def call(args):
        triton_poi_fused_0.run(buf0, buf1, 0, 1024, grid=grid(1024), stream=stream0)
        triton_poi_fused_0.run(buf0, buf2, 1, 1024, grid=grid(1024), stream=stream0)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10

@pytorch-bot bot commented May 23, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102104

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dc1e187:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@FindHao (Member Author) commented May 23, 2023

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the "topic: not user facing" label May 23, 2023
@FindHao changed the title from "Move offsets of loading seeds as kernel arguments" to "Fix redundant kernel generations" May 23, 2023
@FindHao marked this pull request as ready for review May 23, 2023 21:02
@FindHao requested review from desertfire, jansel and ngimel May 23, 2023 21:02
@FindHao (Member Author) commented May 23, 2023

Some dynamic shape tests failed because they count the number of generated kernels, which this PR changes. I'll fix them in the next commit.


def seed_offset(self, name, value):
    if "load_seed_offset" in self.sizevars.values():
        name = "%s%d" % (

Collaborator:
Can you use f-string formatting here, i.e. f"{name}{expr}"?


Member Author:
Fixed in e6268cf.

self.inplace_buffers[output_name] = buf

def seed_offset(self, name, value):
    if "load_seed_offset" in self.sizevars.values():

Collaborator:
Do we have any test cases for multiple load_seed_offset args?


Collaborator:
Also, this wouldn't work if you change name at the call site.


Member Author:
I changed the hardcoded string to the name in e6268cf.

A simple example with multiple load_seed_offset args could be the following:

import torch

def fn():
    random_tensor1 = torch.randint(10, [1024], device="cuda")
    random_tensor2 = torch.randint(11, [1024], device="cuda")
    random_tensor3 = torch.randint(10, [1024], device="cuda")
    tensor4 = random_tensor1 + random_tensor2 + random_tensor3
    return tensor4
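
For illustration (my sketch, not actual compiler output): all three randint calls are pointwise over 1024 elements, so TorchInductor can fuse them into a single Triton kernel that takes several seed-offset arguments, which is exactly the case the name disambiguation (load_seed_offset, load_seed_offset1, ...) has to handle:

def triton_(in_ptr0, out_ptr0, load_seed_offset, load_seed_offset1, load_seed_offset2, xnumel, XBLOCK : tl.constexpr):
    ...
    tmp0 = tl.load(in_ptr0 + load_seed_offset)
    tmp3 = tl.load(in_ptr0 + load_seed_offset1)
    tmp6 = tl.load(in_ptr0 + load_seed_offset2)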


Member Author:
I'm thinking that there is another hardcoded string in the divisible_by_16 check. It would be better to replace that too. Do you know where I should put this hardcoded string as a global constant?


Collaborator:
Yeah, I understand examples with multiple load_seed_offset exist; I'm asking if we have a test for those.


Member Author:
I've only tested several models and haven't found multiple load_seed_offset cases so far.


Member Author:
Do you mean adding a unit test?


Collaborator:
Yes, please add a unit test for this if it doesn't exist.


Member Author:
Added a unit test in dc1e187.


def seed_offset(self, name, value):
    if name in self.sizevars.values():
        name = f"{name}{sum(1 for value in self.sizevars.values() if value.startswith('load_seed_offset'))}"

Collaborator:
you still have load_seed_offset here


Member Author:
Fixed in b84276f.
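
(Presumably the fix replaces the hardcoded prefix with the name parameter, along these lines; an assumption, not the verbatim commit:)

name = f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"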

@FindHao (Member Author) commented May 24, 2023

@pytorchbot merge

@pytorch-bot bot added the "ciflow/trunk" (Trigger trunk jobs on your pull request) label May 24, 2023
@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot:

Merge failed

Reason: 1 job has failed: trunk / win-vs2019-cpu-py3 / build


@FindHao (Member Author) commented May 24, 2023

@pytorchbot merge

@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot:

Merge failed

Reason: 1 job has failed: trunk / win-vs2019-cpu-py3 / test (default, 2, 3, windows.4xlarge.nonephemeral)


@ngimel (Collaborator) commented May 24, 2023

@pytorchbot merge -f "test failure unrelated"

@pytorchmergebot:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).


@github-actions bot deleted the findhao/fix-redundant-kernels branch November 23, 2024 02:05