diff --git a/advanced_source/python_custom_ops.py b/advanced_source/python_custom_ops.py index 5ace0b40897..1f20125f785 100644 --- a/advanced_source/python_custom_ops.py +++ b/advanced_source/python_custom_ops.py @@ -112,7 +112,10 @@ def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor: def _(pic, box): channels = pic.shape[0] x0, y0, x1, y1 = box - return pic.new_empty(channels, y1 - y0, x1 - x0) + result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1) + # The result should have the same metadata (shape/strides/``dtype``/device) + # as running the ``crop`` function above. + return result ###################################################################### # After this, ``crop`` now works without graph breaks: