48 changes: 32 additions & 16 deletions aten/src/ATen/BatchedFallback.cpp
@@ -1,4 +1,5 @@
 #include <ATen/BatchedFallback.h>
+#include <ATen/MatrixRef.h>
 #include <ATen/VmapTransforms.h>
 
 namespace at {
@@ -55,9 +56,9 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
               "Batching rule not implemented for ", schema, ". ",
               "We could not generate a fallback.");
-  TORCH_CHECK(num_returns == 1,
+  TORCH_CHECK(num_returns >= 1,
               "Batching rule not implemented for ", schema, ". ",
-              "We do not yet support operations with multiple returns.");
+              "The fallback path does not support operations with no returns.");
   TORCH_WARN("Batching rule not implemented for ", schema, " falling back "
              "to slow (for loop and stack) implementation");
 
@@ -105,8 +106,14 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   // Strategy: For each batch, we are going to push slices (where applicable)
   // of the arguments onto `stack`, call `op`, and store the result in
   // `output_shards`.
-  std::vector<Tensor> output_shards;
-  output_shards.reserve(num_batches);
+  //
+  // NOTE: [Output shards layout]
+  // Assume that the operator has three outputs: a, b, c.
+  // The layout of output_shards is as follows:
+  // [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
+  // This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
+  // more easily in the next step.
+  std::vector<Tensor> output_shards(num_batches * num_returns);

Contributor:
MatrixRef may be of interest here!

Contributor Author:
I changed the indexing code to use MatrixRef! It's a nice way to abstract that behavior, thanks for pointing it out

   for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
     auto index = computeIndex(linear_idx, batch_sizes);
@@ -130,21 +137,30 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
 
     op.callBoxed(stack);
 
-    // We assume there is a single tensor return
-    output_shards.emplace_back(torch::jit::pop(stack).toTensor());
+    // Store the result into `output_shards`. See NOTE: [Output shards layout]
+    // to learn about the details of how we store the shards.
+    const auto returns = torch::jit::last(stack, num_returns);
+    for (int64_t return_idx = 0; return_idx < returns.size(); ++return_idx) {
+      output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
+    }
+    torch::jit::drop(stack, num_returns);
   }
 
-  // Stack the tensors together to form the result.
-  auto flat_output = at::stack(output_shards);
-  VmapDimVector output_sizes(batch_sizes);
-  output_sizes.insert(
-      output_sizes.end(),
-      flat_output.sizes().begin() + 1,
-      flat_output.sizes().end());
+  // For each output Tensor, stack the shards of the tensor together to form a return
   torch::jit::drop(stack, num_arguments);
-  torch::jit::push(
-      stack,
-      input_physical_views.front().newLogicalFromPhysical(flat_output.view(output_sizes)));
+  auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
+  for (int64_t return_idx = 0; return_idx < num_returns; ++return_idx) {
+    auto shards = output_shards_chunks[return_idx];
+    auto flat_output = at::stack(shards);
+    VmapDimVector output_sizes(batch_sizes);
+    output_sizes.insert(
+        output_sizes.end(),
+        flat_output.sizes().begin() + 1,
+        flat_output.sizes().end());
+    torch::jit::push(
+        stack,
+        input_physical_views.front().newLogicalFromPhysical(flat_output.view(output_sizes)));
+  }
 }
 
 } // namespace at
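
For readers skimming the diff, here is a rough Python model of the new shard bookkeeping. It is illustrative only and not part of the PR; the variable names and the pure-Python stand-ins for the operator's returns are made up. It mirrors NOTE: [Output shards layout]: shards are stored flat, grouped by return index, so each return can be stacked with one at::stack, which is exactly the row-wise view that MatrixRef<Tensor>(output_shards, num_batches) provides in the C++ code.

import torch

# Illustrative only: model the flat output_shards buffer for num_returns = 3
# (call the returns a, b, c) and num_batches = 4, laid out [a0..a3, b0..b3, c0..c3].
num_batches, num_returns = 4, 3
output_shards = [None] * (num_batches * num_returns)
for linear_idx in range(num_batches):
    # Stand-ins for the tensors the op would return for this batch index.
    returns = [torch.full((2,), float(10 * r + linear_idx)) for r in range(num_returns)]
    for return_idx, ret in enumerate(returns):
        # Same indexing as output_shards[num_batches * return_idx + linear_idx] above.
        output_shards[num_batches * return_idx + linear_idx] = ret

# MatrixRef<Tensor>(output_shards, num_batches) views the flat buffer as
# num_returns rows of num_batches shards; a plain slice does the same here.
stacked_returns = [
    torch.stack(output_shards[num_batches * r : num_batches * (r + 1)])
    for r in range(num_returns)
]
assert all(t.shape == (num_batches, 2) for t in stacked_returns)
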
59 changes: 40 additions & 19 deletions test/test_vmap.py
@@ -122,10 +122,6 @@ def test_unsupported_op_err_msg(self):
         with self.assertRaisesRegex(RuntimeError, "doesn't work on in-place or view ops"):
             vmap(torch.as_strided, (0, None, None))(tensor, [2, 3], [0, 0])
 
-        # We don't support multiple returns yet
-        with self.assertRaisesRegex(RuntimeError, 'multiple returns'):
-            vmap(torch.var_mean)(tensor)
-
         # The fallback doesn't support TensorList
         with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
             vmap(lambda t: torch.stack([t]))(tensor)
@@ -442,20 +438,21 @@ def foo(x):
         vmap(foo, in_dims=(0,))(torch.randn(2, 3))
         vmap(foo, in_dims=(1,))(torch.randn(2, 3))
 
+    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
+        with warnings.catch_warnings(record=True) as wa:
+            result = vmap(*vmap_args)(*inputs)
+            self.assertEqual(len(wa), 2)
+            self.assertRegex(str(wa[-1].message),
+                             r'falling back to slow \(for loop and stack\) implementation')
+
     def test_fallback_sub(self):
         # NB: One day we will implement a batching rule for torch.sub.
         # If/when we do, this test should be replaced to test the fallback
         # path on another operator to avoid bitrot.
         x = torch.randn(5, 7, 11)
         y = torch.randn(5, 7, 11)
 
-        # Test the fallback path raises a warning
-        with warnings.catch_warnings(record=True) as wa:
-            result = vmap(torch.sub)(x, y)
-            self.assertEqual(len(wa), 2)
-            self.assertRegex(str(wa[-1].message),
-                             r'falling back to slow \(for loop and stack\) implementation')
-            self.assertEqual(result, x - y)
+        self._assert_uses_vmap_fallback((torch.sub,), (x, y))
 
         # fallback on torch.sub
         x = torch.randn(7, 11, 5)
@@ -486,18 +483,42 @@ def run_test(batch_size):
             index = torch.tensor([0, 4, 2])
             values = torch.randn(B0, 3, 13)
 
-            with warnings.catch_warnings(record=True) as wa:
-                result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
-                self.assertEqual(len(wa), 2)
-                self.assertRegex(str(wa[-1].message),
-                                 r'falling back to slow \(for loop and stack\) implementation')
-                expected = torch.index_add(
-                    x, dim + 1, index, values.view(B0, 3, 1, 13))
-                self.assertEqual(result, expected)
+            self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values))
+
+            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
+            expected = torch.index_add(
+                x, dim + 1, index, values.view(B0, 3, 1, 13))
+            self.assertEqual(result, expected)
 
         run_test(batch_size=5)
         run_test(batch_size=1237)
+
+    def test_fallback_multiple_returns(self):
+        # NB: One day we will implement a batching rule for torch.var_mean
+        # If/when we do, this test should be replaced to test the fallback
+        # path on another operator to avoid bitrot.
+        B0, B1, B2 = 2, 3, 1237
+        tensor = torch.randn(B0, 10)
+
+        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
+
+        # fallback correctness on torch.var_mean
+        result = vmap(torch.var_mean)(tensor)
+        expected = torch.var_mean(tensor, dim=1)
+        self.assertEqual(result, expected)
+
+        # nested vmap
+        tensor = torch.randn(B0, B1, 10)
+        result = vmap(vmap(torch.var_mean))(tensor)
+        expected = torch.var_mean(tensor, dim=2)
+        self.assertEqual(result, expected)
+
+        # big batch size, nested vmap
+        tensor = torch.randn(B0, B1, B2, 10)
+        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
+        expected = torch.var_mean(tensor, dim=3)
+        self.assertEqual(result, expected)
 
 
 def slice_inputs(inputs, bdims, i):
     result = []
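
As a sanity check on the semantics the new test_fallback_multiple_returns relies on, the "for loop and stack" fallback for a multi-return op such as torch.var_mean can be modeled in plain Python as below. This sketch is not part of the PR; the helper name looped_var_mean is made up for illustration, and it only approximates what the boxed C++ fallback does.

import torch

def looped_var_mean(batched):
    # Run the op once per batch element, then stack each return separately,
    # mirroring the per-return at::stack loop in BatchedFallback.cpp.
    shards = [torch.var_mean(batched[i]) for i in range(batched.shape[0])]
    var = torch.stack([v for v, _ in shards])
    mean = torch.stack([m for _, m in shards])
    return var, mean

x = torch.randn(2, 10)
var, mean = looped_var_mean(x)
expected_var, expected_mean = torch.var_mean(x, dim=1)
assert torch.allclose(var, expected_var) and torch.allclose(mean, expected_mean)
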