diff --git a/aten/src/ATen/LegacyBatchingRegistrations.cpp b/aten/src/ATen/LegacyBatchingRegistrations.cpp
index d32c093cdb8f..bae40e3c8e51 100644
--- a/aten/src/ATen/LegacyBatchingRegistrations.cpp
+++ b/aten/src/ATen/LegacyBatchingRegistrations.cpp
@@ -1,12 +1,12 @@
+#include
 #include
-#include
 #include
+#include
 #include
-#include
 #include
-#include
+#include
 #include
-#include
+#include
 #include
@@ -25,15 +25,14 @@ namespace at {
 
 // NOTE: [When should I add a batching rule?]
 // When you are adding a new operator, you'll need to add a batching rule so
-// that vmap can work efficiently with said operator. If you do not, we'll
-// attempt to generate a slow fallback for the batching rule.
+// that vmap can work efficiently with said operator. If you do not, we'll attempt
+// to generate a slow fallback for the batching rule.
 
 // NOTE: [How to write batching rules?]
-// The signature of a batching rule should look like exactly like the C++
-// signature of its operator.
+// The signature of a batching rule should look like exactly like the C++ signature
+// of its operator.
 //
-// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for
-// terminology.
+// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
 //
 // At a high level, what a batching rule does is the following:
 // 1. Converts (logical) BatchedTensors to views on physical tensors.
@@ -43,31 +42,27 @@ namespace at {
 // some physical results.
 // 4. Converts physical results back to BatchedTensors.
 //
-// Steps 1, 2, and 4 differ for operators with different batching behaviors.
-// When writing a new batching rule, please select a VmapTransform that matches
-// the batching behavior of your operation. The VmapTransform provides helper
-// functions to do steps (1), (2), and (4). (see NOTE: [What is an
-// VmapTransform?] in VmapTransforms.h)
+// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
+// writing a new batching rule, please select a VmapTransform that matches the
+// batching behavior of your operation. The VmapTransform provides helper functions
+// to do steps (1), (2), and (4).
+// (see NOTE: [What is an VmapTransform?] in VmapTransforms.h)
 
 // Note: [Future plans]
 // The API for writing a batching rule isn't stable. In the future, we'd like
-// to think about the problem of translating these batching rules to
-// TorchScript. Ideally batching rules in eager mode vs TorchScript would look
-// pretty similar, if not use the same mechanism. In order to accomplish that we
-// might have to do some refactoring.
+// to think about the problem of translating these batching rules to TorchScript.
+// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
+// if not use the same mechanism. In order to accomplish that we might have to
+// do some refactoring.
 
-namespace {
+namespace{
 
 // PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
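As a rough illustration of the four steps above, here is a hypothetical Python sketch of a hand-rolled batching rule for a sum, assuming a single batch dim that has already been moved to the front of the physical tensor (the function name and arguments are illustrative, not part of this file):

import torch

def sum_batching_rule_sketch(physical_self, logical_dims, keepdim=False):
    # Step 1 has already happened: physical_self is the underlying tensor with
    # its batch dim at position 0.
    # Step 2: translate logical dims to physical dims (shift by one batch dim).
    physical_dims = [d + 1 if d >= 0 else d for d in logical_dims]
    # Step 3: run the real op on the physical tensor.
    result = physical_self.sum(dim=physical_dims, keepdim=keepdim)
    # Step 4: the result keeps its batch dim at position 0; wrapping it back into
    # a BatchedTensor is what getPhysicalToLogicalMap().apply does in this file.
    return result

x = torch.randn(4, 3, 5)  # 4 is the batch size
print(torch.allclose(sum_batching_rule_sketch(x, [1]),
                     torch.stack([xi.sum(dim=1) for xi in x])))  # True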
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { return dim == 0 || dim == -1; } -Tensor sum_batching_rule( - const Tensor& self, - OptionalIntArrayRef opt_dims, - bool keepdim, - optional dtype) { +Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional dtype) { if (opt_dims.has_value()) { auto dims = opt_dims.value(); // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail @@ -76,9 +71,7 @@ Tensor sum_batching_rule( // >>> x = torch.randn(B0) # the per-examples are all scalars // >>> vmap(partial(torch.sum, dim=0), x) // then we replicate the behavior of sum(scalar_tensor, dim=0). - if (/*logical*/ self.dim() == 0 && - (dims.empty() || - (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { + if (/*logical*/self.dim() == 0 && (dims.empty() || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { return self.clone(); } } @@ -101,14 +94,10 @@ bool isPhysicalScalarTensor(const Tensor& logical_tensor) { template Tensor binary_pointwise_batching_rule( - const Tensor& self, - const Tensor& other, - ExtraArgs... args) { + const Tensor& self, const Tensor& other, ExtraArgs... args) { if (self.dim() > 0 && other.dim() > 0) { - auto physical_args = - BroadcastingVmapTransform::logicalToPhysical({self, other}); - auto result = - Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); + auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); + auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); return physical_args[0].getPhysicalToLogicalMap().apply(result); } if (isPhysicalScalarTensor(self)) { @@ -122,8 +111,8 @@ Tensor binary_pointwise_batching_rule( return self_physical.getPhysicalToLogicalMap().apply(result); } - // At this point, we know at least one of the operands is a logical Scalar - // tensor. Here we must emulate TensorIterator's special behavior on Scalars. + // At this point, we know at least one of the operands is a logical Scalar tensor. + // Here we must emulate TensorIterator's special behavior on Scalars. 
// // As a motivating example, consider the following: // x = torch.randn(3, 10) @@ -159,64 +148,49 @@ Tensor binary_pointwise_batching_rule( } auto physical_args = BroadcastingVmapTransform::logicalToPhysical( {std::move(logical_self), std::move(logical_other)}); - auto result = - Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); + auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...); return physical_args[0].getPhysicalToLogicalMap().apply(result); } -Tensor expand_batching_rule( - const Tensor& self, - IntArrayRef size, - bool implicit) { +Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto size_physical = self_physical.getPhysicalShape(size); auto self_physical_dim = self_physical.tensor().dim(); - TORCH_CHECK( - self_physical_dim <= static_cast(size_physical.size()), - "expand: the number of sizes provided (", - /*logical*/ size.size(), - ") ", - "must be greater or equal to the number of dimensions in the tensor (", - /*logical dim*/ self.dim(), - ")"); + TORCH_CHECK(self_physical_dim <= static_cast(size_physical.size()), + "expand: the number of sizes provided (", /*logical*/size.size(), ") ", + "must be greater or equal to the number of dimensions in the tensor (", + /*logical dim*/self.dim(), ")"); if (self_physical_dim == static_cast(size_physical.size())) { auto result = self_physical.tensor().expand(size_physical, implicit); return self_physical.getPhysicalToLogicalMap().apply(result); } - TORCH_INTERNAL_ASSERT( - self_physical_dim < static_cast(size_physical.size())); + TORCH_INTERNAL_ASSERT(self_physical_dim < static_cast(size_physical.size())); // Here, we know we are expanding a (logical) tensor to a larger number // of dimensions. We have to be careful because we can't call expand directly // due to the presence of batch dimensions. // - // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, - // 3], [2, 3]). The result should be a tensor of size [B0, 2, 3]. A physical - // view of size [B0, 3] can't directly be expanded to size [B0, 2, 3] so the - // strategy here is to view it first as a tensor of size [B0, 1, 3] and then - // expand. + // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]). + // The result should be a tensor of size [B0, 2, 3]. + // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3] + // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and + // then expand. 
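Before the implementation that follows, the view-then-expand strategy described above can be reproduced with public tensor ops; a small hypothetical sketch with B0 = 4:

import torch

B0 = 4
x_physical = torch.randn(B0, 3)          # physical view of a logical [3] tensor
# expand(Tensor[B0, 3], [2, 3]): first insert size-1 dims after the batch dims,
# then expand to the physical target shape [B0, 2, 3].
result = x_physical.view(B0, 1, 3).expand(B0, 2, 3)
print(result.shape)                                   # torch.Size([4, 2, 3])
print(torch.equal(result[:, 0, :], x_physical))       # True
print(torch.equal(result[:, 1, :], x_physical))       # True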
auto self_physical_size = self_physical.tensor().sizes(); auto extra_dims = size_physical.size() - self_physical_dim; VmapDimVector view_shape(size_physical.size(), 1); - std::copy( - self_physical_size.begin(), - self_physical_size.begin() + self_physical.numBatchDims(), - view_shape.begin()); - std::copy( - self_physical_size.begin() + self_physical.numBatchDims(), - self_physical_size.end(), - view_shape.begin() + self_physical.numBatchDims() + extra_dims); - auto result = - self_physical.tensor().view(view_shape).expand(size_physical, implicit); + std::copy(self_physical_size.begin(), + self_physical_size.begin() + self_physical.numBatchDims(), + view_shape.begin()); + std::copy(self_physical_size.begin() + self_physical.numBatchDims(), + self_physical_size.end(), + view_shape.begin() + self_physical.numBatchDims() + extra_dims); + auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit); return self_physical.getPhysicalToLogicalMap().apply(result); } -std::vector chunk_batching_rule( - const Tensor& self, - int64_t chunks, - int64_t dim) { +std::vector chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::chunk(self_physical.tensor(), chunks, dim_physical); @@ -224,10 +198,7 @@ std::vector chunk_batching_rule( return result; } -Tensor clamp_batching_rule( - const Tensor& self, - const optional& min, - const optional& max) { +Tensor clamp_batching_rule(const Tensor& self, const optional& min, const optional& max) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto result = at::clamp(self_physical.tensor(), min, max); return self_physical.getPhysicalToLogicalMap().apply(result); @@ -245,22 +216,15 @@ Tensor clamp_max_batching_rule(const Tensor& self, const Scalar& max) { return self_physical.getPhysicalToLogicalMap().apply(result); } -std::vector tensor_split_sections_batching_rule( - const Tensor& self, - int64_t sections, - int64_t dim) { +std::vector tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); - auto result = - at::tensor_split(self_physical.tensor(), sections, dim_physical); + auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical); self_physical.getPhysicalToLogicalMap().applyInplace(result); return result; } -std::vector tensor_split_indices_batching_rule( - const Tensor& self, - IntArrayRef indices, - int64_t dim) { +std::vector tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical); @@ -270,13 +234,12 @@ std::vector tensor_split_indices_batching_rule( Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); - // NB: unsqueeze has some special handling of its `dim` argument so we can't - // call self_physical.getPhysicalDim directly. In particular, - // native::unsqueeze wraps the dim to (the logical dimension) + 1, so we need - // to do that here too. 
+ // NB: unsqueeze has some special handling of its `dim` argument so we can't call + // self_physical.getPhysicalDim directly. In particular, native::unsqueeze + // wraps the dim to (the logical dimension) + 1, so we need to do that here too. // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413 - auto dim_physical = self_physical.numBatchDims() + - maybe_wrap_dim(dim, /*logical_dim*/ self.dim() + 1); + auto dim_physical = + self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1); auto result = self_physical.tensor().unsqueeze(dim_physical); return self_physical.getPhysicalToLogicalMap().apply(result); } @@ -292,7 +255,7 @@ Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) { if (value_batched) { auto physical_args = - BroadcastingVmapTransform::logicalToPhysical({self, value}); + BroadcastingVmapTransform::logicalToPhysical({self, value}); physical_args[0].tensor().copy_(physical_args[1].tensor()); } else { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); @@ -301,7 +264,7 @@ Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) { return self; } -Tensor& zero_inplace_batching_rule(Tensor& self) { +Tensor& zero_inplace_batching_rule(Tensor &self) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); self_physical.tensor().zero_(); return self; @@ -318,9 +281,7 @@ Tensor squeeze_batching_rule(const Tensor& self) { squeezed_sizes.end(), physical_sizes.begin(), physical_sizes.begin() + num_batch_dims); - for (auto it = physical_sizes.begin() + num_batch_dims; - it != physical_sizes.end(); - ++it) { + for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) { if (*it != 1) { squeezed_sizes.push_back(*it); } @@ -347,38 +308,29 @@ Tensor squeeze_dims_batching_rule(const Tensor& self, IntArrayRef dims) { Tensor trace_batching_rule(const Tensor& self) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); // Batched Diagonal View - auto self_diag = at::diagonal( - self_physical.tensor(), /*offset*/ 0, /*dim1*/ -2, /*dim2*/ -1); - auto result = at::sum(self_diag, -1); + auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1); + auto result = at::sum(self_diag, -1); return self_physical.getPhysicalToLogicalMap().apply(result); } -Tensor trace_backward_batching_rule( - const Tensor& grad, - IntArrayRef input_sizes) { +Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) { auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); - auto grad_input = - at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); // Batched Diagonal View - auto grad_input_diag = - at::diagonal(grad_input, /*offset*/ 0, /*dim1*/ -2, /*dim2*/ -1); + auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1); // Append a dimension of size one to the grad output auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1); grad_input_diag.copy_(grad_physical_tensor); return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } -Tensor transpose_int_batching_rule( - const Tensor& self, - int64_t dim0, - int64_t dim1) { +Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) { // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) 
works - // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following - // happens: + // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens: // >>> x = torch.randn(B0) # the per-examples are all scalars // >>> vmap(lambda x: x.transpose(0, -1), x) // then we replicate this behavior. - if (/*logical*/ self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) && + if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) && is_allowed_dim_on_scalar_tensor(dim1)) { return self; } @@ -399,7 +351,9 @@ Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) { all_dims_physical.push_back(bdim); } all_dims_physical.insert( - all_dims_physical.end(), dims_physical.begin(), dims_physical.end()); + all_dims_physical.end(), + dims_physical.begin(), + dims_physical.end()); auto result = self_physical.tensor().permute(all_dims_physical); return self_physical.getPhysicalToLogicalMap().apply(result); } @@ -411,23 +365,14 @@ Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) { return self_physical.getPhysicalToLogicalMap().apply(result); } -static int64_t getGradInputPhysicalDim( - int64_t dim, - IntArrayRef input_sizes, - int64_t num_batch_dims) { +static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) { return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims; } -Tensor select_backward_batching_rule( - const Tensor& grad, - IntArrayRef input_sizes, - int64_t dim, - int64_t index) { +Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); - auto grad_input = - at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); - auto physical_dim = - getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); grad_input.select(physical_dim, index).copy_(grad_physical.tensor()); return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } @@ -444,63 +389,36 @@ Tensor slice_batching_rule( return self_physical.getPhysicalToLogicalMap().apply(result); } -Tensor slice_backward_batching_rule( - const Tensor& grad, - IntArrayRef input_sizes, - int64_t dim, - int64_t start, - int64_t end, - int64_t step) { +Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); - auto grad_input = - at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); - auto physical_dim = - getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); - grad_input.slice(physical_dim, start, end, step) - .copy_(grad_physical.tensor()); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims()); + grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor()); return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } -Tensor diagonal_batching_rule( - const Tensor& self, - int64_t offset, - int64_t dim1, - int64_t dim2) { +Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) { auto self_physical = 
MultiBatchVmapTransform::logicalToPhysical(self); auto dim1_physical = self_physical.getPhysicalDim(dim1); auto dim2_physical = self_physical.getPhysicalDim(dim2); - auto result = at::diagonal( - self_physical.tensor(), offset, dim1_physical, dim2_physical); + auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical); return self_physical.getPhysicalToLogicalMap().apply(result); } -Tensor diagonal_backward_batching_rule( - const Tensor& grad, - IntArrayRef input_sizes, - int64_t offset, - int64_t dim1, - int64_t dim2) { +Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad); - auto grad_input = - at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); - auto dim1_physical = - getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims()); - auto dim2_physical = - getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims()); - grad_input.diagonal(offset, dim1_physical, dim2_physical) - .copy_(grad_physical.tensor()); + auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options()); + auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims()); + auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims()); + grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor()); return grad_physical.getPhysicalToLogicalMap().apply(grad_input); } -Tensor movedim_batching_rule( - const Tensor& self, - IntArrayRef source, - IntArrayRef destination) { +Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto source_physical = self_physical.getPhysicalDims(source); auto destination_physical = self_physical.getPhysicalDims(destination); - auto result = at::movedim( - self_physical.tensor(), source_physical, destination_physical); + auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical); return self_physical.getPhysicalToLogicalMap().apply(result); } @@ -511,10 +429,7 @@ Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) { return self_physical.getPhysicalToLogicalMap().apply(result); } -std::vector split_batching_rule( - const Tensor& self, - int64_t split_size, - int64_t dim) { +std::vector split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = at::split(self_physical.tensor(), split_size, dim_physical); @@ -522,14 +437,10 @@ std::vector split_batching_rule( return result; } -std::vector split_with_sizes_batching_rule( - const Tensor& self, - IntArrayRef split_sizes, - int64_t dim) { +std::vector split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); - auto result = - at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical); + auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical); self_physical.getPhysicalToLogicalMap().applyInplace(result); return result; } @@ -542,22 +453,15 @@ std::vector unbind_batching_rule(const Tensor& self, int64_t dim) { return 
result; } -Tensor unfold_batching_rule( - const Tensor& self, - int64_t dim, - int64_t size, - int64_t step) { +Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) { auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto dim_physical = self_physical.getPhysicalDim(dim); auto result = self_physical.tensor().unfold(dim_physical, size, step); return self_physical.getPhysicalToLogicalMap().apply(result); } -Tensor contiguous_batching_rule( - const Tensor& self, - MemoryFormat memory_format) { - TORCH_CHECK( - memory_format == MemoryFormat::Contiguous, +Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) { + TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ", "than torch.contiguous_format"); auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); @@ -575,8 +479,7 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) { Tensor view_as_complex_batching_rule(const Tensor& self) { // guard against the user passing in a batch of scalar tensors with batch // size equal to 2. - TORCH_CHECK( - !self.sizes().empty(), "Input tensor must have one or more dimensions"); + TORCH_CHECK(!self.sizes().empty(), "Input tensor must have one or more dimensions"); auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); auto result = at::view_as_complex(self_physical.tensor()); return self_physical.getPhysicalToLogicalMap().apply(result); @@ -585,9 +488,7 @@ Tensor view_as_complex_batching_rule(const Tensor& self) { // Checks that the smallest batch stride is greater than the largest example // stride. This is something we can support but we choose not to because it's // potentially error prone. -static void checkBatchDimsAtFrontInLayout( - IntArrayRef physical_strides, - int64_t num_batch_dims) { +static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) { auto smallest_batch_stride = std::min_element( physical_strides.begin(), physical_strides.begin() + num_batch_dims); auto largest_example_stride = std::max_element( @@ -596,22 +497,19 @@ static void checkBatchDimsAtFrontInLayout( // No example dimensions return; } - TORCH_CHECK( - *smallest_batch_stride >= *largest_example_stride, - "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ", - "vmapped over are at the front of the tensor (in memory layout). When they are ", - "not at the front of the tensor this operation can be error prone so we " - "actively discourage it; please file us a bug report and/or try to ", - "express the as_strided operation in terms of PyTorch view operations"); + TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride, + "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ", + "vmapped over are at the front of the tensor (in memory layout). When they are ", + "not at the front of the tensor this operation can be error prone so we " + "actively discourage it; please file us a bug report and/or try to ", + "express the as_strided operation in terms of PyTorch view operations"); } // given (sizes, strides, storage_offset) returns the maximum location that // can be indexed (or nullopt if such a location doesn't exist, e.g., tensors // with zero-size dims). 
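One way to picture the computation performed by the helper defined just below: the largest reachable storage index is the offset plus the sum of (size - 1) * stride over all dims, and no location is reachable when some dim has size zero. A hypothetical Python sketch (not part of this file):

def maximum_indexable_location(sizes, strides, storage_offset):
    if any(s == 0 for s in sizes):
        return None  # zero-size dims index no storage at all
    return storage_offset + sum((s - 1) * st for s, st in zip(sizes, strides))

print(maximum_indexable_location([2, 3], [3, 1], 4))  # 4 + (2-1)*3 + (3-1)*1 = 9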
static optional maximum_indexable_location( - IntArrayRef sizes, - IntArrayRef strides, - int64_t storage_offset) { + IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) { auto result = native::storage_size_for(sizes, strides); if (result == 0) { return nullopt; @@ -635,24 +533,15 @@ static void checkBasicAsStridedValidForSlice( auto storage_offset = maybe_storage_offset.value_or(base_offset); - auto max_as_strided_loc = - maximum_indexable_location(sizes, strides, storage_offset); - auto max_slice_loc = - maximum_indexable_location(slice_sizes, slice_strides, base_offset); + auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset); + auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset); if (!max_as_strided_loc.has_value()) { return; } if (!max_slice_loc.has_value()) { - TORCH_CHECK( - false, - "result = tensor.as_strided(", - sizes, - ",", - strides, - ",", - storage_offset, - ")", + TORCH_CHECK(false, + "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ")", "can access memory outside of `tensor`. `tensor` has no storage but the ", "passed-in (size, stride, storage_offset) imply a result with some storage. ", "This is not supported inside of vmap, please try to rewrite the ", @@ -661,31 +550,15 @@ static void checkBasicAsStridedValidForSlice( TORCH_CHECK( *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset, - "result = tensor.as_strided(", - sizes, - ",", - strides, - ",", - storage_offset, - ")", + "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ")", "can access memory outside of `tensor`. `result` can access some", - "memory in range [", - storage_offset, - ", ", - *max_as_strided_loc, - "], but ", - "`tensor` can only access some memory in range [", - base_offset, - ", ", - *max_slice_loc, - "]. This is not supported inside of vmap, please try to", + "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ", + "`tensor` can only access some memory in range [", base_offset, ", ", + *max_slice_loc, "]. 
This is not supported inside of vmap, please try to", "rewrite the `as_strided` call as a sequence of PyTorch view operations"); } -Tensor _reshape_alias_batching_rule( - const Tensor& self, - IntArrayRef sizes, - IntArrayRef strides) { +Tensor _reshape_alias_batching_rule(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) { return reshape_batching_rule(self, sizes); } @@ -693,29 +566,22 @@ Tensor _new_zeros_with_same_feature_meta_batching_rule( const Tensor& self, const Tensor& other, int64_t unused_num_batch_dims) { - TORCH_CHECK( - isBatchedTensor(self) && !isBatchedTensor(other), - "Only the 'batched grad' use case is supported in PyTorch core."); + TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other), + "Only the 'batched grad' use case is supported in PyTorch core."); - TORCH_INTERNAL_ASSERT( - unused_num_batch_dims == 0, - "num_batch_dims should not be explicitly passed in because it will be overridden"); - auto self_physical_view = - at::MultiBatchVmapTransform::logicalToPhysical(self); + TORCH_INTERNAL_ASSERT(unused_num_batch_dims == 0, + "num_batch_dims should not be explicitly passed in because it will be overridden"); + auto self_physical_view = at::MultiBatchVmapTransform::logicalToPhysical(self); const auto& self_physical_tensor = self_physical_view.tensor(); int64_t num_batch_dims = self_physical_view.numBatchDims(); checkBatchDimsAtFrontInLayout(self_physical_tensor.strides(), num_batch_dims); - auto result = at::_new_zeros_with_same_feature_meta( - self_physical_tensor, other, num_batch_dims); + auto result = at::_new_zeros_with_same_feature_meta(self_physical_tensor, other, num_batch_dims); return self_physical_view.getPhysicalToLogicalMap().apply(result); } -bool _has_same_storage_numel_batching_rule( - const Tensor& self, - const Tensor& other) { - TORCH_CHECK( - isBatchedTensor(self) && !isBatchedTensor(other), - "Only the 'batched grad' use case is supported in PyTorch core."); +bool _has_same_storage_numel_batching_rule(const Tensor& self, const Tensor& other) { + TORCH_CHECK(isBatchedTensor(self) && !isBatchedTensor(other), + "Only the 'batched grad' use case is supported in PyTorch core."); // The _has_same_storage_numel check is skipped if the tangent is a batched // tensor because using as_strided to access storage locations not indexable // by the input tensor is not supported in vmap @@ -743,8 +609,7 @@ bool _has_same_storage_numel_batching_rule( // However, we consider the above for-loop comprehension to be a user error: // a user should have written the following if they wanted to use as_strided // in a per-sample way: -// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in -// range(4)] +// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)] Tensor as_strided_batching_rule( const Tensor& tensor, IntArrayRef sizes, @@ -756,22 +621,17 @@ Tensor as_strided_batching_rule( const auto& physical_tensor = physical_view.tensor(); // We can't rely on the physical as_strided call to do this for us because - // we do some sanity checks on the size/strides before calling into - // as_strided. - TORCH_CHECK( - sizes.size() == strides.size(), + // we do some sanity checks on the size/strides before calling into as_strided. + TORCH_CHECK(sizes.size() == strides.size(), "Tensor.as_strided(size, stride, ...): size and stride must have the ", - "same length! Got size ", - sizes, - " and stride ", - strides); + "same length! Got size ", sizes, " and stride ", strides); // Sanity checks: // 1. 
All batch dims are at the front in memory layout (not necessary for // correctness, but we are worried the user might be doing crazy things) - // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - - // tensor.offset()) is valid for a slice of the input tensor. See Note: [When - // will the as_strided batching rule fail?] for details. + // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset()) + // is valid for a slice of the input tensor. + // See Note: [When will the as_strided batching rule fail?] for details. checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims); checkBasicAsStridedValidForSlice( physical_tensor, num_batch_dims, sizes, strides, storage_offset); @@ -785,8 +645,8 @@ Tensor as_strided_batching_rule( physical_strides.insert( physical_strides.end(), strides.begin(), strides.end()); - // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - - // xs.offset()) is valid for all i, then it turns out that + // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) + // is valid for all i, then it turns out that // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds // and creates a tensor y such that each y[i] references the same memory // locations as zi. See NOTE: [When will the as_strided batching rule fail?] @@ -796,15 +656,14 @@ Tensor as_strided_batching_rule( } // NOTE: [When will the as_strided batching rule fail?] -// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - -// xs.offset()) is valid for all i, then it turns out that +// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) +// is valid for all i, then it turns out that // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and // creates a tensor y such that each y[i] refers to the same memory as zi. // -// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - -// xs.offset()). Furthermore, let's say that as a part of being "valid" this -// as_strided call does not return a result that can index memory not indexable -// by xs[i]. +// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()). +// Furthermore, let's say that as a part of being "valid" this as_strided call +// does not return a result that can index memory not indexable by xs[i]. // // WLOG, assume that there's only one batch dim and it is at the front of the // `xs` tensor. Let B be the batch size and S be the stride of the batch dim. @@ -829,15 +688,13 @@ Tensor as_strided_batching_rule( // - strides are positive // - offset is positive // -// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - -// xs.offset()) is valid, then -// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s -// storage. +// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) +// is valid, then +// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage. // // If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset) // won't error out. So all we need to check is that the memory locations are -// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very -// important) +// what we expected. 
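The equivalence claimed above can be checked numerically for the simple case of a contiguous tensor with a single leading batch dim; a hypothetical sketch:

import torch

B = 4
xs = torch.arange(B * 6, dtype=torch.float32).view(B, 6)   # batch stride S = 6
sizes, strides, offset = [2, 2], [2, 1], 1

batched = xs.as_strided([B] + sizes, [xs.stride(0)] + strides, offset)
per_sample = torch.stack([
    xs[i].as_strided(sizes, strides,
                     offset + xs[i].storage_offset() - xs.storage_offset())
    for i in range(B)
])
print(torch.equal(batched, per_sample))  # True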
See [Hand-wavy proof of Claim 1] for proof (it's not very important) // // xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to // xs.as_strided([B] + sizes, [S] + strides, offset) @@ -857,10 +714,8 @@ Tensor as_strided_batching_rule( // [Hand-wavy proof of Claim 1] // Part of our definition of being valid is that xs[i].as_strided(...) // must return a tensor that only uses memory indexable by xs[i]. -// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) -// satisfies: -// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * -// strides[j] +// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies: +// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j] // <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j) // (the largest-index memory location of xs[i].as_strided(...) must be \leq // the largest-index memory location of xs[i]) @@ -890,8 +745,7 @@ Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) { auto* input_batched = unsafeGetBatchedImpl(input); auto output_physical = Func(input_batched->value(), args...); auto old_bdims = input_batched->bdims(); - return makeBatched( - output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); + return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } template @@ -899,23 +753,17 @@ Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) { auto* input_batched = unsafeGetBatchedImpl(input); auto output_physical = (input_batched->value().*Func)(extra_args...); auto old_bdims = input_batched->bdims(); - return makeBatched( - output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); + return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } -Tensor pow_scalar_Tensor_batching_rule( - const Scalar& other, - const Tensor& self) { +Tensor pow_scalar_Tensor_batching_rule(const Scalar& other, const Tensor& self) { auto* self_batched = unsafeGetBatchedImpl(self); auto output_physical = at::pow(other, self_batched->value()); auto old_bdims = self_batched->bdims(); - return makeBatched( - output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); + return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } -Tensor clone_batching_rule( - const Tensor& self, - optional memory_format) { +Tensor clone_batching_rule(const Tensor& self, optional memory_format) { // Memory format support is a little tricky because vmap is allowed to move // around batch dimensions and some memory formats are rank-dependent. // Another weird case is: @@ -923,38 +771,32 @@ Tensor clone_batching_rule( // allow the user to clone a Tensor with 3 logical dimensions and 1 batch // dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims // and N>1 batch dims? - TORCH_CHECK( - !memory_format.has_value() || memory_format == MemoryFormat::Preserve || - memory_format == MemoryFormat::Contiguous, + TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve + || memory_format == MemoryFormat::Contiguous, "NYI: Tensor.clone(memory_format) inside vmap is only supported with ", "memory_format torch.preserve_format or torch.contiguous_format (got ", - *memory_format, - ")"); + *memory_format, ")"); if (memory_format == MemoryFormat::Contiguous) { // There is an ambiguity here when the batch dims are not at the front of // the tensor. 
// >>> x = torch.randn(3, B0, 5) - // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, - // out_dims=0)(x) + // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x) // >>> y[0].is_contiguous() // ??? // Should we make the whole tensor contiguous, or should we // make the non-batch dims contiguous? We've chosen the latter because - // philosophically vmap hides the batch dims and operates on a per-sample - // level. + // philosophically vmap hides the batch dims and operates on a per-sample level. auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); auto output_physical = at::clone(physical_view.tensor(), memory_format); return physical_view.getPhysicalToLogicalMap().apply(output_physical); } - TORCH_INTERNAL_ASSERT( - !memory_format.has_value() || memory_format == MemoryFormat::Preserve); + TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve); auto* self_batched = unsafeGetBatchedImpl(self); auto output_physical = at::clone(self_batched->value(), memory_format); auto old_bdims = self_batched->bdims(); - return makeBatched( - output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); + return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } // Note [Batching rules for matmul-like operators] @@ -968,15 +810,10 @@ Tensor mv_batching_rule(const Tensor& self, const Tensor& other) { auto other_batched = isBatchedTensor(other); // A shape checking API would be nice... - TORCH_CHECK( - self.dim() == 2 && other.dim() == 1, + TORCH_CHECK(self.dim() == 2 && other.dim() == 1, "mv(self, other): Shape mismatch: expected matrix " - "(got `self` of size ", - self.sizes(), - ") ", - "and vector (got `other` of size ", - other.sizes(), - ")"); + "(got `self` of size ", self.sizes(), ") ", + "and vector (got `other` of size ", other.sizes(), ")"); // See Note [Batching rules for matmul-like operators] for why we have cases if (self_batched && !other_batched) { @@ -996,39 +833,34 @@ Tensor mv_batching_rule(const Tensor& self, const Tensor& other) { // self_physical: [..., L, K], other_physical: [..., K] // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get // a tensor of size [..., L, 1], and unsqueeze the last dim. 
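The matmul trick described here for the fully batched mv case can be checked against a per-sample loop before reading the code that implements it below; a hypothetical sketch:

import torch

B0, L, K = 4, 3, 5
self_physical = torch.randn(B0, L, K)
other_physical = torch.randn(B0, K)
out = torch.matmul(self_physical, other_physical.unsqueeze(-1)).squeeze(-1)  # [B0, L]
ref = torch.stack([m.mv(v) for m, v in zip(self_physical, other_physical)])
print(torch.allclose(out, ref))  # True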
- auto physical_args = - MultiBatchVmapTransform::logicalToPhysical({self, other}); + auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); auto result = at::matmul( - physical_args[0].tensor(), physical_args[1].tensor().unsqueeze(-1)); + physical_args[0].tensor(), + physical_args[1].tensor().unsqueeze(-1)); return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1)); } TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } Tensor _make_dual_batching_rule( - c10::DispatchKeySet ks, - const Tensor& primal, - const Tensor& tangent, - int64_t level) { + c10::DispatchKeySet ks, + const Tensor& primal, + const Tensor& tangent, + int64_t level +) { DispatchKeySet after_batched_keyset = DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Batched); - return at::redispatch::_make_dual( - ks & after_batched_keyset, primal, tangent, level); + return at::redispatch::_make_dual(ks & after_batched_keyset, primal, tangent, level); } Tensor dot_batching_rule(const Tensor& self, const Tensor& other) { auto self_batched = isBatchedTensor(self); auto other_batched = isBatchedTensor(other); - TORCH_CHECK( - /*logical*/ self.dim() == 1 && /*logical*/ other.dim() == 1, + TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1, "dot(self, other): Shape mismatch: vector " - "(got `self` of size ", - self.sizes(), - ") ", - "and vector (got `other` of size ", - other.sizes(), - ")"); + "(got `self` of size ", self.sizes(), ") ", + "and vector (got `other` of size ", other.sizes(), ")"); // See Note [Batching rules for matmul-like operators] for why we have cases if (self_batched && !other_batched) { @@ -1047,34 +879,24 @@ Tensor dot_batching_rule(const Tensor& self, const Tensor& other) { } if (self_batched && other_batched) { // self_physical: [..., K], other_physical: [..., K] - // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and - // unsqueeze. - auto physical_args = - MultiBatchVmapTransform::logicalToPhysical({self, other}); + // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and unsqueeze. 
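The dot case mirrors the mv sketch above, just with an extra unsqueeze on the left operand and two squeezes at the end; a hypothetical sketch:

import torch

B0, K = 4, 5
a = torch.randn(B0, K)
b = torch.randn(B0, K)
out = torch.matmul(a.unsqueeze(-2), b.unsqueeze(-1)).squeeze(-1).squeeze(-1)  # [B0]
ref = torch.stack([torch.dot(ai, bi) for ai, bi in zip(a, b)])
print(torch.allclose(out, ref))  # True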
+ auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); auto result = at::matmul( physical_args[0].tensor().unsqueeze(-2), physical_args[1].tensor().unsqueeze(-1)); - return physical_args[0].getPhysicalToLogicalMap().apply( - result.squeeze(-1).squeeze(-1)); + return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1)); } TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) { - TORCH_CHECK( - /*logical*/ self.dim() == 3 && /*logical*/ other.dim() == 3, + TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3, "bmm(self, other): Shape mismatch: expected 3D `self` " - "(got `self` of size ", - self.sizes(), - ") ", - "and 3D `other` (got `other` of size ", - other.sizes(), - ")"); - - auto physical_args = - BroadcastingVmapTransform::logicalToPhysical({self, other}); - auto result = - at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); + "(got `self` of size ", self.sizes(), ") ", + "and 3D `other` (got `other` of size ", other.sizes(), ")"); + + auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); + auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); return physical_args[0].getPhysicalToLogicalMap().apply(result); } @@ -1082,15 +904,10 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { auto self_batched = isBatchedTensor(self); auto other_batched = isBatchedTensor(other); - TORCH_CHECK( - /*logical*/ self.dim() == 2 && /*logical*/ other.dim() == 2, + TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2, "mm(self, other): Shape mismatch: expected matrix " - "(got `self` of size ", - self.sizes(), - ") ", - "and matrix (got `other` of size ", - other.sizes(), - ")"); + "(got `self` of size ", self.sizes(), ") ", + "and matrix (got `other` of size ", other.sizes(), ")"); // See Note [Batching rules for matmul-like operators] for why we have cases if (self_batched && !other_batched) { @@ -1104,12 +921,9 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { return other_physical.getPhysicalToLogicalMap().apply(result); } if (self_batched && other_batched) { - auto physical_args = - MultiBatchVmapTransform::logicalToPhysical({self, other}); - auto result = - at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); - return physical_args[0].getPhysicalToLogicalMap().apply( - result.squeeze(-1).squeeze(-1)); + auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other}); + auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor()); + return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1)); } TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor"); } @@ -1117,35 +931,30 @@ Tensor mm_batching_rule(const Tensor& self, const Tensor& other) { Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); auto physical_tensors = fmap( - physical_views, - [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); + physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); TORCH_INTERNAL_ASSERT( - !tensors.empty(), - "The dispatcher should not have dispatched here otherwise."); - auto result = - at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim)); + !tensors.empty(), "The dispatcher should not 
have dispatched here otherwise."); + auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim)); return physical_views[0].getPhysicalToLogicalMap().apply(result); } Tensor stack_batching_rule(TensorList tensors, int64_t dim) { auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors); auto physical_tensors = fmap( - physical_views, - [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); + physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); }); TORCH_INTERNAL_ASSERT( - !tensors.empty(), - "The dispatcher should not have dispatched here otherwise."); + !tensors.empty(), "The dispatcher should not have dispatched here otherwise."); // NB: stack wraps the dimensionality to (logical dim + 1), so we have to // manually handle that here. - auto dim_physical = physical_views[0].numBatchDims() + - maybe_wrap_dim(dim, /*logical*/ tensors[0].dim() + 1); + auto dim_physical = + physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1); auto result = at::stack(physical_tensors, dim_physical); return physical_views[0].getPhysicalToLogicalMap().apply(result); } -// I am quite sad that we need to register operators with exploded -// TensorOptions, even though the native:: implementations can use -// TensorOptions&. This also makes it hard to metaprogram: i.e., we can't use +// I am quite sad that we need to register operators with exploded TensorOptions, +// even though the native:: implementations can use TensorOptions&. +// This also makes it hard to metaprogram: i.e., we can't use // unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!) Tensor to_dtype_layout_batching_rule( const Tensor& self, @@ -1153,18 +962,17 @@ Tensor to_dtype_layout_batching_rule( optional layout, optional device, optional pin_memory, - bool non_blocking, - bool copy, + bool non_blocking, bool copy, optional memory_format) { - auto options = - TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( - pin_memory); + auto options = TensorOptions() + .dtype(dtype) + .layout(layout) + .device(device) + .pinned_memory(pin_memory); auto* input_batched = unsafeGetBatchedImpl(self); - auto output_physical = - input_batched->value().to(options, non_blocking, copy, memory_format); + auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format); auto old_bdims = input_batched->bdims(); - return makeBatched( - output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); + return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end())); } Tensor new_zeros_batching_rule( @@ -1176,9 +984,11 @@ Tensor new_zeros_batching_rule( optional pin_memory) { auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); auto physical_size = physical_view.getPhysicalShape(size); - auto options = - TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( - pin_memory); + auto options = TensorOptions() + .dtype(dtype) + .layout(layout) + .device(device) + .pinned_memory(pin_memory); auto result = physical_view.tensor().new_zeros(physical_size, options); return physical_view.getPhysicalToLogicalMap().apply(result); } @@ -1192,10 +1002,7 @@ Tensor new_empty_batching_rule( c10::optional pin_memory) { auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self); auto physical_size = physical_view.getPhysicalShape(size); - auto result = physical_view.tensor().new_empty( - physical_size, - 
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( - pin_memory)); + auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)); return physical_view.getPhysicalToLogicalMap().apply(result); } @@ -1212,8 +1019,8 @@ Tensor new_empty_strided_batching_rule( // Let [B0, B1, B2] be the shape of the batch dims. We're going to create // the batch dimensions at the front of the tensor (in memory layout), - // irrespective of whether or not they are actually at the front (in memory - // layout) in the original `self` tensor. This is because when a user calls + // irrespective of whether or not they are actually at the front (in memory layout) + // in the original `self` tensor. This is because when a user calls // `new_empty_strided` in general, the `strides` they provide are for a new // tensor and have no relation to the strides of the original tensor. // @@ -1241,13 +1048,10 @@ Tensor new_empty_strided_batching_rule( // physical_strides = [B1 * B2 * S, B2 * S, S] auto physical_strides = at::detail::defaultStrides(batch_shape); - TORCH_CHECK( - size.size() == stride.size(), - "new_empty_strided(sizes, strides): dimensionality of sizes (", - size.size(), - ") must match dimensionality of strides (", - stride.size(), - ")"); + TORCH_CHECK(size.size() == stride.size(), + "new_empty_strided(sizes, strides): dimensionality of sizes (", + size.size(), ") must match dimensionality of strides (", + stride.size(), ")"); auto storage_size = native::storage_size_for(size, stride); for (auto& physical_stride : physical_strides) { physical_stride *= storage_size; @@ -1262,36 +1066,28 @@ Tensor new_empty_strided_batching_rule( } template -Tensor comparison_pointwise_batching_rule( - const Tensor& self, - const Tensor& other) { - auto physical_args = - BroadcastingVmapTransform::logicalToPhysical({self, other}); +Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) { + auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other}); auto result = Func(physical_args[0].tensor(), physical_args[1].tensor()); return physical_args[0].getPhysicalToLogicalMap().apply(result); } -} // namespace +} TORCH_LIBRARY_IMPL(_, Batched, m) { - m.fallback(torch::CppFunction::makeFromBoxedFunction< - &batchedTensorForLoopFallback>()); + m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>()); } TORCH_LIBRARY_IMPL(aten, Batched, m) { // NB: Ideally we would like some operators, like size.int, to "fallthrough" // to the underlying implementation. However, because a BatchedTensor is a - // Tensor wrapper, it only has one dispatch key (Batched) on it. The - // resolution here is to just directly call the underlying implementation. - m.impl( - "size.int", - static_cast(native::size)); + // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution + // here is to just directly call the underlying implementation. 
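Returning to the new_empty_strided rule above: the physical strides are the default (contiguous) strides of the batch shape, scaled by the per-example storage size, followed by the caller's strides. A hypothetical Python sketch of that computation, using the [B0, B1, B2] example from the comments:

import math

batch_shape = [2, 3, 4]          # [B0, B1, B2]
size, stride = [5, 3], [3, 1]    # per-example size/stride requested by the caller
# storage size for (size, stride): 1 + sum((size_i - 1) * stride_i), or 0 if empty
storage_size = (1 + sum((s - 1) * st for s, st in zip(size, stride))) if all(size) else 0
batch_strides = [math.prod(batch_shape[i + 1:]) for i in range(len(batch_shape))]
physical_strides = [b * storage_size for b in batch_strides] + list(stride)
print(storage_size, physical_strides)   # 15 [180, 60, 15, 3, 1]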
+ m.impl("size.int", static_cast(native::size)); m.impl("_add_batch_dim", native::_add_batch_dim); m.impl("_remove_batch_dim", native::_remove_batch_dim); m.impl("_make_dual", _make_dual_batching_rule); m.impl("_has_same_storage_numel", _has_same_storage_numel_batching_rule); m.impl("is_same_size", native::is_same_size); - m.impl( - "_new_zeros_with_same_feature_meta", - _new_zeros_with_same_feature_meta_batching_rule); + m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule); m.impl("sum.dim_IntList", sum_batching_rule); m.impl("is_complex", native::is_complex); @@ -1310,17 +1106,14 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("expand", expand_batching_rule); m.impl("expand_as", native::expand_as); // composite wrt autograd m.impl("movedim.intlist", movedim_batching_rule); - m.impl( - "movedim.int", - static_cast( - native::movedim)); // composite wrt autograd + m.impl("movedim.int", static_cast(native::movedim)); // composite wrt autograd // There is another variant of narrow. However, we don't // want to support the other variant yet bc it isn't documented... m.impl("narrow", native::narrow_symint); // composite wrt autograd - m.impl("numpy_T", native::numpy_T); // composite wrt autograd + m.impl("numpy_T", native::numpy_T); // composite wrt autograd m.impl("matrix_H", native::matrix_H); // composite wrt autograd - m.impl("mT", native::mT); // composite wrt autograd - m.impl("mH", native::mH); // composite wrt autograd + m.impl("mT", native::mT); // composite wrt autograd + m.impl("mH", native::mH); // composite wrt autograd m.impl("permute", permute_batching_rule); m.impl("reshape", reshape_batching_rule); m.impl("_reshape_alias", _reshape_alias_batching_rule); @@ -1348,8 +1141,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("clamp_max", clamp_max_batching_rule); // unary pointwise, out-of-place, no additional arguments. -#define UNARY_POINTWISE(op) \ - m.impl(#op, unwrap_and_call); +#define UNARY_POINTWISE(op) m.impl(#op, \ + unwrap_and_call); UNARY_POINTWISE(abs); UNARY_POINTWISE(acos); UNARY_POINTWISE(asin); @@ -1382,42 +1175,32 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { UNARY_POINTWISE(tanh); UNARY_POINTWISE(trunc); #undef UNARY_POINTWISE -#define TO_BATCHING_RULE(name, ...) \ - { \ - using to_type = Tensor (Tensor::*)(__VA_ARGS__) const; \ - m.impl(name, unwrap_and_call_method); \ +#define TO_BATCHING_RULE(name, ...) \ + { \ + using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \ + m.impl(name, unwrap_and_call_method< \ + to_type, &Tensor::to, __VA_ARGS__>);\ } - TO_BATCHING_RULE( - "to.device", Device, ScalarType, bool, bool, optional) + TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, optional) TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional) - TO_BATCHING_RULE( - "to.other", const Tensor&, bool, bool, optional) + TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional) m.impl("to.dtype_layout", to_dtype_layout_batching_rule); #undef TO_BATCHING_RULE m.impl("clone", clone_batching_rule); - using TensorTensorScalarType = - Tensor (*)(const Tensor&, const Tensor&, const Scalar&); + using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&); using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&); using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&); -#define BINARY_POINTWISE(op) \ - m.impl( \ - #op ".Tensor", \ - binary_pointwise_batching_rule); \ - m.impl( \ - #op ".Scalar", \ - unwrap_and_call); -#define BINARY_POINTWISE_VA(op, ...) 
\ - { \ +#define BINARY_POINTWISE(op) \ + m.impl(#op".Tensor", binary_pointwise_batching_rule); \ + m.impl(#op".Scalar", unwrap_and_call); +#define BINARY_POINTWISE_VA(op, ...) \ + { \ using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \ - using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \ - m.impl( \ - #op ".Tensor", \ - binary_pointwise_batching_rule); \ - m.impl( \ - #op ".Scalar", \ - unwrap_and_call); \ + using Unop = Tensor (*)(const Tensor&, const Scalar&, __VA_ARGS__); \ + m.impl(#op".Tensor", binary_pointwise_batching_rule); \ + m.impl(#op".Scalar", unwrap_and_call); \ } BINARY_POINTWISE_VA(add, const Scalar&); @@ -1426,37 +1209,18 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { BINARY_POINTWISE(mul); BINARY_POINTWISE(div); { - using Binop = Tensor (*)( - const Tensor&, const Tensor&, c10::optional); - using Unop = Tensor (*)( - const Tensor&, const Scalar&, c10::optional); - m.impl( - "div.Tensor_mode", - binary_pointwise_batching_rule< - Binop, - at::div, - c10::optional>); - m.impl( - "div.Scalar_mode", - unwrap_and_call< - Unop, - at::div, - const Scalar&, - c10::optional>); + using Binop = Tensor (*)(const Tensor&, const Tensor&, c10::optional); + using Unop = Tensor (*)(const Tensor&, const Scalar&, c10::optional); + m.impl("div.Tensor_mode", binary_pointwise_batching_rule>); + m.impl("div.Scalar_mode", unwrap_and_call>); } // at::pow has three out-of-place overloads - m.impl( - "pow.Tensor_Tensor", - binary_pointwise_batching_rule); - m.impl( - "pow.Tensor_Scalar", - unwrap_and_call); + m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule); + m.impl("pow.Tensor_Scalar", unwrap_and_call); m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule); - m.impl( - "sigmoid_backward", - binary_pointwise_batching_rule); + m.impl("sigmoid_backward", binary_pointwise_batching_rule); m.impl( "threshold_backward", binary_pointwise_batching_rule< @@ -1467,28 +1231,17 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { // for at::result_type, call the native::result_type implementation. // We don't have to do anything special because native::result_type operates // on the logical shape of the tensors. - m.impl( - "result_type.Tensor", - static_cast( - native::result_type)); - m.impl( - "result_type.Scalar", - static_cast( - native::result_type)); - m.impl( - "result_type.Scalar_Tensor", - static_cast( - native::result_type)); - m.impl( - "result_type.Scalar_Scalar", - static_cast( - native::result_type)); + m.impl("result_type.Tensor", static_cast(native::result_type)); + m.impl("result_type.Scalar", static_cast(native::result_type)); + m.impl("result_type.Scalar_Tensor", static_cast(native::result_type)); + m.impl("result_type.Scalar_Scalar", static_cast(native::result_type)); #undef BINARY_POINTWISE_VA #undef BINARY_POINTWISE -#define TRIVIAL_OP(op) \ - m.impl(#op, unwrap_and_call); + +#define TRIVIAL_OP(op) m.impl(#op, \ + unwrap_and_call); // complex number view operators TRIVIAL_OP(imag) TRIVIAL_OP(real); @@ -1524,13 +1277,9 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) { m.impl("contiguous", contiguous_batching_rule); // Comparison ops -#define COMPARISON_POINTWISE(op) \ - m.impl( \ - #op ".Tensor", \ - comparison_pointwise_batching_rule); \ - m.impl( \ - #op ".Scalar", \ - unwrap_and_call); +#define COMPARISON_POINTWISE(op) \ + m.impl(#op".Tensor", comparison_pointwise_batching_rule); \ + m.impl(#op".Scalar", unwrap_and_call); COMPARISON_POINTWISE(eq); COMPARISON_POINTWISE(gt);
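The idea behind the unary pointwise and comparison registrations above is that the op can simply be applied to the underlying physical tensor, so vmapping it matches applying it directly. A quick sanity check one could run with the public vmap API, assuming a PyTorch build that provides torch.vmap (the modern vmap uses a different batching-rule implementation than this legacy file, but the semantic contract is the same):

import torch

x = torch.randn(4, 3)
y = torch.randn(4, 3)
print(torch.allclose(torch.vmap(torch.sin)(x), torch.sin(x)))   # True
print(torch.equal(torch.vmap(torch.gt)(x, y), torch.gt(x, y)))  # True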