Skip to content

Commit

Permalink
ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Browse files Browse the repository at this point in the history
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher committed Mar 31, 2024
1 parent 30484d5 commit ea52750
Show file tree
Hide file tree
Showing 41 changed files with 97 additions and 93 deletions.
2 changes: 1 addition & 1 deletion tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def test_labels_and_crops():
crop_dirs = [p for p in (save_path / "crops").iterdir()]
crop_files = [f for p in crop_dirs for f in p.glob("*")]
# Crop directories match detections
assert all([r.names.get(c) in [d.name for d in crop_dirs] for c in cls_idxs])
assert all([r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs])
# Same number of crops as detections
assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = "8.1.39"
__version__ = "8.1.40"

from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
Expand Down
6 changes: 3 additions & 3 deletions ultralytics/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def get_save_dir(args, name=None):

project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
name = name or args.name or f"{args.mode}"
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)

return Path(save_dir)

Expand Down Expand Up @@ -566,10 +566,10 @@ def entrypoint(debug=""):
task = model.task

# Mode
if mode in ("predict", "track") and "source" not in overrides:
if mode in {"predict", "track"} and "source" not in overrides:
overrides["source"] = DEFAULT_CFG.source or ASSETS
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
elif mode in ("train", "val"):
elif mode in {"train", "val"}:
if "data" not in overrides and "resume" not in overrides:
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
Expand Down
4 changes: 2 additions & 2 deletions ultralytics/data/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform):
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
assert n in (4, 9), "grid must be equal to 4 or 9."
assert n in {4, 9}, "grid must be equal to 4 or 9."
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
Expand Down Expand Up @@ -685,7 +685,7 @@ def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
Default is 'horizontal'.
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
"""
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
assert 0 <= p <= 1.0

self.p = p
Expand Down
4 changes: 2 additions & 2 deletions ultralytics/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.utils.data import Dataset

from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from .utils import HELP_URL, IMG_FORMATS
from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS


class BaseDataset(Dataset):
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_img_files(self, img_path):
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f"{self.prefix}No images found in {img_path}"
assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
except Exception as e:
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
if self.fraction < 1:
Expand Down
4 changes: 2 additions & 2 deletions ultralytics/data/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,15 +481,15 @@ def merge_multi_segment(segments):
segments[i] = np.roll(segments[i], -idx[0], axis=0)
segments[i] = np.concatenate([segments[i], segments[i][:1]])
# Deal with the first segment and the last one
if i in [0, len(idx_list) - 1]:
if i in {0, len(idx_list) - 1}:
s.append(segments[i])
else:
idx = [0, idx[1] - idx[0]]
s.append(segments[i][idx[0] : idx[1] + 1])

else:
for i in range(len(idx_list) - 1, -1, -1):
if i not in [0, len(idx_list) - 1]:
if i not in {0, len(idx_list) - 1}:
idx = idx_list[i]
nidx = abs(idx[1] - idx[0])
s.append(segments[i][nidx:])
Expand Down
8 changes: 4 additions & 4 deletions ultralytics/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def cache_labels(self, path=Path("./labels.cache")):
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_labels(self):

# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in (-1, 0):
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache["msgs"]:
Expand Down Expand Up @@ -235,7 +235,7 @@ def collate_fn(batch):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
value = torch.cat(value, 0)
new_batch[k] = value
new_batch["batch_idx"] = list(new_batch["batch_idx"])
Expand Down Expand Up @@ -334,7 +334,7 @@ def verify_images(self):
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
nf, nc, n, samples = cache.pop("results") # found, corrupt, total, samples
if LOCAL_RANK in (-1, 0):
if LOCAL_RANK in {-1, 0}:
d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
if cache["msgs"]:
Expand Down
19 changes: 11 additions & 8 deletions ultralytics/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from PIL import Image

from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
from ultralytics.utils.checks import check_requirements

Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f"{i + 1}/{n}: {s}... "
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
s = get_best_youtube_url(s)
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
Expand Down Expand Up @@ -291,8 +291,14 @@ def __init__(self, path, batch=1, vid_stride=1):
else:
raise FileNotFoundError(f"{p} does not exist")

images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
# Define files as images or videos
images, videos = [], []
for f in files:
suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
if suffix in IMG_FORMATS:
images.append(f)
elif suffix in VID_FORMATS:
videos.append(f)
ni, nv = len(images), len(videos)

self.files = images + videos
Expand All @@ -307,10 +313,7 @@ def __init__(self, path, batch=1, vid_stride=1):
else:
self.cap = None
if self.nf == 0:
raise FileNotFoundError(
f"No images or videos found in {p}. "
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
)
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")

def __iter__(self):
"""Returns an iterator object for VideoStream or ImageFolder."""
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/data/split_dota.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
- train
- val
"""
assert split in ["train", "val"]
assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
im_dir = Path(data_root) / "images" / split
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
im_files = glob(str(Path(data_root) / "images" / split / "*"))
Expand Down
15 changes: 8 additions & 7 deletions ultralytics/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"


