Consolidate Triton cache into Inductor cache #138239
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/138239
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit f394adf with merge base f4ab8b4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D64504909
@ezyang I'm not really sure how to integrate this into TLParse or our other data sources. Any recommendations?
Hi, I don't see any logging / structured trace updates / Scuba updates. Can you think about what kind of instrumentation you want to add and propose something here? EDIT: Oops, didn't see your comment.
```python
@classmethod
def collect(cls) -> List[TritonKernelArtifacts]:
    if not TritonBundler.is_enabled():
        cls._entries = None
```
I am a bit confused about thread safety here. You're saying `_entries` writes from multiple threads are OK because they're atomic. Is collect a sync point, then, or something?
I'm probably being overly cautious here. collect is executed when we are writing the Inductor cache entry to disk/remote, so at that point we expect everything to have completed; this is true by design (no need for any sync primitive).
```python
class TritonBundler:
    # It is safe for multiple threads to insert entries, as the insert
    # operation of dict is atomic
    _entries: Optional[Dict[TritonBundleEntry, None]] = None
```
Nit: is this normal PyTorch style? Or should we just add a lock to make the thread safety explicit?
The operation is atomic; why would we want the overhead of a lock? Documentation?
Documentation. (Also, do you think the overhead will be significant here? If you do, we can drop it, but assuming, say, 100k Triton compiles and 100 ns per mutex acquisition, you get 10 ms of overhead.)
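For concreteness, a minimal sketch of the explicit-lock variant being discussed; the class, field, and `put` shape mirror the hunks quoted in this thread, but the lock itself and the `TritonBundleEntry` stub are hypothetical:

```python
import dataclasses
import threading
from typing import Dict, Optional

@dataclasses.dataclass(frozen=True)
class TritonBundleEntry:  # stub; the real dataclass lives in triton_bundler.py
    kernel_hash: str
    device: int

class TritonBundler:
    _lock = threading.Lock()
    _entries: Optional[Dict[TritonBundleEntry, None]] = None

    @classmethod
    def put(cls, kernel_hash: str, device: int) -> None:
        # CPython dict insertion is already atomic under the GIL; the lock
        # only makes the intended thread safety explicit, at roughly the
        # 100 ns per acquisition estimated above.
        with cls._lock:
            if cls._entries is not None:
                cls._entries[TritonBundleEntry(kernel_hash, device)] = None
```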
```python
for artifact in artifacts.artifacts:
    filepath = os.path.join(directory, artifact.filename)
    with open(filepath, "wb") as file:
```
This isn't really safe. The normal way to do this is to write to a temporary file (either in the same directory or elsewhere) and then swap them, since that's an atomic operation on most file systems. It also helps handle crashes with partial writes, etc.
I don't think we need this to be atomic, which is what I was asking below, but you're right about crashes, so I'll update it to use a temp file and swap.
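A minimal sketch of the temp-file-and-swap pattern, assuming the temp file lives in the same directory so the final rename stays on one filesystem (`write_atomic_bytes` is a hypothetical name; the PR ends up using Inductor's existing `write_atomic` helper, visible in a later hunk):

```python
import os
import tempfile

def write_atomic_bytes(filepath: str, payload: bytes) -> None:
    # Stage the payload in a temp file next to the destination, then
    # publish it with os.replace, which is atomic on POSIX filesystems.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(filepath))
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        os.replace(tmp_path, filepath)  # readers never see a partial file
    except BaseException:
        os.unlink(tmp_path)  # don't leave stale temp files after a failure
        raise
```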
```python
for artifact in artifacts.artifacts:
    filepath = os.path.join(directory, artifact.filename)
    with open(filepath, "wb") as file:
        file.write(artifact.payload)
```
Is there a hash or something we can use here to help detect corruption? Or even a version number, if we ever want to update the on-disk format?
I see a `kernel_hash` above; not sure how easy that is to check.
kernel_hash is just the folder name Triton decided on; I don't think it's a complete hash. We could sha256 the payload when we save it if you want, but we haven't done that anywhere else before.
Yeah - I'm probably just paranoid, but if we ever have corruption issues, it'll help us figure it out. Your call.
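A minimal sketch of the checksum idea, under the assumption that a sidecar digest file is acceptable (both helper names are hypothetical):

```python
import hashlib

def write_with_checksum(filepath: str, payload: bytes) -> None:
    with open(filepath, "wb") as f:
        f.write(payload)
    # Sidecar digest lets readers detect silent corruption of the artifact.
    with open(filepath + ".sha256", "w") as f:
        f.write(hashlib.sha256(payload).hexdigest())

def read_verified(filepath: str) -> bytes:
    with open(filepath, "rb") as f:
        payload = f.read()
    with open(filepath + ".sha256") as f:
        expected = f.read()
    if hashlib.sha256(payload).hexdigest() != expected:
        raise RuntimeError(f"corrupt cache artifact: {filepath}")
    return payload
```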
```python
cls._entries = None
return []

if (entries := cls._entries) is not None:
```
Should we assert that `cls._entries` is not None here? The begin_compile method seems to indicate that we're trying to ensure one compile at a time, but we don't check that any entries were written here.
Or are we instead sometimes going to have an empty `_entries` even for valid compiles, so we can't assert correctness here?
I had it this way because of what happens if someone enables the justknob in the middle of execution: entries is None, but the collect operation still executes because it was turned on midway.
Hmmm - I guess I'm curious whether it makes sense to have begin_compile around at all? Just set _entries in put, and read + clear them in collect? But very much your call.
Separately, if the JK is turned on between begin_compile and put, we're not going to write to the filesystem during put; not sure if that's an issue (if Inductor is going to read the files later).
```python
# Owner(s): ["module: inductor"]
```
Style question: is it PyTorch style to not write explicit tests for the triton_bundler and maybe triton_heuristics files?
All the tests I have seen so far have been end-to-end tests with counters to verify data. I could write more if you want, but I'm just adhering to what currently exists.
Could we test the corner and error cases in triton_bundler? I.e., begin_compile called twice, collect called twice, disk reads + writes throwing errors?
Maybe even writes from two threads, to make sure stuff is thread safe?
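A hedged sketch of what such tests might look like; the module path and method names (`begin_compile`, `put`, `collect`) come from this PR, but the zero-argument `begin_compile`, the bundler being enabled, and the `_entries` inspection are assumptions:

```python
import threading

from torch._inductor.triton_bundler import TritonBundler

def test_collect_without_begin_compile_returns_empty():
    # With no begin_compile, _entries is None and collect is a no-op.
    assert TritonBundler.collect() == []

def test_concurrent_puts_are_all_recorded():
    TritonBundler.begin_compile()
    threads = [
        threading.Thread(target=TritonBundler.put, args=(f"hash{i}", 0))
        for i in range(32)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # collect() itself reads kernel directories from disk, so a full test
    # would need fixture dirs; here we only check the recorded entries.
    assert len(TritonBundler._entries) == 32
```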
Re: data sources, at a minimum can we log whether this config is enabled into dynamo_compile and pt2_remote_cache?
```python
custom_dir = False
basedir = os.path.join(cache_dir(), "triton")

for artifacts in bundle:
```
Worth logging (a counter or something) how many bytes we're writing into the cache?
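A minimal sketch using Dynamo's existing `counters` utility (the counter name and the wrapper function are hypothetical):

```python
from typing import List

from torch._dynamo.utils import counters

def record_bundle_size(bundle: List["TritonKernelArtifacts"]) -> None:
    for artifacts in bundle:
        for artifact in artifacts.artifacts:
            # Surface total cache-write volume in the standard counters table.
            counters["inductor"]["triton_bundle_bytes_written"] += len(artifact.payload)
```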
torch/_inductor/codecache.py
```python
write_atomic(artifact_path, code, make_dirs=True)

if bundle := graph._triton_bundle:
    TritonBundler.write_bundle_to_file_system(bundle)
```
QQ: worth recording how long this takes, in case the disk is slow?
torch/_inductor/codecache.py
```python
compiled_graph._time_taken_ns = time_ns() - start_time
cache_key = key_info[0]
compiled_graph._fx_graph_cache_key = cache_key
compiled_graph._triton_bundle = TritonBundler.collect()
```
QQ: worth recording how long this takes?
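The updated summary later in this thread says `dynamo_timed` is what ended up being used for the collect and write paths; a sketch of the pattern (the phase-name string is illustrative, and `dynamo_timed`'s exact signature has varied across PyTorch versions):

```python
from torch._dynamo.utils import dynamo_timed
from torch._inductor.triton_bundler import TritonBundler

def collect_with_timing(compiled_graph):
    # Mirrors the hunk above, wrapped so collect time shows up in the
    # compile-time events table.
    with dynamo_timed("TritonBundler.collect"):
        compiled_graph._triton_bundle = TritonBundler.collect()
```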
```python
from pathlib import Path
from typing import Dict, List, Optional

import torch
```
@jansel When do I put a file in torch._inductor.runtime as opposed to somewhere else?
Our generated code should only import from torch._inductor.runtime.
```python
@classmethod
def put(cls, kernel_hash: str, device: int) -> None:
    if (entries := cls._entries) is not None:
        # CachingAutotuner.__init__ unconditionally sets TRITON_CACHE_DIR
```
This is not really your problem, but I want to point out that setenv on TRITON_CACHE_DIR is not thread safe lol
```python
log = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
```
Some minimal documentation for these dataclasses would be helpful for review.
```python
class TritonBundler:
    # It is safe for multiple threads to insert entries, as the insert
    # operation of dict is atomic
    _entries: Optional[Dict[TritonBundleEntry, None]] = None
```
Is this... just an OrderedSet? lol
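For context: a `Dict[K, None]` is the common stand-in for an insertion-ordered set, since CPython dicts preserve insertion order (3.7+). A tiny illustration:

```python
entries: dict[str, None] = {}
for h in ("kernel_a", "kernel_b", "kernel_a"):
    entries[h] = None  # duplicate inserts are idempotent

assert list(entries) == ["kernel_a", "kernel_b"]  # deduped, insertion-ordered
```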
Despite commenting a bit, I might not be a good reviewer for this; there's a lot of indirect threading in and out of the Triton backend that I am not familiar with the call sites for. I'll work harder to review it if Bert/Niko don't pick up the review.
Summary: This diff/PR attempts to consolidate Triton caching into the Inductor caching so that there can be just one cache that unifies them both, reducing network requests and increasing success rate. Implementation details can be found via reading the code or the post: https://fb.workplace.com/groups/1553867532149891/posts/1605037517032892

I did not use the Autotune bundler code at all, since I want to simplify that and merge it into this in the next diff/PR.

Test Plan: Updated unit tests

```
TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE=1 buck2 run mode/opt //scripts/oulgen:runner
```

https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpmTZt6b/0_0_0/fx_graph_cache_hit_4.json {F1944553347}

Differential Revision: D64504909
Force-pushed 425420e to 7dfffb3.
At a high level this seems fine, but there are a bunch of details that need work.
Force-pushed 7dfffb3 to 845c7cc.
torch/_inductor/triton_bundler.py
A little confused here:

```
os.replace(src, dst, *, src_dir_fd=None, dst_dir_fd=None)
    Rename the file or directory src to dst. If dst is a non-empty directory,
    OSError will be raised.
```

You seem to pass a directory on the RHS; this doesn't seem like it could possibly work?
We chatted offline; this does work after the skip logic I added.
New os.replace logic is not quite right
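For reference, a small self-contained demonstration of the `os.replace` semantics being debated (POSIX behavior; on Windows, replacing an existing directory fails outright):

```python
import os
import tempfile

base = tempfile.mkdtemp()
src = os.path.join(base, "src")
dst = os.path.join(base, "dst")
os.makedirs(src)
os.makedirs(dst)

os.replace(src, dst)  # OK on POSIX: dst exists but is an empty directory

os.makedirs(src)
open(os.path.join(dst, "somefile"), "w").close()
try:
    os.replace(src, dst)  # dst is now a non-empty directory
except OSError as e:
    print("expected failure:", e)  # e.g. ENOTEMPTY
```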
Force-pushed 845c7cc to 679f89a.
Summary: This diff/PR attempts to consolidate Triton caching into the Inductor caching so that there can be just one cache that unifies them both, reducing network requests and increasing success rate. Implementation details can be found via reading the code or the post: https://fb.workplace.com/groups/1553867532149891/posts/1605037517032892

I did not use the Autotune bundler code at all, since I want to simplify that and merge it into this in the next diff/PR.

In terms of instrumentation:
1) Dynamo compile: `triton_bundler_time_saved_s` is the sum of all triton.compile calls. We don't have to use the specific number; it can be treated as a binary value.
2) Events table: I used dynamo_timed to measure how much time we spend in the bundler collect and write functions, which is all the work we do in this diff.
3) TLParse: I emitted the number of kernels and triton_bundler_time_saved_s into tlparse as well.

Test Plan: Updated unit tests

```
TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE=1 buck2 run mode/opt //scripts/oulgen:runner
```

https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpmTZt6b/0_0_0/fx_graph_cache_hit_4.json {F1944553347}

Reviewed By: ezyang
Differential Revision: D64504909
Force-pushed 679f89a to 76b161e.
```python
if os.path.exists(directory) and len(os.listdir(directory)) != 0:
    # If directory already exists, we bail out and leave
    # local disk to take care of caching
```
I need more description here. Specifically, where are we documenting the fact that we assume exclusive access to the directory?
Force-pushed 76b161e to f394adf.
```python
# __grp__kernel_name.json contains metadata with source code paths;
# we use this as a sentinel value for search and replace
_REPLACE_BYTES: bytes = b"[REPLACE]"
```
There will be some terrifying SEV when this mucks over some unrelated byte string that just happens to be `[REPLACE]`. One way to protect against this is to assert that there are no unrelated occurrences of this string before we munge file paths.
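A minimal sketch of that guard; `_REPLACE_BYTES` is from the PR, while the two helper functions and the exact substitution flow are assumptions:

```python
_REPLACE_BYTES: bytes = b"[REPLACE]"

def neutralize_paths(payload: bytes, cache_dir: bytes) -> bytes:
    # Fail loudly if the sentinel already occurs in unrelated content,
    # instead of silently corrupting the artifact on the round trip.
    assert _REPLACE_BYTES not in payload, "sentinel collides with payload"
    return payload.replace(cache_dir, _REPLACE_BYTES)

def restore_paths(payload: bytes, cache_dir: bytes) -> bytes:
    return payload.replace(_REPLACE_BYTES, cache_dir)
```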
```python
path = os.path.join(entry.directory, entry.kernel_hash)
if not os.path.exists(path):
    continue
for filename in os.listdir(path):
```
So post our discussion, we have exclusive access to the kernel directory at this point in time (no one potentially racing us writing files into this directory while we are collecting).
I added this to the doc block above.
```python
# Random ID to avoid any collisions
rnd_id = str(uuid.uuid4())
tmp_dir = os.path.join(basedir, f"tmp.{rnd_id}")
os.makedirs(tmp_dir)
```
Nit: using a uuid to generate a tempfile is not actually the recommended way. On POSIX, you can use mkstemp or similar to create a temporary file in a directory that is guaranteed not to conflict with any other file.
But actually, because you have exclusive access, you don't even need a random ID: you can just write to a well-known temp file name, and because you assume no one else is mucking with this dir, it's fine. You only need to guard against premature termination.
Although you do have to delete stale tmp dirs if there are any lying around :P If it's a well-known name, it's easier to make sure that you clean up the next time around.
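A sketch of the well-known-temp-name approach; all names are hypothetical, and it assumes this process has exclusive access to `basedir`, as discussed above:

```python
import os
import shutil

def stage_and_publish(basedir: str, kernel_hash: str) -> str:
    tmp_dir = os.path.join(basedir, "tmp.staging")
    # Exclusive access makes a fixed name safe; remove anything a
    # previously crashed run left behind.
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.makedirs(tmp_dir)
    # ... write kernel artifacts into tmp_dir, then publish:
    final_dir = os.path.join(basedir, kernel_hash)
    os.replace(tmp_dir, final_dir)  # atomic when final_dir is absent or empty
    return final_dir
```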
None of the comments are blocking
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary:

This diff/PR attempts to consolidate Triton caching into the Inductor caching so that there can be just one cache that unifies them both, reducing network requests and increasing success rate.

Implementation details can be found via reading the code or the post: https://fb.workplace.com/groups/1553867532149891/posts/1605037517032892

I did not use the Autotune bundler code at all, since I want to simplify that and merge it into this in the next diff/PR.

In terms of instrumentation:
1) Dynamo compile: `triton_bundler_time_saved_s` is the sum of all triton.compile calls. We don't have to use the specific number; it can be treated as a binary value.
2) Events table: I used dynamo_timed to measure how much time we spend in the bundler collect and write functions, which is all the work we do in this diff.
3) TLParse: I emitted the number of kernels and triton_bundler_time_saved_s into tlparse as well.

Test Plan: Updated unit tests

Ad hoc running

```
TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE=1 buck2 run @mode/opt //scripts/oulgen:runner
```

gives https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpmTZt6b/0_0_0/fx_graph_cache_hit_4.json



Differential Revision: D64504909
Pull Request resolved: pytorch#138239
Approved by: https://github.com/ezyang

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec