This repository was archived by the owner on Aug 1, 2025. It is now read-only.
Merged
34 changes: 34 additions & 0 deletions benchmarks/common.py
@@ -33,6 +33,7 @@
from torchdynamo.testing import format_speedup
from torchdynamo.testing import same
from torchdynamo.utils import clone_inputs
from torchinductor.utils import fresh_triton_cache

try:
from functorch._src.aot_autograd import set_model_name
@@ -1284,6 +1285,33 @@ def compare_branches(
"--diff_main called on main branch, what are you diffing?"
)

def maybe_fresh_cache(fn):
def inner(self, *args, **kwargs):
cache_minder = NullContext()
if self.args.cold_start_latency:
cache_entries = {}
cache_minder = fresh_triton_cache(cache_entries)

try:
with cache_minder:
return fn(self, *args, **kwargs)
finally:
dump_cache = False
if dump_cache and self.args.cold_start_latency:
output_csv(
output_filename[:-4] + "_triton_cache.csv",
["dev", "name", "batch_size", "triton_cache"],
[
current_device,
current_name,
current_batch_size,
cache_entries,
],
)

return inner

@maybe_fresh_cache
def run_one_model(
self,
name,
Expand Down Expand Up @@ -1454,6 +1482,12 @@ def parse_args():
help="Delta this branch against main. In the future, we may add support for picking the branch.",
)

parser.add_argument(
"--cold_start_latency",
action="store_true",
help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
)

group_fuser = parser.add_mutually_exclusive_group()
# --nvfuser is now the default, keep the option to not break scripts
group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
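
A note for readers of the decorator above: NullContext, output_csv, output_filename, current_device, current_name, and current_batch_size are defined elsewhere in benchmarks/common.py and are not part of this diff. Assuming NullContext is simply a no-op context manager, a minimal stand-in (a sketch, not the harness's actual definition) could look like this:

import contextlib

class NullContext(contextlib.AbstractContextManager):
    # Do-nothing context manager: when --cold_start_latency is not passed,
    # the `with cache_minder:` block leaves the Triton cache directory untouched.
    def __exit__(self, exc_type, exc_value, traceback):
        return None
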
27 changes: 27 additions & 0 deletions torchinductor/utils.py
@@ -1,11 +1,15 @@
import collections
import contextlib
import functools
import operator
import os
import tempfile
import time
from importlib import import_module
from typing import Any
from typing import Dict
from typing import List
from unittest import mock

import numpy as np
import sympy
@@ -234,3 +238,26 @@ def has_incompatible_cudagraph_ops(gm):
instance_descriptor = collections.namedtuple(
"instance_descriptor", ["divisible_by_16", "equal_to_1"]
)


@contextlib.contextmanager
def fresh_triton_cache(cache_entries=None):
"""
Contextmanager that provides a clean tmp cachedir for triton.

    Optionally, pass a dict as 'cache_entries' to collect the names and sizes of the
    cache files generated while this cache instance is active.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
yield
if isinstance(cache_entries, dict):
assert len(cache_entries) == 0, "expected empty cache_entries dict"
files = os.listdir(tmpdirname)
cache_entries.update(
{
f: os.path.getsize(os.path.join(tmpdirname, f))
for f in files
if ".lock" not in f
}
)
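
A rough usage sketch of the new context manager (my example, not taken from the PR; run_model_once is a hypothetical placeholder for whatever compiles Triton kernels): the cache_entries dict is filled in when the block exits, mapping each generated cache file to its size in bytes.

from torchinductor.utils import fresh_triton_cache

def run_model_once():
    # Hypothetical placeholder: a real benchmark would compile and run a
    # TorchInductor model here, causing Triton to write kernels to its cache.
    pass

cache_entries = {}
with fresh_triton_cache(cache_entries):
    # Inside this block TRITON_CACHE_DIR points at a temporary, initially
    # empty directory, so every compile is a cold-start compile.
    run_model_once()

# On exit, cache_entries maps cache file names (lock files excluded) to sizes.
print(f"{len(cache_entries)} cache files, {sum(cache_entries.values())} bytes total")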