Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RTDETRDetectionModel TorchScript, ONNX Predict and Val support #8818

Merged
merged 11 commits into from
Mar 9, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ jobs:
run: |
yolo checks
pip list
- name: Benchmark World DetectionModel
- name: Benchmark YOLOWorld DetectionModel
shell: bash
run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
- name: Benchmark SegmentationModel
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def test_export(model, format):
def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"):
"""Test the RTDETR functionality with the Ultralytics framework."""
# Warning: MUST use imgsz=640
run(f"yolo train {task} model={model} data={data} --imgsz= 640 epochs =1, cache = disk") # add coma, spaces to args
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=640 save save_crop save_txt")
run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk") # add coma, spaces to args
run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt")


@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM Clip is not supported in Python 3.12")
Expand Down
2 changes: 0 additions & 2 deletions ultralytics/models/rtdetr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def __init__(self, model="rtdetr-l.pt") -> None:
Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
"""
if model and Path(model).suffix not in (".pt", ".yaml", ".yml"):
raise NotImplementedError("RT-DETR only supports creating from *.pt, *.yaml, or *.yml files.")
super().__init__(model=model, task="detect")

@property
Expand Down
5 changes: 4 additions & 1 deletion ultralytics/models/rtdetr/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,17 @@
The method filters detections based on confidence and class if specified in `self.args`.

Args:
preds (torch.Tensor): Raw predictions from the model.
preds (list): List of [predictions, extra] from the model.
img (torch.Tensor): Processed input images.
orig_imgs (list or torch.Tensor): Original, unprocessed images.

Returns:
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
and class labels.
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]

Check warning on line 50 in ultralytics/models/rtdetr/predict.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/models/rtdetr/predict.py#L50

Added line #L50 was not covered by tests

nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

Expand Down
3 changes: 3 additions & 0 deletions ultralytics/models/rtdetr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@

def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]

Check warning on line 98 in ultralytics/models/rtdetr/val.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/models/rtdetr/val.py#L98

Added line #L98 was not covered by tests

bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def check_file(file, suffix="", download=True, hard=True):
downloads.safe_download(url=url, file=file, unzip=False)
return file
else: # search
files = glob.glob(str(ROOT / "cfg" / "**" / file), recursive=True) # find file
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
Expand Down
41 changes: 41 additions & 0 deletions ultralytics/utils/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,44 @@
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ""


def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_names=False):
"""
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.

Args:
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
update_names (bool, optional): Update model names from a data YAML.

Example:
```python
from ultralytics.utils.files import update_models

model_names = (f"rtdetr-{size}.pt" for size in "lx")
update_models(model_names)
```
"""
from ultralytics import YOLO
from ultralytics.nn.autobackend import default_class_names

Check warning on line 168 in ultralytics/utils/files.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/files.py#L167-L168

Added lines #L167 - L168 were not covered by tests

target_dir = source_dir / "updated_models"
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists

Check warning on line 171 in ultralytics/utils/files.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/files.py#L170-L171

Added lines #L170 - L171 were not covered by tests

for model_name in model_names:
model_path = source_dir / model_name
print(f"Loading model from {model_path}")

Check warning on line 175 in ultralytics/utils/files.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/files.py#L173-L175

Added lines #L173 - L175 were not covered by tests

# Load model
model = YOLO(model_path)
model.half()
if update_names: # update model names from a dataset YAML
model.model.names = default_class_names("coco8.yaml")

Check warning on line 181 in ultralytics/utils/files.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/files.py#L178-L181

Added lines #L178 - L181 were not covered by tests

# Define new save path
save_path = target_dir / model_name

Check warning on line 184 in ultralytics/utils/files.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/files.py#L184

Added line #L184 was not covered by tests

# Save model using model.save()
print(f"Re-saving {model_name} model to {save_path}")
model.save(save_path, use_dill=False)

Check warning on line 188 in ultralytics/utils/files.py

View check run for this annotation

Codecov / codecov/patch

ultralytics/utils/files.py#L187-L188

Added lines #L187 - L188 were not covered by tests