Set up per-operator input database, per-operator microbenchmarking #785
Conversation
benchmarks/common.py
Outdated
@@ -1616,6 +1631,44 @@ def main(runner, original_dir=None):
        print_summary(output_filename)


def log_operator_inputs(model, example_inputs, model_iter_fn, name, args):
    output_split = args.output.split("/")
Use os.path for filename manipulation.
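A hedged sketch of what the reviewer is suggesting: derive the suite directory name from `args.output` with `os.path` rather than `str.split("/")`. Here `output` is a stand-in for `args.output`, and the directory names are illustrative, not taken from the PR.

```python
import os

# stand-in for args.output; path is illustrative
output = "/scratch/bench_logs/torchbench_train/"

# os.path.basename of a path with a trailing slash is "", so normalize first
suite_name = os.path.basename(os.path.normpath(output))

# build the log path portably instead of concatenating "/" by hand
log_dir = os.path.join("operator_inp_logs", suite_name)
```

This also sidesteps trailing-slash and separator edge cases that a bare `split("/")` mishandles.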
benchmarks/common.py
Outdated
        model_iter_fn(model, example_inputs, collect_outputs=False)
    except Exception as e2:
        print(f"{name} failed to run with real. Exception: {e2}")
        raise e2
re-raise exception
Suggested change:
-        raise e2
+        raise
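A minimal sketch of the reviewer's point: a bare `raise` re-raises the active exception with its original traceback intact, while `raise e2` makes the traceback point at the handler line instead of the real failure site. The function names here are illustrative, not from the PR.

```python
import traceback

def run_model():
    raise ValueError("model failed")  # the real failure site

def wrapper():
    try:
        run_model()
    except Exception as e2:
        print(f"failed to run. Exception: {e2}")
        raise  # bare raise: preferred over `raise e2`

try:
    wrapper()
except ValueError as err:
    tb = "".join(traceback.format_exception(type(err), err, err.__traceback__))

# the formatted traceback still names the original failure site
assert "run_model" in tb
```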
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_map

OP_INP_DIRECTORY = os.path.dirname(__file__) + "/operator_inp_logs/"
os.path.join()
benchmarks/timm_models.py
Outdated
@@ -275,7 +275,8 @@ def _gen_target(self, batch_size, device):
        )

    def compute_loss(self, pred):
-        return self.loss(pred, self.target)
+        # calling lift so modes enabled for forward/backward can handle self.target
+        return self.loss(pred, torch.ops.aten.lift_fresh_copy(self.target))
Will this affect performance measurements?
Thought about this a little more - instead of storing the .zip files, can we regenerate the contents of the .zip from the runner each time? Storing the .zip in a git repo makes it a hard-to-deal-with black box (can't diff by line), and the way we store .zips without producing them in the runner means there's a hidden step - zipping the output. So instead of storing the .zip files, we do: A) Run operatorbench. What do you think @eellison ?
I didn't really read the PR, but you could also unzip the zip before checking it in lol
It takes way too long to generate the inputs, even with fake tensor, for this to really make sense. You don't want to have to wait multiple minutes every time you want to run a script that tests the performance of changing an operator lowering with recorded inputs.
I'm not sure this is really an issue, since no one is going to be line by line comparing the 6 megabytes of operator inputs from TIMM before and after some change.
This would be over 10MB for the three files - seems kind of wasteful when the whole repo size is ~3 MB (as opposed to ~0.4 MB compressed). I don't know if anyone else has any strong thoughts here - the changes talked about here are pretty minimal, so we can always land and make changes as we use this more.
    if isinstance(i, (torch.memory_format, torch.storage.UntypedStorage)):
        return True
    # TODO: serialize/deserialize sparse arguments
    if isinstance(i, torch.Tensor) and i.is_sparse:
cc @ezyang any idea how hard this is ?
Well, given that you wrote the json format yourself, pretty easy. You'll need to say how many sparse and dense dims, and maybe nnz and coalesced if you want to get frisky
The gists are too big for github to load 😂
Elias, something that's not clear from the PR description: the input database is metadata only, right? If so, I think we should design a compact text format for describing this sort of metadata; something like Python code you could eval() to inflate the tensors would be a pretty good start. Then we should feel pretty comfortable with checking these in as plaintext; they're basically like OpInfo sample inputs but machine generated.
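A toy sketch of the kind of compact, eval()-able format being suggested: each tensor is rendered as a short constructor call, and a tiny namespace maps the constructor back to a metadata record on load. `T` and `TensorMeta` are illustrative names, not part of the actual PR, and real tensors would need a dtype-aware inflation step.

```python
from collections import namedtuple

# metadata-only record: sizes, dtype tag, optional strides
TensorMeta = namedtuple("TensorMeta", ["size", "dtype", "stride"])

def T(size, dtype, stride=None):
    return TensorMeta(tuple(size), dtype, tuple(stride) if stride else None)

def serialize(meta):
    if meta.stride is None:
        return f"T({list(meta.size)}, '{meta.dtype}')"
    return f"T({list(meta.size)}, '{meta.dtype}', stride={list(meta.stride)})"

def deserialize(line):
    # eval() in a restricted namespace inflates the record
    return eval(line, {"T": T})

m = T([64, 128], "f16")
assert deserialize(serialize(m)) == m
```

The resulting files are line-diffable plaintext, which addresses the black-box concern raised about checked-in .zip files above.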
benchmarks/microbenchmarks/operator_inp_logs/hf_train/AlbertForMaskedLM_training.txt
Outdated
It would be good to get PR feedback from the folks who would also be using the microbenchmarks.
It would be useful to see some data generated from this. What ops are we the slowest on? It also might make sense to filter out view ops. Views should be represented in the strides of inputs to other ops.
        torch.jit.trace(gm, gm_args), gm_args, copy_outputs=False
    )

    repeats = 3
Can we just have a fast correctness checking mode? It will improve our test coverage, and may also help us identify whether any of those model accuracy errors are real.
Will do as a follow-up. When I looked into this previously, I think there are a lot of NaN-handling errors, and the details of what PyTorch guarantees there are a little bit tricky (will do non-empty input generation to avoid NaN errors).
Might be useful to get an operator count minus the ones we're already decomposing.
Just opened https://github.com/pytorch/torchdynamo/issues/922 and pytorch/pytorch#93636 (still need to benchmark torchbench ops).
Yea I filter those out and constructors in |
For benchmarking we'd also need strides, not just sizes - different strides can result in completely different perf
@ngimel those are being recorded - see the output. When the tensors are contiguous we omit serializing the strides; otherwise we serialize them:
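A small sketch of the convention described above: serialize strides only when they differ from the contiguous (row-major) strides implied by the sizes. The helper names are illustrative, not taken from the PR.

```python
def contiguous_strides(size):
    # row-major strides: innermost dimension has stride 1
    strides = [1] * len(size)
    for i in range(len(size) - 2, -1, -1):
        strides[i] = strides[i + 1] * size[i + 1]
    return strides

def needs_strides(size, stride):
    # record strides only for non-contiguous layouts
    return list(stride) != contiguous_strides(list(size))

# channels-last strides for a [2, 3, 4, 5] NCHW tensor are non-contiguous
assert contiguous_strides([2, 3, 4, 5]) == [60, 20, 5, 1]
assert needs_strides([2, 3, 4, 5], [60, 1, 15, 3])
assert not needs_strides([2, 3], [3, 1])
```

Omitting the redundant contiguous case keeps the recorded database compact while still capturing the layouts that matter for perf.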
Merge branch 'main' of https://github.com/pytorch/torchdynamo into op_benchmarking
Cool, I've seen a few cases w/o strides so didn't notice that they were recorded when needed.
g.output(node)

gm = torch.fx.GraphModule({}, g)
gm, gm_inps = gen_gm_and_inputs(target, args, kwargs)
CheckEachNode is called with python_key which is being removed. You need to update this if we want to re-generate the data in the future. I am ok with leaving it as is for this commit.
LGTM. I already saw you filing issues found through this approach, so it would be valuable to have the PR in soon and work on follow-ups if needed.
Introduce a mode for benchmarking that will run models and serialize operators and their frequency, toggled with --log-operator-inputs. Example usage:

python benchmarks/runner.py --suites=torchbench --training --dtypes=float16 --output=/scratch/eellison/work/torchdynamo/benchmarks/bench_logs/torchbench_train/ --log-operator-inputs

The outputs for torchbench, timm, and huggingface have been included in this PR as .zip files. Here are operators and call counts for torchbench, huggingface, and timm.

Also introduces a microbench script to compare operators to eager/nvfuser. Example usage (just running a single input for now):

python ./benchmarks/microbenchmarks/operatorbench.py --op=aten.avg_pool2d.default --dtype=float16 --suite=timm

Follow-ups: do a sweep on operators we are slow on, prioritize lowerings which are invoked more frequently.
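The comparison loop such a microbenchmark performs can be sketched roughly as follows: time an op under each backend over a few repeats, take the minimum, and report a speedup. `baseline` and `candidate` are pure-Python stand-ins for the eager and compiled kernels, and the harness shape (3 repeats, taking the min) mirrors the `repeats = 3` seen in the diff above rather than the PR's exact code.

```python
import timeit

def bench(fn, *args, repeats=3, number=200):
    # best-of-N repeats reduces timer noise; timeit accepts a callable stmt
    return min(timeit.repeat(lambda: fn(*args), repeat=repeats, number=number))

# stand-ins for eager vs. compiled implementations of the same op
def baseline(xs):
    return [x * 2.0 for x in xs]

def candidate(xs):
    return [x + x for x in xs]

xs = list(range(1000))
t_base = bench(baseline, xs)
t_cand = bench(candidate, xs)
speedup = t_base / t_cand
```

Real operator benchmarking would additionally need CUDA synchronization and warmup iterations before timing, which this sketch omits.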