diff --git a/hubconf.py b/hubconf.py
index 5bb629005597..011eaa57ff34 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -29,6 +29,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
     from pathlib import Path
 
     from models.common import AutoShape, DetectMultiBackend
+    from models.experimental import attempt_load
     from models.yolo import Model
     from utils.downloads import attempt_download
     from utils.general import LOGGER, check_requirements, intersect_dicts, logging
@@ -42,8 +43,12 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
     try:
         device = select_device(device)
         if pretrained and channels == 3 and classes == 80:
-            model = DetectMultiBackend(path, device=device, fuse=autoshape)  # download/load FP32 model
-            # model = models.experimental.attempt_load(path, map_location=device)  # download/load FP32 model
+            try:
+                model = DetectMultiBackend(path, device=device, fuse=autoshape)  # detection model
+                if autoshape:
+                    model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
+            except Exception:
+                model = attempt_load(path, device=device, fuse=False)  # arbitrary model
         else:
             cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0]  # model.yaml path
             model = Model(cfg, channels, classes)  # create model
@@ -54,9 +59,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
                 model.load_state_dict(csd, strict=False)  # load
             if len(ckpt['model'].names) == classes:
                 model.names = ckpt['model'].names  # set class names attribute
-        if autoshape:
-            model.model.model[-1].inplace = False  # Detect.inplace=False for safe multithread inference
-            model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
         if not verbose:
             LOGGER.setLevel(logging.INFO)  # reset to default
         return model.to(device)
diff --git a/models/common.py b/models/common.py
index 959c965e6002..c898d94a921a 100644
--- a/models/common.py
+++ b/models/common.py
@@ -562,6 +562,9 @@ def __init__(self, model, verbose=True):
         self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
         self.pt = not self.dmb or model.pt  # PyTorch model
         self.model = model.eval()
+        if self.pt:
+            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
+            m.inplace = False  # Detect.inplace=False for safe multithread inference
 
     def _apply(self, fn):
         # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
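
For context, a minimal sketch (not part of the diff) of how the revised _create() path is exercised through torch.hub. It assumes a standard YOLOv5 environment; the 'path/to/best.pt' checkpoint is a placeholder. Official pretrained weights now go through DetectMultiBackend and are wrapped in AutoShape inside _create() when autoshape=True, while checkpoints that DetectMultiBackend cannot parse fall back to attempt_load().

import torch

# Official pretrained checkpoint: DetectMultiBackend + AutoShape wrapping happens inside _create(),
# so file/URI/PIL/cv2/np inputs and NMS are handled automatically.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, autoshape=True)

# Raw model without AutoShape, e.g. for direct tensor inference or further training.
raw = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)

# Custom checkpoint (placeholder path): if DetectMultiBackend raises, _create() now
# falls back to models.experimental.attempt_load() instead of failing outright.
custom = torch.hub.load('ultralytics/yolov5', 'custom', path='path/to/best.pt')

results = model('https://ultralytics.com/images/zidane.jpg')  # AutoShape accepts a URI and runs NMS
results.print()

With AutoShape applied inside _create(), its __init__ (second hunk) is also where Detect.inplace is set to False for PyTorch models, so thread-safe inference no longer depends on hubconf.py touching model internals.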