From 947d6f61d60e95dbc8c8950aa1925012a6dd2e42 Mon Sep 17 00:00:00 2001 From: Shanghua Gao Date: Tue, 22 Nov 2022 19:15:55 +0800 Subject: [PATCH] [Feature] Support receptive field search of CNN models (#2056) * support rfsearch * add labs for rfsearch * format * format * add docstring and type hints * clean code Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * rm unused func * update code * update code * update code * update details * fix details * support asymmetric kernel * support asymmetric kernel * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review * add unit tests for rfsearch * set device for Conv2dRFSearchOp * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * remove unused function search_estimate_only * move unit tests * Update tests/test_cnn/test_rfsearch/test_operator.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmcv/cnn/rfsearch/operator.py Co-authored-by: Yue Zhou <592267829@qq.com> * change logger * Update mmcv/cnn/rfsearch/operator.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: lzyhha <819814373@qq.com> Co-authored-by: Zhongyu Li <44114862+lzyhha@users.noreply.github.com> Co-authored-by: Yue Zhou <592267829@qq.com> [Fix] Fix skip_layer for RF-Next (#2489) * judge skip_layer by fullname * 
lint * skip_layer first * update unit test --- mmcv/cnn/__init__.py | 3 +- mmcv/cnn/rfsearch/__init__.py | 5 + mmcv/cnn/rfsearch/operator.py | 170 +++++++++ mmcv/cnn/rfsearch/search.py | 238 +++++++++++++ mmcv/cnn/rfsearch/utils.py | 68 ++++ tests/test_cnn/test_rfsearch/test_operator.py | 325 ++++++++++++++++++ tests/test_cnn/test_rfsearch/test_search.py | 177 ++++++++++ 7 files changed, 985 insertions(+), 1 deletion(-) create mode 100644 mmcv/cnn/rfsearch/__init__.py create mode 100644 mmcv/cnn/rfsearch/operator.py create mode 100644 mmcv/cnn/rfsearch/search.py create mode 100644 mmcv/cnn/rfsearch/utils.py create mode 100644 tests/test_cnn/test_rfsearch/test_operator.py create mode 100644 tests/test_cnn/test_rfsearch/test_search.py diff --git a/mmcv/cnn/__init__.py b/mmcv/cnn/__init__.py index ce2d2463d0..10e7e027e4 100644 --- a/mmcv/cnn/__init__.py +++ b/mmcv/cnn/__init__.py @@ -11,6 +11,7 @@ build_upsample_layer, conv_ws_2d, is_norm) # yapf: enable from .resnet import ResNet, make_res_layer +from .rfsearch import Conv2dRFSearchOp, RFSearchHook from .utils import fuse_conv_bn, get_model_complexity_info from .vgg import VGG, make_vgg_layer @@ -23,5 +24,5 @@ 'Scale', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'fuse_conv_bn', - 'get_model_complexity_info' + 'get_model_complexity_info', 'Conv2dRFSearchOp', 'RFSearchHook' ] diff --git a/mmcv/cnn/rfsearch/__init__.py b/mmcv/cnn/rfsearch/__init__.py new file mode 100644 index 0000000000..04d45725dc --- /dev/null +++ b/mmcv/cnn/rfsearch/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .operator import BaseConvRFSearchOp, Conv2dRFSearchOp +from .search import RFSearchHook + +__all__ = ['BaseConvRFSearchOp', 'Conv2dRFSearchOp', 'RFSearchHook'] diff --git a/mmcv/cnn/rfsearch/operator.py b/mmcv/cnn/rfsearch/operator.py new file mode 100644 index 0000000000..3d3416f59e --- /dev/null +++ b/mmcv/cnn/rfsearch/operator.py @@ -0,0 +1,170 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import numpy as np +import torch +import torch.nn as nn +from mmengine.logging import MMLogger +from mmengine.model import BaseModule +from torch import Tensor + +from .utils import expand_rates, get_single_padding + +logger = MMLogger.get_current_instance() + + +class BaseConvRFSearchOp(BaseModule): + """Base class of ConvRFSearchOp. + + Args: + op_layer (nn.Module): pytorch module, e.g., Conv2d + global_config (dict): config dict. + """ + + def __init__(self, op_layer: nn.Module, global_config: dict): + super().__init__() + self.op_layer = op_layer + self.global_config = global_config + + def normlize(self, weights: nn.Parameter) -> nn.Parameter: + """Normalize weights. + + Args: + weights (nn.Parameter): Weights to be normalized. + + Returns: + nn.Parameter: Normalized weights. + """ + abs_weights = torch.abs(weights) + normalized_weights = abs_weights / torch.sum(abs_weights) + return normalized_weights + + +class Conv2dRFSearchOp(BaseConvRFSearchOp): + """Enable Conv2d with receptive field searching ability. + + Args: + op_layer (nn.Module): pytorch module, e.g., Conv2d + global_config (dict): config dict. Defaults to None. + By default this must include: + + - "init_alphas": The value for initializing weights of each branch. + - "num_branches": The controller of the size of + search space (the number of branches). + - "exp_rate": The controller of the sparsity of search space. + - "mmin": The minimum dilation rate. + - "mmax": The maximum dilation rate. 
+ + Extra keys may exist, but are used by RFSearchHook, e.g., "step", + "max_step", "search_interval", and "skip_layer". + verbose (bool): Determines whether to print rf-next + related logging messages. + Defaults to True. + """ + + def __init__(self, + op_layer: nn.Module, + global_config: dict, + verbose: bool = True): + super().__init__(op_layer, global_config) + assert global_config is not None, 'global_config is None' + self.num_branches = global_config['num_branches'] + assert self.num_branches in [2, 3] + self.verbose = verbose + init_dilation = op_layer.dilation + self.dilation_rates = expand_rates(init_dilation, global_config) + if self.op_layer.kernel_size[ + 0] == 1 or self.op_layer.kernel_size[0] % 2 == 0: + self.dilation_rates = [(op_layer.dilation[0], r[1]) + for r in self.dilation_rates] + if self.op_layer.kernel_size[ + 1] == 1 or self.op_layer.kernel_size[1] % 2 == 0: + self.dilation_rates = [(r[0], op_layer.dilation[1]) + for r in self.dilation_rates] + + self.branch_weights = nn.Parameter(torch.Tensor(self.num_branches)) + if self.verbose: + logger.info(f'Expand as {self.dilation_rates}') + nn.init.constant_(self.branch_weights, global_config['init_alphas']) + + def forward(self, input: Tensor) -> Tensor: + norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)]) + if len(self.dilation_rates) == 1: + outputs = [ + nn.functional.conv2d( + input, + weight=self.op_layer.weight, + bias=self.op_layer.bias, + stride=self.op_layer.stride, + padding=self.get_padding(self.dilation_rates[0]), + dilation=self.dilation_rates[0], + groups=self.op_layer.groups, + ) + ] + else: + outputs = [ + nn.functional.conv2d( + input, + weight=self.op_layer.weight, + bias=self.op_layer.bias, + stride=self.op_layer.stride, + padding=self.get_padding(r), + dilation=r, + groups=self.op_layer.groups, + ) * norm_w[i] for i, r in enumerate(self.dilation_rates) + ] + output = outputs[0] + for i in range(1, len(self.dilation_rates)): + output += outputs[i] + return 
output + + def estimate_rates(self): + """Estimate new dilation rate based on trained branch_weights.""" + norm_w = self.normlize(self.branch_weights[:len(self.dilation_rates)]) + if self.verbose: + logger.info('Estimate dilation {} with weight {}.'.format( + self.dilation_rates, + norm_w.detach().cpu().numpy().tolist())) + + sum0, sum1, w_sum = 0, 0, 0 + for i in range(len(self.dilation_rates)): + sum0 += norm_w[i].item() * self.dilation_rates[i][0] + sum1 += norm_w[i].item() * self.dilation_rates[i][1] + w_sum += norm_w[i].item() + estimated = [ + np.clip( + int(round(sum0 / w_sum)), self.global_config['mmin'], + self.global_config['mmax']).item(), + np.clip( + int(round(sum1 / w_sum)), self.global_config['mmin'], + self.global_config['mmax']).item() + ] + self.op_layer.dilation = tuple(estimated) + self.op_layer.padding = self.get_padding(self.op_layer.dilation) + self.dilation_rates = [tuple(estimated)] + if self.verbose: + logger.info(f'Estimate as {tuple(estimated)}') + + def expand_rates(self): + """Expand dilation rate.""" + dilation = self.op_layer.dilation + dilation_rates = expand_rates(dilation, self.global_config) + if self.op_layer.kernel_size[ + 0] == 1 or self.op_layer.kernel_size[0] % 2 == 0: + dilation_rates = [(dilation[0], r[1]) for r in dilation_rates] + if self.op_layer.kernel_size[ + 1] == 1 or self.op_layer.kernel_size[1] % 2 == 0: + dilation_rates = [(r[0], dilation[1]) for r in dilation_rates] + + self.dilation_rates = copy.deepcopy(dilation_rates) + if self.verbose: + logger.info(f'Expand as {self.dilation_rates}') + nn.init.constant_(self.branch_weights, + self.global_config['init_alphas']) + + def get_padding(self, dilation): + padding = (get_single_padding(self.op_layer.kernel_size[0], + self.op_layer.stride[0], dilation[0]), + get_single_padding(self.op_layer.kernel_size[1], + self.op_layer.stride[1], dilation[1])) + return padding diff --git a/mmcv/cnn/rfsearch/search.py b/mmcv/cnn/rfsearch/search.py new file mode 100644 index 
0000000000..d54021a0c0 --- /dev/null +++ b/mmcv/cnn/rfsearch/search.py @@ -0,0 +1,238 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +from typing import Dict, Optional + +import mmengine +import torch.nn as nn +from mmengine.hooks import Hook +from mmengine.logging import MMLogger +from mmengine.registry import HOOKS + +from mmcv.cnn.rfsearch.utils import get_single_padding, write_to_json +from .operator import BaseConvRFSearchOp + +logger = MMLogger.get_current_instance() + + +@HOOKS.register_module() +class RFSearchHook(Hook): + """Receptive field search via dilation rates. + + Please refer to `RF-Next: Efficient Receptive Field + Search for Convolutional Neural Networks + `_ for more details. + + + Args: + mode (str, optional): It can be set to the following types: + 'search', 'fixed_single_branch', or 'fixed_multi_branch'. + Defaults to 'search'. + config (Dict, optional): config dict of search. + By default this config contains "search", + and config["search"] must include: + + - "step": recording the current searching step. + - "max_step": The maximum number of searching steps + to update the structures. + - "search_interval": The interval (epoch/iteration) + between two updates. + - "exp_rate": The controller of the sparsity of search space. + - "init_alphas": The value for initializing weights of each branch. + - "mmin": The minimum dilation rate. + - "mmax": The maximum dilation rate. + - "num_branches": The controller of the size of + search space (the number of branches). + - "skip_layer": The modules in skip_layer will be ignored + during the receptive field search. + rfstructure_file (str, optional): Path to load searched receptive + fields of the model. Defaults to None. + by_epoch (bool, optional): Determine to perform step by epoch or + by iteration. If set to True, it will step by epoch. Otherwise, by + iteration. Defaults to True. + verbose (bool): Determines whether to print rf-next related logging + messages. Defaults to True. 
+ """ + + def __init__(self, + mode: str = 'search', + config: Dict = {}, + rfstructure_file: Optional[str] = None, + by_epoch: bool = True, + verbose: bool = True): + assert mode in ['search', 'fixed_single_branch', 'fixed_multi_branch'] + assert config is not None + self.config = config + self.config['structure'] = {} + self.verbose = verbose + if rfstructure_file is not None: + rfstructure = mmengine.load(rfstructure_file)['structure'] + self.config['structure'] = rfstructure + self.mode = mode + self.num_branches = self.config['search']['num_branches'] + self.by_epoch = by_epoch + + def init_model(self, model: nn.Module): + """init model with search ability. + + Args: + model (nn.Module): pytorch model + + Raises: + NotImplementedError: only support three modes: + search/fixed_single_branch/fixed_multi_branch + """ + if self.verbose: + logger.info('RFSearch init begin.') + if self.mode == 'search': + if self.config['structure']: + self.set_model(model, search_op='Conv2d') + self.wrap_model(model, search_op='Conv2d') + elif self.mode == 'fixed_single_branch': + self.set_model(model, search_op='Conv2d') + elif self.mode == 'fixed_multi_branch': + self.set_model(model, search_op='Conv2d') + self.wrap_model(model, search_op='Conv2d') + else: + raise NotImplementedError + if self.verbose: + logger.info('RFSearch init end.') + + def after_train_epoch(self, runner): + """Performs a dilation searching step after one training epoch.""" + if self.by_epoch and self.mode == 'search': + self.step(runner.model, runner.work_dir) + + def after_train_iter(self, runner): + """Performs a dilation searching step after one training iteration.""" + if not self.by_epoch and self.mode == 'search': + self.step(runner.model, runner.work_dir) + + def step(self, model: nn.Module, work_dir: str): + """Performs a dilation searching step. + + Args: + model (nn.Module): pytorch model + work_dir (str): Directory to save the searching results. 
+ """ + self.config['search']['step'] += 1 + if (self.config['search']['step'] + ) % self.config['search']['search_interval'] == 0 and (self.config[ + 'search']['step']) < self.config['search']['max_step']: + self.estimate_and_expand(model) + for name, module in model.named_modules(): + if isinstance(module, BaseConvRFSearchOp): + self.config['structure'][name] = module.op_layer.dilation + + write_to_json( + self.config, + os.path.join( + work_dir, + 'local_search_config_step%d.json' % + self.config['search']['step'], + ), + ) + + def estimate_and_expand(self, model: nn.Module): + """estimate and search for RFConvOp. + + Args: + model (nn.Module): pytorch model + """ + for module in model.modules(): + if isinstance(module, BaseConvRFSearchOp): + module.estimate_rates() + module.expand_rates() + + def wrap_model(self, + model: nn.Module, + search_op: str = 'Conv2d', + prefix: str = ''): + """wrap model to support searchable conv op. + + Args: + model (nn.Module): pytorch model + search_op (str): The module that uses RF search. + Defaults to 'Conv2d'. + init_rates (int, optional): Set to other initial dilation rates. + Defaults to None. + prefix (str): Prefix for function recursion. Defaults to ''. + """ + op = 'torch.nn.' + search_op + for name, module in model.named_children(): + if prefix == '': + fullname = 'module.' + name + else: + fullname = prefix + '.' + name + if self.config['search']['skip_layer'] is not None: + if any(layer in fullname + for layer in self.config['search']['skip_layer']): + continue + if isinstance(module, eval(op)): + if 1 < module.kernel_size[0] and \ + 0 != module.kernel_size[0] % 2 or \ + 1 < module.kernel_size[1] and \ + 0 != module.kernel_size[1] % 2: + moduleWrap = eval(search_op + 'RFSearchOp')( + module, self.config['search'], self.verbose) + moduleWrap = moduleWrap.to(module.weight.device) + if self.verbose: + logger.info('Wrap model %s to %s.' 
% + (str(module), str(moduleWrap))) + setattr(model, name, moduleWrap) + elif not isinstance(module, BaseConvRFSearchOp): + self.wrap_model(module, search_op, fullname) + + def set_model(self, + model: nn.Module, + search_op: str = 'Conv2d', + init_rates: Optional[int] = None, + prefix: str = ''): + """set model based on config. + + Args: + model (nn.Module): pytorch model + config (Dict): config file + search_op (str): The module that uses RF search. + Defaults to 'Conv2d'. + init_rates (int, optional): Set to other initial dilation rates. + Defaults to None. + prefix (str): Prefix for function recursion. Defaults to ''. + """ + op = 'torch.nn.' + search_op + for name, module in model.named_children(): + if prefix == '': + fullname = 'module.' + name + else: + fullname = prefix + '.' + name + if self.config['search']['skip_layer'] is not None: + if any(layer in fullname + for layer in self.config['search']['skip_layer']): + continue + if isinstance(module, eval(op)): + if 1 < module.kernel_size[0] and \ + 0 != module.kernel_size[0] % 2 or \ + 1 < module.kernel_size[1] and \ + 0 != module.kernel_size[1] % 2: + if isinstance(self.config['structure'][fullname], int): + self.config['structure'][fullname] = [ + self.config['structure'][fullname], + self.config['structure'][fullname] + ] + module.dilation = ( + self.config['structure'][fullname][0], + self.config['structure'][fullname][1], + ) + module.padding = ( + get_single_padding( + module.kernel_size[0], module.stride[0], + self.config['structure'][fullname][0]), + get_single_padding( + module.kernel_size[1], module.stride[1], + self.config['structure'][fullname][1])) + setattr(model, name, module) + if self.verbose: + logger.info( + 'Set module %s dilation as: [%d %d]' % + (fullname, module.dilation[0], module.dilation[1])) + elif not isinstance(module, BaseConvRFSearchOp): + self.set_model(module, search_op, init_rates, fullname) diff --git a/mmcv/cnn/rfsearch/utils.py b/mmcv/cnn/rfsearch/utils.py new file mode 
100644 index 0000000000..4c8168e343 --- /dev/null +++ b/mmcv/cnn/rfsearch/utils.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmengine +import numpy as np + + +def write_to_json(config: dict, filename: str): + """save config to json file. + + Args: + config (dict): Config to be saved. + filename (str): Path to save config. + """ + + with open(filename, 'w', encoding='utf-8') as f: + mmengine.dump(config, f, file_format='json') + + +def expand_rates(dilation: tuple, config: dict) -> list: + """expand dilation rate according to config. + + Args: + dilation (tuple): Dilation rates of the two kernel dimensions. + config (dict): config dict + + Returns: + list: list of expanded dilation rates + """ + exp_rate = config['exp_rate'] + + large_rates = [] + small_rates = [] + for _ in range(config['num_branches'] // 2): + large_rates.append( + tuple([ + np.clip( + int(round((1 + exp_rate) * dilation[0])), config['mmin'], + config['mmax']).item(), + np.clip( + int(round((1 + exp_rate) * dilation[1])), config['mmin'], + config['mmax']).item() + ])) + small_rates.append( + tuple([ + np.clip( + int(round((1 - exp_rate) * dilation[0])), config['mmin'], + config['mmax']).item(), + np.clip( + int(round((1 - exp_rate) * dilation[1])), config['mmin'], + config['mmax']).item() + ])) + + small_rates.reverse() + + if config['num_branches'] % 2 == 0: + rate_list = small_rates + large_rates + else: + rate_list = small_rates + [dilation] + large_rates + + unique_rate_list = list(set(rate_list)) + unique_rate_list.sort(key=rate_list.index) + return unique_rate_list + + +def get_single_padding(kernel_size: int, + stride: int = 1, + dilation: int = 1) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding diff --git a/tests/test_cnn/test_rfsearch/test_operator.py b/tests/test_cnn/test_rfsearch/test_operator.py new file mode 100644 index 0000000000..b555605fc5 --- /dev/null +++ b/tests/test_cnn/test_rfsearch/test_operator.py @@ -0,0 +1,325 @@ +# Copyright (c) 
OpenMMLab. All rights reserved. +from copy import deepcopy + +import torch +import torch.nn as nn + +from mmcv.cnn.rfsearch.operator import Conv2dRFSearchOp + +global_config = dict( + step=0, + max_step=12, + search_interval=1, + exp_rate=0.5, + init_alphas=0.01, + mmin=1, + mmax=24, + num_branches=2, + skip_layer=['stem', 'layer1']) + + +# test with 3x3 conv +def test_rfsearch_operator_3x3(): + conv = nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1) + operator = Conv2dRFSearchOp(conv, global_config) + x = torch.randn(1, 3, 32, 32) + + # set no_grad to perform in-place operator + with torch.no_grad(): + # After expand: (1, 1) (2, 2) + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (2, 2) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After estimate: (2, 2) with branch_weights of [0.5 0.5] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (2, 2) + assert operator.op_layer.dilation == (2, 2) + assert operator.op_layer.padding == (2, 2) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After expand: (1, 1) (3, 3) + operator.expand_rates() + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (3, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + operator.branch_weights[0] = 0.1 + operator.branch_weights[1] = 0.4 + # After estimate: (3, 3) with branch_weights of [0.2 0.8] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (3, 3) + assert operator.op_layer.dilation == (3, 3) + assert operator.op_layer.padding == (3, 3) + # test forward + 
assert operator(x).shape == (1, 3, 32, 32) + + +# test with 5x5 conv +def test_rfsearch_operator_5x5(): + conv = nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2) + operator = Conv2dRFSearchOp(conv, global_config) + x = torch.randn(1, 3, 32, 32) + + with torch.no_grad(): + # After expand: (1, 1) (2, 2) + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (2, 2) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After estimate: (2, 2) with branch_weights of [0.5 0.5] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (2, 2) + assert operator.op_layer.dilation == (2, 2) + assert operator.op_layer.padding == (4, 4) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After expand: (1, 1) (3, 3) + operator.expand_rates() + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (3, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + operator.branch_weights[0] = 0.1 + operator.branch_weights[1] = 0.4 + # After estimate: (3, 3) with branch_weights of [0.2 0.8] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (3, 3) + assert operator.op_layer.dilation == (3, 3) + assert operator.op_layer.padding == (6, 6) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + +# test with 5x5 conv num_branches=3 +def test_rfsearch_operator_5x5_branch3(): + conv = nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=5, stride=1, padding=2) + config = deepcopy(global_config) + config['num_branches'] = 3 + operator = Conv2dRFSearchOp(conv, config) + x = 
torch.randn(1, 3, 32, 32) + + with torch.no_grad(): + # After expand: (1, 1) (2, 2) + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (2, 2) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After estimate: (2, 2) with branch_weights of [0.5 0.5] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (2, 2) + assert operator.op_layer.dilation == (2, 2) + assert operator.op_layer.padding == (4, 4) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After expand: (1, 1) (2, 2) (3, 3) + operator.expand_rates() + assert len(operator.dilation_rates) == 3 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (2, 2) + assert operator.dilation_rates[2] == (3, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + operator.branch_weights[0] = 0.1 + operator.branch_weights[1] = 0.3 + operator.branch_weights[2] = 0.6 + # After estimate: (3, 3) with branch_weights of [0.1 0.3 0.6] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (3, 3) + assert operator.op_layer.dilation == (3, 3) + assert operator.op_layer.padding == (6, 6) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + +# test with 1x5 conv +def test_rfsearch_operator_1x5(): + conv = nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=(1, 5), + stride=1, + padding=(0, 2)) + operator = Conv2dRFSearchOp(conv, global_config) + x = torch.randn(1, 3, 32, 32) + + # After expand: (1, 1) (1, 2) + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (1, 2) + assert torch.all( + 
operator.branch_weights.data == global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + with torch.no_grad(): + # After estimate: (1, 2) with branch_weights of [0.5 0.5] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (1, 2) + assert operator.op_layer.dilation == (1, 2) + assert operator.op_layer.padding == (0, 4) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After expand: (1, 1) (1, 3) + operator.expand_rates() + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (1, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + operator.branch_weights[0] = 0.2 + operator.branch_weights[1] = 0.8 + # After estimate: (1, 3) with branch_weights of [0.2 0.8] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (1, 3) + assert operator.op_layer.dilation == (1, 3) + assert operator.op_layer.padding == (0, 6) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + +# test with 5x5 conv initial_dilation=(2, 2) +def test_rfsearch_operator_5x5_d2x2(): + conv = nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=5, + stride=1, + padding=4, + dilation=(2, 2)) + operator = Conv2dRFSearchOp(conv, global_config) + x = torch.randn(1, 3, 32, 32) + + with torch.no_grad(): + # After expand: (1, 1) (3, 3) + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (3, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After estimate: (2, 2) with branch_weights of [0.5 0.5] + operator.estimate_rates() + assert len(operator.dilation_rates) 
== 1 + assert operator.dilation_rates[0] == (2, 2) + assert operator.op_layer.dilation == (2, 2) + assert operator.op_layer.padding == (4, 4) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After expand: (1, 1) (3, 3) + operator.expand_rates() + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (3, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + operator.branch_weights[0] = 0.8 + operator.branch_weights[1] = 0.2 + # After estimate: (1, 1) with branch_weights of [0.8 0.2] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (1, 1) + assert operator.op_layer.dilation == (1, 1) + assert operator.op_layer.padding == (2, 2) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + +# test with 5x5 conv initial_dilation=(1, 2) +def test_rfsearch_operator_5x5_d1x2(): + conv = nn.Conv2d( + in_channels=3, + out_channels=3, + kernel_size=5, + stride=1, + padding=(2, 4), + dilation=(1, 2)) + operator = Conv2dRFSearchOp(conv, global_config) + x = torch.randn(1, 3, 32, 32) + + with torch.no_grad(): + # After expand: (1, 1) (2, 3) + assert len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (2, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After estimate: (2, 2) with branch_weights of [0.5 0.5] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (2, 2) + assert operator.op_layer.dilation == (2, 2) + assert operator.op_layer.padding == (4, 4) + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + # After expand: (1, 1) (3, 3) + operator.expand_rates() + assert 
len(operator.dilation_rates) == 2 + assert operator.dilation_rates[0] == (1, 1) + assert operator.dilation_rates[1] == (3, 3) + assert torch.all(operator.branch_weights.data == + global_config['init_alphas']).item() + # test forward + assert operator(x).shape == (1, 3, 32, 32) + + operator.branch_weights[0] = 0.1 + operator.branch_weights[1] = 0.8 + # After estimate: (3, 3) with branch_weights of [0.1 0.8] + operator.estimate_rates() + assert len(operator.dilation_rates) == 1 + assert operator.dilation_rates[0] == (3, 3) + assert operator.op_layer.dilation == (3, 3) + assert operator.op_layer.padding == (6, 6) + # test forward + assert operator(x).shape == (1, 3, 32, 32) diff --git a/tests/test_cnn/test_rfsearch/test_search.py b/tests/test_cnn/test_rfsearch/test_search.py new file mode 100644 index 0000000000..1821349811 --- /dev/null +++ b/tests/test_cnn/test_rfsearch/test_search.py @@ -0,0 +1,177 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Tests the rfsearch with runners. + +CommandLine: + pytest tests/test_runner/test_hooks.py + xdoctest tests/test_hooks.py zero +""" + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook +from tests.test_runner.test_hooks import _build_demo_runner + + +def test_rfsearchhook(): + + def conv(in_channels, out_channels, kernel_size, stride, padding, + dilation): + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation) + + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.stem = conv(1, 2, 3, 1, 1, 1) + self.conv0 = conv(2, 2, 3, 1, 1, 1) + self.layer0 = nn.Sequential( + conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1)) + self.conv1 = conv(2, 2, 1, 1, 0, 1) + self.conv2 = conv(2, 2, 3, 1, 1, 1) + self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1) + + def forward(self, x): + x1 = self.stem(x) + x2 = self.layer0(x1) 
+ x3 = self.conv0(x2) + x4 = self.conv1(x3) + x5 = self.conv2(x4) + x6 = self.conv3(x5) + return x6 + + def train_step(self, x, optimizer, **kwargs): + return dict(loss=self(x).mean(), num_samples=x.shape[0]) + + rfsearch_cfg = dict( + mode='search', + rfstructure_file=None, + config=dict( + search=dict( + step=0, + max_step=12, + search_interval=1, + exp_rate=0.5, + init_alphas=0.01, + mmin=1, + mmax=24, + num_branches=2, + skip_layer=['stem', 'conv0', 'layer0.1'])), + ) + + # hook for search + rfsearchhook_search = RFSearchHook( + 'search', rfsearch_cfg['config'], by_epoch=True, verbose=True) + rfsearchhook_search.config['structure'] = { + 'module.layer0.0': [1, 1], + 'module.conv2': [2, 2], + 'module.conv3': [1, 1] + } + # hook for fixed_single_branch + rfsearchhook_fixed_single_branch = RFSearchHook( + 'fixed_single_branch', + rfsearch_cfg['config'], + by_epoch=True, + verbose=True) + rfsearchhook_fixed_single_branch.config['structure'] = { + 'module.layer0.0': [1, 1], + 'module.conv2': [2, 2], + 'module.conv3': [1, 1] + } + # hook for fixed_multi_branch + rfsearchhook_fixed_multi_branch = RFSearchHook( + 'fixed_multi_branch', + rfsearch_cfg['config'], + by_epoch=True, + verbose=True) + rfsearchhook_fixed_multi_branch.config['structure'] = { + 'module.layer0.0': [1, 1], + 'module.conv2': [2, 2], + 'module.conv3': [1, 1] + } + + def test_skip_layer(): + assert not isinstance(model.stem, Conv2dRFSearchOp) + assert not isinstance(model.conv0, Conv2dRFSearchOp) + assert isinstance(model.layer0[0], Conv2dRFSearchOp) + assert not isinstance(model.layer0[1], Conv2dRFSearchOp) + + # 1. 
test init_model() with mode of search + model = Model() + rfsearchhook_search.init_model(model) + + test_skip_layer() + assert not isinstance(model.conv1, Conv2dRFSearchOp) + assert isinstance(model.conv2, Conv2dRFSearchOp) + assert isinstance(model.conv3, Conv2dRFSearchOp) + assert model.conv2.dilation_rates == [(1, 1), (3, 3)] + assert model.conv3.dilation_rates == [(1, 1), (1, 2)] + + # 1. test step() with mode of search + loader = DataLoader(torch.ones((1, 1, 1, 1))) + runner = _build_demo_runner() + runner.model = model + runner.register_hook(rfsearchhook_search) + runner.run([loader], [('train', 1)]) + + test_skip_layer() + assert not isinstance(model.conv1, Conv2dRFSearchOp) + assert isinstance(model.conv2, Conv2dRFSearchOp) + assert isinstance(model.conv3, Conv2dRFSearchOp) + assert model.conv2.dilation_rates == [(1, 1), (3, 3)] + assert model.conv3.dilation_rates == [(1, 1), (1, 3)] + + # 2. test init_model() with mode of fixed_single_branch + model = Model() + rfsearchhook_fixed_single_branch.init_model(model) + + assert not isinstance(model.conv1, Conv2dRFSearchOp) + assert not isinstance(model.conv2, Conv2dRFSearchOp) + assert not isinstance(model.conv3, Conv2dRFSearchOp) + assert model.conv1.dilation == (1, 1) + assert model.conv2.dilation == (2, 2) + assert model.conv3.dilation == (1, 1) + + # 2. test step() with mode of fixed_single_branch + runner = _build_demo_runner() + runner.model = model + runner.register_hook(rfsearchhook_fixed_single_branch) + runner.run([loader], [('train', 1)]) + + assert not isinstance(model.conv1, Conv2dRFSearchOp) + assert not isinstance(model.conv2, Conv2dRFSearchOp) + assert not isinstance(model.conv3, Conv2dRFSearchOp) + assert model.conv1.dilation == (1, 1) + assert model.conv2.dilation == (2, 2) + assert model.conv3.dilation == (1, 1) + + # 3. 
test init_model() with mode of fixed_multi_branch + model = Model() + rfsearchhook_fixed_multi_branch.init_model(model) + + test_skip_layer() + assert not isinstance(model.conv1, Conv2dRFSearchOp) + assert isinstance(model.conv2, Conv2dRFSearchOp) + assert isinstance(model.conv3, Conv2dRFSearchOp) + assert model.conv2.dilation_rates == [(1, 1), (3, 3)] + assert model.conv3.dilation_rates == [(1, 1), (1, 2)] + + # 3. test step() with mode of fixed_multi_branch + runner = _build_demo_runner() + runner.model = model + runner.register_hook(rfsearchhook_fixed_multi_branch) + runner.run([loader], [('train', 1)]) + + test_skip_layer() + assert not isinstance(model.conv1, Conv2dRFSearchOp) + assert isinstance(model.conv2, Conv2dRFSearchOp) + assert isinstance(model.conv3, Conv2dRFSearchOp) + assert model.conv2.dilation_rates == [(1, 1), (3, 3)] + assert model.conv3.dilation_rates == [(1, 1), (1, 2)]