Conversation


@kflu kflu commented Aug 11, 2023

Summary:
There appears to be a bug in D47998435: on a cache hit, it returns None.

Repro:

```
import torch

class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + 1

mod = TestModule()
inp = torch.rand(1)
out = mod(inp)
mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])

so, _ = torch._export.aot_compile(mod2, tuple([inp]))
# The 2nd call hits the cache and returns None
so, _ = torch._export.aot_compile(mod2, tuple([inp]))
assert so is not None  # FAILS
```
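The failure mode is a common caching slip: the cache-hit branch returns early without handing back the stored artifact. A minimal sketch of the pattern and its fix (`_cache` and `_do_compile` are hypothetical stand-ins, not the actual internals touched by D47998435):

```python
# Hypothetical sketch of the cache-hit bug; not the real aot_compile internals.
_cache = {}

def _do_compile(key):
    # Stand-in for the expensive compilation step.
    return f"compiled({key})"

def aot_compile_buggy(key):
    if key in _cache:
        # Bug: the hit branch returns without the compiled artifact.
        return None
    so = _do_compile(key)
    _cache[key] = so
    return so

def aot_compile_fixed(key):
    if key in _cache:
        # Fix: return the cached artifact on a hit.
        return _cache[key]
    so = _do_compile(key)
    _cache[key] = so
    return so

first = aot_compile_fixed("mod")
second = aot_compile_fixed("mod")  # cache hit; must equal `first`, not None
```

With the buggy variant, the second call for the same key silently returns None, which is exactly the assertion failure the repro demonstrates.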

Test Plan: Run the repro

Differential Revision: D48258375

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

pytorch-bot bot commented Aug 11, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107020

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 68faf64 with merge base a422969:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D48258375


@kflu force-pushed the export-D48258375 branch from f4f86f9 to 1ffd1f0 on August 14, 2023 20:32

kflu added a commit to kflu/pytorch that referenced this pull request Aug 14, 2023
Summary:
Pull Request resolved: pytorch#107020

There appears to be a bug in D47998435: on a cache hit, it returns None.

Repro:

```
class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + 1

mod = TestModule()
inp = torch.rand(1)
out = mod(inp)
mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])

so, _ = torch._export.aot_compile(mod2, tuple([inp]))
# The 2nd call hits the cache and returns None
so, _ = torch._export.aot_compile(mod2, tuple([inp]))
assert so is not None  # FAIL
```

Test Plan: Run the repro

Reviewed By: frank-wei

Differential Revision: D48258375

fbshipit-source-id: 58982772fcff07595ed20e4856889e79d0ce065f
@kflu force-pushed the export-D48258375 branch from 1ffd1f0 to cae303f on August 14, 2023 20:39

kflu commented Aug 16, 2023

@pytorchbot help


pytorch-bot bot commented Aug 16, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'help' (choose from 'merge', 'revert', 'rebase', 'label', 'drci')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.


kflu commented Aug 16, 2023

@pytorchbot --help


pytorch-bot bot commented Aug 16, 2023

PyTorchBot Help

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

In order to invoke the bot on your PR, include a line that starts with
@pytorchbot anywhere in a comment. That line will form the command; no
multi-line commands are allowed. 

Example:
    Some extra context, blah blah, wow this PR looks awesome

    @pytorchbot merge

optional arguments:
  -h, --help            Show this help message and exit.

command:
  {merge,revert,rebase,label,drci}
    merge               Merge a PR
    revert              Revert a PR
    rebase              Rebase a PR
    label               Add label to a PR
    drci                Update Dr. CI

Merge

usage: @pytorchbot merge [-f MESSAGE | -i] [-ic] [-r [{viable/strict,main}]]

Merge an accepted PR, subject to the rules in .github/merge_rules.json.
By default, this will wait for all required checks (lint, pull) to succeed before merging.

optional arguments:
  -f MESSAGE, --force MESSAGE
Merge without checking anything. This requires a reason for auditing purposes, for example:
                        @pytorchbot merge -f 'Minor update to fix lint. Expecting all PR tests to pass'
                        
                        Please use `-f` as last resort, prefer `--ignore-current` to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.
  -i, --ignore-current  Merge while ignoring the currently failing jobs.  Behaves like -f if there are no pending jobs.
  -ic                   Old flag for --ignore-current. Deprecated in favor of -i.
  -r [{viable/strict,main}], --rebase [{viable/strict,main}]
                        Rebase the PR to re-run checks before merging. Accepts viable/strict or main as branch options and defaults to viable/strict if not specified.

Revert

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Revert a merged PR. This requires that you are a Meta employee.

Example:
  @pytorchbot revert -m="This is breaking tests on trunk. hud.pytorch.org/" -c=nosignal

optional arguments:
  -m MESSAGE, --message MESSAGE
                        The reason you are reverting, will be put in the commit message. Must be longer than 3 words.
  -c {nosignal,ignoredsignal,landrace,weird,ghfirst}, --classification {nosignal,ignoredsignal,landrace,weird,ghfirst}
                        A machine-friendly classification of the revert reason.

Rebase

usage: @pytorchbot rebase [-s | -b BRANCH]

Rebase a PR. Rebasing defaults to the stable viable/strict branch of pytorch.
Repeat contributors may use this command to rebase their PRs.

optional arguments:
  -s, --stable          [DEPRECATED] Rebase onto viable/strict
  -b BRANCH, --branch BRANCH
                        Branch you would like to rebase to

Label

usage: @pytorchbot label labels [labels ...]

Adds labels to a PR

positional arguments:
  labels  Labels to add to given Pull Request

Dr CI

usage: @pytorchbot drci 

Update Dr. CI. Updates the Dr. CI comment on the PR in case it's gotten out of sync with actual CI results.
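The command grammar in the help text above maps naturally onto an argparse subparser tree. A hypothetical sketch (not the actual pytorch-bot source) of how such a bot might pull the first `@pytorchbot` line out of a comment and parse it:

```python
import argparse
import shlex

# A sketch mirroring the help text above; not the real pytorch-bot implementation.
parser = argparse.ArgumentParser(prog="@pytorchbot", add_help=False)
sub = parser.add_subparsers(dest="command", required=True)

merge = sub.add_parser("merge")
excl = merge.add_mutually_exclusive_group()  # matches usage: [-f MESSAGE | -i]
excl.add_argument("-f", "--force", metavar="MESSAGE")
excl.add_argument("-i", "--ignore-current", action="store_true")
merge.add_argument("-r", "--rebase", nargs="?", const="viable/strict",
                   choices=["viable/strict", "main"])

revert = sub.add_parser("revert")
revert.add_argument("-m", "--message", required=True)
revert.add_argument("-c", "--classification", required=True,
                    choices=["nosignal", "ignoredsignal", "landrace", "weird", "ghfirst"])

rebase = sub.add_parser("rebase")
rebase.add_argument("-b", "--branch")

label = sub.add_parser("label")
label.add_argument("labels", nargs="+")

sub.add_parser("drci")

def parse_comment(body):
    """Parse the first comment line that starts with @pytorchbot."""
    for line in body.splitlines():
        if line.lstrip().startswith("@pytorchbot"):
            # shlex keeps quoted messages (e.g. -m '...') as single tokens.
            return parser.parse_args(shlex.split(line)[1:])
    return None  # no bot command in this comment

args = parse_comment("Some extra context, wow this PR looks awesome\n@pytorchbot merge -i")
```

A production bot would also need the deprecated aliases (`-ic`, `-s`) and error handling, since argparse calls `sys.exit` on invalid input rather than returning an error to post back on the PR.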


@kflu force-pushed the export-D48258375 branch from cae303f to 165ddcf on August 16, 2023 22:06

@kflu force-pushed the export-D48258375 branch from 165ddcf to 38b09ab on August 16, 2023 22:11
Summary:
Pull Request resolved: pytorch#107020

There appears to be a bug in D47998435: on a cache hit, it returns None.

Repro:

```
class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + 1

mod = TestModule()
inp = torch.rand(1)
out = mod(inp)
mod2 = torch.fx.symbolic_trace(mod, concrete_args=[inp])

so, _ = torch._export.aot_compile(mod2, tuple([inp]))
# The 2nd call hits the cache and returns None
so, _ = torch._export.aot_compile(mod2, tuple([inp]))
assert so is not None  # FAIL
```

Test Plan: Run the repro

Reviewed By: frank-wei

Differential Revision: D48258375

fbshipit-source-id: 0d1feeabadacaa97b7e545c7b03f0c45fa8d8ce6

@kflu force-pushed the export-D48258375 branch from 38b09ab to 68faf64 on August 16, 2023 22:17

kflu commented Aug 17, 2023

@pytorchbot merge -i

@pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Aug 17, 2023

kflu commented Aug 17, 2023

@pytorchmergebot

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Details for the Dev Infra team: raised by a workflow job.


kflu commented Aug 17, 2023

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the topic: not user facing (topic category) label Aug 17, 2023

kflu commented Aug 17, 2023

@pytorchbot merge -i

@pytorchmergebot

Merge started

Your change will be merged while ignoring the following 2 checks: inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor_torchbench, 1, 1, linux.g5.4xlarge.nvidia.gpu), inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor_torchbench_dynamic, 1, 1, linux.g5.4xlarge.nvidia.gpu)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team



@houseroad houseroad left a comment


Let's add some unit tests in a following PR/diff?


huydhn commented Aug 22, 2023

@pytorchbot drci
