
Conversation

oulgen (Contributor) commented Oct 17, 2024

Summary:
This diff/PR attempts to consolidate Triton caching into the Inductor caching so that there can be just one cache that unifies them both, reducing network requests and increasing success rate.

Implementation details can be found via reading the code or the post: https://fb.workplace.com/groups/1553867532149891/posts/1605037517032892

I did not use the Autotune bundler code at all since I want to simplify that and merge it into this on the next diff/PR.

In terms of instrumentation:

  1. Dynamo compile: `triton_bundler_time_saved_s` is the sum of all `triton.compile` call durations. We don't have to use the specific number; it can serve as a binary signal. (A sketch of this style of timing follows the list.)
  2. Events table: I used dynamo_timed to measure how much time we spend on the bundler's collect and write functions, which is all the work we do in this diff.
  3. TLParse: I emitted the number of kernels and triton_bundler_time_saved_s into tlparse as well.
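For illustration, here is a minimal, standard-library-only sketch of the kind of "time saved" accounting described in item 1; `BundlerMetrics` and `timed_compile` are hypothetical names for this sketch, not the PR's actual code:

```python
import time

class BundlerMetrics:
    # Hypothetical accumulators mirroring the counters described above.
    time_saved_s: float = 0.0
    kernel_count: int = 0

def timed_compile(compile_fn, *args, **kwargs):
    # Wraps a triton.compile call; on a future cache hit this work would
    # be skipped, so its duration counts as "time saved".
    start = time.monotonic()
    result = compile_fn(*args, **kwargs)
    BundlerMetrics.time_saved_s += time.monotonic() - start
    BundlerMetrics.kernel_count += 1
    return result
```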

Test Plan: Updated unit tests

Ad hoc run:

```
TORCHINDUCTOR_BUNDLE_TRITON_INTO_FX_GRAPH_CACHE=1 buck2 run @mode/opt //scripts/oulgen:runner
```

gives
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpmTZt6b/0_0_0/fx_graph_cache_hit_4.json

Differential Revision: D64504909

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov @rec

pytorch-bot bot commented Oct 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138239

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit f394adf with merge base f4ab8b4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot (Contributor):
This pull request was exported from Phabricator. Differential Revision: D64504909

@oulgen oulgen added the ciflow/trunk and release notes: inductor labels Oct 17, 2024
oulgen (Contributor, Author) commented Oct 17, 2024

@ezyang I'm not really sure how to integrate this into TLParse or our other data sources. Any recommendations?

ezyang (Contributor) commented Oct 17, 2024

Hi, I don't see any logging / trace_structured updates / scuba updates. Can you think about what kind of instrumentation you want to add and propose something here?

EDIT: Oh oops, didn't see your comment.

```python
@classmethod
def collect(cls) -> List[TritonKernelArtifacts]:
    if not TritonBundler.is_enabled():
        cls._entries = None
```
Contributor:

I am a bit confused about thread safety here. You are saying that _entries writes from multiple threads are OK because they are atomic. Is collect a sync point then, or something?

Contributor Author:

I'm probably being overly cautious here. collect is executed when we are writing the Inductor cache entry to disk/remote, so at that point we expect everything to have completed; this holds by design (no need for any sync primitive).

```python
class TritonBundler:
    # It is safe for multiple threads to insert entries as insert operation
    # of dict is atomic
    _entries: Optional[Dict[TritonBundleEntry, None]] = None
```
Contributor:
Nit: is this normal PyTorch style? Or should we just add a lock to make the thread safety explicit?

Contributor Author (oulgen, Oct 17, 2024):

The operation is atomic; why would we want the overhead of a lock? Documentation?

Contributor:

Documentation. (Also, do you think the overhead would be significant here? If so we can drop it, but assuming ~100k Triton compiles and ~100 ns per mutex acquisition, you get about 10 ms of total overhead.)
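For reference, a sketch of the explicit-lock variant under discussion; `LockedBundler` is an illustrative name, not the PR's code:

```python
import threading
from typing import Dict

class LockedBundler:
    # An explicit lock makes the thread-safety contract self-documenting.
    _lock = threading.Lock()
    _entries: Dict[str, None] = {}

    @classmethod
    def put(cls, kernel_hash: str) -> None:
        # At ~100 ns per uncontended acquisition and ~100k compiles,
        # total overhead is on the order of 10 ms, per the estimate above.
        with cls._lock:
            cls._entries[kernel_hash] = None
```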


```python
for artifact in artifacts.artifacts:
    filepath = os.path.join(directory, artifact.filename)
    with open(filepath, "wb") as file:
```
Contributor:

This isn't really safe. The normal way to do this is to write to a temporary file (either in the same directory or elsewhere) and then swap them, since the rename is an atomic operation on most file systems. It also helps handle crashes with partial writes, etc.

Contributor Author:

I don't think we need this to be atomic, which is what I was asking below, but you're right about crashes, so I'll update it to use a temp file and swap.
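For reference, a minimal sketch of the write-to-temp-then-swap pattern being discussed; the helper name is hypothetical and this is not necessarily what the PR ended up with:

```python
import os
import tempfile

def atomic_write(filepath: str, payload: bytes) -> None:
    # Create the temp file in the destination directory so that
    # os.replace never crosses a filesystem boundary.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(filepath))
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        os.replace(tmp_path, filepath)  # atomic rename on POSIX
    except BaseException:
        os.unlink(tmp_path)  # drop the partial write on any failure
        raise
```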

```python
for artifact in artifacts.artifacts:
    filepath = os.path.join(directory, artifact.filename)
    with open(filepath, "wb") as file:
        file.write(artifact.payload)
```
Contributor:

Is there a hash or something we can use here to help detect corruption? Or even a version number, in case we ever want to update the on-disk format?

Contributor:

I see a kernel_hash above; not sure how easy that is to check.

Contributor Author:

kernel_hash is just the folder name Triton chose; I don't think it is a complete hash. We could sha256 the payload when we save it if you want, but we haven't done that anywhere else before.

Contributor:
Yeah - I'm probably just paranoid, but if we ever have corruption issues, it'll help us figure it out. Your call.
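As an illustration of the suggestion, a hedged sketch of checksum-on-write / verify-on-read; both function names are hypothetical:

```python
import hashlib

def payload_with_checksum(payload: bytes) -> bytes:
    # Prepend a 32-byte sha256 digest so corruption is detectable at read time.
    return hashlib.sha256(payload).digest() + payload

def verify_payload(data: bytes) -> bytes:
    digest, payload = data[:32], data[32:]
    if hashlib.sha256(payload).digest() != digest:
        raise ValueError("cache entry corrupted: checksum mismatch")
    return payload
```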

```python
        cls._entries = None
        return []

    if (entries := cls._entries) is not None:
```
Contributor (c00w, Oct 17, 2024):

Should we assert that cls._entries is not None here? The begin_compile method seems to indicate that we're trying to ensure one compile at a time, but we don't check that any entries were written here.

Or are we instead sometimes going to have an empty _entries even for valid compiles, so we can't assert correctness here?

Contributor Author:

I had it this way because someone might enable the JustKnobs flags in the middle of execution: entries is None, but the collect operation still executes because the feature was turned on midway.

Contributor:

Hmmm - I guess I'm curious whether it makes sense to have begin_compile around at all? Just set _entries in put, and read + clear them in collect? But very much your call.

Separately, if the JK is turned on between begin_compile and put, we won't write to the filesystem during put; not sure if that's an issue (if Inductor is going to read the files later).
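A hedged sketch of the begin_compile-free lifecycle being proposed here, which also shows why an explicit sync point is defensible (names simplified; not the actual PR code):

```python
from typing import Dict, List, Optional

class SimpleBundler:
    # dict used as an ordered set; single-key insertion is atomic under the GIL
    _entries: Optional[Dict[str, None]] = None

    @classmethod
    def put(cls, kernel_hash: str) -> None:
        # Lazy init is itself racy: two threads can both observe None,
        # and one thread's first entry can be lost. Initializing once in
        # a begin_compile-style sync point avoids this window.
        if cls._entries is None:
            cls._entries = {}
        cls._entries[kernel_hash] = None

    @classmethod
    def collect(cls) -> List[str]:
        # Read and clear together; by design this runs after all puts.
        entries, cls._entries = cls._entries or {}, None
        return list(entries)
```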

```diff
@@ -1,6 +1,7 @@
 # Owner(s): ["module: inductor"]
```
Contributor:

Style question: is it PyTorch style to not write explicit tests for the triton_bundler and maybe triton_heuristics files?

Contributor Author:

All the tests I have seen so far have been end-to-end tests with counters to verify the data. I could write more if you want, but I'm just adhering to what currently exists.

Contributor:

Could we test the corner and error cases in triton_bundler? E.g. begin_compile called twice, collect called twice, disk reads and writes throwing errors?

Maybe even writes from two threads to make sure stuff is thread safe?
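A hedged sketch of what such a two-thread test might look like; the module path and the begin_compile/put semantics are assumed from the review context, so treat this as pseudocode against the real API:

```python
import threading
import unittest

from torch._inductor.triton_bundler import TritonBundler  # module path assumed

class TritonBundlerThreadSafetyTest(unittest.TestCase):
    def test_concurrent_put(self):
        # begin_compile is assumed to initialize _entries when bundling
        # is enabled.
        TritonBundler.begin_compile()
        threads = [
            threading.Thread(target=TritonBundler.put, args=(f"hash{i}", 0))
            for i in range(32)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        # Every concurrent insert should be recorded exactly once
        # (inspecting the private dict is purely illustrative).
        self.assertEqual(len(TritonBundler._entries or {}), 32)
```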

c00w (Contributor) commented Oct 17, 2024

Re: data sources, at a minimum can we log whether this config is enabled into dynamo_compile and pt2_remote_cache?

```python
custom_dir = False
basedir = os.path.join(cache_dir(), "triton")

for artifacts in bundle:
```
Contributor:

Worth logging (a counter or something) how many bytes we're writing into the cache?

```python
write_atomic(artifact_path, code, make_dirs=True)

if bundle := graph._triton_bundle:
    TritonBundler.write_bundle_to_file_system(bundle)
```
Contributor:

qq - Worth recording how long this takes in case the disk is slow?

```python
compiled_graph._time_taken_ns = time_ns() - start_time
cache_key = key_info[0]
compiled_graph._fx_graph_cache_key = cache_key
compiled_graph._triton_bundle = TritonBundler.collect()
```
Contributor:

qq - Worth recording how long this takes?

```python
from pathlib import Path
from typing import Dict, List, Optional

import torch
```
Contributor:

@jansel When do I put a file in torch._inductor.runtime as opposed to somewhere else?

Contributor:

Our generated code should only import from torch._inductor.runtime.

```python
@classmethod
def put(cls, kernel_hash: str, device: int) -> None:
    if (entries := cls._entries) is not None:
        # CachingAutotuner.__init__ unconditionally sets TRITON_CACHE_DIR
```
Contributor:

This is not really your problem, but I want to point out that setenv on TRITON_CACHE_DIR is not thread safe lol

```python
log = logging.getLogger(__name__)


@dataclasses.dataclass(frozen=True)
```
Contributor:

Some minimal documentation for these dataclasses would be helpful on review

```python
class TritonBundler:
    # It is safe for multiple threads to insert entries as insert operation
    # of dict is atomic
    _entries: Optional[Dict[TritonBundleEntry, None]] = None
```
Contributor:

Is this... just an OrderedSet? lol

ezyang (Contributor) commented Oct 19, 2024

Despite commenting a bit, I might not be a good reviewer for this; there's a lot of indirect threading in and out of the Triton backend that I am not familiar with the call sites for. I'll work harder to review it if Bert/Niko don't pick up the review.

facebook-github-bot pushed a commit that referenced this pull request Oct 23, 2024
facebook-github-bot (Contributor):
This pull request was exported from Phabricator. Differential Revision: D64504909

@oulgen oulgen requested a review from bertmaher October 23, 2024 18:49
ezyang (Contributor) commented Oct 25, 2024

At a high level this seems fine but there are a bunch of details that need work.

facebook-github-bot (Contributor):
This pull request was exported from Phabricator. Differential Revision: D64504909

Contributor:

A little confused here. Quoting the os.replace docs:

> os.replace(src, dst, *, src_dir_fd=None, dst_dir_fd=None)
> Rename the file or directory src to dst. If dst is a non-empty directory, OSError will be raised.

You seem to pass a directory on the RHS; this doesn't seem like it could possibly work?

Contributor Author (oulgen, Oct 29, 2024):

We chatted offline; this does work after the skip logic I added.

ezyang (Contributor) left a comment:

New os.replace logic is not quite right

facebook-github-bot (Contributor):
This pull request was exported from Phabricator. Differential Revision: D64504909

facebook-github-bot pushed a commit that referenced this pull request Oct 29, 2024
facebook-github-bot (Contributor):
This pull request was exported from Phabricator. Differential Revision: D64504909


```python
if os.path.exists(directory) and len(os.listdir(directory)) != 0:
    # If directory already exists, we bail out and leave
    # local disk to take care of caching
```
Contributor:

I need more description here. Specifically, where are we documenting the fact that we assume exclusive access to the directory?

facebook-github-bot (Contributor):
This pull request was exported from Phabricator. Differential Revision: D64504909


```python
# __grp__kernel_name.json contains metadata with source code paths
# we use this as a sentinel value for search and replace
_REPLACE_BYTES: bytes = b"[REPLACE]"
```
Contributor:

There will be some terrifying SEV when this mucks over some unrelated byte string that just happens to be [REPLACE]. One way to protect against this is to assert that there are no unrelated occurrences of this string before we munge the file paths.
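A minimal sketch of the guard being suggested, assuming a hypothetical helper that swaps real cache paths for the sentinel on write:

```python
_REPLACE_BYTES: bytes = b"[REPLACE]"

def encode_paths(payload: bytes, cache_dir: bytes) -> bytes:
    # If the payload already contains the sentinel, substituting paths now
    # would make the reverse substitution ambiguous at load time.
    assert _REPLACE_BYTES not in payload, "unrelated [REPLACE] bytes in payload"
    return payload.replace(cache_dir, _REPLACE_BYTES)
```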

```python
path = os.path.join(entry.directory, entry.kernel_hash)
if not os.path.exists(path):
    continue
for filename in os.listdir(path):
```
Contributor:

So, post our discussion: we have exclusive access to the kernel directory at this point in time (no one can be racing us writing files into this directory while we are collecting).

Contributor Author:

I added this to the doc block above.

```python
# Random ID to avoid any collisions
rnd_id = str(uuid.uuid4())
tmp_dir = os.path.join(basedir, f"tmp.{rnd_id}")
os.makedirs(tmp_dir)
```
Contributor:

Nit: using uuid to generate a tempfile name is not actually the recommended way. On POSIX, you can use mkstemp or similar to create a temporary file in a directory that is guaranteed not to conflict with any other file.

Contributor:

But actually, because you have exclusive access, you don't even need a random ID; you can just write to a well-known temp file name, and because you assume no one else is mucking with this dir, it's fine. You only need to guard against premature termination.

Contributor:

Although you do have to delete stale tmp dirs if there are any lying around :P If it's a well-known name, it's easier to make sure that you clean up the next time around.
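A hedged sketch of the well-known-temp-name approach under the exclusive-access assumption discussed above (helper and directory names hypothetical):

```python
import os
import shutil

def write_dir_atomically(basedir: str, final_name: str, write_fn) -> None:
    tmp_dir = os.path.join(basedir, "tmp")
    # A well-known name means stale leftovers from a premature
    # termination are easy to find and remove on the next run.
    shutil.rmtree(tmp_dir, ignore_errors=True)
    os.makedirs(tmp_dir)
    write_fn(tmp_dir)  # populate the directory with artifact files
    # Atomic publish; assumes the destination does not already exist
    # (the skip logic above bails out when it does).
    os.replace(tmp_dir, os.path.join(basedir, final_name))
```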

ezyang (Contributor) left a comment:

None of the comments are blocking

facebook-github-bot (Contributor):
@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

pytorchmergebot (Collaborator):
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

rahulsingh-intel pushed a commit to rahulsingh-intel/pytorch that referenced this pull request Nov 5, 2024

Pull Request resolved: pytorch#138239
Approved by: https://github.com/ezyang
@github-actions github-actions bot deleted the export-D64504909 branch November 30, 2024 02:08