ultralytics 8.1.41 DDP resume untrained-checkpoint fix #9453

Merged
merged 14 commits on Apr 1, 2024
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.1.40"
+__version__ = "8.1.41"

 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
21 changes: 9 additions & 12 deletions ultralytics/engine/trainer.py
@@ -212,7 +212,7 @@
         # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
         dist.init_process_group(
-            "nccl" if dist.is_nccl_available() else "gloo",
+            backend="nccl" if dist.is_nccl_available() else "gloo",
             timeout=timedelta(seconds=10800),  # 3 hours
             rank=RANK,
             world_size=world_size,
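For context, the backend value was previously passed positionally; the change only makes it an explicit keyword argument, so behavior is unchanged. A minimal standalone sketch of the same initialization pattern (hypothetical `init_ddp` helper; the master address/port defaults are assumptions, not Ultralytics code):

```python
import os
from datetime import timedelta

import torch.distributed as dist


def init_ddp(rank: int, world_size: int) -> None:
    """Initialize a DDP process group, preferring NCCL and falling back to Gloo."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # assumed single-node defaults
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl" if dist.is_nccl_available() else "gloo",
        timeout=timedelta(seconds=10800),  # 3 hours, matching the trainer setting above
        rank=rank,
        world_size=world_size,
    )
```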
@@ -648,8 +648,8 @@

                 resume = True
                 self.args = get_cfg(ckpt_args)
-                self.args.model = str(last)  # reinstate model
-                for k in "imgsz", "batch":  # allow arg updates to reduce memory on resume if crashed due to CUDA OOM
+                self.args.model = self.args.resume = str(last)  # reinstate model
+                for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                     if k in overrides:
                         setattr(self.args, k, overrides[k])
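The practical effect of adding "device" to this whitelist is that a device override is now honored on resume, alongside the existing imgsz and batch overrides. A hedged usage sketch, assuming a hypothetical interrupted run saved at runs/detect/train/weights/last.pt:

```python
from ultralytics import YOLO

# Resume an interrupted run while shrinking the batch size and moving to different GPUs;
# every other argument is restored from the training config saved in the checkpoint.
model = YOLO("runs/detect/train/weights/last.pt")  # hypothetical checkpoint path
model.train(resume=True, batch=8, device="0,1")
```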

@@ -662,7 +662,7 @@

     def resume_training(self, ckpt):
         """Resume YOLO training from given epoch and best fitness."""
-        if ckpt is None:
+        if ckpt is None or not self.resume:
             return
         best_fitness = 0.0
         start_epoch = ckpt.get("epoch", -1) + 1
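This guard is the core of the untrained-checkpoint fix: when a checkpoint is passed only as the starting model and resume is not requested, `resume_training()` now returns immediately instead of restoring optimizer/epoch state or tripping the assertion further down. A sketch of that scenario (hypothetical paths, assuming the coco8.yaml demo dataset):

```python
from ultralytics import YOLO

# Start a *new* training from previously saved weights without resuming: the weights are
# loaded, but epoch counters, optimizer state and EMA are not restored from the checkpoint.
model = YOLO("runs/detect/train/weights/last.pt")  # hypothetical checkpoint used only as initial weights
model.train(data="coco8.yaml", epochs=10)  # resume not set, so resume_training() returns early
```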
@@ -672,14 +672,11 @@
         if self.ema and ckpt.get("ema"):
             self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
             self.ema.updates = ckpt["updates"]
-        if self.resume:
-            assert start_epoch > 0, (
-                f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
-                f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
-            )
-            LOGGER.info(
-                f"Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs"
-            )
+        assert start_epoch > 0, (
+            f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
+            f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
+        )
+        LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")

Codecov / codecov/patch check warning on ultralytics/engine/trainer.py#L679: added line was not covered by tests.
         if self.epochs < start_epoch:
             LOGGER.info(
                 f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
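For completeness, the assertion path above still applies when an explicit resume is requested on a run that already reached its final epoch; a hedged sketch of that case (single-GPU call assumed, hypothetical checkpoint path):

```python
from ultralytics import YOLO

model = YOLO("runs/detect/train/weights/last.pt")  # hypothetical checkpoint from a completed run
try:
    model.train(resume=True)
except AssertionError as e:
    print(e)  # "... training to N epochs is finished, nothing to resume. Start a new training without resuming ..."
```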
2 changes: 1 addition & 1 deletion ultralytics/utils/checks.py
@@ -391,7 +391,7 @@ def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=()
             try:
                 t = time.time()
                 assert is_online(), "AutoUpdate skipped (offline)"
-                with Retry(times=1, delay=1):  # retry once on failure after 1 second
+                with Retry(times=2, delay=1):  # run up to 2 times with 1-second retry delay
                     LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
                 dt = time.time() - t
                 LOGGER.info(
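To illustrate the semantics the new comment describes ("run up to 2 times with 1-second retry delay"), here is a generic retry helper; this is an illustrative sketch, not the actual ultralytics.utils.Retry implementation:

```python
import time


def run_with_retry(fn, times=2, delay=1):
    """Call fn(), allowing up to `times` total attempts with `delay` seconds between them."""
    for attempt in range(times):
        try:
            return fn()
        except Exception:
            if attempt == times - 1:
                raise  # out of attempts, surface the last error
            time.sleep(delay)
```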
2 changes: 1 addition & 1 deletion ultralytics/utils/torch_utils.py
@@ -513,7 +513,7 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
"""
     for state in state_dict["state"].values():
         for k, v in state.items():
-            if isinstance(v, torch.Tensor) and v.dtype is torch.float32:
+            if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                 state[k] = v.half()

     return state_dict
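Excluding the "step" key matters because recent PyTorch optimizers such as Adam store the step counter as a tensor in the state dict; casting it to half would lose precision and can corrupt bias correction when the optimizer state is reloaded. A small usage sketch (toy model, recent PyTorch assumed) of what is and is not converted:

```python
import torch
from ultralytics.utils.torch_utils import convert_optimizer_state_dict_to_fp16

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
opt.step()  # populates exp_avg, exp_avg_sq and step in the optimizer state

sd = convert_optimizer_state_dict_to_fp16(opt.state_dict())
state = next(iter(sd["state"].values()))
print(state["exp_avg"].dtype)  # torch.float16: moment tensors are halved to shrink the checkpoint
print(state["step"].dtype)     # unchanged: the step counter is skipped by the k != "step" check
```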