diff --git a/hubconf.py b/hubconf.py
index 3488fef76ac5..03335f7906f0 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -27,6 +27,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
     """
     from pathlib import Path
 
+    from models.common import AutoShape
     from models.experimental import attempt_load
     from models.yolo import Model
     from utils.downloads import attempt_download
@@ -55,7 +56,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
                 if len(ckpt['model'].names) == classes:
                     model.names = ckpt['model'].names  # set class names attribute
         if autoshape:
-            model = model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
+            model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
         return model.to(device)
 
     except Exception as e:
diff --git a/models/common.py b/models/common.py
index 3930c8e7b2df..b9604f3c1cbd 100644
--- a/models/common.py
+++ b/models/common.py
@@ -23,7 +23,7 @@
 from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
                            non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import time_sync
+from utils.torch_utils import copy_attr, time_sync
 
 
 def autopad(k, p=None):  # kernel, padding
@@ -405,12 +405,10 @@ class AutoShape(nn.Module):
 
     def __init__(self, model):
         super().__init__()
+        LOGGER.info('Adding AutoShape... ')
+        copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
         self.model = model.eval()
 
-    def autoshape(self):
-        LOGGER.info('AutoShape already enabled, skipping... ')  # model already converted to model.autoshape()
-        return self
-
     def _apply(self, fn):
         # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
         self = super()._apply(fn)
diff --git a/models/yolo.py b/models/yolo.py
index 305f0ca0cc88..db3d711a81fa 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -22,8 +22,7 @@
 from utils.autoanchor import check_anchor_order
 from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
 from utils.plots import feature_visualization
-from utils.torch_utils import (copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device,
-                               time_sync)
+from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync
 
 try:
     import thop  # for FLOPs computation
@@ -226,12 +225,6 @@ def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
         self.info()
         return self
 
-    def autoshape(self):  # add AutoShape module
-        LOGGER.info('Adding AutoShape... ')
-        m = AutoShape(self)  # wrap model
-        copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=())  # copy attributes
-        return m
-
     def info(self, verbose=False, img_size=640):  # print model information
         model_info(self, verbose, img_size)
 