allow torch.bmm on nested_tensors of dim == 3 or (dim==4 and size(1)==1) #88519

@puririshi98

🚀 The feature, motivation and pitch

torch.bmm currently accepts a nested tensor only when it has 3 dimensions. In the case of 4 dimensions, it should additionally check whether size(1)==1 and accept the input if so. The reason is that, to use torch.bmm for segment_matmul, we split a tensor of shape (x, y, z) and get a list of x tensors of shape (1, y, z). If I build a nested tensor from this list and pass it to bmm, the dimension checks reject it with a 'must be a 3D tensor' error.

example:

>>> torch.bmm(torch.nested.nested_tensor(list(torch.randn((10,5,5)).split(1))),torch.nested.nested_tensor(list(torch.randn((10,5,5)).split(1))))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: batch1 must be a 3D tensor
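
For comparison, here is a minimal sketch of why the shapes differ and of one possible way around the check. It assumes a PyTorch build where torch.nested.nested_tensor and nested torch.bmm are available; using unbind(0) in place of split(1) is my suggestion for illustration, not something the issue itself proposes.

```python
import torch

x = torch.randn(10, 5, 5)
y = torch.randn(10, 5, 5)

# split(1) keeps the leading dimension: each piece has shape (1, 5, 5),
# so the nested tensor built from the pieces is 4-D.
nt_split = torch.nested.nested_tensor(list(x.split(1)))
split_dim = nt_split.dim()

# unbind(0) drops the leading dimension: each piece has shape (5, 5),
# so the nested tensor is 3-D, which the nested bmm check accepts.
a = torch.nested.nested_tensor(list(x.unbind(0)))
b = torch.nested.nested_tensor(list(y.unbind(0)))
unbind_dim = a.dim()

out = torch.bmm(a, b)
# Convert back to a regular (10, 5, 5) tensor for comparison.
dense = torch.stack(out.unbind())
```

Whether this avoids the slowdown described later in this issue is not something the sketch measures.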

This is a problem because nested bmm is essential for our segment_matmul implementation, where we want the following behavior:

    r"""Performs dense-dense matrix multiplication according to segments along
    the first dimension of :obj:`inputs` as given by :obj:`ptr`, utilizing
    dedicated kernels that effectively parallelize over groups.

    .. code-block:: python

        inputs = torch.randn(8, 16)
        ptr = torch.tensor([0, 5, 8])
        other = torch.randn(2, 16, 32)

        out = pyg_lib.ops.segment_matmul(inputs, ptr, other)
        assert out.size() == (8, 32)
        assert torch.allclose(out[0:5], inputs[0:5] @ other[0])
        assert torch.allclose(out[5:8], inputs[5:8] @ other[1])

    Args:
        inputs (torch.Tensor): The left operand 2D matrix of shape
            :obj:`[N, K]`.
        ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
            the boundaries of segments.
        other (torch.Tensor): The right operand 3D tensor of shape
            :obj:`[B, K, M]`.

    Returns:
        torch.Tensor: The 2D output matrix of shape :obj:`[N, M]`.
    """

Here is the code I am using to allow torch>=1.14 to use torch.bmm instead of our custom kernel:

at::Tensor segment_matmul_kernel(const at::Tensor& input,
                                 const at::Tensor& ptr,
                                 const at::Tensor& other) {
  const auto size = pyg::utils::size_from_ptr(ptr).cpu();
  const auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());
#if TORCH_VERSION_MINOR >= 14 || TORCH_VERSION_MAJOR > 1
  auto input_nested = at::_nested_tensor_from_tensor_list(
      input.contiguous().split_with_sizes(/*split_size=*/sizes, /*dim=*/0));
  auto other_nested = at::_nested_tensor_from_tensor_list(
                          other.contiguous().split(/*split_size=*/1, /*dim=*/0))
                          .squeeze(1);
  auto out_nested = at::native::bmm_nested_cuda(input_nested, other_nested);
  auto out = at::cat(out_nested.contiguous().unbind());
#else
  const auto out = input.new_empty({input.size(0), other.size(-1)});
  grouped_matmul_out_kernel(
      input.contiguous().split_with_sizes(/*split_size=*/sizes, /*dim=*/0),
      other.contiguous().split(/*split_size=*/1, /*dim=*/0),
      out.split_with_sizes(/*split_size=*/sizes, /*dim=*/0));
#endif

  return out;
}

In the case of torch>=1.14, we trigger a 'batch2 must be a 3D tensor' error unless I call squeeze(1) on the nested tensor built from other.contiguous().split(/*split_size=*/1, /*dim=*/0). This is a problem: when the nested tensor consists of many subtensors, this extra step slows things down dramatically compared to the torch < 1.14 pathway, where our custom grouped_matmul kernel works even when other has shape (x, 1, y, z). The native implementation should be able to handle this shape too, which would make it feasible for us to use. As the graph below shows, even though our custom kernel is not faster than torch.bmm, we end up with a much slower op using torch 1.14 than using torch < 1.14.
image
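
For readers who want to try the nested pathway from Python, here is a rough analogue of the torch>=1.14 branch. It is a sketch under assumptions: the name segment_matmul_nested is hypothetical, it runs wherever nested torch.bmm is available, and it builds the right operand with unbind(0) so that no squeeze(1) is needed. I have not measured whether that changes the performance picture described above.

```python
import torch


def segment_matmul_nested(inputs, ptr, other):
    """Segment-wise matmul expressed as one nested-tensor bmm call."""
    sizes = (ptr[1:] - ptr[:-1]).tolist()
    # Left operand: ragged segments of shape [n_i, K].
    lhs = torch.nested.nested_tensor(list(inputs.split(sizes, dim=0)))
    # Right operand: unbind(0) yields [K, M] pieces, so the nested
    # tensor is 3-D and passes the bmm dimension check.
    rhs = torch.nested.nested_tensor(list(other.unbind(0)))
    out = torch.bmm(lhs, rhs)
    # Concatenate the per-segment results back into an [N, M] matrix.
    return torch.cat(out.unbind(), dim=0)


inputs = torch.randn(8, 16)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 16, 32)

out = segment_matmul_nested(inputs, ptr, other)
```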

Alternatives

I have tried implementing this in both Python and C++. In both cases the new method is slower, even though our kernel itself is not faster.

Additional context

No response

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @mikaylagawarecki

Labels: module: nestedtensor, triaged