Update on "Fast path binary ops in fake tensor"
This adds fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.

Before:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

After:

```
cuda eval  hrnet_w18                           PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```

My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#

This is not the "most" optimized version of the code; compare it with the Horace/Voz roofline experiment:

```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs if kwargs else {}
 
+        with no_dispatch():
+            if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+                return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
         if func == torch.ops.prim.device.default:
             assert len(args) == 1 and isinstance(args[0], FakeTensor)
             if args[0].fake_mode.in_kernel_invocation:
```

I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu).

The implementation here is based off of #93118, but I modeled the short-circuit logic on TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:

* TensorIterator's traditional fast setup only short circuits if the shapes of all input tensors are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or all the tensors are channels last); see the sketch after this list.
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping them into tensors). I think I got this right.
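
For concreteness, here is a minimal sketch of that eligibility check. The `broadcast_shape` and `is_fastpath_eligible` names are illustrative only, not the functions added by this PR, and the real code also has to handle wrapped numbers and symbolic shapes:

```
import torch


def broadcast_shape(a, b):
    # Right-aligned broadcasting, in the spirit of infer_size in this PR:
    # each aligned dimension pair must be equal or contain a 1.
    ndim = max(len(a), len(b))
    out = [0] * ndim
    for i in range(ndim):
        dim_a = len(a) - ndim + i
        dim_b = len(b) - ndim + i
        size_a = a[dim_a] if dim_a >= 0 else 1
        size_b = b[dim_b] if dim_b >= 0 else 1
        if not (size_a == size_b or size_a == 1 or size_b == 1):
            raise RuntimeError(f"size mismatch at dim {i}: {size_a} vs {size_b}")
        out[i] = size_b if size_a == 1 else size_a
    return tuple(out)


def is_fastpath_eligible(tensors):
    # Hypothetical eligibility check: take the fast path only when
    # (1) at least one input already has the broadcasted output shape, and
    # (2) all inputs are contiguous, or all inputs are channels-last.
    out_shape = tuple(tensors[0].shape)
    for t in tensors[1:]:
        out_shape = broadcast_shape(out_shape, tuple(t.shape))
    if not any(tuple(t.shape) == out_shape for t in tensors):
        return False
    return all(t.is_contiguous() for t in tensors) or all(
        t.is_contiguous(memory_format=torch.channels_last) for t in tensors
    )


# One operand already has the broadcasted shape and both are contiguous: eligible.
print(is_fastpath_eligible([torch.empty(8, 3, 16, 16, device="meta"),
                            torch.empty(1, 3, 1, 1, device="meta")]))   # True
# Neither operand has the broadcasted (8, 3) shape: fall back to the slow path.
print(is_fastpath_eligible([torch.empty(8, 1, device="meta"),
                            torch.empty(1, 3, device="meta")]))         # False
```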

Some evidence that this heuristic is correct is in https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758: I exhaustively test all dim=3 tensors with sizes drawn from [1, 2] and show that PrimTorch and the new algorithm produce the same significant strides. There ARE differences between this algorithm and PrimTorch, but in those cases this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1)).
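
The actual harness is in the gist above; the snippet below is only a rough illustration of this style of exhaustive check, with a made-up `check_stride_agreement` helper, and it compares TensorIterator's CPU output strides against the meta implementation rather than against PrimTorch or the new fast-path algorithm directly:

```
import itertools
import torch


def all_size_stride_pairs():
    # Every 3-dim (size, stride) combination with entries drawn from {1, 2}.
    for size in itertools.product([1, 2], repeat=3):
        for stride in itertools.product([1, 2], repeat=3):
            yield size, stride


def check_stride_agreement(op=torch.mul):
    mismatches = []
    for (size_a, stride_a), (size_b, stride_b) in itertools.product(
        list(all_size_stride_pairs()), repeat=2
    ):
        real = op(torch.empty_strided(size_a, stride_a),
                  torch.empty_strided(size_b, stride_b))
        fake = op(torch.empty_strided(size_a, stride_a, device="meta"),
                  torch.empty_strided(size_b, stride_b, device="meta"))
        assert real.shape == fake.shape
        # Only "significant" strides matter: a dimension of size 1 can carry
        # an arbitrary stride, so it is skipped in the comparison.
        for d in range(real.dim()):
            if real.size(d) > 1 and real.stride(d) != fake.stride(d):
                mismatches.append((size_a, stride_a, size_b, stride_b))
                break
    return mismatches


print(check_stride_agreement(torch.mul))  # [] if the strides always agree
```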

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

[ghstack-poisoned]
ezyang committed Feb 7, 2023
2 parents 10d1ef0 + 57d7abd commit 8e1119d
Showing 1 changed file with 27 additions and 10 deletions.
```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -39,17 +39,21 @@
 
 CONSTANT_NUMEL_LIMIT = 1
 
-CNT = 0
+RECURSION_COUNT = 0
 
 
-class Increment:
+# Small helper that increments recursion count, and
+# resets it when the object goes out of scope. Useful
+# if you don't want to increase indentation which is
+# what a context manager would do.
+class IncrementRecursionCount:
     def __init__(self):
-        global CNT
-        CNT += 1
+        global RECURSION_COUNT
+        RECURSION_COUNT += 1
 
     def __del__(self):
-        global CNT
-        CNT -= 1
+        global RECURSION_COUNT
+        RECURSION_COUNT -= 1
 
 
 @dataclass
@@ -554,7 +558,12 @@ def infer_size(a, b):
         dimB = dimsB - 1 - offset
         sizeA = a[dimA] if dimA >= 0 else 1
         sizeB = b[dimB] if dimB >= 0 else 1
-        assert sizeA == sizeB or sizeA == 1 or sizeB == 1
+        if not (sizeA == sizeB or sizeA == 1 or sizeB == 1):
+            raise RuntimeError(
+                f"The size of tensor a ({sizeA}) "
+                f"must match the size of tensor b ({sizeB}) "
+                f"at non-singleton dimension {i})"
+            )
         expandedSizes[i] = sizeB if sizeA == 1 else sizeA
     return tuple(expandedSizes)
 
@@ -567,7 +576,15 @@ def slow(msg):
                 return slow_ref(*args, **kwargs)
 
         count_label("attempt fast")
-        # == Fast path (based off of TensorIterator fast path) ==
+
+        # Fast path (based off of TensorIterator fast path).
+        # Unfortunately, there is no way to easily deduplicate
+        # this with either the TensorIterator C++ implementation
+        # (which we don't want to SymIntify, and also the algorithm
+        # here is slightly different from TensorIterator to allow
+        # for broadcasting), nor the PrimTorch implementation
+        # (which does not actually implement a fast path.)
+
         operands = args
 
         # compute_shape
@@ -981,8 +998,8 @@ def dispatch(self, func, types, args=(), kwargs=None):
             return args[0].fake_device
 
         if log.getEffectiveLevel() <= logging.DEBUG:
-            log.debug(f"{' ' * CNT}FakeTensorMode.__torch_dispatch__: {func}")
-            incr = Increment()
+            log.debug(f"{' ' * RECURSION_COUNT}FakeTensorMode.__torch_dispatch__: {func}")
+            incr = IncrementRecursionCount()
 
         # Some attribute queries that can be serviced directly
         # See Note [is_coalesced is dispatched]
```
