Description
🚀 The feature, motivation and pitch
For nested tensors with 4 dimensions, torch.bmm should additionally check whether size(1) == 1 and, if so, accept the input. To use torch.bmm for segment_matmul, we split a tensor of shape (x, y, z) into a list of x tensors of shape (1, y, z). If I make a nested tensor out of these and pass it to bmm, it complains that the input 'must be a 3D tensor' because of these checks.
example:
>>> torch.bmm(torch.nested.nested_tensor(list(torch.randn((10,5,5)).split(1))), torch.nested.nested_tensor(list(torch.randn((10,5,5)).split(1))))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: batch1 must be a 3D tensor
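The relaxed check requested here can be sketched on component shapes alone. This is a plain-Python illustration, not PyTorch's actual check: the helper names `bmm_component_shape` and `bmm_output_shapes` are hypothetical, and each nested-tensor component is represented only by its shape tuple.

```python
def bmm_component_shape(shape):
    """Return the (rows, cols) matrix shape for one nested-tensor component.

    Hypothetical helper: accepts (y, z) components and also (1, y, z)
    components, which is what split(1) produces.
    """
    if len(shape) == 2:
        return tuple(shape)
    if len(shape) == 3 and shape[0] == 1:
        # Leading singleton dimension left over from split(1): ignore it.
        return tuple(shape[1:])
    raise ValueError(f"expected (y, z) or (1, y, z), got {tuple(shape)}")


def bmm_output_shapes(shapes1, shapes2):
    """Check per-component compatibility and return the output shapes."""
    if len(shapes1) != len(shapes2):
        raise ValueError("batch sizes differ")
    out = []
    for s1, s2 in zip(shapes1, shapes2):
        n, k1 = bmm_component_shape(s1)
        k2, m = bmm_component_shape(s2)
        if k1 != k2:
            raise ValueError(f"inner dimensions differ: {k1} vs {k2}")
        out.append((n, m))
    return out
```

Under this rule, the ten (1, 5, 5) components from the example above would be treated like (5, 5) matrices, while a component such as (2, 5, 5) would still be rejected.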
This error is a problem because batched matmul over nested tensors is essential for our segment_matmul implementation. We want the following behavior:
r"""Performs dense-dense matrix multiplication according to segments along
the first dimension of :obj:`inputs` as given by :obj:`ptr`, utilizing
dedicated kernels that effectively parallelize over groups.

.. code-block:: python

    inputs = torch.randn(8, 16)
    ptr = torch.tensor([0, 5, 8])
    other = torch.randn(2, 16, 32)
    out = pyg_lib.ops.segment_matmul(inputs, ptr, other)
    assert out.size() == (8, 32)
    assert torch.allclose(out[0:5], inputs[0:5] @ other[0])
    assert torch.allclose(out[5:8], inputs[5:8] @ other[1])

Args:
    inputs (torch.Tensor): The left operand 2D matrix of shape
        :obj:`[N, K]`.
    ptr (torch.Tensor): Compressed vector of shape :obj:`[B + 1]`, holding
        the boundaries of segments.
    other (torch.Tensor): The right operand 3D tensor of shape
        :obj:`[B, K, M]`.

Returns:
    torch.Tensor: The 2D output matrix of shape :obj:`[N, M]`.
"""
Here is the code I am using to allow torch>=1.14 to use torch.bmm instead of our custom kernel:
at::Tensor segment_matmul_kernel(const at::Tensor& input,
                                 const at::Tensor& ptr,
                                 const at::Tensor& other) {
  // Per-segment row counts derived from ptr, copied to the CPU for splitting.
  const auto size = pyg::utils::size_from_ptr(ptr).cpu();
  const auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());
#if TORCH_VERSION_MINOR >= 14 || TORCH_VERSION_MAJOR > 1
  // torch >= 1.14: build nested tensors and use the native nested bmm.
  auto input_nested = at::_nested_tensor_from_tensor_list(
      input.contiguous().split_with_sizes(/*split_size=*/sizes, /*dim=*/0));
  // split(1) leaves a singleton dimension that bmm rejects, hence squeeze(1).
  auto other_nested = at::_nested_tensor_from_tensor_list(
                          other.contiguous().split(/*split_size=*/1, /*dim=*/0))
                          .squeeze(1);
  auto out_nested = at::native::bmm_nested_cuda(input_nested, other_nested);
  auto out = at::cat(out_nested.contiguous().unbind());
#else
  // torch < 1.14: fall back to the custom grouped matmul kernel.
  const auto out = input.new_empty({input.size(0), other.size(-1)});
  grouped_matmul_out_kernel(
      input.contiguous().split_with_sizes(/*split_size=*/sizes, /*dim=*/0),
      other.contiguous().split(/*split_size=*/1, /*dim=*/0),
      out.split_with_sizes(/*split_size=*/sizes, /*dim=*/0));
#endif
  return out;
}
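For reference, the result both branches are meant to produce can be sketched in plain Python on lists of rows. This is only an illustration of the semantics, not the kernel itself. All function names below are hypothetical, and `sizes_from_ptr` assumes that pyg::utils::size_from_ptr returns the differences between consecutive entries of ptr, which the code above does not show.

```python
def sizes_from_ptr(ptr):
    """Segment lengths from boundaries (assumed behavior of size_from_ptr)."""
    return [end - start for start, end in zip(ptr, ptr[1:])]


def split_with_sizes(rows, sizes):
    """Split a list of rows into consecutive chunks of the given sizes."""
    chunks, start = [], 0
    for size in sizes:
        chunks.append(rows[start:start + size])
        start += size
    return chunks


def matmul(a, b):
    """Multiply two dense matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def segment_matmul_reference(inputs, ptr, other):
    """Multiply segment i of inputs by other[i] and concatenate the results.

    Assumes ptr[0] == 0 and ptr[-1] == len(inputs).
    """
    sizes = sizes_from_ptr(ptr)
    if len(sizes) != len(other):
        raise ValueError("ptr must define one segment per matrix in other")
    out = []
    for segment, weight in zip(split_with_sizes(inputs, sizes), other):
        out.extend(matmul(segment, weight))
    return out
```

With ptr = [0, 5, 8], the segment sizes are [5, 3], so rows 0 to 4 are multiplied by other[0] and rows 5 to 7 by other[1], matching the docstring example.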
With torch>=1.14, we trigger a 'batch2 must be a 3D tensor' error unless I call squeeze(1) on the nested tensor built from other.contiguous().split(/*split_size=*/1, /*dim=*/0). However, this is a problem: when the nested tensor consists of many subtensors, this extra step slows things down dramatically compared to the torch < 1.14 pathway. The older pathway avoids the cost because our custom grouped_matmul kernel works even when other has shape (x, 1, y, z). The native implementation should be able to support this too, which would make it feasible for us to use. As you can see from this graph, even though our custom kernel is not faster than torch.bmm, we end up with a much slower op using torch 1.14 than with torch < 1.14.
Alternatives
I have tried implementing this in both Python and C++. In both cases, the new torch.bmm-based method is slower, even though our custom kernel itself is not faster than torch.bmm.
Additional context
No response
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @mikaylagawarecki