
Conversation

@aakhundov aakhundov (Contributor) commented Mar 13, 2024

Stack from ghstack (oldest at bottom):

Summary: Previously, we didn't handle transitive replacements in the MLIR walk-based function info mining in the Triton kernel mutation analysis pass. As a result, for the TTIR below:

```
tt.func private @cumsum__fp32S1_16S__1cconstexpr_1__2cconstexpr_False_(%arg0: tensor<1x16xf32> loc("...":296:0)) -> tensor<1x16xf32> attributes {noinline = false} {
    %0 = "tt.scan"(%arg0) <{axis = 1 : i32, reverse = false}> ({
    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
      %1 = tt.call @_sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 loc(#loc16)
      tt.scan.return %1 : f32 loc(#loc16)
    }) : (tensor<1x16xf32>) -> tensor<1x16xf32> loc(#loc16)
    tt.return %0 : tensor<1x16xf32> loc(#loc18)
  } loc(#loc15)
```

the mined function dict looked like this:

```
{Intermediate(idx=25): [Op(name='tt.call',
                           fn_call_name='_sum_combine__fp32_fp32__',
                           args=[Intermediate(idx=26),
                                 Intermediate(idx=26)])],
 Intermediate(idx=27): [Op(name='tt.scan.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=25)])],
 Intermediate(idx=-4): [Op(name='tt.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=27)])]}
```

whereas it should look like this (note the `Param(idx=0)` arguments of the `tt.call`):

```
{Intermediate(idx=25): [Op(name='tt.call',
                           fn_call_name='_sum_combine__fp32_fp32__',
                           args=[Param(idx=0),
                                 Param(idx=0)])],
 Intermediate(idx=27): [Op(name='tt.scan.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=25)])],
 Intermediate(idx=-4): [Op(name='tt.return',
                           fn_call_name=None,
                           args=[Intermediate(idx=27)])]}
```

This is fixed in the PR.
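For illustration, here is a minimal, self-contained sketch of what resolving replacements transitively looks like (hypothetical helper and index values, not the actual PyTorch implementation). When the MLIR walk descends into a nested region such as the `tt.scan` combine body, the region's block arguments (`%arg1`, `%arg2`) are recorded as replaced by the values they are bound to; those values may themselves be replaced, so each lookup has to be chased to a fixed point rather than stopping after one hop:

```python
from dataclasses import dataclass

# Minimal stand-ins for the value kinds in the mined function dict.
@dataclass(frozen=True)
class Param:
    idx: int

@dataclass(frozen=True)
class Intermediate:
    idx: int

def canonicalize(value, replacements):
    """Follow replacement links transitively until a fixed point."""
    seen = set()
    while value in replacements and value not in seen:
        seen.add(value)  # guard against accidental cycles
        value = replacements[value]
    return value

# The combine-region block arguments (here Intermediate(idx=26)) point at
# the tt.scan operand, which in turn points at the enclosing function's
# first parameter (the idx=24 link is made up for the example):
replacements = {Intermediate(26): Intermediate(24), Intermediate(24): Param(0)}
assert canonicalize(Intermediate(26), replacements) == Param(0)

# A single non-transitive lookup would stop at Intermediate(24), leaving a
# dangling Intermediate in the mined args, as in the buggy dict above.
```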

Test Plan:

```
$ python test/inductor/test_triton_kernels.py -k test_cumsum
.
----------------------------------------------------------------------
Ran 1 test in 1.771s

OK
```


pytorch-bot bot commented Mar 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121867

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit ca91d54 with merge base 70c6f54:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@aakhundov aakhundov requested a review from oulgen March 13, 2024 23:32
aakhundov added a commit that referenced this pull request Mar 13, 2024
ghstack-source-id: c8c271f
Pull Request resolved: #121867
@aakhundov aakhundov added the ciflow/trunk, topic: not user facing, and ciflow/inductor labels Mar 13, 2024
@aakhundov (Contributor, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@github-actions github-actions bot deleted the gh/aakhundov/15/head branch April 14, 2024 02:19