15 changes: 14 additions & 1 deletion references/detection/train.py
@@ -132,6 +132,10 @@ def get_args_parser(add_help=True):
action="store_true",
)

parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)

# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
@@ -153,6 +157,12 @@ def main(args):

device = torch.device(args.device)

if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True

# Data loading code
print("Loading data")

@@ -162,7 +172,7 @@ def main(args):
print("Creating data loaders")
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -243,6 +253,9 @@ def main(args):
scaler.load_state_dict(checkpoint["scaler"])

if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, data_loader_test, device=device)
return

9 changes: 9 additions & 0 deletions references/optical_flow/train.py
@@ -209,6 +209,12 @@ def main(args):
raise ValueError("The device must be cuda if we want to run in distributed mode using torchrun")
device = torch.device(args.device)

if args.use_deterministic_algorithms:
NicolasHug (Member) commented:

To avoid duplicating code with args.test_only, I would suggest something like:

Suggested change:
- if args.use_deterministic_algorithms:
+ if args.use_deterministic_algorithms or args.test_only:

and removing these two lines:

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

I would also suggest doing that for all the references, as they seem to follow the same pattern.

YosuaMichael (Contributor, Author) commented on May 3, 2022:

Hi @NicolasHug, I just tried what you suggested and got an error when using --test-only:

  File "/fsx/users/yosuamichael/conda/envs/vision-c113/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility

Following the error message, it works when I set the env var CUBLAS_WORKSPACE_CONFIG=:4096:8, but I think this will create friction for users.

I think the main issue is that we call torch.use_deterministic_algorithms(True) when args.use_deterministic_algorithms is set, and this is much stricter than the torch.backends.cudnn.deterministic = True that we use for args.test_only (see here).

Hence, as of now we can only deduplicate the line torch.backends.cudnn.benchmark = False, and I think having one duplicated line is still okay for now. What do you think?
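
For context, a minimal illustrative sketch (not part of this PR's diff) contrasting the two settings discussed above; the set_determinism helper name is hypothetical and the CUBLAS_WORKSPACE_CONFIG handling is an assumed workaround that must happen before the first CUDA matmul runs:

import os

import torch

# Assumed workaround for the CuBLAS reproducibility error: set the workspace
# config before any CUDA matmul is executed.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")


def set_determinism(use_deterministic_algorithms, test_only):
    if use_deterministic_algorithms:
        # Strict mode: every op must have a deterministic implementation,
        # otherwise PyTorch raises a RuntimeError (e.g. the CuBLAS case above).
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
    elif test_only:
        # Looser mode: only ask cuDNN to prefer deterministic kernels.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True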

NicolasHug (Member) replied:

Ah, fair point. I forgot that one is stricter than the other. I think the way you did it is fine then. Sorry for the noise!

torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True

model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights)

if args.distributed:
@@ -370,6 +376,9 @@ def get_args_parser(add_help=True):

parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu, Default: cuda)")
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)

return parser

33 changes: 32 additions & 1 deletion references/segmentation/train.py
@@ -1,6 +1,7 @@
import datetime
import os
import time
import warnings

import presets
import torch
@@ -61,16 +62,34 @@ def evaluate(model, data_loader, device, num_classes):
confmat = utils.ConfusionMatrix(num_classes)
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test:"
num_processed_samples = 0
with torch.inference_mode():
for image, target in metric_logger.log_every(data_loader, 100, header):
image, target = image.to(device), target.to(device)
output = model(image)
output = output["out"]

confmat.update(target.flatten(), output.argmax(1).flatten())
# FIXME need to take into account that the datasets
# could have been padded in distributed setup
num_processed_samples += image.shape[0]

confmat.reduce_from_all_processes()

num_processed_samples = utils.reduce_across_processes(num_processed_samples)
if (
hasattr(data_loader.dataset, "__len__")
and len(data_loader.dataset) != num_processed_samples
and torch.distributed.get_rank() == 0
):
# See FIXME above
warnings.warn(
f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
"samples were used for the validation, which might bias the results. "
"Try adjusting the batch size and / or the world size. "
"Setting the world size to 1 is always a safe bet."
)

return confmat


@@ -108,12 +127,18 @@ def main(args):

device = torch.device(args.device)

if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True

dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))

if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
@@ -191,6 +216,9 @@ def main(args):
scaler.load_state_dict(checkpoint["scaler"])

if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
return
@@ -261,6 +289,9 @@ def get_args_parser(add_help=True):
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
24 changes: 13 additions & 11 deletions references/segmentation/utils.py
@@ -30,11 +30,7 @@ def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@@ -92,12 +88,7 @@ def compute(self):
return acc_global, acc, iu

def reduce_from_all_processes(self):
if not torch.distributed.is_available():
return
if not torch.distributed.is_initialized():
return
torch.distributed.barrier()
torch.distributed.all_reduce(self.mat)
reduce_across_processes(self.mat)

def __str__(self):
acc_global, acc, iu = self.compute()
@@ -296,3 +287,14 @@ def init_distributed_mode(args):
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)


def reduce_across_processes(val):
if not is_dist_avail_and_initialized():
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
return torch.tensor(val)

t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
23 changes: 23 additions & 0 deletions references/similarity/train.py
@@ -88,6 +88,13 @@ def save(model, epoch, save_dir, file_name):

def main(args):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True

p = args.labels_per_batch
k = args.samples_per_label
batch_size = p * k
@@ -126,6 +133,13 @@ def main(args):
)
test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.workers)

if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, test_loader, device)
return

for epoch in range(1, args.epochs + 1):
print("Training...")
train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq)
@@ -155,6 +169,15 @@ def parse_args():
parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
parser.add_argument("--save-dir", default=".", type=str, help="Model save directory")
parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)

return parser.parse_args()

37 changes: 35 additions & 2 deletions references/video_classification/train.py
@@ -1,6 +1,7 @@
import datetime
import os
import time
import warnings

import presets
import torch
@@ -50,6 +51,7 @@ def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test:"
num_processed_samples = 0
with torch.inference_mode():
for video, target in metric_logger.log_every(data_loader, 100, header):
video = video.to(device, non_blocking=True)
@@ -64,7 +66,28 @@ def evaluate(model, criterion, data_loader, device):
metric_logger.update(loss=loss.item())
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
num_processed_samples += batch_size
# gather the stats from all processes
num_processed_samples = utils.reduce_across_processes(num_processed_samples)
if isinstance(data_loader.sampler, DistributedSampler):
# Get the len of UniformClipSampler inside DistributedSampler
num_data_from_sampler = len(data_loader.sampler.dataset)
else:
num_data_from_sampler = len(data_loader.sampler)

if (
hasattr(data_loader.dataset, "__len__")
and num_data_from_sampler != num_processed_samples
and torch.distributed.get_rank() == 0
):
# See FIXME above
warnings.warn(
f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} "
"samples were used for the validation, which might bias the results. "
"Try adjusting the batch size and / or the world size. "
"Setting the world size to 1 is always a safe bet."
)

metric_logger.synchronize_between_processes()

print(
@@ -99,7 +122,11 @@ def main(args):

device = torch.device(args.device)

torch.backends.cudnn.benchmark = True
if args.use_deterministic_algorithms:
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
else:
torch.backends.cudnn.benchmark = True

# Data loading code
print("Loading data")
@@ -173,7 +200,7 @@ def main(args):
test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
if args.distributed:
train_sampler = DistributedSampler(train_sampler)
test_sampler = DistributedSampler(test_sampler)
test_sampler = DistributedSampler(test_sampler, shuffle=False)

data_loader = torch.utils.data.DataLoader(
dataset,
@@ -248,6 +275,9 @@ def main(args):
scaler.load_state_dict(checkpoint["scaler"])

if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, criterion, data_loader_test, device=device)
return

@@ -335,6 +365,9 @@ def parse_args():
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)

# distributed training parameters
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
17 changes: 12 additions & 5 deletions references/video_classification/utils.py
@@ -30,11 +30,7 @@ def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@@ -255,3 +251,14 @@ def init_distributed_mode(args):
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)


def reduce_across_processes(val):
if not is_dist_avail_and_initialized():
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
return torch.tensor(val)

t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t