Commit

Update
[ghstack-poisoned]
jansel committed Jun 19, 2024
2 parents 98525d1 + a7a33c3 commit a8d60f7
Showing 3 changed files with 44 additions and 8 deletions.
benchmarks/dynamo/common.py (18 additions, 0 deletions)
@@ -3123,6 +3123,12 @@ def parse_args(args=None):
     parser.add_argument(
         "--freezing", action="store_true", help="turn on freezing", default=False
     )
+    parser.add_argument(
+        "--inductor-config",
+        "-c",
+        action="append",
+        help="key=value in torch._inductor.config",
+    )
     parser.add_argument(
         "--ci", action="store_true", help="Flag to tell that its a CI run"
     )
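Note on the flag above: action="append" means each repeated -c/--inductor-config occurrence is collected into a list of key=value strings (args.inductor_config is None if the flag is never passed). A standalone sketch of that behavior, using only the standard library; the config keys shown are illustrative, not taken from this commit:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--inductor-config",
        "-c",
        action="append",
        help="key=value in torch._inductor.config",
    )
    # Two occurrences of the flag accumulate into one list.
    args = parser.parse_args(["-c", "max_autotune=True", "-c", "fx_graph_cache=0"])
    print(args.inductor_config)  # ['max_autotune=True', 'fx_graph_cache=0']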
@@ -4025,6 +4031,18 @@ def run(runner, args, original_dir=None):
         inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
         if args.inference:
             inductor_config.freezing = args.freezing
+        if args.inductor_config:
+            for config in args.inductor_config:
+                key, value = config.split("=")
+                typ = type(inductor_config.__getattr__(key))
+                if issubclass(typ, bool):
+                    assert value in ("0", "1", "True", "False")
+                    value = value in ("1", "True")
+                elif issubclass(typ, (str, int, float)):
+                    value = typ(value)
+                else:
+                    raise NotImplementedError(typ)
+                inductor_config.__setattr__(key, value)
 
     runner.setup_amp()
 
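Note on the run() hunk above: each key=value pair is looked up on torch._inductor.config to discover the existing entry's type, and the string value is coerced accordingly. The bool branch has to come first: bool is a subclass of int, and bool("False") would be truthy, so booleans need the explicit "0"/"1"/"True"/"False" check. A standalone sketch of the same idea against a plain object, so it runs without PyTorch (the Config class here is invented for illustration):

    class Config:
        max_autotune = False  # a bool-typed entry

    cfg = Config()
    for item in ["max_autotune=1"]:
        key, value = item.split("=")
        typ = type(getattr(cfg, key))  # probe the existing entry's type
        if issubclass(typ, bool):
            # checked before int/float because bool subclasses int
            assert value in ("0", "1", "True", "False")
            parsed = value in ("1", "True")
        elif issubclass(typ, (str, int, float)):
            parsed = typ(value)
        else:
            raise NotImplementedError(typ)
        setattr(cfg, key, parsed)
    print(cfg.max_autotune)  # True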
test/inductor/test_torchinductor.py (10 additions, 6 deletions)
@@ -9878,7 +9878,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         # Inductor specializes on the (unguarded) alignment of the initial input.
         # Make sure that for different configurations, nothing breaks.
         for offset in (0, 1, 2, 3, 4):
-            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
             inp = torch.as_strided(base, (64, 64), (64, 1), offset)
             torch._dynamo.reset()
             fn_c = torch.compile(fn)
@@ -9888,8 +9888,10 @@ def fn(x: torch.Tensor) -> torch.Tensor:
             self.assertEqual(ref, res)
 
             for offset2 in (0, 1, 2, 3, 4):
-                base2 = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
-                inp2 = torch.as_strided(base, (64, 64), (64, 1), offset2)
+                base2 = torch.randn(
+                    64 * 64 + 64, dtype=torch.float32, device=self.device
+                )
+                inp2 = torch.as_strided(base2, (64, 64), (64, 1), offset2)
                 ref2 = fn(inp2)
                 res2 = fn_c(inp2)
                 self.assertEqual(ref2, res2)
@@ -9910,7 +9912,7 @@ def fail(guard):
         def fn(x: torch.Tensor) -> torch.Tensor:
             return x.sin() + x.cos()
 
-        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+        base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
 
         inp1 = torch.as_strided(base, (32, 32), (32, 1), 4)
         inp2 = torch.as_strided(base, (64, 64), (64, 1), 4)
@@ -9955,9 +9957,11 @@ def fn(x):
             ((64, 64), (64, 1), 5),
         ):
             torch.manual_seed(42)
-            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
             torch.manual_seed(42)
-            base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
+            base_ref = torch.randn(
+                64 * 64 + 64, dtype=torch.float32, device=self.device
+            )
 
             inp = torch.as_strided(base, size, stride, offset)
             inp_ref = torch.as_strided(base_ref, size, stride, offset)
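Note on the test changes: swapping the module-level GPU_TYPE constant for self.device lets the same test body run on whichever device the test class is instantiated for, and the second hunk also fixes inp2 to view base2 instead of base. A minimal sketch of the device-agnostic pattern (class and test names are invented for this illustration):

    import unittest

    import torch

    class StridedInputTest(unittest.TestCase):
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def test_offset_view(self):
            # Allocate on whatever device the class is configured with.
            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
            inp = torch.as_strided(base, (64, 64), (64, 1), 4)
            self.assertEqual(inp.shape, torch.Size([64, 64]))

    if __name__ == "__main__":
        unittest.main()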
torch/_inductor/codegen/halide.py (16 additions, 2 deletions)
@@ -262,7 +262,7 @@ def maximum(a, b):
 
     @staticmethod
     def where(a, b, c):
-        return f"hl.select({a}, {b}, {c})"
+        return f"hl.select({a}, {b}, hl.cast({b.name}.type(), {c}))"
 
     @staticmethod
     def cos(x):
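Note on the where() change: Halide's select requires its two value operands to share a type, so the generated code now casts the third argument to the type of the second. A minimal sketch of the emitted pattern (assumes the Halide Python bindings, as used by the generated code itself; variable names are invented):

    import halide as hl

    x = hl.Var("x")
    out = hl.Func("out")
    tmp0 = hl.cast(hl.Float(32), x)        # float32 branch
    tmp1 = hl.cast(tmp0.type(), 0)         # other branch cast to the same type
    out[x] = hl.select(x > 4, tmp0, tmp1)  # branch types now agree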
@@ -1059,9 +1059,23 @@ def update_index(m):
 
         code.do_unindent(2)
         code.splice(
-            """
+            f"""
             if __name__ == "__main__":
                 hl.main()
+            else:
+                hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
+                target = hl.Target({meta.target!r})
+                autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
+                with hl.GeneratorContext(target, autoscheduler):
+                    gen = Kernel()
+                    pipeline = gen._build_pipeline()
+                    # gen.compile_to_callable() does not run the autoscheduler
+                    pipeline.apply_autoscheduler(target, autoscheduler)
+                    kernel = pipeline.compile_to_callable([
+                        gen._get_input_parameter(a.name)._to_argument()
+                        for a in gen._get_arginfos()
+                        if a.dir == hl.ArgInfoDirection.Input
+                    ], target)
             """
         )
         return code.getvalue()
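Note on the splice change: turning the literal into an f-string means the {...} placeholders are expanded at codegen time, so an imported kernel module now loads the autoscheduler plugin and compiles itself to a callable instead of only supporting command-line use. With illustrative meta values (scheduler "Anderson2021", target "host-cuda", empty scheduler flags, and a made-up plugin path; all assumptions, not from this commit), the tail of a generated module would read roughly:

    if __name__ == "__main__":
        hl.main()
    else:
        hl.load_plugin('/path/to/libautoschedule_anderson2021.so')
        target = hl.Target('host-cuda')
        autoscheduler = hl.AutoschedulerParams('Anderson2021', {})
        with hl.GeneratorContext(target, autoscheduler):
            gen = Kernel()
            pipeline = gen._build_pipeline()
            # gen.compile_to_callable() does not run the autoscheduler
            pipeline.apply_autoscheduler(target, autoscheduler)
            kernel = pipeline.compile_to_callable([
                gen._get_input_parameter(a.name)._to_argument()
                for a in gen._get_arginfos()
                if a.dir == hl.ArgInfoDirection.Input
            ], target)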
