From e49776371210a7184680bd68e57cbf1a7d73f8a7 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Tue, 13 Sep 2022 01:21:27 +0000
Subject: [PATCH 1/2] Clean/monitor triton cache during benchmarking

---
 benchmarks/common.py   | 42 +++++++++++++++++++++++++-----------------
 torchinductor/utils.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 17 deletions(-)

diff --git a/benchmarks/common.py b/benchmarks/common.py
index 367406f07c..47abf758a4 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -17,6 +17,7 @@
 import numpy as np
 import pandas as pd
 import torch
+from torchinductor.utils import TritonCacheMinder
 from microbenchmarks.operator_inp_utils import OperatorInputsMode
 from scipy.stats import gmean
 from scipy.stats import ttest_ind
@@ -449,24 +450,26 @@ def maybe_profile(*args, **kwargs):
         else:
             yield

-    with maybe_profile(enabled=args.export_profiler_trace) as p:
-        frozen_model_iter_fn = torchdynamo.run(model_iter_fn)
-        for rep in range(args.repeat):
-            inputs = (
-                randomize_input(copy.deepcopy(example_inputs))
-                if should_randomize_input
-                else example_inputs
-            )
+    with TritonCacheMinder() as triton_cache_minder:
+        with maybe_profile(enabled=args.export_profiler_trace) as p:
+            frozen_model_iter_fn = torchdynamo.run(model_iter_fn)
+            for rep in range(args.repeat):
+                inputs = (
+                    randomize_input(copy.deepcopy(example_inputs))
+                    if should_randomize_input
+                    else example_inputs
+                )

-            # interleave the runs to handle frequency scaling and load changes
-            timings[rep, 0], expected_output = timed(
-                model, model_iter_fn, inputs, return_result=True
-            )
-            timings[rep, 1], actual_output = timed(
-                model, frozen_model_iter_fn, inputs, return_result=True
-            )
-            if should_check_result:
-                is_correct = is_correct and same(expected_output, actual_output)
+                # interleave the runs to handle frequency scaling and load changes
+                timings[rep, 0], expected_output = timed(
+                    model, model_iter_fn, inputs, return_result=True
+                )
+                timings[rep, 1], actual_output = timed(
+                    model, frozen_model_iter_fn, inputs, return_result=True
+                )
+                if should_check_result:
+                    is_correct = is_correct and same(expected_output, actual_output)
+        triton_cache = triton_cache_minder.get_cache_entries()
     if args.export_profiler_trace:
         name = args.profiler_trace_name + "_" + model.name + ".json"
         name = os.path.join(torchdynamo.config.base_dir, name)
@@ -501,6 +504,11 @@ def maybe_profile(*args, **kwargs):
         ["dev", "name", "batch_size"] + headers,
         [current_device, current_name, current_batch_size] + data,
     )
+    output_csv(
+        output_filename[:-4] + "_triton_cache.csv",
+        ["dev", "name", "batch_size", "triton_cache"],
+        [current_device, current_name, current_batch_size, triton_cache],
+    )

     return format_speedup(speedup, pvalue, is_correct=is_correct)
diff --git a/torchinductor/utils.py b/torchinductor/utils.py
index a5d5e58bec..f2e4d4b6e7 100644
--- a/torchinductor/utils.py
+++ b/torchinductor/utils.py
@@ -1,11 +1,15 @@
 import collections
+import contextlib
 import functools
 import operator
+import os
+import tempfile
 import time
 from importlib import import_module
 from typing import Any
 from typing import Dict
 from typing import List
+from unittest import mock

 import numpy as np
 import sympy
@@ -234,3 +238,27 @@ def has_incompatible_cudagraph_ops(gm):
 instance_descriptor = collections.namedtuple(
     "instance_descriptor", ["divisible_by_16", "equal_to_1"]
 )
+
+
+class TritonCacheMinder:
+    """
+    Contextmanager used in benchmarking, which provides a clean triton cache dir
+    and also audits the files that were created during a run.
+    """
+    def __init__(self):
+        self.cache_entries = {}
+
+    @contextlib.contextmanager
+    def fresh_cache(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
+                yield
+            cache_dir = os.path.join(tmpdirname, "cache")
+            files = os.listdir(cache_dir)
+            self.cache_entries = {
+                f: os.path.getsize(os.path.join(cache_dir, f))
+                for f in files if ".lock" not in f
+            }
+
+    def get_cache_entries(self):
+        return self.cache_entries

From e9e43bb9af4eb751bba6da93854f368837679a54 Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Wed, 12 Oct 2022 17:13:31 +0000
Subject: [PATCH 2/2] Refactor triton cache minder

- pass in cache_entries instead of making minder a class
- add wrapper around run_one_model
- optionally dump cache entries to csv
---
 benchmarks/common.py   | 76 ++++++++++++++++++++++++++++--------------
 torchinductor/utils.py | 39 ++++++++++++++--------------
 2 files changed, 70 insertions(+), 45 deletions(-)

diff --git a/benchmarks/common.py b/benchmarks/common.py
index 47abf758a4..0fc264ab86 100644
--- a/benchmarks/common.py
+++ b/benchmarks/common.py
@@ -17,7 +17,6 @@
 import numpy as np
 import pandas as pd
 import torch
-from torchinductor.utils import TritonCacheMinder
 from microbenchmarks.operator_inp_utils import OperatorInputsMode
 from scipy.stats import gmean
 from scipy.stats import ttest_ind
@@ -34,6 +33,7 @@
 from torchdynamo.testing import format_speedup
 from torchdynamo.testing import same
 from torchdynamo.utils import clone_inputs
+from torchinductor.utils import fresh_triton_cache

 try:
     from functorch._src.aot_autograd import set_model_name
@@ -450,26 +450,24 @@ def maybe_profile(*args, **kwargs):
         else:
             yield

-    with TritonCacheMinder() as triton_cache_minder:
-        with maybe_profile(enabled=args.export_profiler_trace) as p:
-            frozen_model_iter_fn = torchdynamo.run(model_iter_fn)
-            for rep in range(args.repeat):
-                inputs = (
-                    randomize_input(copy.deepcopy(example_inputs))
-                    if should_randomize_input
-                    else example_inputs
-                )
+    with maybe_profile(enabled=args.export_profiler_trace) as p:
+        frozen_model_iter_fn = torchdynamo.run(model_iter_fn)
+        for rep in range(args.repeat):
+            inputs = (
+                randomize_input(copy.deepcopy(example_inputs))
+                if should_randomize_input
+                else example_inputs
+            )

-                # interleave the runs to handle frequency scaling and load changes
-                timings[rep, 0], expected_output = timed(
-                    model, model_iter_fn, inputs, return_result=True
-                )
-                timings[rep, 1], actual_output = timed(
-                    model, frozen_model_iter_fn, inputs, return_result=True
-                )
-                if should_check_result:
-                    is_correct = is_correct and same(expected_output, actual_output)
-        triton_cache = triton_cache_minder.get_cache_entries()
+            # interleave the runs to handle frequency scaling and load changes
+            timings[rep, 0], expected_output = timed(
+                model, model_iter_fn, inputs, return_result=True
+            )
+            timings[rep, 1], actual_output = timed(
+                model, frozen_model_iter_fn, inputs, return_result=True
+            )
+            if should_check_result:
+                is_correct = is_correct and same(expected_output, actual_output)
     if args.export_profiler_trace:
         name = args.profiler_trace_name + "_" + model.name + ".json"
         name = os.path.join(torchdynamo.config.base_dir, name)
@@ -504,11 +502,6 @@ def maybe_profile(*args, **kwargs):
         ["dev", "name", "batch_size"] + headers,
         [current_device, current_name, current_batch_size] + data,
     )
-    output_csv(
-        output_filename[:-4] + "_triton_cache.csv",
-        ["dev", "name", "batch_size", "triton_cache"],
-        [current_device, current_name, current_batch_size, triton_cache],
-    )

     return format_speedup(speedup, pvalue, is_correct=is_correct)
@@ -1292,6 +1285,33 @@ def compare_branches(
                 "--diff_main called on main branch, what are you diffing?"
             )

+    def maybe_fresh_cache(fn):
+        def inner(self, *args, **kwargs):
+            cache_minder = NullContext()
+            if self.args.cold_start_latency:
+                cache_entries = {}
+                cache_minder = fresh_triton_cache(cache_entries)
+
+            try:
+                with cache_minder:
+                    return fn(self, *args, **kwargs)
+            finally:
+                dump_cache = False
+                if dump_cache and self.args.cold_start_latency:
+                    output_csv(
+                        output_filename[:-4] + "_triton_cache.csv",
+                        ["dev", "name", "batch_size", "triton_cache"],
+                        [
+                            current_device,
+                            current_name,
+                            current_batch_size,
+                            cache_entries,
+                        ],
+                    )
+
+        return inner
+
+    @maybe_fresh_cache
     def run_one_model(
         self,
         name,
@@ -1462,6 +1482,12 @@ def parse_args():
         help="Delta this branch against main. In the future, we may add support for picking the branch.",
     )

+    parser.add_argument(
+        "--cold_start_latency",
+        action="store_true",
+        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
+    )
+
     group_fuser = parser.add_mutually_exclusive_group()
     # --nvfuser is now the default, keep the option to not break scripts
     group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
diff --git a/torchinductor/utils.py b/torchinductor/utils.py
index f2e4d4b6e7..88a71eebe3 100644
--- a/torchinductor/utils.py
+++ b/torchinductor/utils.py
@@ -240,25 +240,24 @@ def has_incompatible_cudagraph_ops(gm):
 )


-class TritonCacheMinder:
+@contextlib.contextmanager
+def fresh_triton_cache(cache_entries=None):
     """
-    Contextmanager used in benchmarking, which provides a clean triton cache dir
-    and also audits the files that were created during a run.
+    Contextmanager that provides a clean tmp cachedir for triton.
+
+    Optionally, pass a dict as 'cache_entries' to get a list of filenames and sizes
+    generated with this cache instance.
     """
-    def __init__(self):
-        self.cache_entries = {}
-
-    @contextlib.contextmanager
-    def fresh_cache(self):
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
-                yield
-            cache_dir = os.path.join(tmpdirname, "cache")
-            files = os.listdir(cache_dir)
-            self.cache_entries = {
-                f: os.path.getsize(os.path.join(cache_dir, f))
-                for f in files if ".lock" not in f
-            }
-
-    def get_cache_entries(self):
-        return self.cache_entries
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": tmpdirname}):
+            yield
+            if isinstance(cache_entries, dict):
+                assert len(cache_entries) == 0, "expected empty cache_entries dict"
+                files = os.listdir(tmpdirname)
+                cache_entries.update(
+                    {
+                        f: os.path.getsize(os.path.join(tmpdirname, f))
+                        for f in files
+                        if ".lock" not in f
+                    }
+                )
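
Usage note (an illustrative sketch, not part of the patches above): the snippet
below shows one way the final fresh_triton_cache helper from
torchinductor/utils.py can be driven directly. Only fresh_triton_cache and its
cache_entries contract come from the patch; the compile_fn argument and the
report_cold_start_cache wrapper are hypothetical stand-ins for whatever code
actually triggers Triton compilation.

    # Hypothetical caller: `compile_fn` is any callable that causes Triton to
    # compile kernels (e.g. running a torchinductor-compiled model once).
    from torchinductor.utils import fresh_triton_cache

    def report_cold_start_cache(compile_fn):
        cache_entries = {}  # must start empty; it is filled in when the block exits
        with fresh_triton_cache(cache_entries):
            # Inside the block TRITON_CACHE_DIR points at a brand-new temp dir,
            # so every kernel compiled here is a cold-start compile.
            compile_fn()
        # cache_entries now maps each cached artifact (excluding .lock files)
        # to its size in bytes, the same data the harness can write out to
        # *_triton_cache.csv.
        for fname, size in sorted(cache_entries.items()):
            print(f"{fname}: {size} bytes")

This mirrors what the maybe_fresh_cache wrapper in benchmarks/common.py does
around run_one_model when --cold_start_latency is passed.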