diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index a6c26913093..8a25e1f7187 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -184,6 +184,23 @@ def _get_leaf_modules_for_ops() -> List[type]: return result +def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: + default_autowrap_modules = (math, torchvision.ops) + default_leaf_modules = _get_leaf_modules_for_ops() + result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs + result_tracer_kwargs["autowrap_modules"] = ( + tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules)) + if "autowrap_modules" in result_tracer_kwargs + else default_autowrap_modules + ) + result_tracer_kwargs["leaf_modules"] = ( + list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules)) + if "leaf_modules" in result_tracer_kwargs + else default_leaf_modules + ) + return result_tracer_kwargs + + def get_graph_node_names( model: nn.Module, tracer_kwargs: Optional[Dict[str, Any]] = None, @@ -212,7 +229,11 @@ def get_graph_node_names( tracer_kwargs (dict, optional): a dictionary of keywork arguments for ``NodePathTracer`` (they are eventually passed onto `torch.fx.Tracer `_). - By default it will be set to wrap and make leaf nodes all torchvision ops. + By default it will be set to wrap and make leaf nodes all torchvision ops: + {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),} + WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user + provided dictionary. + suppress_diff_warning (bool, optional): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. @@ -226,14 +247,7 @@ def get_graph_node_names( >>> model = torchvision.models.resnet18() >>> train_nodes, eval_nodes = get_graph_node_names(model) """ - if tracer_kwargs is None: - tracer_kwargs = { - "autowrap_modules": ( - math, - torchvision.ops, - ), - "leaf_modules": _get_leaf_modules_for_ops(), - } + tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs) is_training = model.training train_tracer = NodePathTracer(**tracer_kwargs) train_tracer.trace(model.train()) @@ -378,7 +392,10 @@ def create_feature_extractor( tracer_kwargs (dict, optional): a dictionary of keywork arguments for ``NodePathTracer`` (which passes them onto it's parent class `torch.fx.Tracer `_). - By default it will be set to wrap and make leaf nodes all torchvision ops. + By default it will be set to wrap and make leaf nodes all torchvision ops: + {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),} + WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user + provided dictionary. suppress_diff_warning (bool, optional): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. @@ -423,14 +440,7 @@ def create_feature_extractor( >>> 'autowrap_functions': [leaf_function]}) """ - if tracer_kwargs is None: - tracer_kwargs = { - "autowrap_modules": ( - math, - torchvision.ops, - ), - "leaf_modules": _get_leaf_modules_for_ops(), - } + tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs) is_training = model.training if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):