diff --git a/references/classification/utils.py b/references/classification/utils.py index 1a4adc7f60f..61aff1b6dfb 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -403,7 +403,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T def reduce_across_processes(val): if not is_dist_avail_and_initialized(): - return val + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + t = torch.tensor(val, device="cuda") dist.barrier() dist.all_reduce(t)