Skip to content

Commit

Permalink
ultralytics 8.1.43 40% faster ultralytics imports (#9547)
Browse files Browse the repository at this point in the history
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher committed Apr 5, 2024
1 parent 99c61d6 commit a262865
Show file tree
Hide file tree
Showing 21 changed files with 240 additions and 225 deletions.
29 changes: 11 additions & 18 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import torch
import yaml
from PIL import Image
from torchvision.transforms import ToTensor

from ultralytics import RTDETR, YOLO
from ultralytics.cfg import TASK2DATA
Expand Down Expand Up @@ -108,20 +107,17 @@ def test_predict_img():
assert len(model(batch, imgsz=32)) == len(batch) # multiple sources in a batch

# Test tensor inference
im = cv2.imread(str(SOURCE)) # OpenCV
t = cv2.resize(im, (32, 32))
t = ToTensor()(t)
t = torch.stack([t, t, t, t])
results = model(t, imgsz=32)
assert len(results) == t.shape[0]
results = seg_model(t, imgsz=32)
assert len(results) == t.shape[0]
results = cls_model(t, imgsz=32)
assert len(results) == t.shape[0]
results = pose_model(t, imgsz=32)
assert len(results) == t.shape[0]
results = obb_model(t, imgsz=32)
assert len(results) == t.shape[0]
im = torch.rand((4, 3, 32, 32)) # batch-size 4, FP32 0.0-1.0 RGB order
results = model(im, imgsz=32)
assert len(results) == im.shape[0]
results = seg_model(im, imgsz=32)
assert len(results) == im.shape[0]
results = cls_model(im, imgsz=32)
assert len(results) == im.shape[0]
results = pose_model(im, imgsz=32)
assert len(results) == im.shape[0]
results = obb_model(im, imgsz=32)
assert len(results) == im.shape[0]


def test_predict_grey_and_4ch():
Expand Down Expand Up @@ -592,8 +588,6 @@ def image():
)
def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):
"""Tests classification transforms during training with various augmentation settings."""
import torchvision.transforms as T

from ultralytics.data.augment import classify_augmentations

transform = classify_augmentations(
Expand All @@ -610,7 +604,6 @@ def test_classify_transforms_train(image, auto_augment, erasing, force_color_jit
hsv_v=0.4,
force_color_jitter=force_color_jitter,
erasing=erasing,
interpolation=T.InterpolationMode.BILINEAR,
)

transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.1.42"
__version__ = "8.1.43"

from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
Expand Down
21 changes: 12 additions & 9 deletions ultralytics/data/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
Expand Down Expand Up @@ -167,8 +167,8 @@ def _update_label_text(self, labels):
text2id = {text: i for i, text in enumerate(mix_texts)}

for label in [labels] + labels["mix_labels"]:
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(l)]
for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(cls)]
label["cls"][i] = text2id[tuple(text)]
label["texts"] = mix_texts
return labels
Expand Down Expand Up @@ -1133,7 +1133,7 @@ def classify_transforms(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation=Image.BILINEAR,
crop_fraction: float = DEFAULT_CROP_FTACTION,
):
"""
Expand All @@ -1149,6 +1149,7 @@ def classify_transforms(
Returns:
(T.Compose): torchvision transforms
"""
import torchvision.transforms as T # scope for faster 'import ultralytics'

if isinstance(size, (tuple, list)):
assert len(size) == 2
Expand All @@ -1157,12 +1158,12 @@ def classify_transforms(
scale_size = math.floor(size / crop_fraction)
scale_size = (scale_size, scale_size)

# aspect ratio is preserved, crops center within image, no borders are added, image is lost
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost
if scale_size[0] == scale_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
else:
# resize shortest edge to matching target dim for non-square target
# Resize the shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]

Expand Down Expand Up @@ -1192,7 +1193,7 @@ def classify_augmentations(
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.0,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation=Image.BILINEAR,
):
"""
Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
Expand All @@ -1216,7 +1217,9 @@ def classify_augmentations(
Returns:
(T.Compose): torchvision transforms
"""
# Transforms to apply if albumentations not installed
# Transforms to apply if Albumentations not installed
import torchvision.transforms as T # scope for faster 'import ultralytics'

if not isinstance(size, int):
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
Expand Down

0 comments on commit a262865

Please sign in to comment.