Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
jansel committed Jun 19, 2024
1 parent a0e1e20 commit cafbd58
Show file tree
Hide file tree
Showing 14 changed files with 2,141 additions and 87 deletions.
18 changes: 18 additions & 0 deletions benchmarks/dynamo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3123,6 +3123,12 @@ def parse_args(args=None):
parser.add_argument(
    "--freezing", action="store_true", help="turn on freezing", default=False
)
# Repeatable flag: each occurrence appends one "key=value" override that is
# later applied to torch._inductor.config (see run()).
parser.add_argument(
    "--inductor-config",
    "-c",
    action="append",
    help="key=value in torch._inductor.config",
)
parser.add_argument(
    # Fixed typo in help text: "its" -> "it's".
    "--ci", action="store_true", help="Flag to tell that it's a CI run"
)
Expand Down Expand Up @@ -4025,6 +4031,18 @@ def run(runner, args, original_dir=None):
inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
if args.inference:
inductor_config.freezing = args.freezing
if args.inductor_config:
    # Apply each "key=value" override from --inductor-config / -c to
    # torch._inductor.config, coercing the string value to the type of
    # the config entry's current value.
    for override in args.inductor_config:
        # Split only on the first '=' so values may themselves
        # contain '=' characters (original split("=") raised a
        # ValueError on such inputs).
        key, value = override.split("=", 1)
        typ = type(getattr(inductor_config, key))
        # bool must be tested before int: bool is a subclass of int.
        if issubclass(typ, bool):
            # Explicit raise instead of assert: asserts are stripped
            # under `python -O`, and this is user-input validation.
            if value not in ("0", "1", "True", "False"):
                raise ValueError(
                    f"invalid boolean for --inductor-config {key}: {value!r}"
                )
            value = value in ("1", "True")
        elif issubclass(typ, (str, int, float)):
            value = typ(value)
        else:
            # Other config value types (lists, dicts, None, ...) are not
            # parseable from the command line yet.
            raise NotImplementedError(typ)
        setattr(inductor_config, key, value)

runner.setup_amp()

Expand Down
46 changes: 42 additions & 4 deletions test/inductor/test_halide.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._inductor import config
from torch._inductor.codecache import HalideCodeCache
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
from torch._inductor.test_case import run_tests, TestCase
Expand All @@ -21,21 +22,51 @@
HAS_HALIDE = False


try:
from . import test_torchinductor
except ImportError:
import test_torchinductor


# Config overrides that re-target Inductor codegen at the Halide backend
# for both CPU and CUDA test runs.
_HALIDE_SETTINGS = {
    "cpu_backend": "halide",
    "cuda_backend": "halide",
    # TODO(jansel): support random
    "fallback_random": True,
    "halide.scan_kernels": True,
}

# Usable both as a class decorator and as a context manager to run the
# shared Inductor test suites under the Halide backend.
make_halide = config.patch(_HALIDE_SETTINGS)


@unittest.skipUnless(HAS_HALIDE, "requires halide")
class HalideTests(TestCase):
def test_codecache(self):
fn = HalideCodeCache.generate_halide(
HalideMeta(
argtypes=[
HalideInputSpec(ctype="float*", name="in_ptr0", numel="1024L"),
HalideInputSpec(ctype="float*", name="in_ptr1", numel="1024L"),
HalideInputSpec(
ctype="float*",
name="in_ptr0",
shape=["1024L"],
stride=["1L"],
offset="0",
),
HalideInputSpec(
ctype="float*",
name="in_ptr1",
shape=["1024L"],
stride=["1L"],
offset="0",
),
HalideInputSpec(
ctype="float*",
name="out_ptr0",
numel="1024L",
shape=["1024L"],
stride=["1L"],
offset="0",
),
],
target="host",
target="host-no_runtime",
scheduler="Mullapudi2016",
scheduler_flags={
"parallelism": parallel_num_threads(),
Expand Down Expand Up @@ -82,6 +113,13 @@ def generate(g):
self.assertEqual(c, a + b)


# Re-instantiate the generic Inductor test suites with the Halide backend
# config patched in (see make_halide above in this file).
SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
CpuHalideTests = make_halide(test_torchinductor.CpuTests)

# GPU suite variants are only defined when a GPU is available.
if test_torchinductor.HAS_GPU:
    SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
    GPUHalideTests = make_halide(test_torchinductor.GPUTests)

if __name__ == "__main__":
    # Only run when a CPU build with Halide is present; skipped on macOS.
    if HAS_CPU and not IS_MACOS and HAS_HALIDE:
        run_tests(needs="filelock")
Loading

0 comments on commit cafbd58

Please sign in to comment.