11 changes: 11 additions & 0 deletions README.md
@@ -61,6 +61,17 @@
### Using `test.py`
`python test.py` will execute the APIs for each model as a sanity check. For benchmarking, use `test_bench.py`. The test suite is based on `unittest` and supports filtering via the CLI.

For instance, to run the BERT model on CPU in the example execution mode:
```
python test.py -k "test_BERT_pytorch_example_cpu"
```

Test names follow this pattern:

```
"test_" + <model_name> + "_" + {"example" | "train" | "eval" } + "_" + {"cpu" | "cuda"}
```
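
For example, applying the pattern to run the BERT training test on GPU (assuming a CUDA device is available on your machine):

```
python test.py -k "test_BERT_pytorch_train_cuda"
```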

### Using the `pytest-benchmark` driver
`pytest test_bench.py` invokes the benchmark driver. See `pytest test_bench.py --help` for a complete list of options.
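
For instance, a filtered run that saves its results might look like the sketch below; `-k` is standard pytest filtering, and `--benchmark-autosave` is a `pytest-benchmark` option (assumed to be installed) that stores results under `.benchmarks/`:

```
pytest test_bench.py -k "BERT_pytorch" --benchmark-autosave
```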

76 changes: 76 additions & 0 deletions collect_graph_ir.py
@@ -0,0 +1,76 @@
#!/usr/bin/env python
import argparse

import torch

from torchbenchmark import list_models

NO_JIT = {"demucs", "dlrm", "maml", "yolov3", "moco", "pytorch_CycleGAN_and_pix2pix", "tacotron2"}
NO_GET_MODULE = {"Background_Matting"}

def get_dump_filename(name, device, args):
if args.no_profiling:
return f"{name}.{device}.last_executed_graph.noprofile.log"
if args.inlined_graph:
return f"{name}.{device}.inlined_graph.log"
return f"{name}.{device}.last_executed_graph.log"

def iter_models(args):
device = "cpu"
for benchmark_cls in list_models():
bench_name = benchmark_cls.name
if args.benchmark and args.benchmark != bench_name:
continue
if bench_name in NO_GET_MODULE:
print(f"{bench_name} has no get_module, skipped")
continue
if bench_name in NO_JIT:
print(f"{bench_name} has no scripted module, skipped")
continue
try:
            # disable profiling mode so that the collected graph does not
            # contain profiling nodes
if args.no_profiling:
torch._C._jit_set_profiling_mode(False)

benchmark = benchmark_cls(device=device, jit=True)
model, example_inputs = benchmark.get_module()

            # extract the underlying ScriptModule for the BERT model
if bench_name == "BERT_pytorch":
model = model.bert

fname = get_dump_filename(bench_name, device, args)
print(f"Dump Graph IR for {bench_name} to {fname}")

            # the default mode needs a warm-up run so the profiling executor
            # can produce its optimized graph
if not (args.no_profiling or args.inlined_graph):
model.graph_for(*example_inputs)

with open(fname, 'w') as dump_file:
if args.inlined_graph:
print(model.inlined_graph, file=dump_file)
else:
print(model.graph_for(*example_inputs), file=dump_file)
        except NotImplementedError:
            print(f"Cannot collect graph IR dump for {bench_name}")

def main(args=None):
    parser = argparse.ArgumentParser(
        description="dump the last-executed graph for every benchmark with a JIT implementation")
    parser.add_argument("--benchmark", "-b",
                        help="dump the graph for <BENCHMARK> only")
    parser.add_argument("--no_profiling", action="store_true",
                        help="dump last-executed graphs without the profiling executor")
    parser.add_argument("--inlined_graph", action="store_true",
                        help="dump the graph reported by module.inlined_graph")
args = parser.parse_args(args)

iter_models(args)

if __name__ == '__main__':
main()
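
For reference, a sketch of how the script might be invoked, assuming the torchbenchmark dependencies are installed and using the flags defined in `main` above:

```
# dump the last-executed graph for one model, with profiling nodes disabled
python collect_graph_ir.py --benchmark BERT_pytorch --no_profiling

# dump the inlined graph for every supported model
python collect_graph_ir.py --inlined_graph
```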