
Conversation

yushangdi
Contributor

@yushangdi yushangdi commented Dec 16, 2024

Summary:

  • When a user specifies the `TORCHINDUCTOR_MAX_AUTOTUNE=1` env variable, we add `config.max_autotune=True` to the generated minifier_launcher
  • We should do this for other inductor configs as well in a follow-up Diff

Currently, in the dynamo and AOTI minifiers, if a config is overwritten by an env variable, the config will not show up in the config list in the minifier_launcher.py file. As a result, when running the minifier_launcher, users need to re-apply the same env variable.
This is:

  1. not convenient for the users;
  2. if they copy-paste the minifier_launcher.py to us without mentioning the env variable, we may be confused and unable to reproduce the error.

Underlying implementation change:

  • Add `env_default` parameter to `codegen_config()`. If set, configs overridden by the env are not considered default. (A minimal sketch of the intended behavior follows below.)
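For illustration, here is a minimal sketch of the intended `env_default` behavior. The helper name, signature, and dictionaries are hypothetical simplifications, not the actual torch/utils/_config_module.py code:

```
import os

# Hypothetical sketch: with env_default=True, a config whose current value was
# set by an env variable is compared against its true default (e.g. False for
# max_autotune) instead of the env-derived value, so it is no longer skipped.
def codegen_config(values, true_defaults, env_names, env_default=True):
    lines = []
    for name, value in values.items():
        env_name = env_names.get(name)
        overridden_by_env = env_name is not None and env_name in os.environ
        is_default = value == true_defaults[name] and not (env_default and overridden_by_env)
        if not is_default:
            lines.append(f"torch._inductor.config.{name} = {value!r}")
    return "\n".join(lines)

# With TORCHINDUCTOR_MAX_AUTOTUNE=1 set, this prints:
#   torch._inductor.config.max_autotune = True
print(codegen_config(
    {"max_autotune": os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"},
    {"max_autotune": False},
    {"max_autotune": "TORCHINDUCTOR_MAX_AUTOTUNE"},
))
```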

Test Plan:

```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:utils -- -r test_codegen_config
```

Differential Revision: D67299312

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov


pytorch-bot bot commented Dec 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143330

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (10 Unrelated Failures)

As of commit 65f117b with merge base f4e9aeb:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D67299312

skip_default does two things. One, if a key has not been modified, it skips it. The other is that it modifies the logging behaviour to match what codegen already did for modified skipped keys.
env_default: If set, configs overridden by the env are not considered default.
Contributor

I think this should always be True?

Contributor Author

I'll change it to always be True

@desertfire
Contributor

Adding more inductor reviewers

Contributor

@eellison eellison left a comment

Can you reuse the existing config serialization from the minifier? I would not expect this many changes.


```
# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
max_autotune = Config(False, env_name_default="TORCHINDUCTOR_MAX_AUTOTUNE")
```
Contributor

Why do we need this? The existing minifier serializes config settings without this change.

Contributor Author

@yushangdi yushangdi Dec 17, 2024

@eellison

The problem is that the minifier cannot reproduce the same config as the runtime config.

For example, if you run the script below like TORCHINDUCTOR_MAX_AUTOTUNE=1 python test.py, the configs that the minifier dumps don't have torch._inductor.config.max_autotune = True, even though you set the env variable TORCHINDUCTOR_MAX_AUTOTUNE=1.

This is because the config module takes os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" as the default and skips this config, which is wrong (or at least not convenient for minifier users). The true default value is False, and the user actually changed the value of this config with an env variable.

```
import torch
from torch._inductor import config as inductor_config


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.sigmoid(x)
        return x


inductor_config.aot_inductor.dump_aoti_minifier = True
torch._inductor.config.triton.inject_relu_bug_TESTING_ONLY = "compile_error"

with torch.no_grad():
    model = Model().to("cuda")
    example_inputs = (torch.randn(8, 10).to("cuda"),)
    ep = torch.export.export(model, example_inputs)
    package_path = torch._inductor.aoti_compile_and_package(ep)
    compiled_model = torch._inductor.aoti_load_package(package_path)
    result = compiled_model(*example_inputs)
```

Contributor Author

Most of the changes here are to add a flag when dumping the configs, so we don't skip the configs that are changed by setting env variables.

Contributor

@eellison eellison Dec 17, 2024

If we're going to fix max-autotune, it would be great to do a complete fix, not just max autotune. We could do this less invasively by:

  • Recording modified config values - either w/ custom logic in setattr, or by saving config on instantiation and diffing with current config
  • Serialize modified config values
  • Serialize any env vars with TORCHINDUCTOR

Contributor Author

If we're going to fix max-autotune, it would be great to do a complete fix, not just max autotune. We could do this less invasively by:

@eellison Yeah, I'm planning to fix all configs in a follow-up Diff if the approach in this Diff looks ok.

  • Recording modified config values - either w/ custom logic in setattr, or by saving config on instantiation and diffing with current config
  • Serialize modified config values
  • Serialize any env vars with TORCHINDUCTOR

I believe currently we are doing these already in torch/utils/_config_module.py?

Contributor

@eellison eellison Dec 17, 2024

I think we should be able to do the above without the max_autotune = Config(False, env_name_default="TORCHINDUCTOR_MAX_AUTOTUNE") changes. But I guess this is the standard way, so fine with me.

Contributor Author

I think we should be able to do the above without the max_autotune = Config(False, env_name_default="TORCHINDUCTOR_MAX_AUTOTUNE") changes. But I guess this is the standard way, so fine with me.

@eellison Yeah, I agree. We could do it another way, but since the Config class already exists, I thought we might as well just use it.
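To illustrate the point, here is a simplified, hypothetical stand-in for the Config class (not the real torch.utils._config_module implementation): it keeps the true default and the env override separate, so the minifier can tell that the value was changed.

```
import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    default: bool
    env_name_default: Optional[str] = None

    def value(self) -> bool:
        # The env override wins when the variable is set to "1".
        if self.env_name_default and os.environ.get(self.env_name_default) == "1":
            return True
        return self.default

    def is_default(self) -> bool:
        # An env-overridden value no longer looks like the default, so
        # codegen_config() will serialize it into minifier_launcher.py.
        return self.value() == self.default

max_autotune = Config(False, env_name_default="TORCHINDUCTOR_MAX_AUTOTUNE")
print(max_autotune.value(), max_autotune.is_default())
```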

Contributor

It does work, but I would still prefer we do the general thing that covers all configs instead of just this specific one.

Contributor Author

@yushangdi yushangdi Dec 18, 2024

It does work, but I would still prefer we do the general thing that covers all configs instead of just this specific one.

@eellison How about I change all configs to use the Config class (like D67303965)? It's a lot of lines to change, but if we do the general thing, we still need to touch all configs. (I can do that in a follow-up PR.)

We won't be able to just leave the config like max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1" untouched.

If we want to record the default, we at least need to change that to two steps:

  1. specify the default: max_autotune = False.
  2. record the change: setattr(sys.modules[__name__], "max_autotune", os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1") (see the runnable sketch below).
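Putting the two steps together, a small runnable sketch (hypothetical module-level code, not the actual torch._inductor.config) of how applying the env override through a tracked setter records it as a user change:

```
import os

_defaults = {"max_autotune": False}   # step 1: the true defaults
_values = dict(_defaults)
_modified = set()

def _set_config(name, value):
    # Stand-in for the __setattr__ bookkeeping in torch/utils/_config_module.py.
    _values[name] = value
    _modified.add(name)               # step 2: the env override is recorded as a change

if os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1":
    _set_config("max_autotune", True)

# The minifier can now serialize every config in _modified, env-driven or not.
```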

```
if DEBUG_DIR_VAR_NAME in os.environ:
    return os.path.join(os.environ[DEBUG_DIR_VAR_NAME], "torch_compile_debug")
elif is_fbcode():
    fbcode_dir = os.path.join(os.path.expanduser("~"), "fbsource", "fbcode")
```
Contributor

Is this path general? Should this be cwd or home? Putting it in the fbsource tree seems a bit weird.

Contributor Author

It's intended to be in fbcode. The generated TARGETS file and Python files need to be in fbsource for them to be compiled by buck.

Contributor Author

@yushangdi yushangdi Dec 17, 2024

I guess I could also make it cwd. I'll just delete the elif is_fbcode() branch. The file written to /tmp cannot be compiled by buck.

Comment on lines 1384 to 1394
```
# def _set_config_from_env(config, env) -> None:
#     """
#     Set config to True if env is set to 1.
#     We use the if statement to avoid calling setattr if env is not set.
#     This helps config.codegen_config() to only generate configs that are not
#     the default value.
#     """
#     if os.environ.get(env) == "1":
#         setattr(sys.modules[__name__], config, True)

# _set_config_from_env("max_autotune", "TORCHINDUCTOR_MAX_AUTOTUNE")
```
Contributor

Dead code?

```
ignored_keys: Optional[List[str]] = None,
ignored_prefixes: Optional[List[str]] = None,
skip_default: bool = False,
env_default: bool = False,
```
Contributor

Does anything break if we just turn this on always? If so let's just remove the option and make it default.

Contributor Author

I think it's probably ok. I'll make it default then.


```
subparsers = parser.add_subparsers(
    dest="command", metavar="{run,minify,analyze}", required=True
    dest="command", metavar="{run,minify}", required=True
```
Contributor

Was this unused?

Contributor Author

Yeah, the AOTI minifier doesn't have an analyze mode yet.

yushangdi added a commit to yushangdi/pytorch that referenced this pull request Dec 17, 2024
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D67299312

@yushangdi
Contributor Author

yushangdi commented Dec 18, 2024

@jansel @desertfire @eellison Ping for review on the updated version of the PR. Thank you!

@yushangdi yushangdi requested a review from eellison December 18, 2024 17:27
@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Dec 19, 2024
pytorch-bot bot pushed a commit that referenced this pull request Jan 7, 2025
Summary:

- When a user specifies env variables whose names include "TORCH", "INDUCTOR", "DYNAMO" or "TRITON", we add those env variable definitions to the generated minifier_launcher


Example env variables generated:

```
import os
os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1'
os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0'
os.environ['TRITON_CACHE_MANAGER'] = 'triton.runtime.cache:RemoteCacheManager'
os.environ['TRITON_REMOTE_CACHE_BACKEND'] = 'triton.fb.fb_memcache:FbMemcacheRemoteKernelCache'
os.environ['TRITON_LIBDEVICE_PATH'] = '/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/585c17c94f036562/tutorials/pytorch/nntest/__torchtest__/torchtest#link-tree/triton/fb/libdevice.10.bc'
os.environ['TRITON_PTXAS_PATH'] = '/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/585c17c94f036562/tutorials/pytorch/nntest/__torchtest__/torchtest#link-tree/triton/fb/ptxas'
os.environ['TRITON_LIBCUDA_PATH'] = '/usr/local/fbcode/platform010/lib'
os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_shangdiy'
```

Currently in dynamo and aoti minifier, if a config is overwritten by an env variable, the config will not show up in the config list in the minifier_launcher.py file. As a result, when running the minifier_launcher, users need to re-apply the same env variable.
This is:
1) not convenient for the users
2) if they copy-paste the minifier_launcher.py to us without mentioning the env variable, we may be confused and unable to reproduce the error.

We also make the following change in `codegen_config()`:

- Add `env_default` parameter to `codegen_config()`. If set, configs overridden by the env are not considered default.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:utils -- -r test_codegen_config
```

Reviewed By: eellison

Differential Revision: D67299312
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D67299312

yushangdi added a commit to yushangdi/pytorch that referenced this pull request Jan 7, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D67299312

yushangdi added a commit to yushangdi/pytorch that referenced this pull request Jan 7, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D67299312

Summary:

- When a user specifies env variables whose names include "TORCH", "INDUCTOR", "DYNAMO" or "TRITON", we add those env variable definitions to the generated minifier_launcher


Example env variables generated:

```
import os
os.environ['TORCHINDUCTOR_MAX_AUTOTUNE'] = '1'
os.environ['PYTORCH_DDP_USE_SIDE_STREAM'] = '0'
os.environ['TRITON_CACHE_MANAGER'] = 'triton.runtime.cache:RemoteCacheManager'
os.environ['TORCHINDUCTOR_CACHE_DIR'] = '/tmp/torchinductor_shangdiy'
os.environ['TRITON_REMOTE_CACHE_BACKEND'] = 'triton.fb.fb_memcache:FbMemcacheRemoteKernelCache'
```

We skip some env vars, like those below, because they are generated per compile (a sketch of this filtering follows after the example).

```

os.environ['TRITON_LIBDEVICE_PATH'] = '/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/585c17c94f036562/tutorials/pytorch/nntest/__torchtest__/torchtest#link-tree/triton/fb/libdevice.10.bc'
os.environ['TRITON_PTXAS_PATH'] = '/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/585c17c94f036562/tutorials/pytorch/nntest/__torchtest__/torchtest#link-tree/triton/fb/ptxas'
os.environ['TRITON_LIBCUDA_PATH'] = '/usr/local/fbcode/platform010/lib'
```
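For illustration, a minimal sketch (hypothetical helper, not the code in this PR) of how the relevant env variables could be emitted into minifier_launcher.py while skipping the per-compile ones:

```
import os

RELEVANT_SUBSTRINGS = ("TORCH", "INDUCTOR", "DYNAMO", "TRITON")
SKIPPED = {"TRITON_LIBDEVICE_PATH", "TRITON_PTXAS_PATH", "TRITON_LIBCUDA_PATH"}

def codegen_env_vars() -> str:
    # Emit os.environ assignments that the generated launcher can replay.
    lines = ["import os"]
    for name, value in sorted(os.environ.items()):
        if name in SKIPPED:
            continue  # regenerated per compile; replaying a stale path would be wrong
        if any(sub in name for sub in RELEVANT_SUBSTRINGS):
            lines.append(f"os.environ[{name!r}] = {value!r}")
    return "\n".join(lines)
```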

Currently in dynamo and aoti minifier, if a config is overwritten by an env variable, the config will not show up in the config list in the minifier_launcher.py file. As a result, when running the minifier_launcher, users need to re-apply the same env variable.
This is:
1) not convenient for the users
2) if they copy-paste the minifier_launcher.py to us without mentioning the env variable, we may be confused and unable to reproduce the error.

We also make the following change in `codegen_config()`:

- Add `env_default` parameter to `codegen_config()`. If set, configs overridden by the env are not considered default.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:utils -- -r test_codegen_config
```

Reviewed By: eellison

Differential Revision: D67299312
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D67299312

Contributor

@eellison eellison left a comment

🚀 🚀 🚀

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@facebook-github-bot
Contributor

This pull request has been reverted by 66ce13b. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

@yushangdi
Contributor Author

This pull request has been reverted by 66ce13b. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

Note: this Diff is only partially reverted in 66ce13b. The majority of the Diff is not reverted.

@yushangdi yushangdi removed the Reverted label Jan 9, 2025