Skip to content
46 changes: 28 additions & 18 deletions torchvision/models/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,23 @@ def _get_leaf_modules_for_ops() -> List[type]:
return result


def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
default_autowrap_modules = (math, torchvision.ops)
default_leaf_modules = _get_leaf_modules_for_ops()
result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
result_tracer_kwargs["autowrap_modules"] = (
tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules))
if "autowrap_modules" in result_tracer_kwargs
else default_autowrap_modules
)
result_tracer_kwargs["leaf_modules"] = (
list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules))
if "leaf_modules" in result_tracer_kwargs
else default_leaf_modules
)
return result_tracer_kwargs


def get_graph_node_names(
model: nn.Module,
tracer_kwargs: Optional[Dict[str, Any]] = None,
Expand Down Expand Up @@ -212,7 +229,11 @@ def get_graph_node_names(
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
``NodePathTracer`` (they are eventually passed onto
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default it will be set to wrap and make leaf nodes all torchvision ops.
By default it will be set to wrap and make leaf nodes all torchvision ops:
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.

suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
Expand All @@ -226,14 +247,7 @@ def get_graph_node_names(
>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)
"""
if tracer_kwargs is None:
tracer_kwargs = {
"autowrap_modules": (
math,
torchvision.ops,
),
"leaf_modules": _get_leaf_modules_for_ops(),
}
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
Expand Down Expand Up @@ -378,7 +392,10 @@ def create_feature_extractor(
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
``NodePathTracer`` (which passes them onto it's parent class
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default it will be set to wrap and make leaf nodes all torchvision ops.
By default it will be set to wrap and make leaf nodes all torchvision ops:
{"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
provided dictionary.
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
Expand Down Expand Up @@ -423,14 +440,7 @@ def create_feature_extractor(
>>> 'autowrap_functions': [leaf_function]})

"""
if tracer_kwargs is None:
tracer_kwargs = {
"autowrap_modules": (
math,
torchvision.ops,
),
"leaf_modules": _get_leaf_modules_for_ops(),
}
tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
is_training = model.training

if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
Expand Down