diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 3d4ff7199f71f..3c08e7e89be11 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3123,6 +3123,12 @@ def parse_args(args=None): parser.add_argument( "--freezing", action="store_true", help="turn on freezing", default=False ) + parser.add_argument( + "--inductor-config", + "-c", + action="append", + help="key=value in torch._inductor.config", + ) parser.add_argument( "--ci", action="store_true", help="Flag to tell that its a CI run" ) @@ -4025,6 +4031,18 @@ def run(runner, args, original_dir=None): inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16 if args.inference: inductor_config.freezing = args.freezing + if args.inductor_config: + for config in args.inductor_config: + key, value = config.split("=") + typ = type(inductor_config.__getattr__(key)) + if issubclass(typ, bool): + assert value in ("0", "1", "True", "False") + value = value in ("1", "True") + elif issubclass(typ, (str, int, float)): + value = typ(value) + else: + raise NotImplementedError(typ) + inductor_config.__setattr__(key, value) runner.setup_amp()