Always apply unspec logic
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 924899944d9c2c61c4e28edb5f8df14f824f215e
Pull Request resolved: #103302
ezyang committed Jun 11, 2023
1 parent 03101a2 commit 090166d
Showing 6 changed files with 24 additions and 15 deletions.
7 changes: 6 additions & 1 deletion benchmarks/dynamo/common.py
@@ -2366,7 +2366,12 @@ def run(runner, args, original_dir=None):
         torch.use_deterministic_algorithms(True)
         if args.only in {"hf_T5_generate"}:
             # See https://github.com/pytorch/pytorch/issues/102814
-            torch._dynamo.config.assume_static_by_default = False
+            if torch._dynamo.config.dynamic_shapes:
+                torch._dynamo.config.assume_static_by_default = False
+            if not torch._dynamo.config.automatic_dynamic_shapes:
+                log.warning(
+                    "hf_T5_generate compiles extremely slowly without dynamic shapes; consider lowering cache_size_limit"
+                )
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.allow_tf32 = False
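For context on the two configs toggled above, a hedged, minimal sketch (the compiled function step is illustrative, not from the benchmark): with automatic_dynamic_shapes enabled, a model such as hf_T5_generate that sees varying sequence lengths can eventually reuse one dynamic-shape graph instead of recompiling for every new size.

import torch

# Assumed illustration of the config interplay referenced in the hunk above.
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.assume_static_by_default = True  # sizes start static and are marked dynamic after a size change

@torch.compile
def step(x):
    return x * 2

for n in (4, 8, 16):
    step(torch.randn(1, n))  # later sizes should hit a dynamic-shape graph rather than fresh compiles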
8 changes: 4 additions & 4 deletions torch/_dynamo/eval_frame.py
@@ -988,10 +988,10 @@ def result_capturing_wrapper(*graph_inputs):
     remove_from_cache(f)

     if (
-        shape_env := getattr(fake_mode, "shape_env", None)
-    ) is not None and not skipfiles.check(inspect.getsourcefile(call_to_inspect)):
-        dim_constraints = shape_env.dim_constraints
-        assert dim_constraints is not None
+        (shape_env := getattr(fake_mode, "shape_env", None)) is not None
+        and (dim_constraints := shape_env.dim_constraints) is not None
+        and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
+    ):
         dim_constraints.solve()
         msg = dim_constraints.prettify_results(original_signature)
         forced_specializations = dim_constraints.forced_specializations()
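The rewritten guard folds the old assert into the condition via chained walrus assignments. A generic sketch of the idiom (names here are illustrative, not taken from eval_frame.py):

def maybe_dim_constraints(fake_mode, skip_file):
    # Each walrus binds its name only if the preceding checks passed; if any check
    # fails, including dim_constraints being None, the whole block is skipped.
    if (
        (shape_env := getattr(fake_mode, "shape_env", None)) is not None
        and (dim_constraints := getattr(shape_env, "dim_constraints", None)) is not None
        and not skip_file
    ):
        return dim_constraints
    return None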
2 changes: 1 addition & 1 deletion torch/_dynamo/output_graph.py
@@ -233,7 +233,7 @@ def __init__(
                 allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
                 frame_id=frame_state["_id"],
             )
-            if config.dynamic_shapes
+            if config.dynamic_shapes or config.automatic_dynamic_shapes
             else None,
             # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
             allow_non_fake_inputs=True if self.export else False,
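What the widened condition feeds, as a minimal standalone sketch (assumed, not copied from output_graph.py, which passes additional ShapeEnv arguments): a ShapeEnv is created and handed to the FakeTensorMode when either flavor of dynamic shapes is enabled.

from torch._dynamo import config as dynamo_config
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = (
    ShapeEnv()
    if dynamo_config.dynamic_shapes or dynamo_config.automatic_dynamic_shapes
    else None
)
# Tensors faked under this mode can carry symbolic sizes when a ShapeEnv is present.
fake_mode = FakeTensorMode(shape_env=shape_env)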
12 changes: 5 additions & 7 deletions torch/_dynamo/variables/builder.py
@@ -753,7 +753,7 @@ def wrap_module(self, value: torch.nn.Module):
         )

     def wrap_literal(self, value):
-        unspec = not config.specialize_int and config.dynamic_shapes
+        unspec = not config.specialize_int
         if unspec and type(value) is torch.Size:
             return SizeVariable(
                 [
@@ -930,8 +930,7 @@ def wrap_unspecialized_primitive(self, value):
             # but the general idea is that we generate kernels that can
             # take unspecialized floats and use them in sizevar computation
             if (
-                config.dynamic_shapes
-                and isinstance(value, int)
+                isinstance(value, int)
                 and not is_constant_source(self.get_source())
                 and not isinstance(self.get_source(), RandomValueSource)
             ):
@@ -1218,10 +1217,9 @@ def _clone_input(value):
         elif istype(example_value, (list, immutable_list)):
             return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
         else:
-            assert (
-                example_value.__class__.__module__ == "torch.return_types"
-                or hasattr(example_value, "_fields")
-            ), ("namedtuple?")
+            assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
+                example_value, "_fields"
+            ), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
             return NamedTupleVariable(unpacked, example_value.__class__, **options)
     elif example_value is None or proxy.node.target is torch.manual_seed:
         return ConstantVariable(None, **options)
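Taken together, these hunks drop the config.dynamic_shapes gate so the unspec path applies whenever specialize_int is off. A hedged behavioral sketch of what that buys (the function and flag setting are illustrative):

import torch

torch._dynamo.config.specialize_int = False  # do not bake plain Python ints into the graph

@torch.compile
def scale(x, k):
    return x * k

x = torch.randn(4)
scale(x, 2)
scale(x, 3)  # with k left unspecialized, this call should reuse the graph rather than recompile for k=3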
8 changes: 7 additions & 1 deletion torch/_inductor/codegen/triton.py
@@ -2133,7 +2133,13 @@ def current_reduction_nodes(nodes):
                     kernel.set_last_usage(current_reduction_nodes(node_schedule[i:]))
                 else:
                     # TODO - mostly works but needs a couple fixes
-                    if not dynamo_config.dynamic_shapes:
+                    # Problem looks like free variables NYI: s0
+                    # We need to detect if the proposed ranges would have
+                    # symbols and bail out on this optimization if so
+                    if (
+                        not dynamo_config.dynamic_shapes
+                        and dynamo_config.assume_static_by_default
+                    ):
                         # TODO - use split ranges ?
                         indexing_dtype_strength_reduction(node._body)
                     index_vars = kernel.split_and_set_ranges(node.get_ranges())
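The new comment spells out the missing piece behind the TODO: ideally the check would inspect the proposed ranges for free symbols rather than keying off global config. A possible sketch of such a check (an assumption, not part of this commit):

import sympy

def ranges_have_free_symbols(ranges):
    # ranges: iterable of sizes that may be Python ints or sympy expressions (e.g. s0 * 2)
    return any(getattr(r, "free_symbols", None) for r in ranges)

# ranges_have_free_symbols([64, 128])                  -> False
# ranges_have_free_symbols([sympy.Symbol("s0"), 128])  -> True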
2 changes: 1 addition & 1 deletion torch/fx/passes/shape_prop.py
@@ -186,7 +186,7 @@ def propagate(self, *args):
             Any: The value returned from executing the Module
         """
         if self.fake_mode is not None:
-            fake_args = [self.fake_mode.from_tensor(t) for t in args]
+            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
         else:
             fake_args = args
         return super().run(*fake_args)
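A minimal usage sketch of why the isinstance guard matters (an assumed example; the fake_mode constructor argument is inferred from the self.fake_mode attribute used above): propagate() may receive non-tensor arguments, which from_tensor cannot convert, so they now pass through unchanged.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp
from torch._subclasses.fake_tensor import FakeTensorMode

def f(x, scale):
    return x * scale

gm = symbolic_trace(f)
# Before this change, the non-tensor argument (2) would be fed to fake_mode.from_tensor and fail.
ShapeProp(gm, fake_mode=FakeTensorMode()).propagate(torch.randn(2, 3), 2)
for node in gm.graph.nodes:
    print(node.name, node.meta.get("tensor_meta"))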
