diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 2823ca726e0..5cd8628c603 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -33,7 +33,10 @@ from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass -from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from executorch.exir.passes.sym_shape_eval_pass import ( + ConstraintBasedSymShapeEvalPass, + HintBasedSymShapeEvalPass, +) from executorch.extension.llm.export.builder import DType, LLMEdgeManager from executorch.extension.llm.tokenizer.tokenizer import Tokenizer @@ -227,6 +230,8 @@ def export_all(llava_model: LlavaModel): memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False), sym_shape_eval_pass={ "image_encoder": ConstraintBasedSymShapeEvalPass(), + "text_model": ConstraintBasedSymShapeEvalPass(), + "token_embedding": HintBasedSymShapeEvalPass(), }, ) ) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index c0f7b71baf9..7b91464bdce 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -12,7 +12,7 @@ from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass -from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.tracer import ExirDynamoConfig from torch.fx._compatibility import compatibility @@ -86,7 +86,7 @@ class ExecutorchBackendConfig: # A single sym shape eval pass can be defined for all the programs in the # EdgeProgramManager or can be defined per program. 
sym_shape_eval_pass: Union[PassType, Dict[str, PassType]] = ( - HintBasedSymShapeEvalPass() + ConstraintBasedSymShapeEvalPass() ) # If set to true, view_copy operations will be converted to lightweight diff --git a/exir/passes/TARGETS b/exir/passes/TARGETS index 4e59af26eae..eeb1e5265b0 100644 --- a/exir/passes/TARGETS +++ b/exir/passes/TARGETS @@ -202,6 +202,7 @@ python_library( ], deps = [ "//caffe2:torch", + "//executorch/exir:_warnings", "//executorch/exir:pass_base", "//executorch/exir:sym_util", "//executorch/exir:tensor", diff --git a/exir/passes/sym_shape_eval_pass.py b/exir/passes/sym_shape_eval_pass.py index f4d11ed8143..ec61d4b3a6f 100644 --- a/exir/passes/sym_shape_eval_pass.py +++ b/exir/passes/sym_shape_eval_pass.py @@ -10,6 +10,8 @@ import torch import torch.utils._pytree as pytree + +from executorch.exir._warnings import deprecated from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import PassBase, PassResult from executorch.exir.sym_util import eval_expr, eval_shape, eval_upper_bound @@ -164,8 +166,21 @@ def index_Tensor(args, kwargs) -> List[Optional[int]]: # noqa: C901 return out_sizes +@deprecated( + "`HintBasedSymShapeEvalPass` is deprecated " + "and will be removed in a future version of ExecuTorch. " + "Please use `ConstraintBasedSymShapeEvalPass` instead.", + category=FutureWarning, +) class HintBasedSymShapeEvalPass(PassBase): """ + + .. warning:: + + `HintBasedSymShapeEvalPass` is deprecated + and will be removed in a future version of ExecuTorch. + Please use `ConstraintBasedSymShapeEvalPass` instead. + If we enable dynamic shape tracing, a tensor's shape may become a symbolic formula. We should convert those symbolic formula to concrete value for static/upperbound tensors so we can properly do memory planning for them.