Commit aa6de48

Update

[ghstack-poisoned]

jansel committed Jun 19, 2024
1 parent 17cfb89 commit aa6de48
Showing 6 changed files with 324 additions and 40 deletions.
20 changes: 17 additions & 3 deletions test/inductor/test_halide.py
@@ -40,15 +40,29 @@ def test_codecache(self):
fn = HalideCodeCache.generate_halide(
HalideMeta(
argtypes=[
HalideInputSpec(ctype="float*", name="in_ptr0", shape=["1024L"]),
HalideInputSpec(ctype="float*", name="in_ptr1", shape=["1024L"]),
HalideInputSpec(
ctype="float*",
name="in_ptr0",
shape=["1024L"],
stride=["1L"],
offset="0",
),
HalideInputSpec(
ctype="float*",
name="in_ptr1",
shape=["1024L"],
stride=["1L"],
offset="0",
),
HalideInputSpec(
ctype="float*",
name="out_ptr0",
shape=["1024L"],
stride=["1L"],
offset="0",
),
],
target="host",
target="host-no_runtime",
scheduler="Mullapudi2016",
scheduler_flags={
"parallelism": parallel_num_threads(),
250 changes: 223 additions & 27 deletions torch/_inductor/codecache.py
@@ -56,7 +56,7 @@
_reload_python_module,
_reload_python_module_in_subproc,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux

from torch._logging import trace_structured
@@ -72,7 +72,7 @@

from torch._inductor.graph import GraphLowering
from torch._inductor.ir import ChoiceCaller
from torch._inductor.runtime.hints import HalideMeta
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta


_HERE = os.path.abspath(__file__)
@@ -2595,6 +2595,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
[[unlikely]] throw std::runtime_error("expected int arg");
return result;
}
template <> inline uintptr_t parse_arg<uintptr_t>(PyObject* args, size_t n) {
auto result = PyLong_AsVoidPtr(PyTuple_GET_ITEM(args, n));
if(result == reinterpret_cast<void*>(-1) && PyErr_Occurred())
[[unlikely]] throw std::runtime_error("expected int arg");
return reinterpret_cast<uintptr_t>(result);
}
%s
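Note: this new specialization lets generated bindings accept raw addresses (device pointers, CUDA streams) as plain Python ints; PyLong_AsVoidPtr reports failure by returning (void*)-1 with an exception set, which is exactly what the guard checks. A caller-side sketch, assuming a loaded binding named kernel (name and tensors are illustrative):

    import torch

    a, b = (torch.randn(1024, device="cuda") for _ in range(2))
    out = torch.empty(1024, device="cuda")
    # every pointer travels as a plain Python int
    kernel(a.data_ptr(), b.data_ptr(), out.data_ptr(),
           torch.cuda.current_stream().cuda_stream)  # parsed via parse_arg<uintptr_t>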
@@ -2871,40 +2877,142 @@ def validate_new_cpp_commands():
class HalideCodeCache(CppPythonBindingsCodeCache):
cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
cache_clear = staticmethod(cache.clear)
glue_template = textwrap.dedent(
_standalone_runtime_path: Optional[str] = None
prefix = textwrap.dedent(
"""
#include "{halidebuffer_h}"
#include "{halideruntime_h}"
#include "{headerfile}"
#include <stdexcept>
#include <cmath>
namespace c10 {{
inline long div_floor_integer(long a, long b) {{
if ((a<0) != (b<0)) {{
const auto quot = a / b;
const auto rem = a % b;
return rem ? quot - 1 : quot;
}}
return a / b;
}}
}}
"""
)
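Note: C's / truncates toward zero, while Inductor's symbolic math expects Python-style floor division; the helper adjusts the quotient when the operand signs differ. An equivalent Python port of the same logic (a sketch for checking the semantics, not part of the diff):

    def div_floor_integer(a: int, b: int) -> int:
        # truncating (C-style) quotient first, then the same sign-mismatch fixup
        quot = -(-a // b) if (a < 0) != (b < 0) else a // b
        rem = a - quot * b
        return quot - 1 if (a < 0) != (b < 0) and rem else quot

    assert div_floor_integer(-7, 2) == -7 // 2 == -4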
glue_template_cpp = prefix + textwrap.dedent(
"""
void kernel({argdefs}) {{
{buffers}
int err = halide_kernel({buffer_names});
if(err != 0) {{
throw std::runtime_error("halide_kernel failed");
}}
if(err != 0) throw std::runtime_error("halide_kernel failed");
}}
"""
)
glue_template_cuda = prefix + textwrap.dedent(
"""
#include <cuda.h>
static const halide_device_interface_t* cuda_interface = halide_cuda_device_interface();
void kernel({argdefs}, uintptr_t stream) {{
{buffers}
int err = halide_kernel(reinterpret_cast<void*>(stream), {buffer_names});
if(err != 0) throw std::runtime_error("halide_kernel failed");
}}
"""
)
standalone_runtime_cuda_init = textwrap.dedent(
"""
#include "{}"
#include <cuda.h>
static int acquire_context(void* user_context,
void** cuda_context_out,
bool create) {{
return cuCtxGetCurrent(reinterpret_cast<CUcontext*>(cuda_context_out));
}}
static int release_context(void* user_context) {{
return 0;
}}
static int get_stream(void* user_context,
void* cuda_context,
void** stream_out) {{
*stream_out = user_context;
return 0;
}}
static int register_halide_hooks() {{
halide_set_cuda_acquire_context(&acquire_context);
halide_set_cuda_release_context(&release_context);
halide_set_cuda_get_stream(&get_stream);
return 0;
}}
int inductor_register_halide_hooks_result = register_halide_hooks();
"""
)
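Note: these hooks make the standalone runtime borrow PyTorch's CUDA state rather than create its own: acquire_context reports the thread's current context, and get_stream echoes back user_context, which the CUDA glue above populated with the caller's stream. A rough Python analogy of the handshake (illustrative pseudocode; current_cuda_context is hypothetical):

    def acquire_context(user_context):
        return current_cuda_context()   # hypothetical: reuse the thread's context

    def get_stream(user_context, cuda_context):
        return user_context             # the stream pointer kernel(...) passed in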

@classmethod
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool):
assert arg.shape is not None
assert arg.stride is not None and len(arg.shape) == len(arg.stride)
assert arg.offset is not None
data_ptr = f"{arg.alias_of or arg.name} + {arg.offset}"
if cuda:
device = f"reinterpret_cast<uint64_t>({data_ptr})"
device_interface = "cuda_interface"
host = "nullptr"
flags = "halide_buffer_flag_device_dirty"
else:
device = "0"
device_interface = "nullptr"
host = f"reinterpret_cast<uint8_t*>({data_ptr})"
flags = "halide_buffer_flag_host_dirty"

dims = []
for size, stride in zip(arg.shape, arg.stride):
dims.append(f"halide_dimension_t(0, {size}, {stride})")

return [
f"halide_buffer_t {name};",
f"halide_dimension_t {name}_dims[] = {{{', '.join(dims)}}};",
f"{name}.device = {device};",
f"{name}.device_interface = {device_interface};",
f"{name}.host = {host};",
f"{name}.flags = {flags};",
f"{name}.type = {arg.halide_type()};",
f"{name}.dimensions = {len(dims)};",
f"{name}.dim = {name}_dims;",
f"{name}.padding = nullptr;",
]
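Note: _codegen_buffer emits the C++ statements that populate a halide_buffer_t in the glue code. A sketch of invoking it for a contiguous 1024-float CPU input (spec values assumed, expected output abridged):

    from torch._inductor.codecache import HalideCodeCache
    from torch._inductor.runtime.hints import HalideInputSpec

    spec = HalideInputSpec(ctype="float*", name="in_ptr0",
                           shape=["1024L"], stride=["1L"], offset="0")
    lines = HalideCodeCache._codegen_buffer("hl_buf_0", spec, cuda=False)
    # Emits, roughly:
    #   halide_buffer_t hl_buf_0;
    #   halide_dimension_t hl_buf_0_dims[] = {halide_dimension_t(0, 1024L, 1L)};
    #   hl_buf_0.host = reinterpret_cast<uint8_t*>(in_ptr0 + 0);
    #   ...plus the device, flags, type, dimensions, dim, padding fields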

@classmethod
def _codegen_glue(cls, argtypes, headerfile):
def _codegen_glue(cls, meta, headerfile):
is_cuda = meta.is_cuda()
assert is_cuda is ("user_context" in meta.target)
assert "no_runtime" in meta.target
buffers = []
buffer_names = []
for i, arg in enumerate(argtypes):
for i, arg in enumerate(meta.argtypes):
if arg.is_buffer():
buffer_names.append(f"hl_buf_{i}")
buffers.append(
f" Halide::Runtime::Buffer {buffer_names[-1]}({arg.halide_type()}, {arg.name}, {', '.join(arg.shape)});"
)
buffer_names.append(f"&hl_buf_{i}")
buffers.extend(cls._codegen_buffer(f"hl_buf_{i}", arg, is_cuda))
else:
assert "*" not in arg.ctype
buffer_names.append(arg.name)
glue_code = cls.glue_template.format(
halidebuffer_h=cls.find_header("HalideBuffer.h"),
buffers = "\n".join([f" {line}" for line in buffers]).lstrip()

glue_template = cls.glue_template_cuda if is_cuda else cls.glue_template_cpp
glue_code = glue_template.format(
halideruntime_h=cls.find_header(
"HalideRuntimeCuda.h" if is_cuda else "HalideRuntime.h"
),
headerfile=headerfile,
argdefs=", ".join(f"{a.bindings_type()} {a.name}" for a in argtypes),
buffers="\n".join(buffers).lstrip(),
argdefs=", ".join(
f"{a.bindings_type()} {a.name}"
for a in meta.argtypes
if a.alias_of is None
),
buffers=buffers,
buffer_names=", ".join(buffer_names),
)
return glue_code
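Note: arguments with alias_of set are dropped from the generated kernel signature but still get their own halide_buffer_t built from the base pointer (see data_ptr in _codegen_buffer above). A sketch with assumed values:

    base = HalideInputSpec(ctype="float*", name="in_ptr0",
                           shape=["64L"], stride=["1L"], offset="0")
    view = HalideInputSpec(ctype="float*", name="in_ptr1", shape=["32L"],
                           stride=["2L"], offset="0", alias_of="in_ptr0")
    # argdefs -> "float* in_ptr0" only; both specs still produce buffers,
    # and the view's halide_buffer_t points at in_ptr0 + offset.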
@@ -2915,7 +3023,9 @@ def config_hash(cls):
return sha256_hash(
"\n".join(
[
cls.glue_template,
cls.glue_template_cpp,
cls.glue_template_cuda,
cls.standalone_runtime_cuda_init,
f"{cls.cpu_cache_size()}",
cpp_compile_command("I", "O"),
]
@@ -2939,10 +3049,11 @@ def cpu_cache_size():

@staticmethod
def _search_for_file(suffix, errmsg):
spec = importlib.machinery.PathFinder.find_spec("halide")
if spec is None or not spec.submodule_search_locations:
raise RuntimeError("halide python bindings not installed")
try:
search, *_ = importlib.machinery.PathFinder.find_spec( # type: ignore[union-attr,misc]
"halide"
).submodule_search_locations
search = spec.submodule_search_locations[0]
for file in os.listdir(search):
if file.endswith(".so"):
try:
@@ -3034,11 +3145,17 @@ def generate_halide_async(cls, meta: HalideMeta, source_code: str, submit_fn=None):
)
)

binding_types = [
arg.bindings_type() for arg in meta.argtypes if arg.alias_of is None
]
if meta.is_cuda():
binding_types.append("uintptr_t") # stream
bindings_future = cls.load_pybinding_async(
[arg.bindings_type() for arg in meta.argtypes],
cls._codegen_glue(meta.argtypes, headerfile),
extra_flags=(libfile,),
binding_types,
cls._codegen_glue(meta, headerfile),
extra_flags=(libfile, cls.build_standalone_runtime()),
submit_fn=jobs.append if need_compile else None,
cuda=meta.is_cuda(),
)

if need_compile:
@@ -3060,13 +3177,92 @@ def load():
def generate_halide(cls, *args, **kwargs):
return cls.generate_halide_async(*args, **kwargs)()

@classmethod
def build_standalone_runtime(cls):
if cls._standalone_runtime_path and os.path.exists(
cls._standalone_runtime_path
):
return cls._standalone_runtime_path
is_cuda = torch.cuda.is_available()
libname = "libStandaloneHalideRuntime.so"
target = "host-cuda" if is_cuda else "host"
if cls._standalone_runtime_path:
assert not os.path.exists(cls._standalone_runtime_path)
# We hit this case in unittests when we run with fresh_inductor_cache()
# Generating a fresh runtime over and over causes errors because we initialize
# cuda hundreds of times in the same process and run out of file descriptors.
# Workaround by jail breaking the current fresh_inductor_cache().
base = default_cache_dir()
else:
base = cache_dir()
dirpath = Path(base) / f"halide-runtime-{target}-{cls.config_hash()}"
os.makedirs(dirpath, exist_ok=True)
donefile = str(dirpath / "done")
lockfile = str(dirpath / "lock")
hookfile = str(dirpath / "hooks.cpp")
afile = str(dirpath / "standalone_halide_runtime.a")
sofile = str(dirpath / libname)
if not os.path.exists(donefile):
import filelock
import halide as hl # type: ignore[import-untyped]

with filelock.FileLock(lockfile, LOCK_TIMEOUT):
if not os.path.exists(donefile):
with open(hookfile, "w") as f:
if is_cuda:
f.write(
cls.standalone_runtime_cuda_init.format(
cls.find_header("HalideRuntimeCuda.h")
)
)
hl.compile_standalone_runtime(afile, hl.Target(target))
subprocess.check_call(
shlex.split(
cpp_compile_command([hookfile, afile], sofile, cuda=is_cuda)
)
)
touch(donefile)
assert os.path.exists(sofile)
cls._standalone_runtime_path = sofile
return sofile
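Note: the donefile/lockfile pattern above lets many processes share one runtime build: the first caller builds under the lock, the rest wait and reuse. The same idiom in isolation (a sketch; the 600s timeout stands in for LOCK_TIMEOUT):

    import os
    import filelock

    def build_once(dirpath, build):
        # first caller builds under the lock; everyone else waits, then reuses
        done = os.path.join(dirpath, "done")
        if not os.path.exists(done):
            with filelock.FileLock(os.path.join(dirpath, "lock"), timeout=600):
                if not os.path.exists(done):   # re-check once the lock is held
                    build()
                    open(done, "a").close()    # touch(donefile)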


def _worker_task_halide(lockfile, jobs):
from filelock import FileLock

with FileLock(lockfile, LOCK_TIMEOUT):
for job in jobs:
job()
try:
with FileLock(lockfile, LOCK_TIMEOUT):
for job in jobs:
job()
except subprocess.SubprocessError as e:
if os.environ.get("HALIDE_REPRO") == "1":
python, script, *cmd = getattr(e, "cmd", ("", "", ""))
if os.path.basename(python).startswith("python"):
code = open(script).read()
main = " hl.main()"
assert code.count(main) == 1

class Out:
def __repr__(self):
return "out"

cmd[cmd.index("-o") + 1] = Out() # type: ignore[call-overload]
repl = textwrap.indent(
textwrap.dedent(
f"""\
import sys, tempfile
with tempfile.TemporaryDirectory() as out:
sys.argv = {["repro.py", *cmd]!r}
hl.main()
"""
),
" ",
)
code = code.replace(main, repl)
with open("repro.py", "w") as fd:
fd.write(code.lstrip())
raise RuntimeError(f"wrote repro.py: {e}") from e
raise
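Note: repro generation is opt-in. A sketch of enabling it (the environment variable name comes from the code above):

    import os

    os.environ["HALIDE_REPRO"] = "1"  # set before the failing compile runs
    # After a Halide subprocess error, repro.py is written to the working
    # directory and re-runs the same hl.main() invocation standalone.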


def touch(filename):
7 changes: 6 additions & 1 deletion torch/_inductor/codegen/common.py
@@ -69,14 +69,19 @@ class TensorArg:
name: str
buffer: str
dtype: torch.dtype
offset: sympy.Expr = sympy.Integer(0)
offset: sympy.Expr = sympy.Integer(0) # c++ only
alias_of: Optional[str] = None # halide only


@dataclasses.dataclass
class SizeArg:
name: str
expr: sympy.Expr

@property
def alias_of(self):
return None
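Note: the property gives SizeArg the same interface as TensorArg, so codegen can filter aliased arguments without isinstance checks. A one-line sketch (call_args assumed):

    live_args = [a for a in call_args if a.alias_of is None]  # TensorArg or SizeArg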


@dataclasses.dataclass
class DeviceCodegen: