In [1]:
import functools
from copy import deepcopy
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
print(torch.__version__)
from torch._C import ScriptObject  # type: ignore[attr-defined]
from torch.fx import Graph, GraphModule, Tracer
from torch.fx._symbolic_trace import (_autowrap_check,
                                          _patch_wrapped_functions, _Patcher)
from torch.fx.proxy import Proxy

1.13.1+cu116


In [2]:
# Capture the unpatched ``nn.Module`` hooks once, at definition time.
# ``CustomTracer.trace`` temporarily replaces ``nn.Module.__call__`` and
# ``nn.Module.__getattr__`` with wrappers, and those wrappers delegate to the
# originals saved here.
_orig_module_call: Callable = nn.Module.__call__
_orig_module_getattr: Callable = nn.Module.__getattr__

In [3]:
class UntracedMethodRegistry:
    """Registry of methods that a tracer should record as a single node.

    Each instance holds one method. Calling ``__set_name__`` stores a wrapped
    version of that method in the class-level ``method_dict`` under the
    method's name, together with the class that owns it. The wrapped version
    does not run the method directly; it hands the call to
    ``tracer.call_method`` so the tracer decides whether to execute it or to
    emit a leaf node.

    Example:
        >>> # `imported_cls` is the owner of the untraced method;
        >>> # `method_str` is the name of the untraced method.
        >>> method_registry = UntracedMethodRegistry(method)
        >>> method_registry.__set_name__(imported_cls, method_str)

    Args:
        method (FunctionType): Function to be registered.
    """
    # Shared across all instances: name -> dict(mod=<owner>, wrapped=<callable>).
    # Entries are keyed by method name only.
    method_dict: Dict = dict()
    # Tracer that wrapped methods dispatch to; assigned from outside the class.
    tracer = None

    def __init__(self, method: FunctionType):
        self.owner = None
        self.method = method

    def __set_name__(self, owner, name):
        """Remember the owner/name pair and publish the wrapped method."""
        self.owner, self.name = owner, name
        entry = dict(mod=self.owner, wrapped=self.method_wrapper())
        self.method_dict[name] = entry

    def method_wrapper(self):
        """Return a replacement for ``self.method`` that defers to the tracer."""

        @functools.wraps(self.method)
        def wrapped_method(mod, *call_args, **call_kwargs):
            # Re-bind the raw function to ``mod`` so the tracer can invoke it
            # with only the remaining arguments.
            bound = lambda *a, **kw: self.method(mod, *a, **kw)  # noqa: E731
            return self.tracer.call_method(mod, self.name, bound, call_args,
                                           call_kwargs)

        return wrapped_method

In [4]:
def _prepare_module_dict(model: torch.nn.Module, fx_graph):
    """If there is a class method that can not be traced by the symbolic
    tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in
    ``CustomTracer``.

    Example:
        >>> class Model:
        ...     def __init__(self):
        ...         self.head = ClsHead()
        ...
        >>> class ClsHead(nn.Module):
        ...     def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        ...         return feats[-1]
        ...
        ...     def loss(self, feats: Tuple[torch.Tensor],
        ...              data_samples: List[ClsDataSample], **kwargs) -> dict:
        ...         cls_score = self(feats)
        ...         # The part can not be traced by torch.fx
        ...         losses = self._get_loss(cls_score, data_samples, **kwargs)
        ...         return losses
        ...
        ...     def _get_loss(self, cls_score: torch.Tensor,
        ...                   data_samples: List[ClsDataSample], **kwargs):
        ...         if 'score' in data_samples[0].gt_label:
        ...             xxx
        ...         else:
        ...             xxx
        ...         losses = xxx
        ...         return losses

    As the ``_get_loss`` can not be traced by torch.fx, ``Toy._get_loss`` need
    to be added to ``skipped_methods`` in ``CustomTracer``. Hence the code
    above will product the following Graph::

    .. code-block:: text
        ... ...
        %head : [#users=1] = get_attr[target=head]
        %_get_loss : [#users=1] = call_method[target=_get_loss](args = (%head, %head_fc, %data_samples), kwargs = {})  # noqa: E501
        return _get_loss

    Hence, the head module in the ``GraphModule`` and that in the original
    model are the same one (refer to https://github.com/pytorch/pytorch/blob/master/torch/fx/graph_module.py#L346).  # noqa: E501
    So changes made to the graph module (in ``prepare()``) will also modify
    the original model.

    Args:
        model (torch.nn.Module): Module or function to be
            traced and converted into a Graph representation.
        fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It
            contains the nodes this GraphModule should use for code generation.
    """

    def _get_attrs(target, attrs):
        attrs = attrs.split('.')
        for att in attrs:
            target = getattr(target, att)
        return target

    module_dict = dict()
    special_nodes = []

    for node in fx_graph.nodes:
        if node.op == 'get_attr':
            attr = _get_attrs(model, node.target)
            if isinstance(attr, nn.Module):
                module_dict[node.target] = nn.Module()
                special_nodes.append(node)
        elif node.op == 'call_method':
            for special_node in special_nodes:
                if special_node in node.args or \
                        special_node in node.kwargs.values():
                    origin_module = getattr(model, special_node.target)
                    setattr(module_dict[special_node.target], node.target,
                            getattr(origin_module, node.target))

    return module_dict

In [5]:
def duplicate_reused_nodes(graph: Graph, modules: Dict[str, Any] = {}):
    """Give every repeated ``call_module`` target its own deep-copied module.

    When several ``call_module`` nodes point at the same target (e.g. the
    shared detection head in RetinaNet), all but the first are retargeted to
    ``<target>_dup<idx>`` and a deep copy of the original module is returned
    under that new name, so later fusion passes can treat each call site
    independently.

    Modified from https://github.com/ModelTC/MQBench/blob/main/mqbench/prepare_by_platform.py  # noqa: E501

    Args:
        graph (Graph): Graph whose ``call_module`` nodes are inspected; node
            targets are rewritten in place.
        modules (Dict[str, Any]): Maps module names to modules; looked up by
            the original target of every duplicated node.

    Returns:
        tuple: ``(graph, dup_modules)`` where ``dup_modules`` maps each new
        target name to its deep-copied module.
    """
    suffix = '_dup'

    # Group call sites by the module they invoke, in graph order.
    calls_by_target = dict()
    for node in graph.nodes:
        if node.op != 'call_module':
            continue
        calls_by_target.setdefault(node.target, []).append(node)

    dup_modules = dict()
    for call_sites in calls_by_target.values():
        # The first call site keeps the original module; every later one gets
        # a copy. Targets that are used once yield an empty slice.
        for idx, node in enumerate(call_sites[1:], start=1):
            cloned = deepcopy(modules[node.target])
            node.target += suffix + str(idx)
            dup_modules[node.target] = cloned

    graph.lint()
    return graph, dup_modules

In [6]:
def build_graphmodule(model: torch.nn.Module,
                      fx_graph,
                      name: str = 'GraphModule'):
    """Assemble a ``GraphModule`` from a graph produced by ``CustomTracer``.

    Instead of handing ``model`` itself to ``GraphModule``, a name-to-module
    dict is passed as the root. It starts from ``model.named_modules()`` and
    is then overlaid with:

    1. the placeholder modules from ``_prepare_module_dict`` (for sub-modules
       whose skipped methods appear as ``call_method`` nodes), and
    2. the deep-copied modules from ``duplicate_reused_nodes`` (for modules
       called more than once in the graph).

    Later entries override earlier ones that share a name.

    Args:
        model (torch.nn.Module): Module or function to be
            traced and converted into a Graph representation.
        fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It
            contains the nodes this GraphModule should use for code generation.
        name (str): The name of generated GraphModule.

    Returns:
        GraphModule: GraphModule is an nn.Module generated from an fx.Graph.
        Graphmodule has a ``graph`` attribute, as well as ``code`` and
        ``forward`` attributes generated from that ``graph``.

    .. warning::
        When ``graph`` is reassigned, ``code`` and ``forward`` will be
        automatically regenerated. However, if you edit the contents of the
        ``graph`` without reassigning the ``graph`` attribute itself, you must
        call ``recompile()`` to update the generated code.
    """
    named = dict(model.named_modules())
    placeholder_modules = _prepare_module_dict(model, fx_graph)
    # Duplicates are copied from the original modules, before any overlay.
    fx_graph, cloned_modules = duplicate_reused_nodes(fx_graph, named)
    root_modules = {**named, **placeholder_modules, **cloned_modules}
    return GraphModule(root_modules, fx_graph, name)

In [7]:
import torch
from torch.fx._symbolic_trace import Tracer
from torch.fx.node import Target, Node, Argument
from torch.nn.intrinsic import _FusedModule
from typing import List, Callable, Tuple, Any, Dict, Optional

# Public names of this cell's code; only the tracer is listed, while the
# ``Scope`` and ``ScopeContextManager`` helpers defined below are not.
__all__ = [
    "QuantizationTracer",
]

class Scope(object):
    """ Mutable record of "which module are we currently inside".

    It stores two things about a module: its path and its type. The tracer
    uses it to remember, for each Node of a GraphModule's Graph, the module
    that contained the Node when it was created. For example::
        class Sub(torch.nn.Module):
            def forward(self, x):
                # Becomes a call_method Node in the GraphModule; its scope
                # is (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)
        class M(torch.nn.Module):
            def __init__(self):
                self.sub = Sub()
            def forward(self, x):
                # Also a call_method Node, but created at the top level,
                # so its scope is (module_path="", None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    Args:
        module_path (str): Path of the module.
        module_type (Any): Type of the module.
    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path, self.module_type = module_path, module_type


class ScopeContextManager(object):
    """ Context manager that swaps the tracer's ``Scope`` for one module call.

    The shared ``scope`` object is updated to describe ``current_module`` as
    soon as the manager is constructed (not in ``__enter__``); the values it
    held before are written back in ``__exit__``.

    Args:
        scope (Scope): Scope object to update in place.
        current_module (torch.nn.Module): Module being entered; its type
            becomes ``scope.module_type``.
        current_module_path (str): Becomes ``scope.module_path``.
    """

    def __init__(
        self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
    ):
        super().__init__()
        # Snapshot the outer scope first so it can be restored on exit.
        self.prev_module_type, self.prev_module_path = (
            scope.module_type,
            scope.module_path,
        )
        self.scope = scope
        scope.module_path = current_module_path
        scope.module_type = type(current_module)

    def __enter__(self):
        # Nothing to do: the scope was already switched in ``__init__``.
        return None

    def __exit__(self, *args):
        restored = self.scope
        restored.module_path = self.prev_module_path
        restored.module_type = self.prev_module_type
        return None

class QuantizationTracer(Tracer):
    """``Tracer`` that treats more modules as leaves and records node scopes.

    Args:
        skipped_module_names (List[str]): Qualified names of modules that
            must not be traced into.
        skipped_module_classes (List[Callable]): Module classes that must not
            be traced into (matched on the exact type).
    """

    def __init__(
        self, skipped_module_names: List[str], skipped_module_classes: List[Callable]
    ):
        super().__init__()
        self.skipped_module_classes = skipped_module_classes
        self.skipped_module_names = skipped_module_names
        # NB: initialized the module_type of top level module to None
        # we are assuming people won't configure the model with the type of top level
        # module here, since people can use "" for global config
        # We can change this if there is a use case that configures
        # qconfig using top level module type
        self.scope = Scope("", None)
        # node name -> (module_path, module_type) at the time the node was made
        self.node_name_to_scope: Dict[str, Tuple[str, type]] = {}
        self.record_stack_traces = True

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """Return True when ``m`` should appear as one ``call_module`` node."""
        defined_in = m.__module__
        # Modules shipped under torch.nn / torch.ao.nn are leaves, except
        # plain containers (``Sequential``), which are traced through.
        if defined_in.startswith(("torch.nn", "torch.ao.nn")) and not isinstance(
            m, torch.nn.Sequential
        ):
            return True
        if module_qualified_name in self.skipped_module_names:
            return True
        if type(m) in self.skipped_module_classes:
            return True
        return isinstance(m, _FusedModule)

    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """Delegate to ``Tracer.call_module`` with the scope switched to ``m``."""
        path = self.path_of_module(m)
        # The previous scope is restored when the ``with`` block exits.
        with ScopeContextManager(self.scope, m, path):
            outcome = super().call_module(m, forward, args, kwargs)
        return outcome

    def create_node(
        self,
        kind: str,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> Node:
        """Create the node as usual and remember the scope it was created in."""
        created = super().create_node(kind, target, args, kwargs, name, type_expr)
        active = self.scope
        self.node_name_to_scope[created.name] = (
            active.module_path,
            active.module_type,
        )
        return created

In [8]:
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raise. Defaults to False.

    Returns:
        list[module] | module | None: The imported modules. ``None`` is
        returned when ``imports`` is empty or ``None``; a single module is
        returned when ``imports`` is a single string.

    Raises:
        TypeError: If ``imports`` is neither a str nor a list, or if one of
            its items is not a str.
        ImportError: If a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    # Imported locally: neither name is imported anywhere else in this
    # notebook, so without these the function fails with a NameError.
    import warnings
    from importlib import import_module

    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError as err:
            if allow_failed_imports:
                warnings.warn(f'{imp} failed to import and is ignored.',
                              UserWarning)
                imported_tmp = None
            else:
                # Chain the original error so the root cause stays visible.
                raise ImportError(f'Failed to import {imp}') from err
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported

In [9]:
class CustomTracer(QuantizationTracer):
    """Custom tracer based on QuantizationTracer of pytorch. It can not only
    skip some modules and classes while tracing, but also skip some methods
    untraced by torch.fx.Tracer.

    Args:
        skipped_methods (str | List[str], optional): Methods to be skipped
            while tracing, each given as a dotted path of the form
            ``package.module.Class.method``. Defaults to None.
        skipped_module_names (List[str], optional): Modules to be skipped
            while tracing. Defaults to None.
        skipped_module_classes (List[Callable], optional): Class to be skipped
            while tracing. Defaults to None.
    """

    def __init__(self,
                 skipped_methods: Union[str, List[str], None] = None,
                 skipped_module_names: Optional[List[str]] = None,
                 skipped_module_classes: Optional[List[Callable]] = None,
                 *args,
                 **kwargs):
        # ``None`` stands for "nothing to skip"; fresh lists are created per
        # instance instead of sharing mutable default arguments.
        # Extra positional/keyword arguments are accepted but not used.
        if skipped_methods is None:
            skipped_methods = []
        if skipped_module_names is None:
            skipped_module_names = []
        if skipped_module_classes is None:
            skipped_module_classes = []
        super(CustomTracer, self).__init__(skipped_module_names,
                                           skipped_module_classes)
        UntracedMethodRegistry.tracer = self  # type: ignore
        self.skipped_methods = skipped_methods
        if self.skipped_methods:
            self.register_skipped_methods()

    @staticmethod
    def _check_valid_source(source):
        """Check if the source's format is valid."""
        if not isinstance(source, str):
            raise TypeError(f'source should be a str '
                            f'instance, but got {type(source)}')

        assert len(source.split('.')) > 1, \
            'source must have at least one `.`'

    def register_skipped_methods(self):
        """Register skipped methods to UntracedMethodRegistry.method_dict.

        Raises:
            TypeError: If a source is not a str, or if the resolved class is
                not a type.
            ImportError: If the module part of a source is missing or cannot
                be imported.
        """
        if not isinstance(self.skipped_methods, list):
            self.skipped_methods = [self.skipped_methods]
        for s_method in self.skipped_methods:
            self._check_valid_source(s_method)
            parts = s_method.split('.')
            mod_str = '.'.join(parts[:-2])
            cls_str = parts[-2]
            method_str = parts[-1]

            try:
                mod = import_modules_from_strings(mod_str)
            except ImportError as err:
                raise ImportError(
                    f'{mod_str} is not imported correctly.') from err

            if mod is None:
                # An empty module path (e.g. ``'Class.method'``) yields no
                # module at all; fail here with an actionable message rather
                # than with an AttributeError on ``None`` below.
                raise ImportError(
                    f'Cannot resolve a module from {s_method!r}; expected '
                    f'the form `package.module.Class.method`.')

            imported_cls: type = getattr(mod, cls_str)
            if not isinstance(imported_cls, type):
                raise TypeError(f'{cls_str} should be a type '
                                f'instance, but got {type(imported_cls)}')
            assert hasattr(imported_cls, method_str), \
                   f'{method_str} is not in {mod_str}.{cls_str}.'

            method = getattr(imported_cls, method_str)

            method_registry = UntracedMethodRegistry(method)
            method_registry.__set_name__(imported_cls, method_str)

    def call_method(self, m: torch.nn.Module, name: str, method: Callable,
                    args: Tuple, kwargs: Dict):
        """Method that specifies the behavior of this ``Tracer`` when it
        encounters a call to a method wrapped by ``UntracedMethodRegistry``.

        If ``m`` is an instance of a class that owns a registered skipped
        method (see ``is_skipped_method``), the call is not executed. A
        ``call_method`` node named ``name`` is emitted instead, with ``m``
        prepended to the positional arguments. Otherwise ``method`` is simply
        called, so its body is traced as usual.

        Args:
            m (torch.nn.Module): The module whose method is being called.
            name (str): The name of the method; used as the node target.
            method (Callable): The method of the ``Module`` to be invoked.
                It is called as ``method(*args, **kwargs)``.
            args (Tuple): Positional args of the method callsite, without
                the module itself.
            kwargs (Dict): kwargs of the method callsite

        Return:

            A ``Proxy`` for the emitted ``call_method`` node when the method
            is skipped; otherwise whatever ``method`` returned.
        """
        if not self.is_skipped_method(m):
            return method(*args, **kwargs)
        # The module becomes the first argument of the call_method node.
        return self.create_proxy('call_method', name, (m, *args), kwargs)

    def trace(self,
              root: Union[torch.nn.Module, Callable[..., Any]],
              concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        """Trace ``root`` and return the corresponding FX ``Graph``
        representation. ``root`` can either be an ``nn.Module`` instance or a
        Python callable. Note that after this call, ``self.root`` may be
        different from the ``root`` passed in here. For example, when a free
        function is passed to ``trace()``, we will create an ``nn.Module``
        instance to use as the root and add embedded constants to.

        In addition to patching ``nn.Module.__call__`` and
        ``nn.Module.__getattr__``, every method registered in
        ``UntracedMethodRegistry.method_dict`` is patched on its owner class
        for the duration of the trace.

        Args:
            root (Union[Module, Callable]): Either a ``Module`` or a function
                to be traced through. Backwards-compatibility for this
                parameter is guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that
                should not be treated as Proxies. This parameter is
                experimental and its backwards-compatibility is *NOT*
                guaranteed.

        Returns:
            A ``Graph`` representing the semantics of the passed-in ``root``.
        """
        # The wrapped skipped methods dispatch to the class-level
        # ``UntracedMethodRegistry.tracer``. Point it at this tracer so that
        # creating another ``CustomTracer`` in between does not redirect
        # skipped-method calls of this trace to that other instance.
        UntracedMethodRegistry.tracer = self  # type: ignore

        if isinstance(root, torch.nn.Module):
            self.root = root
            fn = type(root).forward
            self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = {
                mod: name
                for name, mod in root.named_modules()
            }
        else:
            self.root = nn.Module()
            fn = root

        tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None)
        self.graph = Graph(tracer_cls=tracer_cls)

        # When we encounter a Tensor value that's not a parameter, we look if
        # it is some other attribute on the model. Construct a dict mapping
        # Tensor values to the qualified name here for efficiency. This is
        # used downstream in create_arg
        self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {}

        def collect_tensor_attrs(m: nn.Module, prefix_atoms: List[str]):
            for k, v in m.__dict__.items():
                if isinstance(v, (torch.Tensor, ScriptObject)):
                    self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
            for k, v in m.named_children():
                collect_tensor_attrs(v, prefix_atoms + [k])

        collect_tensor_attrs(self.root, [])

        assert isinstance(fn, FunctionType)

        fn_globals = fn.__globals__  # run before it gets patched
        fn, args = self.create_args_for_root(fn, isinstance(root, nn.Module),
                                             concrete_args)

        # Reduce number of get_attr calls
        parameter_proxy_cache: Dict[str, Proxy] = {}

        # Method dispatch on parameters is not recorded unless it's directly
        # used. Thus, we need to insert a proxy when __getattr__ requests a
        # parameter.
        @functools.wraps(_orig_module_getattr)
        def module_getattr_wrapper(mod, attr):
            attr_val = _orig_module_getattr(mod, attr)
            return self.getattr(attr, attr_val, parameter_proxy_cache)

        @functools.wraps(_orig_module_call)
        def module_call_wrapper(mod, *args, **kwargs):

            def forward(*args, **kwargs):
                return _orig_module_call(mod, *args, **kwargs)

            _autowrap_check(
                patcher,
                getattr(getattr(mod, 'forward', mod), '__globals__', {}),
                self._autowrap_function_ids)
            return self.call_module(mod, forward, args, kwargs)

        with _Patcher() as patcher:
            # allow duplicate patches to support the case of nested calls
            patcher.patch_method(
                nn.Module,
                '__getattr__',
                module_getattr_wrapper,
                deduplicate=False)
            patcher.patch_method(
                nn.Module, '__call__', module_call_wrapper, deduplicate=False)

            # Swap each registered skipped method for its tracer-aware
            # wrapper on the owning class.
            for name, value in UntracedMethodRegistry.method_dict.items():
                wrapped = value['wrapped']
                patcher.patch_method(
                    value['mod'], name, wrapped, deduplicate=False)

            _patch_wrapped_functions(patcher)
            _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)
            for module in self._autowrap_search:
                _autowrap_check(patcher, module.__dict__,
                                self._autowrap_function_ids)
            self.create_node(
                'output',
                'output', (self.create_arg(fn(*args)), ), {},
                type_expr=fn.__annotations__.get('return', None))

        self.submodule_paths = None

        return self.graph

    def is_skipped_method(self, m: torch.nn.Module):
        """Judge if ``m`` is an instance of any class that owns a registered
        skipped method.

        Only the class of ``m`` is checked, not a particular method name.
        Returns False when nothing is registered.
        """
        mods = tuple(value['mod']
                     for value in UntracedMethodRegistry.method_dict.values())
        custom = isinstance(m, mods)
        return custom

    def is_leaf_module(self, m: torch.nn.Module,
                       module_qualified_name: str) -> bool:
        """A method to specify whether a given ``nn.Module`` is a "leaf"
        module. Leaf modules are the atomic units that appear in the IR,
        referenced by ``call_module`` calls. The decision is delegated
        unchanged to ``QuantizationTracer.is_leaf_module``.

        Args:
            m (Module): The module being queried about
            module_qualified_name (str): The path to root of this module.
                For example, if you have a module hierarchy where submodule
                ``foo`` contains submodule ``bar``, which contains submodule
                ``baz``, that module will appear with the qualified name
                ``foo.bar.baz`` here.
        """
        leaf = super().is_leaf_module(m, module_qualified_name)
        return leaf

In [10]:
def custom_symbolic_trace(
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
    """Variant of pytorch's `symbolic_trace` that traces with ``CustomTracer``.

    A fresh ``CustomTracer`` (with no skipped methods, names or classes) is
    used to record the operations seen while running ``root``; the resulting
    graph is wrapped in a ``GraphModule`` rooted at ``tracer.root``.

    Args:
        root (torch.nn.Module): Module or function to be
            traced and converted into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially
            specialized.

    Returns:
        GraphModule: a Module created from the recorded operations from
        ``root``. It is named after the class of ``root`` when ``root`` is a
        module, and after ``root.__name__`` otherwise.
    """
    tracer = CustomTracer()
    traced_graph = tracer.trace(root, concrete_args)
    if isinstance(root, torch.nn.Module):
        module_name = root.__class__.__name__
    else:
        module_name = root.__name__
    return GraphModule(tracer.root, traced_graph, module_name)


In [12]:
!pip install timm==0.8.15.dev0
import timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm==0.8.15.dev0
  Downloading timm-0.8.15.dev0-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 KB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
Collecting safetensors
  Downloading safetensors-0.2.8-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m55.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: safetensors, huggingface-hub, timm
Successfully installed huggingface-hub-0.12.1 safetensors-0.2.8 timm-0.8.15.dev0


In [13]:
# Names of the SwinV2 variants known to timm (handy for picking a model;
# `print(all_models)` shows them).
all_models = timm.list_models('*swinv2*')
from torch.fx import symbolic_trace

# Build the three backbones (each with a 1000-class head), then trace each
# of them with the custom tracer. Creation and tracing happen in the order
# listed.
swin_model, swinv2_model, davit_model = (
    timm.create_model(arch, num_classes=1000)
    for arch in ('swin_base_patch4_window7_224',
                 'swinv2_base_window8_256',
                 'davit_tiny'))
swin_graph_module, swinv2_graph_module, davit_graph_module = (
    custom_symbolic_trace(model)
    for model in (swin_model, swinv2_model, davit_model))

# To inspect a traced module, call e.g.
#   swin_graph_module.print_readable()
#   swin_graph_module.graph.print_tabular()
# (the same works for swinv2_graph_module and davit_graph_module).


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [23]:
!pip install mmcls==1.0.0rc5 mmcv==2.0.0rc1 mmengine
from mmcls.models.backbones.resnet import ResLayer
from mmengine.config import Config
from mmengine.registry import MODELS
cfg = Config.fromfile(
    'tests/data/test_models/test_task_modules/mmcls_cfg.py')
skipped_methods = [
    'mmcls.models.heads.ClsHead._get_loss',
    'mmcls.models.heads.ClsHead._get_predictions'
]
skipped_module_names = ['backbone.layer4.0']
skipped_module_classes = [ResLayer]

# init without skipped_methods
tracer = CustomTracer()
assert hasattr(tracer, 'skipped_methods')
assert len(tracer.skipped_methods) == 0
# init with skipped_methods(list)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=self.skipped_methods)
assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
assert '_get_predictions' in UntracedMethodRegistry.method_dict.keys()
# init with skipped_methods(str)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=self.skipped_methods[0])
assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()

 # test trace with skipped_methods
model = MODELS.build(cfg.model)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_methods=skipped_methods)
graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
graph_loss = tracer.trace(model, concrete_args={'mode': 'loss'})
graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'})
assert isinstance(graph_tensor, Graph)
assert isinstance(graph_loss, Graph)
skip_flag_loss = False
for node in graph_loss.nodes:
    if node.op == 'call_method' and node.target == '_get_loss':
        skip_flag_loss = True
assert isinstance(graph_predict, Graph)
skip_flag_predict = False
for node in graph_predict.nodes:
    if node.op == 'call_method' and node.target == '_get_predictions':
        skip_flag_predict = True
assert skip_flag_loss and skip_flag_predict

# test trace with skipped_module_names
model = MODELS.build(cfg.model)
UntracedMethodRegistry.method_dict = dict()
tracer = CustomTracer(skipped_module_names=skipped_module_names)
graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
skip_flag = False
for node in graph_tensor.nodes:
    skipped_module_name = skipped_module_names[0]
    if node.op == 'call_module' and node.target == skipped_module_name:
        skip_flag = True
assert skip_flag

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mmcv==2.0.0rc1
  Using cached mmcv-2.0.0rc1.tar.gz (406 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: mmcv
  Building wheel for mmcv (setup.py) ... [?25l[?25hdone
  Created wheel for mmcv: filename=mmcv-2.0.0rc1-cp38-cp38-linux_x86_64.whl size=28115159 sha256=d32719d080c9861260c6c98af489fe6981401f88ff79e1ee8f525999442bbe44
  Stored in directory: /root/.cache/pip/wheels/5d/d3/64/54e29987d1fb1abb6ea08307121a60d6a84463e43603956fc2
Successfully built mmcv
Installing collected packages: mmcv
  Attempting uninstall: mmcv
    Found existing installation: mmcv 1.7.1
    Uninstalling mmcv-1.7.1:
      Successfully uninstalled mmcv-1.7.1
Successfully installed mmcv-2.0.0rc1


AssertionError: ignored