AutoShape() models as DetectMultiBackend() instances #5845

Merged · Dec 4, 2021 · 34 commits

Commits
a157ca6  Update AutoShape()  (glenn-jocher, Dec 1, 2021)
1cb159d  autodownload ONNX  (glenn-jocher, Dec 1, 2021)
c3bca29  Cleanup  (glenn-jocher, Dec 1, 2021)
b61478e  Finish updates  (glenn-jocher, Dec 1, 2021)
6ac0646  Add Usage  (glenn-jocher, Dec 1, 2021)
cb4b15f  Update  (glenn-jocher, Dec 1, 2021)
9e267b9  Update  (glenn-jocher, Dec 1, 2021)
1212a1d  Update  (glenn-jocher, Dec 1, 2021)
897b65e  Update  (glenn-jocher, Dec 1, 2021)
8de5735  Update  (glenn-jocher, Dec 1, 2021)
6b6e22a  Update  (glenn-jocher, Dec 1, 2021)
a054dbb  Update  (glenn-jocher, Dec 1, 2021)
d6fc861  Update  (glenn-jocher, Dec 1, 2021)
6b48205  Update  (glenn-jocher, Dec 1, 2021)
514b6a1  Update  (glenn-jocher, Dec 1, 2021)
7ed948f  Update  (glenn-jocher, Dec 1, 2021)
e2396b9  Update  (glenn-jocher, Dec 1, 2021)
9ae1554  Update  (glenn-jocher, Dec 1, 2021)
03ca9e9  Update  (glenn-jocher, Dec 1, 2021)
beb3775  Update  (glenn-jocher, Dec 1, 2021)
a5bb9c1  Update  (glenn-jocher, Dec 1, 2021)
7e8815e  Update  (glenn-jocher, Dec 1, 2021)
4979514  fix device  (glenn-jocher, Dec 1, 2021)
fc6fcf0  Update hubconf.py  (glenn-jocher, Dec 2, 2021)
cd9e306  Update common.py  (glenn-jocher, Dec 2, 2021)
2c7a229  smart param selection  (glenn-jocher, Dec 2, 2021)
3148c3b  autodownload all formats  (glenn-jocher, Dec 2, 2021)
84c5ca3  autopad only pytorch models  (glenn-jocher, Dec 2, 2021)
981c4a0  new_shape edits  (glenn-jocher, Dec 2, 2021)
aa140e2  Merge master  (glenn-jocher, Dec 2, 2021)
4593d4a  Merge master  (glenn-jocher, Dec 2, 2021)
266a6f7  stride tensor fix  (glenn-jocher, Dec 4, 2021)
8fd4d28  Cleanup  (glenn-jocher, Dec 4, 2021)
76f39bf  Merge master  (glenn-jocher, Dec 4, 2021)
export.py (2 changes: 1 addition & 1 deletion)

@@ -411,7 +411,7 @@ def parse_opt():
     parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
     parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
-    parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
+    parser.add_argument('--opset', type=int, default=14, help='ONNX: opset version')
     parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
     parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
     parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
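
A minimal sketch of exercising the new default opset from Python rather than the CLI; it assumes you are in the yolov5 repo root and that export.run() accepts keyword arguments mirroring the parse_opt() flags shown above, as the repo's entry points generally do:

# Illustrative sketch, not part of the PR diff: export yolov5s.pt to ONNX with the new default opset.
# Assumes export.run() mirrors the parse_opt() flags above.
import export  # yolov5/export.py

export.run(weights='yolov5s.pt', include=('onnx',), opset=14)  # writes yolov5s.onnx next to the weights
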
hubconf.py (14 changes: 7 additions & 7 deletions)

@@ -5,6 +5,7 @@
 Usage:
     import torch
     model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
+    model = torch.hub.load('ultralytics/yolov5:master', 'custom', 'path/to/yolov5s.onnx')  # file from branch
 """
 
 import torch
@@ -27,26 +28,25 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
     """
     from pathlib import Path
 
-    from models.common import AutoShape
-    from models.experimental import attempt_load
+    from models.common import AutoShape, DetectMultiBackend
     from models.yolo import Model
     from utils.downloads import attempt_download
     from utils.general import check_requirements, intersect_dicts, set_logging
     from utils.torch_utils import select_device
 
-    file = Path(__file__).resolve()
     check_requirements(exclude=('tensorboard', 'thop', 'opencv-python'))
     set_logging(verbose=verbose)
 
-    save_dir = Path('') if str(name).endswith('.pt') else file.parent
-    path = (save_dir / name).with_suffix('.pt')  # checkpoint path
+    name = Path(name)
+    path = name.with_suffix('.pt') if name.suffix == '' else name  # checkpoint path
     try:
         device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)
 
         if pretrained and channels == 3 and classes == 80:
-            model = attempt_load(path, map_location=device)  # download/load FP32 model
+            model = DetectMultiBackend(path, device=device)  # download/load FP32 model
+            # model = models.experimental.attempt_load(path, map_location=device)  # download/load FP32 model
         else:
-            cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0]  # model.yaml path
+            cfg = list((Path(__file__).parent / 'models').rglob(f'{path.name}.yaml'))[0]  # model.yaml path
             model = Model(cfg, channels, classes)  # create model
             if pretrained:
                 ckpt = torch.load(attempt_download(path), map_location=device)  # load
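
A hedged usage sketch of what the hubconf.py change enables, using the 'custom' call shown in the new docstring line; it assumes a yolov5s.onnx file exported beforehand (e.g. with export.py):

# Illustrative sketch, not part of the PR diff: load a non-PyTorch checkpoint through torch.hub.
import torch

model = torch.hub.load('ultralytics/yolov5', 'custom', 'yolov5s.onnx')  # DetectMultiBackend wrapped in AutoShape
results = model('https://ultralytics.com/images/zidane.jpg')  # URL, file, PIL, OpenCV, numpy or torch input
results.print()  # or .show(), .save(), .pandas().xyxy[0]
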
models/common.py (40 changes: 24 additions & 16 deletions)

@@ -276,7 +276,7 @@ def forward(self, x):
 
 class DetectMultiBackend(nn.Module):
     # YOLOv5 MultiBackend class for python inference on various backends
-    def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
+    def __init__(self, weights='yolov5s.pt', device=None, dnn=False):
         # Usage:
         #   PyTorch: weights = *.pt
         #   TorchScript: *.torchscript
@@ -287,13 +287,16 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
         #   ONNX Runtime: *.onnx
         #   OpenCV DNN: *.onnx with dnn=True
         #   TensorRT: *.engine
+        from models.experimental import attempt_download, attempt_load  # scoped to avoid circular import
+
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
         suffix = Path(w).suffix.lower()
         suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
         check_suffix(w, suffixes)  # check weights have acceptable suffix
         pt, jit, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
         stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
+        attempt_download(w)  # download if not local
 
         if jit:  # TorchScript
             LOGGER.info(f'Loading {w} for TorchScript inference...')
@@ -303,11 +306,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
                 d = json.loads(extra_files['config.txt'])  # extra_files dict
                 stride, names = int(d['stride']), d['names']
         elif pt:  # PyTorch
-            from models.experimental import attempt_load  # scoped to avoid circular import
             model = attempt_load(weights, map_location=device)
             stride = int(model.stride.max())  # model stride
             names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
         elif coreml:  # CoreML
+            LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
             model = ct.models.MLModel(w)
         elif dnn:  # ONNX OpenCV DNN
@@ -316,7 +320,7 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
             net = cv2.dnn.readNetFromONNX(w)
         elif onnx:  # ONNX Runtime
             LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
-            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
+            check_requirements(('onnx', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'))
             import onnxruntime
             session = onnxruntime.InferenceSession(w, None)
         elif engine:  # TensorRT
@@ -376,7 +380,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
         if self.pt:  # PyTorch
             y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
             return y if val else y[0]
-        elif self.coreml:  # CoreML *.mlmodel
+        elif self.coreml:  # CoreML
             im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
             im = Image.fromarray((im[0] * 255).astype('uint8'))
             # im = im.resize((192, 320), Image.ANTIALIAS)
@@ -433,24 +437,28 @@ class AutoShape(nn.Module):
     # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
     conf = 0.25  # NMS confidence threshold
     iou = 0.45  # NMS IoU threshold
-    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
+    agnostic = False  # NMS class-agnostic
     multi_label = False  # NMS multiple labels per box
+    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
     max_det = 1000  # maximum number of detections per image
 
     def __init__(self, model):
         super().__init__()
         LOGGER.info('Adding AutoShape... ')
         copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
+        self.dmb = isinstance(model, DetectMultiBackend)  # DetectMultiBackend() instance
+        self.pt = not self.dmb or model.pt  # PyTorch model
         self.model = model.eval()
 
     def _apply(self, fn):
         # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
         self = super()._apply(fn)
-        m = self.model.model[-1]  # Detect()
-        m.stride = fn(m.stride)
-        m.grid = list(map(fn, m.grid))
-        if isinstance(m.anchor_grid, list):
-            m.anchor_grid = list(map(fn, m.anchor_grid))
+        if self.pt:
+            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
+            m.stride = fn(m.stride)
+            m.grid = list(map(fn, m.grid))
+            if isinstance(m.anchor_grid, list):
+                m.anchor_grid = list(map(fn, m.anchor_grid))
         return self
 
     @torch.no_grad()
@@ -465,7 +473,7 @@ def forward(self, imgs, size=640, augment=False, profile=False):
         #   multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
 
         t = [time_sync()]
-        p = next(self.model.parameters())  # for device and type
+        p = next(self.model.parameters()) if self.pt else torch.zeros(1)  # for device and type
         if isinstance(imgs, torch.Tensor):  # torch
             with amp.autocast(enabled=p.device.type != 'cpu'):
                 return self.model(imgs.to(p.device).type_as(p), augment, profile)  # inference
@@ -489,21 +497,21 @@ def forward(self, imgs, size=640, augment=False, profile=False):
             g = (size / max(s))  # gain
             shape1.append([y * g for y in s])
             imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
-        shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)]  # inference shape
-        x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs]  # pad
+        shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)]  # inference shape
+        x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs]  # pad
         x = np.stack(x, 0) if n > 1 else x[0][None]  # stack
         x = np.ascontiguousarray(x.transpose((0, 3, 1, 2)))  # BHWC to BCHW
         x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
         t.append(time_sync())
 
         with amp.autocast(enabled=p.device.type != 'cpu'):
             # Inference
-            y = self.model(x, augment, profile)[0]  # forward
+            y = self.model(x, augment, profile)  # forward
             t.append(time_sync())
 
             # Post-process
-            y = non_max_suppression(y, self.conf, iou_thres=self.iou, classes=self.classes,
-                                    multi_label=self.multi_label, max_det=self.max_det)  # NMS
+            y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes,
+                                    agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det)  # NMS
             for i in range(n):
                 scale_coords(shape1, y[i][:, :4], shape0[i])
 
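
Taken together, the common.py changes let AutoShape wrap a DetectMultiBackend instance directly. A hedged sketch of that pattern from inside the repo, using only names that appear in the diff above; the local yolov5s.onnx file is an assumption:

# Illustrative sketch, not part of the PR diff: AutoShape around a non-PyTorch backend.
import torch
from models.common import AutoShape, DetectMultiBackend

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
backend = DetectMultiBackend('yolov5s.onnx', device=device)  # dnn=False -> ONNX Runtime backend
model = AutoShape(backend)      # dmb=True, pt=False, so letterbox pads to the fixed `size`
model.conf = 0.4                # NMS confidence threshold
model.classes = [0]             # keep COCO 'person' detections only
results = model('data/images/zidane.jpg', size=640)
results.print()
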
utils/general.py (4 changes: 3 additions & 1 deletion)

@@ -455,7 +455,9 @@ def download_one(url, dir):
 
 
 def make_divisible(x, divisor):
-    # Returns x evenly divisible by divisor
+    # Returns nearest x divisible by divisor
+    if isinstance(divisor, torch.Tensor):
+        divisor = int(divisor.max())  # to int
     return math.ceil(x / divisor) * divisor
 
 
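
A quick self-contained check of the make_divisible() behaviour above; the function body is copied from the diff and the example values are illustrative:

# Illustrative sketch, not part of the PR diff: make_divisible() now also accepts a tensor divisor.
import math

import torch


def make_divisible(x, divisor):
    # Returns nearest x divisible by divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


print(make_divisible(191, 32))                            # 192
print(make_divisible(191, torch.tensor([8., 16., 32.])))  # 192, e.g. a DetectMultiBackend stride tensor
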