Move group_index_select_dim0_gpu meta impl to python & remove CPU sync in GPU kernel (#2573)

Summary:
Pull Request resolved: #2573

As title

Reviewed By: sryap

Differential Revision: D57135334

fbshipit-source-id: 6d0f669ef86fa339355e319b04f20a38e9843e7f
williamwen42 authored and facebook-github-bot committed May 14, 2024
1 parent af97deb commit 8f5eabc
Showing 3 changed files with 86 additions and 69 deletions.
61 changes: 61 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -6,6 +6,7 @@

# pyre-strict

import math
from typing import Callable, List, Optional, Tuple

import torch
@@ -579,3 +580,63 @@ def bounds_check_indices_abstract(
from the original function `fbgemm::bounds_check_indices`
"""
return


@impl_abstract("fbgemm::group_index_select_dim0_gpu_impl")
def group_index_select_dim0_gpu_impl_abstract(
inputs: List[torch.Tensor], group_size: int
) -> List[torch.Tensor]:
"""
Calculate output shapes for group_index_select_dim0_gpu_impl
without the actual data.
"""
indices_group = inputs[:group_size]
input_group = inputs[group_size:]
torch._check(len(input_group) == group_size)

ret = []
for i in range(group_size):
size = list(input_group[i].size())
ret.append(input_group[i].new_empty([indices_group[i].size(0)] + size[1:]))

# divide by 2 since sizeof(int64_t) / sizeof(int32_t) = 2
args_tensor_numel = 4 * group_size + 1 + int(math.ceil(group_size / 2))

ret.append(
# sizeof(int64_t) = 8, torch.uint8 = at::kByte
input_group[0].new_empty(
args_tensor_numel * 8, dtype=torch.uint8, pin_memory=True
)
)

ret.append(torch.zeros(5, dtype=torch.int64, device="cpu"))

return ret


@impl_abstract("fbgemm::group_index_select_dim0_gpu_backward")
def group_index_select_dim0_gpu_backward_abstract(
all_inputs: List[torch.Tensor], output_shape_group_ref: List[torch.SymInt]
) -> List[torch.Tensor]:
"""
Calculate output shapes for group_index_select_dim0_gpu_backward
without the actual data.
"""
torch._check(len(all_inputs) > 3)
group_size = (len(all_inputs) - 3) // 2
ret = []

# indices
for _ in range(group_size):
ret.append(all_inputs[0].new_empty(0))

# inputs
output_dim = len(output_shape_group_ref) // group_size
for i in range(group_size):
ret.append(
all_inputs[0].new_empty(
output_shape_group_ref[i * output_dim : (i + 1) * output_dim]
)
)

return ret
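
As a sanity check on the sizing arithmetic in the forward abstract impl above, here is a small worked example (not part of the commit; the breakdown of the four per-group int64 values is an assumption about the kernel-args layout):

import math

group_size = 4
# Assumed layout: four int64 entries per group (e.g. input, output, and
# indices pointers plus a per-group offset), one trailing int64, and
# group_size int32 values packed two per int64 slot.
args_tensor_numel = 4 * group_size + 1 + math.ceil(group_size / 2)
print(args_tensor_numel)      # 19 int64 slots
print(args_tensor_numel * 8)  # 152 bytes, the numel of the uint8 args tensor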
25 changes: 25 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2941,6 +2941,20 @@ torch::autograd::variable_list group_index_select_dim0(
return output_group;
}

torch::autograd::variable_list group_index_select_dim0_gpu_impl_cpu(
at::TensorList all_indices_input,
const int64_t group_size) {
throw std::runtime_error(
"group_index_select_dim0_gpu_impl is not implemented for CPU");
}

torch::autograd::variable_list group_index_select_dim0_gpu_backward_cpu(
at::TensorList all_inputs,
c10::SymIntArrayRef output_shape_group_ref) {
throw std::runtime_error(
"group_index_select_dim0_gpu_backward is not implemented for CPU");
}

Tensor bottom_k_per_row(
const Tensor& input,
const Tensor& k_offsets,
@@ -3104,6 +3118,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"group_index_select_dim0(Tensor[] input_group, Tensor[] indices_group) -> Tensor[]",
{PT2_COMPLIANT_TAG});
// group_index_select_dim0_gpu helper functions - not defined for CPU!
m.def(
"group_index_select_dim0_gpu_impl(Tensor[] inputs, int group_size) -> Tensor[]");
m.def(
"group_index_select_dim0_gpu_backward(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]");
// This is an one-off op to be used in split_embedding_utils.py for zipf
// generation w/o replacement along dim=-1. If requires_unique=True, find
// smallest unique k. If the number of unique elements is less than k,
@@ -3193,6 +3212,12 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU("index_select_dim0", fbgemm_gpu::index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0);
DISPATCH_TO_CPU(
"group_index_select_dim0_gpu_impl",
fbgemm_gpu::group_index_select_dim0_gpu_impl_cpu);
DISPATCH_TO_CPU(
"group_index_select_dim0_gpu_backward",
fbgemm_gpu::group_index_select_dim0_gpu_backward_cpu);
DISPATCH_TO_CPU("bottom_k_per_row", fbgemm_gpu::bottom_k_per_row);
}
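
With the stub registrations above in place, calling either helper on CPU tensors fails loudly instead of falling through the dispatcher. A minimal sketch of the expected behavior (assumes a build of fbgemm_gpu that includes this change):

import torch
import fbgemm_gpu  # loads the library and registers the fbgemm ops

# Per the schema, the first group_size tensors are indices, the rest inputs.
indices = [torch.zeros(2, dtype=torch.long)]
inputs = [torch.zeros(4, 3)]
try:
    torch.ops.fbgemm.group_index_select_dim0_gpu_impl(indices + inputs, 1)
except RuntimeError as e:
    print(e)  # group_index_select_dim0_gpu_impl is not implemented for CPU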

69 changes: 0 additions & 69 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -384,7 +384,6 @@ class GroupIndexSelectDim0GPUOp
at::TensorOptions().dtype(at::kLong));
TORCH_CHECK(saved_data_t.is_contiguous());
memcpy(saved_data_t.data_ptr<int64_t>(), saved_data, sizeof(saved_data));
saved_data_t = saved_data_t.to(first_input.device(), true);

group_index_select_or_add_cuda(
input_ptrs,
@@ -471,7 +470,6 @@
all_inputs.cbegin() + group_size, all_inputs.cbegin() + 2 * group_size);

// Retrieve saved data
saved_data = saved_data.to(at::kCPU);
TORCH_CHECK(saved_data.device() == at::kCPU);
TORCH_CHECK(saved_data.is_contiguous());
int64_t* saved_data_ptr = saved_data.data_ptr<int64_t>();
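
These two deleted .to(...) calls are the "remove CPU sync" half of the commit title: saved_data no longer makes a CPU -> GPU -> CPU round trip between forward and backward, and the TORCH_CHECK above holds because it now stays on the host. For intuition, a generic PyTorch illustration (not FBGEMM-specific) of why the old device-to-host copy forced a sync:

import torch

if torch.cuda.is_available():
    x = torch.arange(1024, device="cuda")
    # A default (blocking) device-to-host copy makes the host wait until
    # all pending GPU work producing x has finished -- the "CPU sync".
    y = x.to("cpu")
    # Keeping a small metadata tensor on the CPU for its whole lifetime,
    # as this commit does with saved_data, avoids that stall.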
@@ -653,57 +651,6 @@ torch::autograd::variable_list group_index_select_dim0_gpu_impl(
return GroupIndexSelectDim0GPUOp::apply(all_indices_input, group_size);
}

torch::autograd::variable_list group_index_select_dim0_gpu_impl_meta(
at::TensorList all_indices_input,
const int64_t group_size) {
auto [input_group, indices_group] =
group_index_select_dim0_unpack(all_indices_input, group_size);

int num_groups = input_group.size();
TORCH_CHECK(num_groups == (int)indices_group.size())
std::vector<Tensor> res;
for (const auto i : c10::irange(num_groups)) {
auto output_size = input_group[i].sym_sizes().vec();
output_size[0] = indices_group[i].sym_size(0);
res.push_back(at::zeros_symint(output_size, input_group[i].options()));
}
int64_t args_tensor_numel =
4 * group_size + 1 + compute_num_int64s<int32_t>(group_size);
res.push_back(at::zeros_symint(
{c10::SymInt(args_tensor_numel * sizeof(int64_t))},
at::TensorOptions().dtype(at::kByte).pinned_memory(true).device(
input_group[0].device())));
res.push_back(at::zeros_symint(
{c10::SymInt(5)},
at::TensorOptions().dtype(at::kLong).device(input_group[0].device())));
return res;
}

torch::autograd::variable_list group_index_select_dim0_gpu_backward_meta(
at::TensorList all_inputs,
c10::SymIntArrayRef output_shape_group_ref) {
TORCH_CHECK(all_inputs.size() > 3);
const auto group_size = (all_inputs.size() - 3) / 2;
std::vector<Tensor> outputs;
outputs.reserve(group_size * 2 + 1);

// grad for indices
for (int i = 0; i < (int)group_size; i++) {
outputs.push_back(
at::zeros_symint({c10::SymInt(0)}, all_inputs[0].options()));
}
// grad for inputs
const auto output_dim = output_shape_group_ref.size() / group_size;
for (int i = 0; i < (int)group_size; i++) {
outputs.push_back(at::zeros_symint(
std::vector<c10::SymInt>(
output_shape_group_ref.cbegin() + i * output_dim,
output_shape_group_ref.cbegin() + (i + 1) * output_dim),
all_inputs[0].options()));
}
return outputs;
}

torch::autograd::variable_list group_index_select_dim0_gpu(
at::TensorList input_group,
at::TensorList indices_group) {
@@ -739,13 +686,6 @@
}
} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"group_index_select_dim0_gpu_impl(Tensor[] inputs, int group_size) -> Tensor[]");
m.def(
"group_index_select_dim0_gpu_backward(Tensor[] inputs, SymInt[] output_shape_group) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA(
"reorder_batched_ad_lengths", fbgemm_gpu::reorder_batched_ad_lengths_gpu);
@@ -780,15 +720,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
"group_index_select_dim0", fbgemm_gpu::group_index_select_dim0_gpu);
}

TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
m.impl(
"group_index_select_dim0_gpu_impl",
&fbgemm_gpu::group_index_select_dim0_gpu_impl_meta);
m.impl(
"group_index_select_dim0_gpu_backward",
&fbgemm_gpu::group_index_select_dim0_gpu_backward_meta);
}

TORCH_LIBRARY_IMPL(fbgemm, AutogradCUDA, m) {
m.impl("group_index_select_dim0", &fbgemm_gpu::group_index_select_dim0_gpu);
m.impl(
