Feature: Add --tolerance option to benchmark scripts (#102218)
The "tolerance" option evaluates the model on the baseline device in eager mode (default: CPU) compared to the test device (e.g., CUDA, XLA, etc.) and compares the output tensors to determine the absolute tolerance value based on the [formula](https://pytorch.org/docs/stable/generated/torch.allclose.html). It then saves the results in a CSV file. This comparison highlights the tolerance/accuracy difference between XLA and GPU/CPU devices and can also be used to evaluate newer accelerators. This feature aims to identify accuracy failures on the test device (e.g., XLA) and facilitate quick bug triaging.

This feature enables the following capabilities:
1. Monitor accuracy issues of backends
2. Provide a more informative picture of accuracy beyond a pass/fail status
3. Produce a dump of accuracy information to help triage models accordingly

The data generated using this feature is in the [spreadsheet](https://docs.google.com/spreadsheets/d/1A8BAzSqfAw0Q5rgzK5Gk__Uy7qhuynh8tedxKnH-t94/edit#gid=0).

The spreadsheet data can be used to compile the summary table below:

| Suite       | Max tolerance (xla) | Max tolerance (inductor) | No. of models with high inaccuracy ≥ 0.005 (xla) | No. of models with high inaccuracy ≥ 0.005 (inductor) | Mean tolerance (xla) | Mean tolerance (inductor) |
|:------------|--------------------:|-------------------------:|--------------------------------------------------:|--------------------------------------------------------:|---------------------:|--------------------------:|
| huggingface |              0.1169 |                   0.0032 |                                                  1 |                                                        0 |               0.0022 |                     0.0005 |
| timm_models |              0.0373 |                   2.8892 |                                                 10 |                                                        8 |               0.0028 |                     0.7044 |
| torchbench  |              3.013  |                   3.0381 |                                                  6 |                                                        2 |               0.0016 |                     0.0016 |
| All models  |              3.013  |                   3.0381 |                                                 17 |                                                       10 |               0.0028 |                     0.7044 |
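As a sketch, a hypothetical pandas snippet for compiling one column group of this summary from the generated CSV (the file name follows the `tolerance_{backend}.csv` pattern from the diff; counting "high inaccuracy" against the per-model max is an assumption here):

```python
import pandas as pd

# Columns written by check_tolerance: dev, name, batch_size, max, mean, std
df = pd.read_csv("tolerance_xla.csv")  # e.g., produced with --backend=xla

summary = {
    "max_tolerance": df["max"].max(),                            # worst model in the suite
    "high_inaccuracy_models": int((df["max"] >= 0.005).sum()),   # assumed threshold basis
    "mean_tolerance": df["mean"].mean(),                         # average across models
}
print(summary)
```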

I used the PyTorch release/2.0 branch and the corresponding [commit_pin](https://github.com/pytorch/pytorch/blob/release/2.0/.github/ci_commit_pins/xla.txt) for XLA to generate the data above.

Fixes #ISSUE_NUMBER

Pull Request resolved: #102218
Approved by: https://github.com/jansel
vinayburugu authored and pytorchmergebot committed Jun 3, 2023
1 parent 1237502 commit 8215468
Showing 1 changed file with 93 additions and 1 deletion.
94 changes: 93 additions & 1 deletion benchmarks/dynamo/common.py
@@ -1531,6 +1531,89 @@ def deepcopy_and_maybe_ddp(model):

        return record_status(accuracy_status, dynamo_start_stats=start_stats)

    def check_tolerance(
        self, name, model, example_inputs, optimize_ctx, base_device="cpu"
    ):
        """
        Checks tolerance based on https://pytorch.org/docs/stable/generated/torch.allclose.html.
        """
        tolerance_status = "pass"
        if name in self.skip_accuracy_checks_large_models_dashboard:
            tolerance_status = "pass_due_to_skip"
            return tolerance_status
        # Cast the model to float16/float32 as necessary
        model, example_inputs = self.maybe_cast(model, example_inputs)

        with self.pick_grad(name, self.args.training):
            # Get the results of native (eager) PyTorch on the baseline device
            reset_rng_state()
            model_copy = copy.deepcopy(model)
            model_copy = model_copy.to(base_device)
            example_inputs_copy = copy.deepcopy(example_inputs)
            example_inputs_copy = tree_map(
                lambda x: x.to(base_device), example_inputs_copy
            )
            self.init_optimizer(name, base_device, model_copy.parameters())
            correct_result = self.run_n_iterations(model_copy, example_inputs_copy)

            # Run with Dynamo on the test device; `current_device` is a
            # module-level global set elsewhere in this file.
            # Sometimes CI fails with a random Triton compilation failure,
            # which is skipped for now.
            # TODO: revisit this after switching to the new Triton runtime
            reset_rng_state()
            torch._dynamo.reset()
            try:
                self.init_optimizer(name, current_device, model.parameters())
                optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
                new_result = optimized_model_iter_fn(model, example_inputs)
            except Exception as e:
                log.exception(e)
                if (
                    self.args.ci
                    and isinstance(e, BackendCompilerFailed)
                    and (
                        "Internal Triton PTX codegen error" in str(e)
                        or "cubin" in str(e)
                    )
                ):
                    return "pass_due_to_skip"
                else:
                    print(
                        "TorchDynamo optimized model failed to run because of the following error"
                    )
                    return "fail_to_run"

            def dump_max_mean_values(tol, ref, res):
                # Recursively walk nested outputs and collect the flattened
                # per-element tolerance |ref - res| / (1 + |ref|) for every
                # tensor leaf.
                if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
                    for refi, resi in zip(ref, res):
                        dump_max_mean_values(tol, refi, resi)
                elif isinstance(ref, dict):
                    for k in ref.keys():
                        dump_max_mean_values(tol, ref[k], res[k])
                elif isinstance(ref, torch.Tensor):
                    res = res.to(base_device)
                    t = torch.abs(ref - res) / (1 + torch.abs(ref))
                    tol.append(t.flatten().to(torch.float32))
                return tol

            tol = []
            dump_max_mean_values(tol, correct_result, new_result)
            # Concatenate the per-leaf tolerance vectors into one 1-D tensor
            tol = torch.cat(tol)
            max = torch.max(tol)
            mean = torch.mean(tol)
            div = torch.std(tol)
            headers = ["dev", "name", "batch_size", "max", "mean", "std"]
            fields = [
                current_device,
                current_name,
                current_batch_size,
                max.item(),
                mean.item(),
                div.item(),
            ]
            output_csv(output_filename, headers, fields)
        return tolerance_status

    def run_performance_test(
        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
    ):
@@ -1644,6 +1727,9 @@ def run_one_model(
                name, model, example_inputs, optimize_ctx, experiment, tag
            )
            print(status)
        elif self.args.tolerance:
            status = self.check_tolerance(name, model, example_inputs, optimize_ctx)
            print(status)
        elif self.args.performance:
            status = self.run_performance_test(
                name, model, example_inputs, optimize_ctx, experiment, tag
@@ -2132,7 +2218,11 @@ def get_example_inputs(self):
    mode_group.add_argument(
        "--performance", action="store_true", help="Measures performance speedup"
    )

    mode_group.add_argument(
        "--tolerance",
        action="store_true",
        help="extracts the tolerance for each model with small batch size and eval mode",
    )
    run_mode_group = parser.add_mutually_exclusive_group(required=True)
    run_mode_group.add_argument(
        "--training",
@@ -2445,6 +2535,8 @@ def run(runner, args, original_dir=None):
        experiment = speedup_experiment
        if args.accuracy:
            output_filename = f"accuracy_{args.backend}.csv"
        elif args.tolerance:
            output_filename = f"tolerance_{args.backend}.csv"
        else:
            output_filename = f"speedup_{args.backend}.csv"
    elif args.recompile_profiler:
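For illustration, a self-contained sketch of the flattening logic in `dump_max_mean_values` above (the function name, toy inputs, and defaulted `base_device` here are stand-ins; the real code is nested inside `check_tolerance` and also handles `torch.nn.ParameterList` and `torch.Size`):

```python
import torch

def collect_tolerances(tol, ref, res, base_device="cpu"):
    # Walk nested containers and collect the flattened per-element
    # tolerance |ref - res| / (1 + |ref|) for every tensor leaf.
    if isinstance(ref, (list, tuple)):
        for refi, resi in zip(ref, res):
            collect_tolerances(tol, refi, resi, base_device)
    elif isinstance(ref, dict):
        for k in ref:
            collect_tolerances(tol, ref[k], res[k], base_device)
    elif isinstance(ref, torch.Tensor):
        res = res.to(base_device)
        t = torch.abs(ref - res) / (1 + torch.abs(ref))
        tol.append(t.flatten().to(torch.float32))
    return tol

# Toy nested outputs: a dict holding a tensor and a list of tensors.
ref = {"logits": torch.ones(2, 3), "hidden": [torch.zeros(4)]}
res = {"logits": torch.ones(2, 3) * 1.001, "hidden": [torch.zeros(4) + 1e-4]}

tol = torch.cat(collect_tolerances([], ref, res))
print(tol.max().item(), tol.mean().item(), tol.std().item())
```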
