Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
header = f"Test: {log_suffix}"

num_processed_samples = 0
with torch.no_grad():
with torch.inference_mode():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
Expand Down
2 changes: 1 addition & 1 deletion references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main(args):
print("Starting training for epoch", epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
lr_scheduler.step()
with torch.no_grad():
with torch.inference_mode():
if epoch >= args.num_observer_update_epochs:
print("Disabling observer for subseq epochs, epoch = ", epoch)
model.apply(torch.quantization.disable_observer)
Expand Down
2 changes: 1 addition & 1 deletion references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update_parameters(self, model):

def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
with torch.inference_mode():
maxk = max(topk)
batch_size = target.size(0)
if target.ndim == 2:
Expand Down
2 changes: 1 addition & 1 deletion references/detection/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _get_iou_types(model):
return iou_types


@torch.no_grad()
@torch.inference_mode()
def evaluate(model, data_loader, device):
n_threads = torch.get_num_threads()
# FIXME remove this and make paste_masks_in_image run on the GPU
Expand Down
2 changes: 1 addition & 1 deletion references/detection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def reduce_dict(input_dict, average=True):
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
with torch.inference_mode():
names = []
values = []
# sort the keys so that they are consistent across processes
Expand Down
2 changes: 1 addition & 1 deletion references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def evaluate(model, data_loader, device, num_classes):
confmat = utils.ConfusionMatrix(num_classes)
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test:"
with torch.no_grad():
with torch.inference_mode():
for image, target in metric_logger.log_every(data_loader, 100, header):
image, target = image.to(device), target.to(device)
output = model(image)
Expand Down
2 changes: 1 addition & 1 deletion references/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.no_grad():
with torch.inference_mode():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
Expand Down
2 changes: 1 addition & 1 deletion references/similarity/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def find_best_threshold(dists, targets, device):
return best_thresh, accuracy


@torch.no_grad()
@torch.inference_mode()
def evaluate(model, loader, device):
model.eval()
embeds, labels = [], []
Expand Down
2 changes: 1 addition & 1 deletion references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def evaluate(model, criterion, data_loader, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = "Test:"
with torch.no_grad():
with torch.inference_mode():
for video, target in metric_logger.log_every(data_loader, 100, header):
video = video.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
Expand Down
2 changes: 1 addition & 1 deletion references/video_classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def log_every(self, iterable, print_freq, header=None):

def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
with torch.inference_mode():
maxk = max(topk)
batch_size = target.size(0)

Expand Down