Set default compile backend to eager and convert inputs to fake tensors
drisspg committed Apr 26, 2024
1 parent e5255f5 commit 43cd623
Showing 2 changed files with 22 additions and 30 deletions.
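
The core of the change is to pick a FakeTensorMode up front and convert the inputs under it, so the joint graph can be traced without any real compute even when the HOP is entered from eager. A minimal standalone sketch of that pattern (illustrative, not the commit's exact code):

import torch
from torch._guards import TracingContext
from torch._subclasses.fake_tensor import FakeTensorMode

# Reuse the ambient fake mode when one is active under torch.compile;
# otherwise build a fresh one so real eager inputs can be faked.
maybe_tracing = TracingContext.try_get()  # None outside of compilation
fake_mode = (
    maybe_tracing.fake_mode
    if maybe_tracing
    else FakeTensorMode(allow_non_fake_inputs=True)
)
with fake_mode:
    t = torch.empty_strided((2, 3), (3, 1))  # a FakeTensor: metadata only, no real storage

allow_non_fake_inputs=True matters only for the eager path, where the mode may encounter the caller's real tensors rather than already-faked ones.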
7 changes: 5 additions & 2 deletions test/inductor/test_templated_attention.py
@@ -455,7 +455,10 @@ def test_aot_eager_gradcheck(self, score_mod):
 
     @supported_platform
     @common_utils.parametrize("score_mod_name", ["_head_offset", "_buffer_reduced"])
-    def test_captured_score_mod_aot_eager_gradcheck(self, score_mod_name: str):
+    @common_utils.parametrize("mode", ["eager", "aot_eager"])
+    def test_captured_score_mod_aot_eager_gradcheck(
+        self, score_mod_name: str, mode: str
+    ):
         make_tensor = functools.partial(
             torch.randn,
             (2, 2, 8, 4),
@@ -465,7 +468,7 @@ def test_captured_score_mod_aot_eager_gradcheck(self, score_mod_name: str):
         )
         query, key, value = make_tensor(), make_tensor(), make_tensor()
 
-        func = torch.compile(_templated_attention, backend="aot_eager", fullgraph=True)
+        func = torch.compile(_templated_attention, backend=mode, fullgraph=True)
         score_mod = captured_buffers_map[score_mod_name](torch.float64)
 
         self.assertTrue(
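
Parametrizing over mode means the same gradcheck body now runs under both backends; roughly, the test expands to the following (a sketch: the assertion truncated above is assumed to wrap torch.autograd.gradcheck, as in the sibling test_aot_eager_gradcheck):

for mode in ("eager", "aot_eager"):
    func = torch.compile(_templated_attention, backend=mode, fullgraph=True)
    score_mod = captured_buffers_map[score_mod_name](torch.float64)
    # ... run gradcheck over (query, key, value) with func and score_mod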
45 changes: 17 additions & 28 deletions torch/_higher_order_ops/templated_attention.py
@@ -251,30 +251,14 @@ def templated_attention_fake_tensor_mode(
         return torch.empty_like(query, memory_format=torch.contiguous_format), logsumexp
 
 
-def is_fake_tensor(t: torch.Tensor):
-    """Why not use is_fake in fake_tensor?
-    That is specifically designed to pick up on traceable wrapper subclasses so instead we
-    define this one off
-    """
-    from torch._subclasses.fake_tensor import FakeTensor
-    from torch._subclasses.functional_tensor import FunctionalTensor
-
-    return (
-        isinstance(t, torch.Tensor)
-        and isinstance(t, FunctionalTensor)
-        and torch._is_functional_tensor(t.elem)
-        and isinstance(torch._from_functional_tensor(t.elem), FakeTensor)
-    )
-
-
 # ---------------------------- Autograd Implementation ----------------------------
 def create_fw_bw_graph(score_mod, index_values, other_buffers):
     # See Note:[HOP create fw_bw graph]
 
     # All of these imports need to be here in order to avoid circular dependencies
     from torch._dispatch.python import suspend_functionalization
     from torch._functorch.aot_autograd import AOTConfig, create_joint
-    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 
     from torch._subclasses.functional_tensor import disable_functional_mode
     from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
@@ -288,15 +272,6 @@ def create_fw_bw_graph(score_mod, index_values, other_buffers):
         aot_id=0,
         keep_inference_input_mutations=False,
     )
-    assert all(is_fake_tensor(t) for t in index_values), (
-        "Expected all index_values to create_fw_bw_graph to be FakeTensors! ",
-        "Ensure that FlexAttention was called with backend >= aot_eager",
-    )
-
-    assert all(is_fake_tensor(t) for t in other_buffers), (
-        "Expected all other_buffers to create_fw_bw_graph to be FakeTensors! ",
-        "Ensure that FlexAttention was called with backend >= aot_eager",
-    )
 
     with suspend_functionalization(), disable_functional_mode():
         with disable_proxy_modes_tracing():
@@ -310,8 +285,22 @@ def _from_fun(t):
                     requires_grad=t.requires_grad,
                 )
 
-            unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
-            unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
+            # If someone runs this hop under the default compiler backend ("eager"),
+            # then this path will be run with the actual user inputs. We convert them
+            # to fake tensors in order not to perform any actual compute.
+            maybe_tracing = torch._guards.TracingContext.try_get()
+            fake_mode = (
+                maybe_tracing.fake_mode
+                if maybe_tracing
+                else FakeTensorMode(allow_non_fake_inputs=True)
+            )
+
+            with fake_mode:
+                unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
+                unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
+
+            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
+            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)
 
             example_flat_out = pytree.tree_map(
                 _from_fun,
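
The old asserts rejected real tensors outright ("backend >= aot_eager"); the new code converts first and asserts afterwards, which succeeds by construction. A standalone sketch of why (_from_fun here is a simplified stand-in for the diff's helper):

import torch
import torch.utils._pytree as pytree
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def _from_fun(t):
    # Factory calls made under an active FakeTensorMode yield FakeTensors
    # that carry only metadata (shape, stride, dtype, requires_grad).
    return torch.empty_strided(
        t.size(), t.stride(), dtype=t.dtype, requires_grad=t.requires_grad
    )

real_buffers = (torch.randn(8), torch.randn(4, 4))
with FakeTensorMode(allow_non_fake_inputs=True):
    fake_buffers = pytree.tree_map(_from_fun, real_buffers)

assert all(isinstance(t, FakeTensor) for t in fake_buffers)  # holds for real or fake inputs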
