22 changes: 19 additions & 3 deletions benchmarks/README.md
@@ -46,7 +46,7 @@ parent of the `xla` directory.

The following example runs the alexnet benchmark on GPU through the
Pytorch/XLA-dynamo path and through the Inductor-dynamo with 5 repetitions each.
The results will be stored in a json file in `experiment_results`.
The results will be stored in a JSON file (e.g. `results.jsonl`) in `experiment_results`.

```
cd pytorch
@@ -74,6 +74,22 @@ among the flags `--dynamo`, `--xla`, and `--test`, 4 of which are supported:
- `dynamo=inductor`, `xla=None`, `test=train`


## Run benchmarking for a single configuration

The section `Experiment runner` above shows how to run the benchmarking script for a combination of configurations. For each configuration,
the script starts a new process and runs the benchmark. This section shows how to run the benchmark for a single configuration without spawning new processes.

```
cd pytorch
python xla/benchmarks/experiment_runner.py \
--suite-name=torchbench \
--progress-bar \
--model-config='{"model_name":"BERT_pytorch"}' \
--experiment-config='{"accelerator":"cuda","xla":"PJRT","xla_flags":null,"dynamo":"openxla","torch_xla2":null,"test":"train","keep_model_data_on_cuda":"false"}' \
--repeat 1
```
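
Each run appends one JSON record per line to the results file. The snippet below is a minimal sketch of reading those records back, assuming the default `experiment_results/results.jsonl` output path; the field names mirror what `result_analyzer.py` reads, and real records carry additional metrics.

```
# Sketch only: print a few fields from each benchmark record.
import json

with open("experiment_results/results.jsonl") as f:
    for line in f:
        record = json.loads(line)
        experiment = record["experiment"]
        print(record["model"]["model_name"], experiment["dynamo"],
              experiment["keep_model_data_on_cuda"], record["repeat"])
```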


## Verification module

Verification flag, enabled by running the experiment runner script with `--verify`
@@ -116,8 +132,8 @@ PT/XLA, and compare it against some baseline.
Run the `result_analyzer.py` from the `pytorch` directory, which should be the
parent of the `xla` directory.

The following example analyzes the results generated by the above invocation of
`experiment_runner.py`. The aggregates are saved in CSV format in
The following example analyzes the results (e.g. `results.jsonl`) generated by the above invocation of
`experiment_runner.py`. Make sure to use a consistent `--output-dirname` parameter. The aggregates are saved in CSV format in
`experiment_results/metric_report.csv`.
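
Once the report has been generated, it can be inspected with pandas. This is a minimal sketch assuming the default output path; the column names shown are among those declared in `result_analyzer.py`, and the full report contains many more.

```
# Sketch only: load the aggregated report and peek at a few columns.
import pandas as pd

report = pd.read_csv("experiment_results/metric_report.csv")
print(report[["dynamo", "keep_model_data_on_cuda", "test", "batch_size"]].head())
```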

```
18 changes: 16 additions & 2 deletions benchmarks/benchmark_experiment.py
@@ -24,6 +24,7 @@ def list_experiment_configs(self):
"dynamo": [None, "inductor", "openxla_eval", "openxla"],
"torch_xla2": [None], # options only apply to torch_xla2
"test": ["eval", "train"],
"keep_model_data_on_cuda": [False],
}

# Apply command line choices.
@@ -42,6 +43,10 @@ def list_experiment_configs(self):
if self._args.xla_flags:
config_choices["xla_flags"] = list(
map(parse_none_str, set(self._args.xla_flags)))
if self._args.keep_model_data_on_cuda:
config_choices["keep_model_data_on_cuda"] = [
self._args.keep_model_data_on_cuda
]

# Expand experiment configs and add env vars.
logger.debug(f"Expand experiment configs")
@@ -71,6 +76,7 @@ def _is_available(self, experiment_config):
cfg_xla = experiment_config["xla"]
cfg_test = experiment_config["test"]
cfg_torch_xla2 = experiment_config["torch_xla2"]
cfg_keep_model_data_on_cuda = experiment_config["keep_model_data_on_cuda"]

# Check that dynamo refers to an existing backend.
if cfg_dynamo is not None and cfg_dynamo not in dynamo.list_backends(
@@ -110,6 +116,10 @@ def _is_available(self, experiment_config):
else:
raise NotImplementedError

# cfg_keep_model_data_on_cuda is only available when using the openxla dynamo backend.
if cfg_keep_model_data_on_cuda and cfg_dynamo != "openxla":
return False

return True

def load_experiment(self, experiment_config):
@@ -120,25 +130,28 @@ def load_experiment(self, experiment_config):
test = experiment_config["test"]
batch_size = experiment_config.get("batch_size", self._args.batch_size)
torch_xla2 = experiment_config["torch_xla2"]
keep_model_data_on_cuda = experiment_config["keep_model_data_on_cuda"]
return BenchmarkExperiment(
accelerator=accelerator,
xla=xla,
xla_flags=xla_flags,
dynamo=dynamo,
torch_xla2=torch_xla2,
keep_model_data_on_cuda=keep_model_data_on_cuda,
test=test,
batch_size=batch_size)


class BenchmarkExperiment:

def __init__(self, accelerator, xla, xla_flags, dynamo, torch_xla2, test,
batch_size):
def __init__(self, accelerator, xla, xla_flags, dynamo, torch_xla2,
keep_model_data_on_cuda: bool, test, batch_size):
self.accelerator = accelerator
self.xla = xla
self.xla_flags = xla_flags
self.dynamo = dynamo
self.torch_xla2 = torch_xla2
self.keep_model_data_on_cuda = keep_model_data_on_cuda
self.test = test
self.batch_size = batch_size
self.accelerator_model = get_accelerator_model(self.accelerator)
@@ -202,6 +215,7 @@ def to_dict(self):
d["xla_flags"] = self.xla_flags
d["dynamo"] = self.dynamo
d["torch_xla2"] = self.torch_xla2
d["keep_model_data_on_cuda"] = self.keep_model_data_on_cuda
d["test"] = self.test
d["batch_size"] = self.batch_size
return d
15 changes: 10 additions & 5 deletions benchmarks/benchmark_model.py
@@ -122,6 +122,7 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
else:
raise NotImplementedError

keep_model_data_on_cuda = self.benchmark_experiment.keep_model_data_on_cuda
if self.benchmark_experiment.torch_xla2:
import torch_xla2.export
import torch_xla2
@@ -141,12 +142,12 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
weights)
jax_func = jax.jit(jax_func)
self.module = lambda *x: jax_func(weights, x)
self.example_inputs = move_to_device(self.example_inputs, device,
self.benchmark_experiment.torch_xla2)
else:
self.example_inputs = move_to_device(
self.example_inputs, device, torch_xla2=True)
elif not keep_model_data_on_cuda:
self.module = self.module.to(self.device)
self.example_inputs = move_to_device(self.example_inputs, self.device,
self.benchmark_experiment.torch_xla2)
self.example_inputs = move_to_device(
self.example_inputs, self.device, torch_xla2=False)

if self.benchmark_experiment.dynamo:
compilation_opts = dynamo_compilation_opts.copy()
@@ -155,6 +156,10 @@ def prepare_for_experiment(self, dynamo_compilation_opts):
logger.info(f"Running torch.compile with opts {compilation_opts}")
self.model_iter_fn = torch.compile(self.model_iter_fn, **compilation_opts)

if keep_model_data_on_cuda:
assert self.example_inputs[0].device.type.lower(
) == 'cuda', 'When keep_model_data_on_cuda is set, the input data should remain on the CUDA device.'

def pick_grad(self):
if self.benchmark_experiment.test == "eval":
return torch.no_grad()
5 changes: 5 additions & 0 deletions benchmarks/experiment_runner.py
@@ -862,6 +862,11 @@ def __str__(self):
help="""Collect CUDA and CPU times per operation. This will also gather
CPU fallbacks.""",
)
parser.add_argument(
"--keep-model-data-on-cuda",
action="store_true",
help="""Whether to keep the model and data on CUDA and not to move to an XLA device. This is to be used with PyTorch/XLA dynamo. When set, PyTorch/XLA dynamo bridge move the model and data to the XLA device.""",
)
parser.add_argument(
"--xla-flags",
type=str,
52 changes: 37 additions & 15 deletions benchmarks/result_analyzer.py
@@ -56,6 +56,7 @@ def run_csv(self):
"xla_flags": pd.Series(dtype="str"),
"dynamo": pd.Series(dtype="str"),
"torch_xla2": pd.Series(dtype="str"),
"keep_model_data_on_cuda": pd.Series(dtype="bool"),
"test": pd.Series(dtype="str"),
"batch_size": pd.Series(dtype="int"),
"repeat": pd.Series(dtype="int"),
@@ -119,6 +120,9 @@ def extract_metrics_jsonl(self, file):
dynamo_value = "None" if dynamo is None else dynamo
torch_xla2 = dataline["experiment"]["torch_xla2"]
torch_xla2_value = "None" if torch_xla2 is None else torch_xla2
keep_model_data_on_cuda = dataline["experiment"][
"keep_model_data_on_cuda"]
keep_model_data_on_cuda_value = "None" if keep_model_data_on_cuda is None else keep_model_data_on_cuda
test = dataline["experiment"]["test"]
test_value = "None" if test is None else test
outputs_file = dataline["experiment"].get("outputs_file", None)
@@ -139,6 +143,7 @@ def extract_metrics_jsonl(self, file):
"xla": xla_value,
"dynamo": dynamo_value,
"torch_xla2": torch_xla2_value,
"keep_model_data_on_cuda": keep_model_data_on_cuda_value,
"test": test_value,
"outputs_file": outputs_file_value
}
@@ -171,21 +176,38 @@ def extract_metrics_csv(self, file, metric_df):
timestamp = dataline[
"timestamp"] if "timestamp" in dataline else self.timestamp
d = {
"timestamp": timestamp,
"suite_name": dataline["model"]["suite_name"],
"model_name": dataline["model"]["model_name"],
"accelerator": dataline["experiment"]["accelerator"],
"accelerator_model": dataline["experiment"]["accelerator_model"],
"xla": dataline["experiment"]["xla"],
"xla_flags": dataline["experiment"]["xla_flags"],
"dynamo": dataline["experiment"]["dynamo"],
"torch_xla2": dataline["experiment"]["torch_xla2"],
"test": dataline["experiment"]["test"],
"batch_size": dataline["experiment"]["batch_size"],
"repeat": dataline["repeat"],
"iterations_per_run": dataline["iterations_per_run"],
"error_message": None,
"outputs_file": dataline["experiment"].get("outputs_file", ""),
"timestamp":
timestamp,
"suite_name":
dataline["model"]["suite_name"],
"model_name":
dataline["model"]["model_name"],
"accelerator":
dataline["experiment"]["accelerator"],
"accelerator_model":
dataline["experiment"]["accelerator_model"],
"xla":
dataline["experiment"]["xla"],
"xla_flags":
dataline["experiment"]["xla_flags"],
"dynamo":
dataline["experiment"]["dynamo"],
"torch_xla2":
dataline["experiment"]["torch_xla2"],
"keep_model_data_on_cuda":
dataline["experiment"]["keep_model_data_on_cuda"],
"test":
dataline["experiment"]["test"],
"batch_size":
dataline["experiment"]["batch_size"],
"repeat":
dataline["repeat"],
"iterations_per_run":
dataline["iterations_per_run"],
"error_message":
None,
"outputs_file":
dataline["experiment"].get("outputs_file", ""),
}

if "error" in dataline["metrics"] and not self._args.hide_errors:
3 changes: 2 additions & 1 deletion benchmarks/torchbench_model.py
@@ -255,7 +255,8 @@ def set_up(self):
if self.benchmark_experiment.xla:
# First, move the model and the inputs to CPU.
# This avoids having duplicated data on CUDA.
if self.is_accelerator_cuda():
keep_model_data_on_cuda = self.benchmark_experiment.keep_model_data_on_cuda
if self.is_accelerator_cuda() and not keep_model_data_on_cuda:
self.module = self.module.to("cpu")
self.example_inputs = move_to_device(self.example_inputs, "cpu")
self._cleanup()
5 changes: 3 additions & 2 deletions test/benchmarks/test_benchmark_experiment.py
@@ -7,15 +7,16 @@ class BenchmarkExperimentTest(unittest.TestCase):

def test_to_dict(self):
be = BenchmarkExperiment("cpu", "PJRT", "some xla_flags", "openxla", None,
"train", "123")
False, "train", "123")
actual = be.to_dict()
self.assertEqual(8, len(actual))
self.assertEqual(9, len(actual))
self.assertEqual("cpu", actual["accelerator"])
self.assertTrue("accelerator_model" in actual)
self.assertEqual("PJRT", actual["xla"])
self.assertEqual("some xla_flags", actual["xla_flags"])
self.assertEqual("openxla", actual["dynamo"])
self.assertEqual(None, actual["torch_xla2"])
self.assertEqual(False, actual["keep_model_data_on_cuda"])
self.assertEqual("train", actual["test"])
self.assertEqual("123", actual["batch_size"])
