From 9f61b17a951d9f247b20a79dae96a86e6f080ed0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 13 Oct 2021 09:29:52 +0000 Subject: [PATCH 1/5] WIP --- references/classification/train.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/references/classification/train.py b/references/classification/train.py index a71d337a1b4..de520fb34f3 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -54,6 +54,13 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = f"Test: {log_suffix}" + def _reduce(val): + val = torch.tensor([val], dtype=torch.int, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(val) + return val.item() + + n_samples = 0 with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, print_freq, header): image = image.to(device, non_blocking=True) @@ -68,7 +75,12 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" metric_logger.update(loss=loss.item()) metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + n_samples += batch_size # gather the stats from all processes + + n_samples = _reduce(n_samples) + print(f"We processed {n_samples} in total") + metric_logger.synchronize_between_processes() print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") @@ -164,7 +176,10 @@ def main(args): device = torch.device(args.device) - torch.backends.cudnn.benchmark = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.deterministic = True + train_dir = os.path.join(args.data_path, "train") val_dir = os.path.join(args.data_path, "val") From 5264b1a670107bcb4dc89e83a369f6fd97466ef8 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 13 Oct 
2021 12:03:10 +0000 Subject: [PATCH 2/5] i'm not inspired to write a message --- references/classification/train.py | 40 +++++++++++++++++++----------- references/classification/utils.py | 9 +++++++ 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index de520fb34f3..a6666c5dbf2 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -1,6 +1,7 @@ import datetime import os import time +import warnings import presets import torch @@ -54,13 +55,8 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = f"Test: {log_suffix}" - def _reduce(val): - val = torch.tensor([val], dtype=torch.int, device="cuda") - torch.distributed.barrier() - torch.distributed.all_reduce(val) - return val.item() - n_samples = 0 + num_processed_samples = 0 with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, print_freq, header): image = image.to(device, non_blocking=True) @@ -75,11 +71,19 @@ def _reduce(val): metric_logger.update(loss=loss.item()) metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) - n_samples += batch_size + num_processed_samples += batch_size # gather the stats from all processes - n_samples = _reduce(n_samples) - print(f"We processed {n_samples} in total") + if torch.distributed.is_initialized(): + # See FIXME above + num_processed_samples = utils.reduce_across_processes(num_processed_samples) + if hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples: + warnings.warn( + f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " + "samples were used for the validation, which might bias the results. " + "Try adjusting the batch size and / or the world size. 
" + "Setting the world size to 1 is always a safe bet." + ) metric_logger.synchronize_between_processes() @@ -159,7 +163,7 @@ def load_data(traindir, valdir, args): print("Creating data loaders") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) else: train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) @@ -176,10 +180,11 @@ def main(args): device = torch.device(args.device) - torch.backends.cudnn.benchmark = False - torch.use_deterministic_algorithms(True) - torch.backends.cudnn.deterministic = True - + if args.use_deterministic_algorithms: + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True train_dir = os.path.join(args.data_path, "train") val_dir = os.path.join(args.data_path, "val") @@ -292,6 +297,10 @@ def main(args): model_ema.load_state_dict(checkpoint["model_ema"]) if args.test_only: + # We disable the cudnn benchmarking because it can noticeably affect the accuracy + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + evaluate(model, criterion, data_loader_test, device=device) return @@ -409,6 +418,9 @@ def get_args_parser(add_help=True): default=0.9, help="decay factor for Exponential Moving Average of model parameters(default: 0.9)", ) + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." 
+ ) return parser diff --git a/references/classification/utils.py b/references/classification/utils.py index 5dbb6b8fd24..e8b036cdc56 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -400,3 +400,12 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T os.replace(tmp_path, output_path) return output_path + + +def reduce_across_processes(val): + if not torch.cuda.is_available(): + return val + val = torch.tensor([val], device="cuda") + dist.barrier() + dist.all_reduce(val) + return val.item() From 40a42ed43b914e792c85e511b09e2fc2ef302734 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 13 Oct 2021 12:42:57 +0000 Subject: [PATCH 3/5] avoid some duplication --- references/classification/train.py | 17 ++++++++--------- references/classification/utils.py | 18 +++++++----------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index a6666c5dbf2..59ba08d2e16 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -74,16 +74,15 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=" num_processed_samples += batch_size # gather the stats from all processes - if torch.distributed.is_initialized(): + num_processed_samples = utils.reduce_across_processes(num_processed_samples).item() + if hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples: # See FIXME above - num_processed_samples = utils.reduce_across_processes(num_processed_samples) - if hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples: - warnings.warn( - f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " - "samples were used for the validation, which might bias the results. " - "Try adjusting the batch size and / or the world size. 
" - "Setting the world size to 1 is always a safe bet." - ) + warnings.warn( + f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " + "samples were used for the validation, which might bias the results. " + "Try adjusting the batch size and / or the world size. " + "Setting the world size to 1 is always a safe bet." + ) metric_logger.synchronize_between_processes() diff --git a/references/classification/utils.py b/references/classification/utils.py index e8b036cdc56..c7f9af94518 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -32,11 +32,7 @@ def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) + t = reduce_across_processes([self.count, self.total]) t = t.tolist() self.count = int(t[0]) self.total = t[1] @@ -402,10 +398,10 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T return output_path -def reduce_across_processes(val): - if not torch.cuda.is_available(): - return val - val = torch.tensor([val], device="cuda") +def reduce_across_processes(l): + if not is_dist_avail_and_initialized(): + return l + t = torch.tensor(l, device="cuda") dist.barrier() - dist.all_reduce(val) - return val.item() + dist.all_reduce(t) + return t From 4985725fedc4c0befff5534e7e660227037749e4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 13 Oct 2021 12:52:39 +0000 Subject: [PATCH 4/5] Only warn on rank == 0 --- references/classification/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 59ba08d2e16..9b1994bad57 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -74,8 +74,12 @@ def evaluate(model, criterion, 
data_loader, device, print_freq=100, log_suffix=" num_processed_samples += batch_size # gather the stats from all processes - num_processed_samples = utils.reduce_across_processes(num_processed_samples).item() - if hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples: + num_processed_samples = utils.reduce_across_processes(num_processed_samples) + if ( + hasattr(data_loader.dataset, "__len__") + and len(data_loader.dataset) != num_processed_samples + and torch.distributed.get_rank() == 0 + ): # See FIXME above warnings.warn( f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " From eaa0536b89a4c76ec66260be0cd94e5f5e5f2044 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 13 Oct 2021 12:56:14 +0000 Subject: [PATCH 5/5] Rename parameter `l` to `val` in reduce_across_processes to fix flake8 --- references/classification/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/references/classification/utils.py b/references/classification/utils.py index c7f9af94518..c186a60fc1e 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -398,10 +398,10 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T return output_path -def reduce_across_processes(l): +def reduce_across_processes(val): if not is_dist_avail_and_initialized(): - return l - t = torch.tensor(l, device="cuda") + return val + t = torch.tensor(val, device="cuda") dist.barrier() dist.all_reduce(t) return t