
Commit

needs work
drisspg committed Apr 22, 2024
1 parent 9c2ac44 commit 5ea3c59
Showing 2 changed files with 160 additions and 7 deletions.
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/__init__.py
@@ -1,3 +1,3 @@
from .cond import cond
from .while_loop import while_loop
from .templated_attention import templated_attention
# from .templated_attention import templated_attention
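With the package-level re-export commented out above, the operator is only reachable through its defining module for now. A minimal, hypothetical import sketch, assuming the templated_attention module itself still imports cleanly at this commit:

from torch._higher_order_ops.templated_attention import templated_attention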
165 changes: 159 additions & 6 deletions torch/_higher_order_ops/templated_attention.py
@@ -97,12 +97,6 @@ def sdpa_dense(
return out, lse


# TODO We need to implement an autograd function for this, there is some complexity to do this generically
templated_attention.py_impl(DispatchKey.Autograd)(
autograd_not_implemented(templated_attention, deferred_error=True)
)


def trace_templated_attention(
proxy_mode: ProxyTorchDispatchMode,
query: torch.Tensor,
@@ -218,3 +212,162 @@ def templated_attention_fake_tensor_mode(
batch_size, num_heads, seq_len_q, dtype=torch.float32
)
return torch.empty_like(query, memory_format=torch.contiguous_format), logsumexp

# ---------------------------- Autograd Implementation ----------------------------
# # TODO We need to implement an autograd function for this, there is some complexity to do this generically
# templated_attention.py_impl(DispatchKey.Autograd)(
# autograd_not_implemented(templated_attention, deferred_error=True)
# )

from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun, default_partition

from torch._subclasses.functional_tensor import (
disable_functional_mode,
FunctionalTensor,
)
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
# from torch._higher_order_ops.utils import _unstack_pytree, _stack_pytree
from torch.multiprocessing.reductions import StorageWeakRef


dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)


def create_fw_bw_graph(score_mod: Callable, index_values, other_buffers):

    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
    # between the Autograd and Python keys. Currently we only suspend functionalization, but
    # more keys can be added when required. Without suspending functionalization we hit two
    # problems:
    #
    # 1. make_fx fails to capture operations on the inputs: the inputs are wrapped as
    # _to_functional_tensor_wrapper, but they are unwrapped before entering
    # ProxyTorchDispatchMode as part of dispatching. However, it is the outer wrapper that the
    # tracer creates proxies for, so the tracer cannot fetch the proxies for the inputs and
    # fails to capture any operations on them.
    #
    # 2. make_fx fails to capture the outputs: the outputs returned from ProxyTorchDispatchMode
    # are further wrapped as FunctionalTensorWrapper under the Functionalize key. However, the
    # tracer only associates the inner tensor with a proxy inside ProxyTorchDispatchMode, so
    # when creating the output node it fails to associate the wrapped tensor with its proxy
    # and instead emits a _tensor_constant as the output.

with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
assert len(other_buffers) == 0, "Other buffers are not yet supported. We will properly generate the graph for them later."

def _from_fun(t):
if isinstance(t, torch.Tensor):
if t.dtype != torch.bool:
return torch.empty_strided(
t.size(),
t.stride(),
dtype=t.dtype,
requires_grad=t.requires_grad,
)
else:
# clone of a functional tensor produces a functional tensor
# but we want to avoid it so we clone a non-functional version
maybe_unfunc_t = t
if isinstance(t, FunctionalTensor):
torch._sync(t)
maybe_unfunc_t = from_fun(t)
elif torch._is_functional_tensor(t):
# need to handle both types of functionalization here:
# these are the tensors that came from the user,
# which could be either FunctionalTensorWrapper or FunctionalTensor
torch._sync(t)
maybe_unfunc_t = torch._from_functional_tensor(t)
return maybe_unfunc_t.clone()
return t

unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)
example_flat_out = pytree.tree_map(
_from_fun, score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers)
)
            if not isinstance(example_flat_out, torch.Tensor):
                raise RuntimeError(
                    "Expected output of score_mod to be a tensor. "
                    f"Got type {type(example_flat_out)}."
                )
example_grad = _from_fun(example_flat_out)

fw_graph = make_fx(score_mod)(*unwrapped_score_mod_indexes, *unwrapped_other_buffers)

def joint_f(index_values, other_buffers, example_grad):
def fw_with_masks(*args):
fw_out = score_mod(*args)
out_requires_grad = fw_out.requires_grad
return ((fw_out,), (out_requires_grad,))

joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
args = index_values + list(other_buffers)
_, grads = joint(args, [example_grad])

                # To keep the backward graph functional we may need to clone outputs that alias
                # inputs (carried over from the map HOP; not wired up yet):
                # input_storage = {
                #     StorageWeakRef(arg._typed_storage())
                #     for arg in example_args
                #     if isinstance(arg, torch.Tensor)
                # }
                # return pytree.tree_map(maybe_clone, grads)

                return grads

            joint_graph = make_fx(joint_f)(
                unwrapped_score_mod_indexes, unwrapped_other_buffers, example_grad
            )
            # default_partition splits the joint graph into forward and backward pieces; the
            # partitioned graphs are not consumed yet. For now the separately traced forward
            # graph and the full joint graph are returned.
            fwd_graph, bwd_graph = default_partition(
                joint_graph,
                (unwrapped_score_mod_indexes, unwrapped_other_buffers, example_grad),
                num_fwd_outputs=1,
            )
            return fw_graph, joint_graph
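To make the tracing step above concrete, here is a small editorial sketch of what make_fx captures for a toy score_mod, using the same scalar example values that templated_attention_autograd builds below. relative_bias and its bias term are made up for illustration and are not part of this commit.

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def relative_bias(score, b, h, q_idx, kv_idx):
    # Toy score_mod: bias each attention score by the query/key token distance.
    return score + (q_idx - kv_idx)


# Mirrors the example_vals construction in templated_attention_autograd below:
# one scalar in the score dtype plus four integer index scalars.
example_vals = [torch.zeros((), dtype=torch.float32, requires_grad=True)] + [
    torch.zeros((), dtype=torch.int) for _ in range(4)
]
fw_graph = make_fx(relative_bias)(*example_vals)
print(fw_graph.graph)  # a small FX graph: a sub on the indices, then an add onto the score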

# from torch._higher_order_ops.map import create_fw_bw_graph
class TemplatedAttentionAutogradOp(torch.autograd.Function):
@staticmethod
def forward(ctx, fw_graph, joint_graph, *flat_args):
ctx.save_for_backward(*flat_args)
ctx._joint_graph = joint_graph
# ctx._num_mapped_args = num_mapped_args
with torch._C._AutoDispatchBelowAutograd():
query, key, value, score_mod, *other_buffers = flat_args

            # Both out and logsumexp are returned; the traced fw_graph replaces the
            # original score_mod callable here.
            return templated_attention(query, key, value, fw_graph, *other_buffers)


    @staticmethod
    def backward(ctx, *flat_grads):
        # Backward is not implemented yet: the saved joint graph still needs to be
        # replayed over the per-element score_mod inputs to produce gradients.
        raise NotImplementedError("Backward for templated_attention is not yet implemented")

        # Intended structure (mirrors the map HOP):
        # fw_args = ctx.saved_tensors
        # fw_mapped_args = fw_args[: ctx._num_mapped_args]
        # pos_args = fw_args[ctx._num_mapped_args :]
        # grads = templated_attention_backward()
        # grads = map_impl(
        #     ctx._joint_graph,
        #     fw_mapped_args + flat_grads,
        #     pos_args,
        # )
        # return grads

@templated_attention.py_impl(DispatchKey.Autograd)
def templated_attention_autograd(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
*other_buffers: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
input_requires_grad = query.requires_grad or key.requires_grad
example_vals = [torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)] + [
torch.zeros((), dtype=torch.int) for _ in range(4)
]
fw_graph, bw_graph = create_fw_bw_graph(score_mod, example_vals, other_buffers)
    flat_out = TemplatedAttentionAutogradOp.apply(
        fw_graph, bw_graph, query, key, value, score_mod, *other_buffers
    )
    return flat_out
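For reference, an editorial, eager-mode sketch of the semantics this HOP computes, independent of the WIP autograd machinery above: materialize the score matrix, apply a score_mod over broadcast (batch, head, q_idx, kv_idx) indices, then softmax, returning both the output and the logsumexp. checkerboard, the shapes, and the 1/sqrt(D) scaling are assumptions for illustration; sdpa_dense and the real kernel may differ in such details.

import torch


def checkerboard(score, b, h, q_idx, kv_idx):
    # Made-up score_mod: double scores at even query/key distances.
    return torch.where((q_idx - kv_idx) % 2 == 0, score * 2.0, score)


def reference_templated_attention(query, key, value, score_mod):
    # Eager reference: build the full score matrix, apply score_mod with broadcast
    # index tensors, and return both the attention output and the logsumexp of scores.
    B, H, M, D = query.shape
    N = key.shape[-2]
    scores = query @ key.transpose(-2, -1) / (D ** 0.5)  # scaling is an assumption
    b = torch.arange(B).view(B, 1, 1, 1)
    h = torch.arange(H).view(1, H, 1, 1)
    q_idx = torch.arange(M).view(1, 1, M, 1)
    kv_idx = torch.arange(N).view(1, 1, 1, N)
    scores = score_mod(scores, b, h, q_idx, kv_idx)
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.softmax(scores, dim=-1) @ value
    return out, lse


query = torch.randn(2, 4, 16, 8)
key = torch.randn(2, 4, 16, 8)
value = torch.randn(2, 4, 16, 8)
out, lse = reference_templated_attention(query, key, value, checkerboard)
print(out.shape, lse.shape)  # torch.Size([2, 4, 16, 8]) torch.Size([2, 4, 16])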

