Commit
Enable optional tensor args with requires_grad=False
vasunvidia committed Jun 7, 2024
1 parent 090e724 commit d0a1057
Showing 1 changed file with 82 additions and 7 deletions.
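In short, the change lets each graphed callable capture extra keyword tensors that never receive gradients: they are validated (Tensors only, requires_grad=False), flattened into a static "optional" input surface, and passed as **kwargs during warmup and graph capture. A minimal capture-time sketch; the layer, shapes, and the attention_mask key are illustrative assumptions, only the optional_args argument itself comes from this commit:

import torch
import transformer_engine.pytorch as te

# Hypothetical module and inputs (not from the commit).
layer = te.TransformerLayer(1024, 4096, 16).cuda()
hidden = torch.randn(32, 4, 1024, device="cuda", requires_grad=True)
# Optional tensors must have requires_grad=False; they are captured as a
# static input surface alongside sample_args and module parameters.
mask = torch.zeros(4, 1, 32, 32, dtype=torch.bool, device="cuda")

graphed = te.make_graphed_callables(
    layer,
    (hidden,),
    optional_args={"attention_mask": mask},  # dict of name -> grad-free tensor
)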
89 changes: 82 additions & 7 deletions transformer_engine/pytorch/graph.py
@@ -7,6 +7,7 @@
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle
from typing import Dict

from .fp8 import (
fp8_autocast,
@@ -53,6 +54,7 @@ def _make_graphed_callables(
num_warmup_iters=3,
allow_unused_input=False,
fp8_weight_caching=False,
optional_args=None,
_order=None,
):
"""
@@ -71,8 +73,12 @@
just_one_callable = True
callables = (callables,)
sample_args = (sample_args,)
if optional_args is not None:
optional_args = (optional_args,)

flatten_sample_args = []
flatten_optional_keys = []
flatten_optional_args = []
if _order is not None:
# order is a list containing 1..model_chunk values in the order of microbatch schedule
num_model_chunks = max(_order)
@@ -93,6 +99,10 @@ def _make_graphed_callables(
), (f"Expected {num_model_chunks * num_microbatches}"
+ f"args tuple, but got {len(sample_args)}."
)
assert (
optional_args is None
or len(optional_args) == len(sample_args)
), "Number of optional_args entries does not match the number of sample_args."

if fp8_weight_caching:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
@@ -120,6 +130,29 @@ def _make_graphed_callables(
"In the beta API, sample_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
)
flatten_optional_args = []
if optional_args is not None:
per_callable_len_optional_user_args = []
for i, optional_arg in enumerate(optional_args):
keys = list(optional_arg.keys())
optional_tensors = tuple(optional_arg.values())
assert all(isinstance(arg, torch.Tensor) for arg in optional_tensors), (
"In the beta API, optional_args "
+ "for each callable must contain only Tensors. Other types are not allowed."
)
assert all(not arg.requires_grad for arg in optional_tensors), (
"In the beta API, gradients for optional_args are not supported. "
+ "All optional_args tensors must have requires_grad=False."
)
if i == 0:
flatten_optional_keys = keys
assert flatten_optional_keys == keys, "optional_args keys must match across all callables."
flatten_arg, _ = _tree_flatten(optional_tensors)
flatten_optional_args.append(tuple(flatten_arg))
per_callable_len_optional_user_args.append(len(flatten_arg))
else:
per_callable_len_optional_user_args = [0 for _ in flatten_sample_args]


# If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
# passes to forward (ie, its sample_args) AND the module's parameter attributes.
@@ -133,6 +166,12 @@ def _make_graphed_callables(
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(callables))
]
if len(flatten_optional_args) != 0:
per_callable_static_optional_input_surfaces = [
flatten_optional_args[i] for i in range(len(callables))
]
else:
per_callable_static_optional_input_surfaces = [() for i in range(len(callables))]
else:
per_callable_module_params = []
for c in callables:
@@ -145,6 +184,12 @@ def _make_graphed_callables(
flatten_sample_args[i] + per_callable_module_params[i]
for i in range(len(flatten_sample_args))
]
if len(flatten_optional_args) != 0:
per_callable_static_optional_input_surfaces = [
flatten_optional_args[i] for i in range(len(flatten_sample_args))
]
else:
per_callable_static_optional_input_surfaces = [() for i in range(len(flatten_sample_args))]

fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
@@ -165,9 +210,10 @@ def _make_graphed_callables(
with torch.cuda.stream(torch.cuda.Stream()):
for c_i, func in enumerate(callables):
args = sample_args[c_i]
opt_args = optional_args[c_i] if optional_args is not None else {}
static_input_surface = per_callable_static_input_surfaces[c_i]
for _ in range(num_warmup_iters):
outputs, _ = _tree_flatten(func(*args))
outputs, _ = _tree_flatten(func(*args, **opt_args))
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -200,9 +246,10 @@ def _make_graphed_callables(
per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) \
+ (fwd_idx[m_chunk] * num_layers + l_no)
args = sample_args[per_callable_fwd_idx]
opt_args = optional_args[per_callable_fwd_idx] if optional_args is not None else {}
fwd_graph = fwd_graphs[per_callable_fwd_idx]
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
outputs = func(*args, **opt_args)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
@@ -215,6 +262,7 @@ def _make_graphed_callables(
per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) \
+ (bwd_idx[m_chunk] * num_layers + l_no)
static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]

static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
bwd_graph = bwd_graphs[per_callable_bwd_idx]
# For now, assumes all static_outputs require grad
@@ -250,9 +298,13 @@ def _make_graphed_callables(
per_callable_static_outputs = []
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs):
if optional_args is None:
opt_args = [{} for _ in range(len(sample_args))]
else:
opt_args = optional_args
for func, args, opt_arg, fwd_graph in zip(callables, sample_args, opt_args, fwd_graphs):
with torch.cuda.graph(fwd_graph, pool=mempool):
outputs = func(*args)
outputs = func(*args, **opt_arg)
graph_callables[graph_id] = func
graph_id += 1

@@ -306,8 +358,10 @@ def make_graphed_autograd_function(
bwd_graph,
module_params,
len_user_args,
len_optional_user_args,
output_unflatten_spec,
static_input_surface,
static_optional_input_surface,
static_outputs,
static_grad_outputs,
static_grad_inputs,
@@ -324,6 +378,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
for i in range(len_user_args):
if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
static_input_surface[i].copy_(inputs[i])
for i in range(len_optional_user_args):
input_idx = len(inputs) - len_optional_user_args + i
assert (
static_optional_input_surface[i].shape == inputs[input_idx].shape
and static_optional_input_surface[i].dtype == inputs[input_idx].dtype
), "optional_args tensors must match the shape and dtype used at capture."
static_optional_input_surface[i].copy_(inputs[input_idx])
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() for o in static_outputs)
@@ -347,7 +405,7 @@ def backward(ctx, *grads):
assert isinstance(static_grad_inputs, tuple)
return (None,) + tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)
) + tuple(None for _ in range(len_optional_user_args))

def functionalized(*user_args, **user_kwargs):
# Runs the autograd function with inputs == all
@@ -364,7 +422,21 @@ def functionalized(*user_args, **user_kwargs):
skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

flatten_user_args, _ = _tree_flatten(user_args)
out = Graphed.apply(skip_fp8_weight_update, *(tuple(flatten_user_args) + module_params))
if optional_args is not None:
assert (
"optional_args" in user_kwargs
and isinstance(user_kwargs["optional_args"], Dict)
), "`optional_args` Dict kwarg missing during replay."
assert all(key in user_kwargs["optional_args"] for key in flatten_optional_keys), (
f"Missing optional_args during replay. Expected keys: {flatten_optional_keys}."
)
assert all(not user_kwargs["optional_args"][key].requires_grad for key in flatten_optional_keys), (
"optional_args tensors must have requires_grad=False during replay; gradients are not supported."
)
flatten_optional_args, _ = _tree_flatten(
tuple(user_kwargs["optional_args"][key] for key in flatten_optional_keys)
)
else:
flatten_optional_args = []
out = Graphed.apply(
skip_fp8_weight_update,
*(tuple(flatten_user_args) + module_params + tuple(flatten_optional_args)),
)
return _tree_unflatten(out, output_unflatten_spec)

return functionalized
@@ -377,8 +449,10 @@ def functionalized(*user_args, **user_kwargs):
bwd_graphs[i],
per_callable_module_params[i],
per_callable_len_user_args[i],
per_callable_len_optional_user_args[i],
per_callable_output_unflatten_spec[i],
per_callable_static_input_surfaces[i],
per_callable_static_optional_input_surfaces[i],
per_callable_static_outputs[i],
per_callable_static_grad_outputs[i],
per_callable_static_grad_inputs[i],
@@ -452,6 +526,7 @@ def make_graphed_callables(
fp8_calibrating=False,
fp8_recipe=None,
fp8_weight_caching=False,
optional_args=None,

@timmoon10 commented Jun 13, 2024:

I'd prefer sample_kwargs for consistency with sample_args. Also, the name is misleading since these args won't be optional when the graph is replayed.

_order=None,
):
"""
Expand Down Expand Up @@ -527,7 +602,7 @@ def forward_func(*args, **kwargs):
graphed_callables = _make_graphed_callables(
forward_funcs, sample_args, num_warmup_iters=num_warmup_iters,
allow_unused_input=allow_unused_input,
fp8_weight_caching=fp8_weight_caching, _order=_order)
fp8_weight_caching=fp8_weight_caching, optional_args=optional_args, _order=_order)

# Ensures warmup does not affect numerics for ops such as dropout.
if graph_safe_rng_available():
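At replay time the graphed callable requires the same keys through an optional_args kwarg: forward() copies each tensor into the static optional input surface (asserting matching shape and dtype) before fwd_graph.replay(), and backward() returns None for these inputs. A sketch continuing the hypothetical example above:

# Fresh optional tensors at replay: same shape/dtype as at capture,
# still requires_grad=False (both conditions are asserted).
new_mask = torch.ones(4, 1, 32, 32, dtype=torch.bool, device="cuda")
out = graphed(hidden, optional_args={"attention_mask": new_mask})
out.sum().backward()  # no gradient flows to new_mask; its grad slot is None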
