diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 3d4ff7199f71f..3c08e7e89be11 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -3123,6 +3123,12 @@ def parse_args(args=None):
     parser.add_argument(
         "--freezing", action="store_true", help="turn on freezing", default=False
     )
+    parser.add_argument(
+        "--inductor-config",
+        "-c",
+        action="append",
+        help="key=value in torch._inductor.config",
+    )
     parser.add_argument(
         "--ci", action="store_true", help="Flag to tell that its a CI run"
     )
@@ -4025,6 +4031,18 @@ def run(runner, args, original_dir=None):
         inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
         if args.inference:
             inductor_config.freezing = args.freezing
+        if args.inductor_config:
+            for config in args.inductor_config:
+                key, value = config.split("=")
+                typ = type(inductor_config.__getattr__(key))
+                if issubclass(typ, bool):
+                    assert value in ("0", "1", "True", "False")
+                    value = value in ("1", "True")
+                elif issubclass(typ, (str, int, float)):
+                    value = typ(value)
+                else:
+                    raise NotImplementedError(typ)
+                inductor_config.__setattr__(key, value)
 
     runner.setup_amp()
 
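A note on the new flag, since the coercion is easy to misread: each `--inductor-config`/`-c` occurrence is split on `=` and coerced to the type of the attribute's current value on `torch._inductor.config`, and the `bool` branch must run before the `str`/`int`/`float` branch because `bool` is a subclass of `int`. A minimal sketch of the equivalent effect, with illustrative values (`max_autotune` and `unroll_reductions_threshold` are real top-level config options, but the values chosen here are arbitrary):

```python
# Equivalent of passing `-c max_autotune=True -c unroll_reductions_threshold=16`
# on the benchmark command line. The keys are real torch._inductor.config
# options; the values are purely illustrative.
import torch._inductor.config as inductor_config

for pair in ["max_autotune=True", "unroll_reductions_threshold=16"]:
    key, value = pair.split("=")
    typ = type(getattr(inductor_config, key))  # type of the current value
    if issubclass(typ, bool):  # checked first: bool is a subclass of int
        assert value in ("0", "1", "True", "False")
        value = value in ("1", "True")
    elif issubclass(typ, (str, int, float)):
        value = typ(value)  # e.g. int("16") -> 16
    else:
        raise NotImplementedError(typ)
    setattr(inductor_config, key, value)

assert inductor_config.max_autotune is True
assert inductor_config.unroll_reductions_threshold == 16
```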
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index c88f0e9d3b5be..cf925e8eaefe0 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -9878,7 +9878,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         # Inductor specializes on the (unguarded) alignment of the initial input.
         # Make sure that for different configurations, nothing breaks.
         for offset in (0, 1, 2, 3, 4):
-            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
             inp = torch.as_strided(base, (64, 64), (64, 1), offset)
             torch._dynamo.reset()
             fn_c = torch.compile(fn)
@@ -9888,8 +9888,10 @@ def fn(x: torch.Tensor) -> torch.Tensor:
             self.assertEqual(ref, res)
 
             for offset2 in (0, 1, 2, 3, 4):
-                base2 = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
-                inp2 = torch.as_strided(base, (64, 64), (64, 1), offset2)
+                base2 = torch.randn(
+                    64 * 64 + 64, dtype=torch.float32, device=self.device
+                )
+                inp2 = torch.as_strided(base2, (64, 64), (64, 1), offset2)
                 ref2 = fn(inp2)
                 res2 = fn_c(inp2)
                 self.assertEqual(ref2, res2)
@@ -9910,7 +9912,7 @@ def fail(guard):
         def fn(x: torch.Tensor) -> torch.Tensor:
             return x.sin() + x.cos()
 
-        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
         inp1 = torch.as_strided(base, (32, 32), (32, 1), 4)
         inp2 = torch.as_strided(base, (64, 64), (64, 1), 4)
 
@@ -9955,9 +9957,11 @@ def fn(x):
             ((64, 64), (64, 1), 5),
         ):
             torch.manual_seed(42)
-            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
             torch.manual_seed(42)
-            base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base_ref = torch.randn(
+                64 * 64 + 64, dtype=torch.float32, device=self.device
+            )
 
             inp = torch.as_strided(base, size, stride, offset)
             inp_ref = torch.as_strided(base_ref, size, stride, offset)
diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py
index 7f7a201acd570..b1335aeb71e11 100644
--- a/torch/_inductor/codegen/halide.py
+++ b/torch/_inductor/codegen/halide.py
@@ -262,7 +262,7 @@ def maximum(a, b):
 
     @staticmethod
     def where(a, b, c):
-        return f"hl.select({a}, {b}, {c})"
+        return f"hl.select({a}, {b}, hl.cast({b.name}.type(), {c}))"
 
     @staticmethod
     def cos(x):
@@ -1059,9 +1059,23 @@ def update_index(m):
 
         code.do_unindent(2)
         code.splice(
-            """
+            f"""
            if __name__ == "__main__":
                hl.main()
+            else:
+                hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
+                target = hl.Target({meta.target!r})
+                autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
+                with hl.GeneratorContext(target, autoscheduler):
+                    gen = Kernel()
+                    pipeline = gen._build_pipeline()
+                    # gen.compile_to_callable() does not run the autoscheduler
+                    pipeline.apply_autoscheduler(target, autoscheduler)
+                    kernel = pipeline.compile_to_callable([
+                        gen._get_input_parameter(a.name)._to_argument()
+                        for a in gen._get_arginfos()
+                        if a.dir == hl.ArgInfoDirection.Input
+                    ], target)
             """
         )
         return code.getvalue()
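For context on the Halide change: the new `else:` branch makes importing a generated kernel module produce a ready `kernel` callable, instead of only supporting script-style execution through `hl.main()`. A sketch of what the f-string trailer renders to inside a generated module (which already imports `halide as hl` and defines `Kernel`); the plugin path, target string, and scheduler flags below are illustrative stand-ins for what `meta.*` and `HalideCodeCache.find_libautoschedule()` would supply:

```python
# Hypothetical rendering of the trailer above for one generated module.
# '/usr/local/lib/libautoschedule_anderson2021.so', 'host-cuda', and
# {'parallelism': 82} are made-up example values.
if __name__ == "__main__":
    hl.main()
else:
    hl.load_plugin('/usr/local/lib/libautoschedule_anderson2021.so')
    target = hl.Target('host-cuda')
    autoscheduler = hl.AutoschedulerParams('Anderson2021', {'parallelism': 82})
    with hl.GeneratorContext(target, autoscheduler):
        gen = Kernel()
        pipeline = gen._build_pipeline()
        # gen.compile_to_callable() does not run the autoscheduler, so the
        # pipeline is scheduled explicitly before compiling to a callable
        pipeline.apply_autoscheduler(target, autoscheduler)
        kernel = pipeline.compile_to_callable([
            gen._get_input_parameter(a.name)._to_argument()
            for a in gen._get_arginfos()
            if a.dir == hl.ArgInfoDirection.Input
        ], target)
```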