Skip to content

Commit

Permalink
[export][refactor][6/n] Remove equality_constraints (pytorch#116979)
Browse files Browse the repository at this point in the history
Through the new dynamic_shapes API and using torch.export.Dim, dimensions that are equal will now be represented by the same symbol, so we no longer need to store `equality_constraints`.

Differential Revision: D52351705

Pull Request resolved: pytorch#116979
Approved by: https://github.com/avikchaudhuri
  • Loading branch information
angelayi authored and xadupre committed Jan 10, 2024
1 parent 9e3bb02 commit 6f38c1a
Showing 1 changed file with 1 addition and 15 deletions.
16 changes: 1 addition & 15 deletions torch/export/exported_program.py
Expand Up @@ -103,18 +103,12 @@ def __init__(
graph_signature: ExportGraphSignature,
state_dict: Dict[str, Union[torch.Tensor, torch.nn.Parameter]],
range_constraints: "Dict[sympy.Symbol, Any]",
equality_constraints: Optional[List[Tuple[Any, Any]]] = None,
module_call_graph: Optional[
List[ModuleCallEntry]
] = None, # TODO: make this not optional
module_call_graph: List[ModuleCallEntry],
example_inputs: Optional[Tuple[Tuple[Any, ...], Dict[str, Any]]] = None,
verifier: Optional[Type[Any]] = None, # TODO Change typing hint to Verifier.
tensor_constants: Optional[Dict[str, torch.Tensor]] = None,
):
from torch._export.exported_program import _create_graph_module_for_export
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
InputDim,
)

# Remove codegen related things from the graph. It should just be a flat graph.
graph._codegen = torch.fx.graph.CodeGen()
Expand All @@ -125,9 +119,6 @@ def __init__(
self._graph_signature: ExportGraphSignature = graph_signature
self._state_dict: Dict[str, Any] = state_dict
self._range_constraints: "Dict[sympy.Symbol, ValueRanges]" = range_constraints
self._equality_constraints: List[Tuple[InputDim, InputDim]] = (
equality_constraints or []
)
assert module_call_graph is not None
self._module_call_graph: List[ModuleCallEntry] = module_call_graph
self._example_inputs = example_inputs
Expand Down Expand Up @@ -202,11 +193,6 @@ def named_buffers(self) -> Iterator[Tuple[str, torch.Tensor]]:
def range_constraints(self):
return self._range_constraints

@property
@compatibility(is_backward_compatible=False)
def equality_constraints(self):
return self._equality_constraints

@property
@compatibility(is_backward_compatible=False)
def module_call_graph(self):
Expand Down

0 comments on commit 6f38c1a

Please sign in to comment.