Always create ShapeEnv, always apply unspec logic #103302

Status: Closed (7 commits)

benchmarks/dynamo/common.py (6 additions, 1 deletion)

@@ -2366,7 +2366,12 @@ def run(runner, args, original_dir=None):
         torch.use_deterministic_algorithms(True)
         if args.only in {"hf_T5_generate"}:
             # See https://github.com/pytorch/pytorch/issues/102814
-            torch._dynamo.config.assume_static_by_default = False
+            if torch._dynamo.config.dynamic_shapes:
+                torch._dynamo.config.assume_static_by_default = False
+            if not torch._dynamo.config.automatic_dynamic_shapes:
+                log.warning(
+                    "hf_T5_generate compiles extremely slowly without dynamic shapes; consider lowering cache_size_limit"
+                )
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.allow_tf32 = False

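For reference, the same knobs can be toggled outside the benchmark harness. A minimal sketch (the model, sizes, and loop are illustrative, not part of the PR):

```python
import torch
import torch._dynamo.config as dynamo_config

# Mirror the harness settings above: treat sizes as dynamic unless proven
# static, and let automatic_dynamic_shapes mark a dimension dynamic after a
# recompile caused by a size change.
dynamo_config.assume_static_by_default = False
dynamo_config.automatic_dynamic_shapes = True

model = torch.nn.Linear(8, 8)      # stand-in model
compiled = torch.compile(model)

# Varying the batch size exercises the dynamic-shape path instead of
# triggering one recompile per distinct size.
for batch in (2, 4, 8):
    compiled(torch.randn(batch, 8))
```
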
torch/_dynamo/eval_frame.py (4 additions, 4 deletions)

@@ -988,10 +988,10 @@ def result_capturing_wrapper(*graph_inputs):
        remove_from_cache(f)

        if (
-            shape_env := getattr(fake_mode, "shape_env", None)
-        ) is not None and not skipfiles.check(inspect.getsourcefile(call_to_inspect)):
-            dim_constraints = shape_env.dim_constraints
-            assert dim_constraints is not None
+            (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+            and (dim_constraints := shape_env.dim_constraints) is not None
+            and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
+        ):
            dim_constraints.solve()
            msg = dim_constraints.prettify_results(original_signature)
            forced_specializations = dim_constraints.forced_specializations()

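The rewritten condition folds what used to be an assert into the guard itself: each walrus assignment binds a name, and the and-chain short-circuits, so a ShapeEnv without dim_constraints is now skipped rather than tripping an assertion. A small self-contained sketch of the pattern (the stub class is illustrative):

```python
class FakeModeStub:
    shape_env = None  # e.g. a mode that was created without a ShapeEnv

def usable_dim_constraints(fake_mode):
    # Later clauses only run when the earlier ones passed, so
    # dim_constraints is only read off a non-None shape_env.
    if (
        (shape_env := getattr(fake_mode, "shape_env", None)) is not None
        and (dim_constraints := getattr(shape_env, "dim_constraints", None)) is not None
    ):
        return dim_constraints
    return None

assert usable_dim_constraints(FakeModeStub()) is None
```
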
torch/_dynamo/output_graph.py (1 addition, 3 deletions)

@@ -232,9 +232,7 @@ def __init__(
                allow_scalar_outputs=config.capture_scalar_outputs,
                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
                frame_id=frame_state["_id"],
-            )
-            if config.dynamic_shapes
-            else None,
+            ),
            # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
            allow_non_fake_inputs=True if self.export else False,
        )

torch/_dynamo/variables/builder.py (46 additions, 50 deletions)

@@ -753,7 +753,7 @@ def wrap_module(self, value: torch.nn.Module):
        )

    def wrap_literal(self, value):
-        unspec = not config.specialize_int and config.dynamic_shapes
+        unspec = not config.specialize_int
        if unspec and type(value) is torch.Size:
            return SizeVariable(
                [
@@ -930,8 +930,7 @@ def wrap_unspecialized_primitive(self, value):
        # but the general idea is that we generate kernels that can
        # take unspecialized floats and use them in sizevar computation
        if (
-            config.dynamic_shapes
-            and isinstance(value, int)
+            isinstance(value, int)
            and not is_constant_source(self.get_source())
            and not isinstance(self.get_source(), RandomValueSource)
        ):
@@ -1218,10 +1217,9 @@ def _clone_input(value):
        elif istype(example_value, (list, immutable_list)):
            return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
        else:
-            assert (
-                example_value.__class__.__module__ == "torch.return_types"
-                or hasattr(example_value, "_fields")
-            ), ("namedtuple?")
+            assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
+                example_value, "_fields"
+            ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
            return NamedTupleVariable(unpacked, example_value.__class__, **options)
    elif example_value is None or proxy.node.target is torch.manual_seed:
        return ConstantVariable(None, **options)
@@ -1338,51 +1336,49 @@ def update_dim2constraint(dim, constraint_range):
                constraint.shared.dim, constraint.constraint_range
            )

-    dynamic_dims = None
-    constraint_dims = None
-    if tx.fake_mode.shape_env is not None:
-        dynamic_dims = []
-        constraint_dims = []
-        for i in range(e.dim()):
-            # NB: mark dynamic has precedence over static
-            marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
-            marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
-            marked_static = i in getattr(e, "_dynamo_static_indices", set())
-
-            # NB: both static and dynamic have precedence over
-            automatic_dynamic = config.automatic_dynamic_shapes and (
-                frame_state_entry.size is None or frame_state_entry.size[i] is None
-            )
+    dynamic_dims = []
+    constraint_dims = []
+    for i in range(e.dim()):
+        # NB: mark dynamic has precedence over static
+        marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
+        marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
+        marked_static = i in getattr(e, "_dynamo_static_indices", set())
+
+        # NB: both static and dynamic have precedence over
+        automatic_dynamic = config.automatic_dynamic_shapes and (
+            frame_state_entry.size is None or frame_state_entry.size[i] is None
+        )

-            # Reflect the user directive in the frame_state
-            # For dynamic, apply None always
-            if frame_state_entry.size and marked_dynamic:
-                frame_state_entry.size[i] = None
-
-            # We will process constraints first, as they will imply that we
-            # have a dynamic dimension
-            # Precedence: export constraints > eager constraints
-            constraint = dim2constraint.get(i)
-            if constraint is None:
-                if marked_dynamic and not config.allow_ignore_mark_dynamic:
-                    constraint = RelaxedUnspecConstraint(warn_only=False)
-                elif not marked_static and automatic_dynamic:
-                    constraint = RelaxedUnspecConstraint(warn_only=True)
-            constraint_dims.append(constraint)
-
-            # Now, figure out if the dim is dynamic/duck/static
-            if constraint is not None or marked_dynamic or marked_weak_dynamic:
-                # NB: We could assert static_shapes is False here, but it
-                # seems better to allow the user to override policy in this
-                # case
-                dynamic = DimDynamic.DYNAMIC
-            elif static_shapes or config.assume_static_by_default or marked_static:
-                dynamic = DimDynamic.STATIC
-            else:
-                dynamic = DimDynamic.DUCK
-            dynamic_dims.append(dynamic)
+        # Reflect the user directive in the frame_state
+        # For dynamic, apply None always
+        if frame_state_entry.size and marked_dynamic:
+            frame_state_entry.size[i] = None
+
+        # We will process constraints first, as they will imply that we
+        # have a dynamic dimension
+        # Precedence: export constraints > eager constraints
+        constraint = dim2constraint.get(i)
+        if constraint is None:
+            if marked_dynamic and not config.allow_ignore_mark_dynamic:
+                constraint = RelaxedUnspecConstraint(warn_only=False)
+            elif not marked_static and automatic_dynamic:
+                constraint = RelaxedUnspecConstraint(warn_only=True)
+        constraint_dims.append(constraint)
+
+        # Now, figure out if the dim is dynamic/duck/static
+        if constraint is not None or marked_dynamic or marked_weak_dynamic:
+            # NB: We could assert static_shapes is False here, but it
+            # seems better to allow the user to override policy in this
+            # case
+            dynamic = DimDynamic.DYNAMIC
+        elif static_shapes or config.assume_static_by_default or marked_static:
+            dynamic = DimDynamic.STATIC
+        else:
+            dynamic = DimDynamic.DUCK
+
+        dynamic_dims.append(dynamic)

-        tx.output.frame_state[name] = frame_state_entry
+    tx.output.frame_state[name] = frame_state_entry

    return dynamic_dims, constraint_dims

Author comment on this hunk: "This is reindentation only, as shape_env is always non-None."

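The per-dimension policy above has a fixed precedence: an explicit or automatic dynamic signal wins, then any static signal, then duck sizing. A simplified, self-contained restatement of that decision (it deliberately omits export constraints, frame-state bookkeeping, and allow_ignore_mark_dynamic):

```python
from enum import Enum, auto

class Dim(Enum):
    DYNAMIC = auto()
    DUCK = auto()
    STATIC = auto()

def classify_dim(
    marked_dynamic: bool,
    marked_weak_dynamic: bool,
    marked_static: bool,
    automatic_dynamic: bool,
    static_shapes: bool,
    assume_static_by_default: bool,
) -> Dim:
    # mark_dynamic (or an automatic-dynamic decision) takes precedence
    # over every static signal.
    if marked_dynamic or marked_weak_dynamic or (not marked_static and automatic_dynamic):
        return Dim.DYNAMIC
    # Any static signal comes next.
    if static_shapes or assume_static_by_default or marked_static:
        return Dim.STATIC
    # Otherwise fall back to duck sizing.
    return Dim.DUCK

# mark_dynamic beats assume_static_by_default; an unmarked dim stays static.
assert classify_dim(True, False, False, False, False, True) is Dim.DYNAMIC
assert classify_dim(False, False, False, False, False, True) is Dim.STATIC
```
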
torch/_inductor/codegen/triton.py (7 additions, 1 deletion)

@@ -2133,7 +2133,13 @@ def current_reduction_nodes(nodes):
                kernel.set_last_usage(current_reduction_nodes(node_schedule[i:]))
            else:
                # TODO - mostly works but needs a couple fixes
-                if not dynamo_config.dynamic_shapes:
+                # Problem looks like free variables NYI: s0
+                # We need to detect if the proposed ranges would have
+                # symbols and bail out on this optimization if so
+                if (
+                    not dynamo_config.dynamic_shapes
+                    and dynamo_config.assume_static_by_default
+                ):
                    # TODO - use split ranges ?
                    indexing_dtype_strength_reduction(node._body)
                    index_vars = kernel.split_and_set_ranges(node.get_ranges())

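The added comment describes the missing piece: before applying this optimization, the scheduler would need to check whether any proposed range is still symbolic. A hypothetical helper (not part of the PR) sketching that check with sympy:

```python
import sympy

def ranges_have_free_symbols(ranges) -> bool:
    # Bail out when any range is a symbolic expression with free variables
    # (e.g. s0) rather than a concrete integer.
    return any(isinstance(r, sympy.Expr) and r.free_symbols for r in ranges)

assert ranges_have_free_symbols([sympy.Symbol("s0") * 4, 16])
assert not ranges_have_free_symbols([8, 16])
```
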
torch/fx/experimental/symbolic_shapes.py (10 additions, 4 deletions)

@@ -2043,6 +2043,14 @@ def create_symbolic_sizes_strides_storage_offset(
                dynamic_dims.append(r)
            dynamic_dims = [DimDynamic.DUCK] * dim

+        # TODO: make this configurable from outside policy; we made a policy
+        # decision here where if all sizes are static, we are going to
+        # specialize all of the inner strides/offset too. We don't have to
+        # do this, and arguably we should ALWAYS allow for dynamic offset,
+        # this is cheap.
+        # TODO: This should be DYNAMIC, using DUCK for BC
+        dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK
+
        assert len(dynamic_dims) == dim
        assert len(constraint_dims) == dim

@@ -2078,8 +2086,7 @@ def create_symbolic_sizes_strides_storage_offset(
            stride[i] = self.create_symbol(
                val,
                TensorPropertySource(source, TensorProperty.STRIDE, i),
-                # TODO: This should be DYNAMIC, using DUCK for BC
-                dynamic_dim=DimDynamic.DUCK,
+                dynamic_dim=dynamic_strides_offset,
                constraint_dim=None,
            )
        assert all(x is not None for x in stride)

@@ -2094,8 +2101,7 @@
        sym_storage_offset = self.create_symintnode(self.create_symbol(
            ex.storage_offset(),
            TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
-            # TODO: This should be DYNAMIC, using DUCK for BC
-            dynamic_dim=DimDynamic.DUCK,
+            dynamic_dim=dynamic_strides_offset,
            constraint_dim=None,
        ), hint=ex.storage_offset())
        return sym_sizes, sym_stride, sym_storage_offset

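The new policy only specializes strides and storage offset when every size dimension came out static; a single dynamic or duck-sized size keeps them duck-sized. A self-contained sketch of just that decision (using a stand-in enum rather than the real DimDynamic):

```python
from enum import Enum, auto

class DimKind(Enum):   # stand-in for DimDynamic
    DYNAMIC = auto()
    DUCK = auto()
    STATIC = auto()

def strides_offset_policy(dynamic_dims):
    # Specialize strides/offset only if all sizes are static; otherwise fall
    # back to duck sizing (the TODO above notes DYNAMIC may be preferable).
    if all(d is DimKind.STATIC for d in dynamic_dims):
        return DimKind.STATIC
    return DimKind.DUCK

assert strides_offset_policy([DimKind.STATIC, DimKind.STATIC]) is DimKind.STATIC
assert strides_offset_policy([DimKind.STATIC, DimKind.DYNAMIC]) is DimKind.DUCK
```
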
torch/fx/passes/fake_tensor_prop.py (3 additions, 1 deletion)

@@ -36,7 +36,9 @@ def extract_val(obj):
            if isinstance(obj, FakeTensor):
                return snapshot_fake(obj)
            elif isinstance(obj, torch.Tensor):
-                return snapshot_fake(self._mode.from_tensor(obj))
+                # TODO: How is it possible that we get a non fake tensor? We
+                # should be running under the mode...
+                return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
            elif isinstance(obj, py_sym_types):
                return obj
            else:

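static_shapes=True tells FakeTensorMode.from_tensor not to allocate fresh symbolic sizes for a stray real tensor encountered during propagation. A small sketch of the call; constructing the mode with an explicit ShapeEnv here is an assumption for illustration, and the behavior of the non-static path depends on the surrounding config:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

mode = FakeTensorMode(shape_env=ShapeEnv())
t = torch.randn(3, 4)

# As in the diff above: keep the fake tensor's sizes as the concrete 3 and 4
# instead of introducing new symbols into the ShapeEnv.
fake_static = mode.from_tensor(t, static_shapes=True)
print(fake_static.shape)
```
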
torch/fx/passes/shape_prop.py (1 addition, 1 deletion)

@@ -186,7 +186,7 @@ def propagate(self, *args):
            Any: The value returned from executing the Module
        """
        if self.fake_mode is not None:
-            fake_args = [self.fake_mode.from_tensor(t) for t in args]
+            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
        else:
            fake_args = args
        return super().run(*fake_args)

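The isinstance guard matters because propagate can receive non-tensor arguments (ints, floats, None) alongside tensors, and only the tensors should be converted to fakes. A usage sketch; the example module and the fake_mode constructor argument are assumptions for illustration:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class Scale(torch.nn.Module):
    def forward(self, x, factor):
        return x * factor

gm = symbolic_trace(Scale())

# With the fix, the float argument is passed through unchanged instead of
# being handed to fake_mode.from_tensor.
ShapeProp(gm, fake_mode=FakeTensorMode()).propagate(torch.randn(4, 4), 2.0)
for node in gm.graph.nodes:
    print(node.name, node.meta.get("tensor_meta", None))
```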