def img2label_paths(img_paths):
Expand All @@ -63,7 +64,7 @@ def exif_size(img: Image.Image):
exif = img.getexif()
if exif:
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
if rotation in [6, 8]: # rotation 270 or 90
if rotation in {6, 8}: # rotation 270 or 90
s = s[1], s[0]
return s

Expand All @@ -79,8 +80,8 @@ def verify_image(args):
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
if im.format.lower() in {"jpg", "jpeg"}:
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
Expand All @@ -105,8 +106,8 @@ def verify_image_label(args):
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
if im.format.lower() in {"jpg", "jpeg"}:
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
Expand Down Expand Up @@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
else: # python script
exec(s, {"yaml": data})
dt = f"({round(time.time() - t, 1)}s)"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
LOGGER.info(f"Dataset download {s}\n")
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts

Expand Down Expand Up @@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
# Download (optional if dataset=https://file.zip is passed directly)
if str(dataset).startswith(("http:/", "https:/")):
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
file = check_file(dataset)
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

Expand Down
6 changes: 3 additions & 3 deletions ultralytics/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback

self.callbacks = _callbacks or callbacks.get_default_callbacks()
Expand All @@ -171,9 +171,9 @@ def __call__(self, model=None):
self.run_callbacks("on_export_start")
t = time.time()
fmt = self.args.format.lower() # to lowercase
if fmt in ("tensorrt", "trt"): # 'engine' aliases
if fmt in {"tensorrt", "trt"}: # 'engine' aliases
fmt = "engine"
if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
fmt = "coreml"
fmts = tuple(export_formats()["Argument"][1:]) # available export formats
flags = [x == fmt for x in fmts]
Expand Down
4 changes: 2 additions & 2 deletions ultralytics/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
return

# Load or create new YOLO model
if Path(model).suffix in (".yaml", ".yml"):
if Path(model).suffix in {".yaml", ".yml"}:
self._new(model, task=task, verbose=verbose)
else:
self._load(model, task=task)
Expand Down Expand Up @@ -666,7 +666,7 @@ def train(
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# Update model and cfg after training
if RANK in (-1, 0):
if RANK in {-1, 0}:
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
self.model, _ = attempt_load_one_weight(ckpt)
self.overrides = self.model.args
Expand Down
4 changes: 2 additions & 2 deletions ultralytics/engine/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def __init__(self, boxes, orig_shape) -> None:
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 7
self.orig_shape = orig_shape
Expand Down Expand Up @@ -687,7 +687,7 @@ def __init__(self, boxes, orig_shape) -> None:
if boxes.ndim == 1:
boxes = boxes[None, :]
n = boxes.shape[-1]
assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
super().__init__(boxes, orig_shape)
self.is_track = n == 8
self.orig_shape = orig_shape
Expand Down

0 comments on commit ea52750

Please sign in to comment.