Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update on "Fast path binary ops in fake tensor"
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup. Before: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:53.97591 backend_compile:33.60832 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` After: ``` cuda eval hrnet_w18 PASS TIMING: entire_frame_compile:40.18931 backend_compile:25.28828 STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010 ``` My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit# This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment: ``` diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index e3bf545f3b8..395942c6ffe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs if kwargs else {} + with no_dispatch(): + if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}: + return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda') + if func == torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: ``` I am still leaving about 10s of trace time improvement on the table (5s of which is attributable to not yet handling relu.) The implementation here is based off of #93118 but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences: * Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path through if at least one of the input operands matches the broadcasted shape exactly (the idea being that we will probably use that tensor's layout.) I am pretty sure this is not sound, but I need to check tests to see how unsound it is. * I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right. I intend to verify whether or not the new algorithm is correct using Z3. Signed-off-by: Edward Z. Yang <ezyangmeta.com> [ghstack-poisoned]
- Loading branch information