diff --git a/references/detection/engine.py b/references/detection/engine.py index 9f34336b0cc..49992af60a9 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -84,7 +84,8 @@ def evaluate(model, data_loader, device): for images, targets in metric_logger.log_every(data_loader, 100, header): images = list(img.to(device) for img in images) - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() model_time = time.time() outputs = model(images)