Skip to content

Commit

Permalink
small cleanup of executor (#79973)
Browse files Browse the repository at this point in the history
per title
Pull Request resolved: #79973
Approved by: https://github.com/mruberry
  • Loading branch information
ngimel authored and pytorchmergebot committed Jun 22, 2022
1 parent ec4be38 commit 9244547
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion test/test_prims.py
Expand Up @@ -136,7 +136,7 @@ def _wrapper(a):
traced = make_traced(_wrapper)
make_arg = partial(make_tensor, device=device, dtype=dtype)

for executor in ('aten',): # FIXME test fails on nvfuser executor
for executor in ('aten', 'nvfuser'):
fn = partial(traced, executor=executor)
shape = (5, 5)
a = make_arg(shape)
Expand Down
7 changes: 3 additions & 4 deletions torch/_prims/executor.py
Expand Up @@ -13,7 +13,7 @@
from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import]


def execute(gm: GraphModule, *args, executor: str = "aten", **kwargs):
def execute(gm: GraphModule, *args, executor: str = "aten"):
"""
Prototype ATen executor.
Expand Down Expand Up @@ -45,7 +45,7 @@ def call_function(self, target, args, kwargs):
args = tuple(map(_to_nvfuser_constant, args))
target = target.impl_nvfuser
args = (fd,) + args
return target(*args)
return target(*args, **kwargs)

def to_nv(arg):
if isinstance(arg, torch.Tensor):
Expand All @@ -60,8 +60,7 @@ def to_nv(arg):
# Transforms graph to call nvfuser lowerings
# Note, this doesn't handle nested structures in the args, TODO: add tree_flatten
nv_args = tree_map(to_nv, args)
nv_kwargs = tree_map(to_nv, kwargs)
out = FusionInterpreter(gm).run(*nv_args, **nv_kwargs)
out = FusionInterpreter(gm).run(*nv_args)
flat_out, unflatten_spec = tree_flatten(out)
for o in flat_out:
fd.add_output(o)
Expand Down

0 comments on commit 9244547

Please sign in to comment.