diff --git a/benchmarks/common.py b/benchmarks/common.py
index 367406f07c..0fc264ab86 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -33,6 +33,7 @@
 from torchdynamo.testing import format_speedup
 from torchdynamo.testing import same
 from torchdynamo.utils import clone_inputs
+from torchinductor.utils import fresh_triton_cache
 
 try:
     from functorch._src.aot_autograd import set_model_name
@@ -1284,6 +1285,33 @@ def compare_branches(
                 "--diff_main called on main branch, what are you diffing?"
             )
 
+    def maybe_fresh_cache(fn):
+        def inner(self, *args, **kwargs):
+            cache_minder = NullContext()
+            if self.args.cold_start_latency:
+                cache_entries = {}
+                cache_minder = fresh_triton_cache(cache_entries)
+
+            try:
+                with cache_minder:
+                    return fn(self, *args, **kwargs)
+            finally:
+                dump_cache = False
+                if dump_cache and self.args.cold_start_latency:
+                    output_csv(
+                        output_filename[:-4] + "_triton_cache.csv",
+                        ["dev", "name", "batch_size", "triton_cache"],
+                        [
+                            current_device,
+                            current_name,
+                            current_batch_size,
+                            cache_entries,
+                        ],
+                    )
+
+        return inner
+
+    @maybe_fresh_cache
     def run_one_model(
         self,
         name,
@@ -1454,6 +1482,12 @@ def parse_args():
         help="Delta this branch against main. In the future, we may add support for picking the branch.",
     )
+    parser.add_argument(
+        "--cold_start_latency",
+        action="store_true",
+        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
+    )
+
     group_fuser = parser.add_mutually_exclusive_group()
     # --nvfuser is now the default, keep the option to not break scripts
     group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
diff --git a/torchinductor/utils.py b/torchinductor/utils.py
index a5d5e58bec..88a71eebe3 100644
--- a/torchinductor/utils.py
+++ b/torchinductor/utils.py
@@ -1,11 +1,15 @@
 import collections
+import contextlib
 import functools
 import operator
+import os
+import tempfile
 import time
 from importlib import import_module
 from typing import Any
 from typing import Dict
 from typing import List
+from unittest import mock
 
 import numpy as np
 import sympy
@@ -234,3 +238,26 @@
 instance_descriptor = collections.namedtuple(
     "instance_descriptor", ["divisible_by_16", "equal_to_1"]
 )
+
+
+@contextlib.contextmanager
+def fresh_triton_cache(cache_entries=None):
+    """
+    Contextmanager that provides a clean tmp cachedir for triton.
+
+    Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
+    generated with this cache instance.
+    """
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
+            yield
+            if isinstance(cache_entries, dict):
+                assert len(cache_entries) == 0, "expected empty cache_entries dict"
+                files = os.listdir(tmpdirname)
+                cache_entries.update(
+                    {
+                        f: os.path.getsize(os.path.join(tmpdirname, f))
+                        for f in files
+                        if ".lock" not in f
+                    }
+                )