diff --git a/references/detection/train.py b/references/detection/train.py index 758171013e8..d3b394b8bd0 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -132,6 +132,10 @@ def get_args_parser(add_help=True): action="store_true", ) + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." + ) + # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") @@ -153,6 +157,12 @@ def main(args): device = torch.device(args.device) + if args.use_deterministic_algorithms: + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True + # Data loading code print("Loading data") @@ -162,7 +172,7 @@ def main(args): print("Creating data loaders") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) else: train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) @@ -243,6 +253,9 @@ def main(args): scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: + # We disable the cudnn benchmarking because it can noticeably affect the accuracy + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True evaluate(model, data_loader_test, device=device) return diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 9b88c83df3a..7c4c45ab275 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -209,6 +209,12 @@ def main(args): raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun") device = torch.device(args.device) + if args.use_deterministic_algorithms: + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True + model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights) if args.distributed: @@ -370,6 +376,9 @@ def get_args_parser(add_help=True): parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)") + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." + ) return parser diff --git a/references/segmentation/train.py b/references/segmentation/train.py index e8570ab7f69..95dfedb5e9a 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -1,6 +1,7 @@ import datetime import os import time +import warnings import presets import torch @@ -61,6 +62,7 @@ def evaluate(model, data_loader, device, num_classes): confmat = utils.ConfusionMatrix(num_classes) metric_logger = utils.MetricLogger(delimiter=" ") header = "Test:" + num_processed_samples = 0 with torch.inference_mode(): for image, target in metric_logger.log_every(data_loader, 100, header): image, target = image.to(device), target.to(device) @@ -68,9 +70,26 @@ def evaluate(model, data_loader, device, num_classes): output = output["out"] confmat.update(target.flatten(), output.argmax(1).flatten()) + # FIXME need to take into account that the datasets + # could have been padded in distributed setup + num_processed_samples += image.shape[0] confmat.reduce_from_all_processes() + num_processed_samples = utils.reduce_across_processes(num_processed_samples) + if ( + hasattr(data_loader.dataset, "__len__") + and len(data_loader.dataset) != num_processed_samples + and torch.distributed.get_rank() == 0 + ): + # See FIXME above + warnings.warn( + f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " + "samples were used for the validation, which might bias the results. " + "Try adjusting the batch size and / or the world size. " + "Setting the world size to 1 is always a safe bet." + ) + return confmat @@ -108,12 +127,18 @@ def main(args): device = torch.device(args.device) + if args.use_deterministic_algorithms: + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True + dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args)) dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args)) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) - test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False) else: train_sampler = torch.utils.data.RandomSampler(dataset) test_sampler = torch.utils.data.SequentialSampler(dataset_test) @@ -191,6 +216,9 @@ def main(args): scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: + # We disable the cudnn benchmarking because it can noticeably affect the accuracy + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) print(confmat) return @@ -261,6 +289,9 @@ def get_args_parser(add_help=True): help="Only test the model", action="store_true", ) + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." + ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index 27c8f4ce51e..dfd12726b53 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -30,11 +30,7 @@ def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) + t = reduce_across_processes([self.count, self.total]) t = t.tolist() self.count = int(t[0]) self.total = t[1] @@ -92,12 +88,7 @@ def compute(self): return acc_global, acc, iu def reduce_from_all_processes(self): - if not torch.distributed.is_available(): - return - if not torch.distributed.is_initialized(): - return - torch.distributed.barrier() - torch.distributed.all_reduce(self.mat) + reduce_across_processes(self.mat) def __str__(self): acc_global, acc, iu = self.compute() @@ -296,3 +287,14 @@ def init_distributed_mode(args): ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t diff --git a/references/similarity/train.py b/references/similarity/train.py index 9c24ce73f3c..146e2bef688 100644 --- a/references/similarity/train.py +++ b/references/similarity/train.py @@ -88,6 +88,13 @@ def save(model, epoch, save_dir, file_name): def main(args): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + if args.use_deterministic_algorithms: + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True + p = args.labels_per_batch k = args.samples_per_label batch_size = p * k @@ -126,6 +133,13 @@ def main(args): ) test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.workers) + if args.test_only: + # We disable the cudnn benchmarking because it can noticeably affect the accuracy + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + evaluate(model, test_loader, device) + return + for epoch in range(1, args.epochs + 1): print("Training...") train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq) @@ -155,6 +169,15 @@ def parse_args(): parser.add_argument("--print-freq", default=20, type=int, help="print frequency") parser.add_argument("--save-dir", default=".", type=str, help="Model save directory") parser.add_argument("--resume", default="", type=str, help="path of checkpoint") + parser.add_argument( + "--test-only", + dest="test_only", + help="Only test the model", + action="store_true", + ) + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." + ) return parser.parse_args() diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 918a012282e..26c856da878 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -1,6 +1,7 @@ import datetime import os import time +import warnings import presets import torch @@ -50,6 +51,7 @@ def evaluate(model, criterion, data_loader, device): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = "Test:" + num_processed_samples = 0 with torch.inference_mode(): for video, target in metric_logger.log_every(data_loader, 100, header): video = video.to(device, non_blocking=True) @@ -64,7 +66,28 @@ def evaluate(model, criterion, data_loader, device): metric_logger.update(loss=loss.item()) metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + num_processed_samples += batch_size # gather the stats from all processes + num_processed_samples = utils.reduce_across_processes(num_processed_samples) + if isinstance(data_loader.sampler, DistributedSampler): + # Get the len of UniformClipSampler inside DistributedSampler + num_data_from_sampler = len(data_loader.sampler.dataset) + else: + num_data_from_sampler = len(data_loader.sampler) + + if ( + hasattr(data_loader.dataset, "__len__") + and num_data_from_sampler != num_processed_samples + and torch.distributed.get_rank() == 0 + ): + # See FIXME above + warnings.warn( + f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} " + "samples were used for the validation, which might bias the results. " + "Try adjusting the batch size and / or the world size. " + "Setting the world size to 1 is always a safe bet." + ) + metric_logger.synchronize_between_processes() print( @@ -99,7 +122,11 @@ def main(args): device = torch.device(args.device) - torch.backends.cudnn.benchmark = True + if args.use_deterministic_algorithms: + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True # Data loading code print("Loading data") @@ -173,7 +200,7 @@ def main(args): test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video) if args.distributed: train_sampler = DistributedSampler(train_sampler) - test_sampler = DistributedSampler(test_sampler) + test_sampler = DistributedSampler(test_sampler, shuffle=False) data_loader = torch.utils.data.DataLoader( dataset, @@ -248,6 +275,9 @@ def main(args): scaler.load_state_dict(checkpoint["scaler"]) if args.test_only: + # We disable the cudnn benchmarking because it can noticeably affect the accuracy + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True evaluate(model, criterion, data_loader_test, device=device) return @@ -335,6 +365,9 @@ def parse_args(): help="Only test the model", action="store_true", ) + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." + ) # distributed training parameters parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py index 116adf8d72f..024426d5916 100644 --- a/references/video_classification/utils.py +++ b/references/video_classification/utils.py @@ -30,11 +30,7 @@ def synchronize_between_processes(self): """ Warning: does not synchronize the deque! """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") - dist.barrier() - dist.all_reduce(t) + t = reduce_across_processes([self.count, self.total]) t = t.tolist() self.count = int(t[0]) self.total = t[1] @@ -255,3 +251,14 @@ def init_distributed_mode(args): ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t