Update on "Remove .impl_UNBOXED() and functionalities associated with…
Browse files Browse the repository at this point in the history
… it"

Since all ops are c10-full, we can remove .impl_UNBOXED now.
This also removes the ability of KernelFunction or CppFunction to store unboxedOnly kernels.

Differential Revision: [D25490225](https://our.internmc.facebook.com/intern/diff/D25490225/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25490225/)!

[ghstack-poisoned]
smessmer committed Jan 6, 2021
2 parents 416581d + 981282a commit 9c99d93
Showing 23 changed files with 362 additions and 180 deletions.
11 changes: 10 additions & 1 deletion aten/src/ATen/VmapTransforms.h
@@ -96,8 +96,17 @@ struct VmapPhysicalToLogicalMap;
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// bit of `levels` specifies the minimum number of nested vmaps we are in at
// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
// than or equal to 3.
// bitset: 010100
// ^
// |
// levels: 012345
struct TORCH_API VmapPhysicalView {
VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
: levels_(levels), tensor_(tensor) {
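To make the updated comment concrete, here is a minimal standalone sketch of how the two quantities it describes can be read off a `levels` bitset: the number of set bits gives the number of batch dimensions at the front of the physical tensor, and the highest set bit bounds how deeply nested the current vmap call is. This is plain `std::bitset` with `kVmapNumLevels` assumed to be 64 for illustration; it is not the ATen implementation.

```cpp
#include <bitset>
#include <cstdint>
#include <iostream>

constexpr int64_t kVmapNumLevels = 64;  // assumed value, for illustration only

int64_t numBatchDims(const std::bitset<kVmapNumLevels>& levels) {
  // One batch dimension per set vmap level.
  return static_cast<int64_t>(levels.count());
}

int64_t maxLevel(const std::bitset<kVmapNumLevels>& levels) {
  // Highest set bit; returns -1 if no level is set.
  for (int64_t i = kVmapNumLevels - 1; i >= 0; --i) {
    if (levels[i]) return i;
  }
  return -1;
}

int main() {
  std::bitset<kVmapNumLevels> levels;
  levels.set(1);
  levels.set(3);  // corresponds to the levels={1, 3} example in the comment
  std::cout << "batch dims: " << numBatchDims(levels) << "\n";        // 2
  std::cout << "max nesting level: " << maxLevel(levels) << "\n";     // 3
  return 0;
}
```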
17 changes: 7 additions & 10 deletions aten/src/ATen/native/Distributions.cpp
@@ -118,7 +118,7 @@ DEFINE_DISPATCH(bernoulli_tensor_stub);
DEFINE_DISPATCH(bernoulli_scalar_stub);
DEFINE_DISPATCH(cauchy_stub);
DEFINE_DISPATCH(exponential_stub);
DEFINE_DISPATCH(multinomial_stub);
DEFINE_DISPATCH(multinomial_with_replacement_stub);
DEFINE_DISPATCH(geometric_stub);
DEFINE_DISPATCH(log_normal_stub);
DEFINE_DISPATCH(uniform_stub);
@@ -497,8 +497,10 @@ Tensor& multinomial_out(
// Reference:
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
// Half is not supported on CPU.
if (!with_replacement &&
!(self.device().is_cpu() && self.scalar_type() == ScalarType::Half)) {
TORCH_CHECK(
!(self.device().is_cpu() && self.scalar_type() == ScalarType::Half),
"multinomial is not implemented for half on CPU");
if (!with_replacement) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
Expand Down Expand Up @@ -537,13 +539,8 @@ Tensor& multinomial_out(
return result;
}

multinomial_stub(
result.device().type(),
result,
self,
n_sample,
with_replacement,
gen);
multinomial_with_replacement_stub(
result.device().type(), result, self, n_sample, gen);
return result;
}

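For orientation, a hedged usage sketch of the two paths touched in this hunk, assuming a working libtorch build (the snippet is illustrative and not part of the commit): sampling without replacement returns from the early branch above, sampling with replacement goes through the renamed `multinomial_with_replacement_stub`, and a Half tensor on CPU now fails the new `TORCH_CHECK` up front instead of being silently routed past the fast path.

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  torch::manual_seed(0);
  auto probs = torch::tensor({0.1, 0.2, 0.3, 0.4});

  // Without replacement: served by the early-return branch in multinomial_out.
  auto no_repl = torch::multinomial(probs, /*num_samples=*/3, /*replacement=*/false);

  // With replacement: dispatched through multinomial_with_replacement_stub.
  auto with_repl = torch::multinomial(probs, /*num_samples=*/10, /*replacement=*/true);

  std::cout << no_repl << "\n" << with_repl << "\n";

  // A Half probability tensor on CPU now raises
  // "multinomial is not implemented for half on CPU" instead of being skipped.
  return 0;
}
```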
44 changes: 25 additions & 19 deletions aten/src/ATen/native/TensorShape.cpp
@@ -98,23 +98,25 @@ static inline void check_cat_shape_except_dim(const Tensor & first, const Tensor
if (dim == dimension) {
continue;
}
int64_t first_dim_size = first.size(dim);
int64_t second_dim_size = second.size(dim);
int64_t first_dim_size = first.sizes()[dim];
int64_t second_dim_size = second.sizes()[dim];
TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
dimension, ". Got ", first_dim_size, " and ", second_dim_size, " in dimension ", dim,
" (The offending index is ", index, ")");
}
}

static bool should_skip(const Tensor& t) {
return t.numel() == 0 && t.dim() == 1;
}

Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
// previously, size [0] tensors were the only possible empty tensors; thus, it wasn't possible
// to cat empty tensors unless all the other tensors were 1-dimensional, so we allowed these tensors
// to be "skipped". We maintain this behavior for backwards compatibility, but only for this specific
// size (i.e. other empty sizes are not skipped).
// FIXME: warn if this is the case
bool allSkipped = true;

bool allContiguous = true;
Tensor notSkippedTensor;

// Inputs cannot alias the output tensor
for (int64_t i = 0; i < tensors.size(); i++) {
@@ -126,19 +128,23 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
}
at::assert_no_internal_overlap(result);

auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; };
for (auto const &tensor : tensors) {
if (should_skip(tensor)) {
continue;
const Tensor* pnotSkippedTensor = [](TensorList tensors) -> const Tensor* {
for (auto const &tensor : tensors) {
if (should_skip(tensor)) {
continue;
}
// we've found a non-empty tensor
return &tensor;
}
// we've found a non-empty tensor
allSkipped = false;
notSkippedTensor = tensor;
break;
}
if (allSkipped) {
return nullptr;
}(tensors);

if (!pnotSkippedTensor) {
// FIXME: warn if this is the case -- see comment about skipped
// tensors at top of function.
return result;
}
const Tensor& notSkippedTensor = *pnotSkippedTensor;

TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
TORCH_CHECK(dim <= notSkippedTensor.dim(), "dimension ", dim, "out of range");
@@ -161,7 +167,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
continue;
}
check_cat_shape_except_dim(notSkippedTensor, tensor, dim, i);
cat_dim_size += tensor.size(dim);
cat_dim_size += tensor.sizes()[dim];

if (!tensor.is_contiguous(first_tensor_mem_format)) {
allContiguous = false;
@@ -196,8 +202,8 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
if (reuse_iterator &&
result.is_contiguous(first_tensor_mem_format) &&
no_type_promotion) {
auto source_slice = notSkippedTensor;
auto slice_dim_size = source_slice.size(dim);
const auto& source_slice = notSkippedTensor;
auto slice_dim_size = source_slice.sizes()[dim];
auto result_slice = result.narrow(dim, 0, slice_dim_size);
auto result_slice_data = result_slice.data_ptr();
auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());
@@ -226,7 +232,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
if (should_skip(tensor)) {
continue;
}
auto slice_dim_size = tensor.size(dim);
auto slice_dim_size = tensor.sizes()[dim];
auto result_slice = result.narrow(dim, offset, slice_dim_size);

auto iter = TensorIteratorConfig()
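The structural change in this hunk is easiest to see in isolation: the mutable `allSkipped` flag and default-constructed `notSkippedTensor` are replaced by an immediately-invoked lambda that returns a pointer to the first input that should not be skipped, or `nullptr` when everything is skipped. A minimal standalone sketch of that pattern, using a hypothetical `FakeTensor` stand-in rather than `at::Tensor`:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeTensor {
  int64_t numel;
  int64_t dim;
};

static bool should_skip(const FakeTensor& t) {
  return t.numel == 0 && t.dim == 1;
}

int main() {
  std::vector<FakeTensor> tensors = {{0, 1}, {6, 2}, {0, 1}};

  // Immediately-invoked lambda: "not found" is encoded in the pointer itself.
  const FakeTensor* first_kept = [](const std::vector<FakeTensor>& ts) -> const FakeTensor* {
    for (const auto& t : ts) {
      if (should_skip(t)) continue;
      return &t;  // first non-skipped element
    }
    return nullptr;  // everything was skipped
  }(tensors);

  if (!first_kept) {
    std::cout << "all inputs skipped\n";
  } else {
    std::cout << "first kept tensor has numel=" << first_kept->numel << "\n";
  }
  return 0;
}
```

The benefit is that no separate boolean or placeholder Tensor has to be kept in sync with the search result.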
4 changes: 3 additions & 1 deletion aten/src/ATen/native/UnaryOps.h
@@ -77,7 +77,9 @@ DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional<Generator>), random_full
DECLARE_DISPATCH(void(*)(TensorIterator&, c10::optional<Generator>), random_stub);
DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub);
DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar a, Scalar b), clamp_stub);
DECLARE_DISPATCH(void(*)(Tensor&, const Tensor&, int64_t, bool, c10::optional<Generator>), multinomial_stub);
DECLARE_DISPATCH(
void (*)(Tensor&, const Tensor&, int64_t, c10::optional<Generator>),
multinomial_with_replacement_stub);
DECLARE_DISPATCH(
void (*)(
TensorIterator&,
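The DECLARE_DISPATCH change above narrows the stub's function-pointer signature by dropping the `bool with_replacement` argument, since the stub now only ever handles the with-replacement case. As a rough analogue of what a dispatch stub is (plain C++ with made-up names, not the real `DispatchStub` machinery): a stub is a named slot with a fixed function-pointer type, and each backend registers its kernel into that slot.

```cpp
#include <cstdint>
#include <iostream>

struct Tensor {};  // stand-in type for illustration only

using multinomial_with_replacement_fn =
    void (*)(Tensor& result, const Tensor& self, int64_t n_sample);

// The "stub": a single registration point shared by the backends (here just one).
static multinomial_with_replacement_fn multinomial_with_replacement_stub = nullptr;

static void cpu_kernel(Tensor& /*result*/, const Tensor& /*self*/, int64_t n_sample) {
  std::cout << "cpu kernel drawing " << n_sample << " samples with replacement\n";
}

int main() {
  multinomial_with_replacement_stub = &cpu_kernel;      // REGISTER_DISPATCH analogue
  Tensor result, self;
  multinomial_with_replacement_stub(result, self, 4);   // call through the stub
  return 0;
}
```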
8 changes: 5 additions & 3 deletions aten/src/ATen/native/cpu/CatKernel.cpp
@@ -15,18 +15,20 @@ struct InputMeta {

InputMeta(const Tensor& t, int64_t dim, int64_t inner)
: data_ptr(t.data_ptr())
, inner_size(t.size(dim) * inner) {}
, inner_size(t.sizes()[dim] * inner) {}
};

template <typename scalar_t>
void cat_serial_kernel_impl(Tensor& result, TensorList tensors, int64_t dim) {
int64_t outer = result.numel() / (result.size(dim) * result.stride(dim));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl");
int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]);
scalar_t* result_data = result.data_ptr<scalar_t>();
int64_t ninputs = tensors.size();
std::vector<InputMeta> inputs;
inputs.reserve(ninputs);
for (auto const &tensor : tensors) {
inputs.emplace_back(tensor, dim, result.stride(dim));
inputs.emplace_back(tensor, dim, result.strides()[dim]);
}

using Vec = vec256::Vec256<scalar_t>;
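A plausible reading of the `size(dim)` to `sizes()[dim]` swap in this kernel, sketched with plain STL stand-ins (none of this is the ATen code): the checked accessor validates `dim` on every call (and in ATen also wraps negative dims), while indexing the raw size array does not, which is why the hunk also adds a debug-only assertion on `dim` before the raw accesses.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct FakeTensor {
  std::vector<int64_t> sizes_;

  // Checked accessor: validates dim on every call (rough analogue of Tensor::size).
  int64_t size(int64_t dim) const {
    assert(dim >= 0 && dim < static_cast<int64_t>(sizes_.size()));
    return sizes_[dim];
  }

  // Raw view of the size array (rough analogue of Tensor::sizes()).
  const std::vector<int64_t>& sizes() const { return sizes_; }
};

int64_t inner_size(const FakeTensor& t, int64_t dim, int64_t inner) {
  // Assert the precondition once, then index the raw array in the hot path.
  assert(dim >= 0 && dim < static_cast<int64_t>(t.sizes().size()));
  return t.sizes()[dim] * inner;
}

int main() {
  FakeTensor t{{2, 3, 4}};
  return inner_size(t, /*dim=*/1, /*inner=*/10) == 30 ? 0 : 1;
}
```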
50 changes: 15 additions & 35 deletions aten/src/ATen/native/cpu/MultinomialKernel.cpp
@@ -11,8 +11,12 @@ namespace at {
namespace native {
namespace {

template<typename scalar_t>
void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional<Generator> generator) {
template <typename scalar_t>
void multinomial_with_replacement_apply(
Tensor& result,
const Tensor& self,
const int64_t n_sample,
c10::optional<Generator> generator) {
auto gen = get_generator_or_default<CPUGeneratorImpl>(generator, detail::getDefaultCPUGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
@@ -61,8 +65,6 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl
}

TORCH_CHECK(sum > 0, "invalid multinomial distribution (sum of probabilities <= 0)");
TORCH_CHECK(with_replacement || (n_categories - n_zeros >= n_sample),
"invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)");

/* normalize cumulative probability distribution so that last val is 1
i.e. doesn't assume original self row sums to one */
@@ -100,45 +102,23 @@ void multinomial_apply(Tensor& result, const Tensor& self, const int64_t n_sampl

/* store in result tensor (will be incremented for lua compat by wrapper) */
result_ptr[i * result_dist_stride_0 + j * result_dist_stride_1] = sample_idx;

/* Once a sample is drawn, it cannot be drawn again. ie sample without replacement */
if (!with_replacement && j < n_sample - 1) {
/* update cumulative distribution so that sample cannot be drawn again */
scalar_t diff;
scalar_t new_val = 0;
scalar_t sum;

if (sample_idx != 0) {
new_val = cum_dist_ptr[(sample_idx - 1) * cum_dist_stride_0];
}
/* marginal cumulative mass (i.e. original probability) of sample */
diff = cum_dist_ptr[sample_idx * cum_dist_stride_0] - new_val;
/* new sum of marginals is not one anymore... */
sum = 1.0 - diff;
for (int64_t k = 0; k < n_categories; k++) {
new_val = cum_dist_ptr[k * cum_dist_stride_0];
if (k >= sample_idx) {
/* remove sampled probability mass from later cumulative probabilities */
new_val -= diff;
}
/* make total marginals sum to one */
new_val /= sum;
cum_dist_ptr[k * cum_dist_stride_0] = new_val;
}
}
}
}
}

static void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional<Generator> gen) {
static void multinomial_with_replacement_kernel_impl(
Tensor& result,
const Tensor& self,
const int64_t n_sample,
c10::optional<Generator> gen) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "multinomial", [&] {
multinomial_apply<scalar_t>(result, self, n_sample, with_replacement, gen);
multinomial_with_replacement_apply<scalar_t>(result, self, n_sample, gen);
});
}

}

REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl);

REGISTER_DISPATCH(
multinomial_with_replacement_stub,
&multinomial_with_replacement_kernel_impl);
}
}
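What remains in this file after the diff is only the with-replacement path. As a standalone sketch of that algorithm using plain `<random>` (not the ATen kernel, which operates row-wise on tensors): build a cumulative distribution normalized so the last entry is 1, then for each sample draw a uniform value and search for the first cumulative entry that is greater than or equal to it. The deleted block above was the extra bookkeeping that renormalized the distribution between draws when sampling without replacement.

```cpp
#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

std::vector<int64_t> multinomial_with_replacement(
    const std::vector<double>& weights, int64_t n_sample, std::mt19937& gen) {
  // Cumulative distribution, normalized so the last entry is exactly 1.
  std::vector<double> cum_dist(weights.size());
  double sum = 0;
  for (size_t i = 0; i < weights.size(); ++i) {
    sum += weights[i];
    cum_dist[i] = sum;
  }
  for (double& c : cum_dist) c /= sum;
  cum_dist.back() = 1.0;

  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  std::vector<int64_t> samples(n_sample);
  for (int64_t j = 0; j < n_sample; ++j) {
    double u = uniform(gen);
    // First index whose cumulative mass is >= u (binary search).
    auto it = std::lower_bound(cum_dist.begin(), cum_dist.end(), u);
    samples[j] = static_cast<int64_t>(it - cum_dist.begin());
  }
  return samples;
}

int main() {
  std::mt19937 gen(0);
  auto s = multinomial_with_replacement({0.1, 0.2, 0.3, 0.4}, /*n_sample=*/10, gen);
  return static_cast<int>(s.size()) == 10 ? 0 : 1;
}
```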
21 changes: 18 additions & 3 deletions aten/src/ATen/native/cuda/Dropout.cu
@@ -57,6 +57,12 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,

accscalar_t pinv = accscalar_t(1)/p;

// Helps align the total number of times curand_uniform4 is called by each thread for the same totalElements
// in the vec=2 and vec=4 cases.
bool gridxvec_loop_state = 0;

float4 rand;

// Note: Vectorized loads means we'll stride each thread by an additional VEC factor, as we'll load VEC elements at a time
for (IndexType linearIndex = idx * VEC;
linearIndex < totalElements;
@@ -69,12 +75,21 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<scalar_t, IndexType> a,
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
// Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4)
// sets of rand.
float4 rand = curand_uniform4(&state);
if ((VEC == 4) || (gridxvec_loop_state == 0)) {
rand = curand_uniform4(&state);
} else {
// sets up the last two values we generated last iteration to be used this iteration.
rand.x = rand.z;
rand.y = rand.w;
gridxvec_loop_state ^= 1;
}

rand.x = rand.x < p;
rand.y = rand.y < p;
rand.z = rand.z < p;
rand.w = rand.w < p;
if (VEC == 4) {
rand.z = rand.z < p;
rand.w = rand.w < p;
}

// Note: We explicitly check for is_contiguous() before launching the vectorized kernel
// and replace IndexToOffset call with linearIndex to allow vectorization of NHWC (or other)
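The Dropout.cu change lets the VEC == 2 instantiation consume a `float4` from curand over two loop iterations instead of discarding half of it, so each thread calls `curand_uniform4` the same number of times for a given `totalElements` regardless of VEC. A host-side sketch of the same reuse pattern, with `std::mt19937` standing in for the Philox state and a hypothetical `uniform4` standing in for `curand_uniform4` (this is plain C++, not CUDA code):

```cpp
#include <cstdio>
#include <random>

struct Float4 { float x, y, z, w; };

Float4 uniform4(std::mt19937& gen) {  // stand-in for curand_uniform4
  std::uniform_real_distribution<float> d(0.0f, 1.0f);
  return {d(gen), d(gen), d(gen), d(gen)};
}

int main() {
  constexpr int VEC = 2;
  std::mt19937 gen(0);
  Float4 rand{};
  bool need_fresh_batch = true;  // analogue of gridxvec_loop_state == 0

  for (int iter = 0; iter < 8; ++iter) {
    float a, b;
    if (VEC == 4 || need_fresh_batch) {
      rand = uniform4(gen);  // draw four new values
      a = rand.x;
      b = rand.y;
    } else {
      // Recycle the two values left over from the previous generator call.
      a = rand.z;
      b = rand.w;
    }
    if (VEC == 2) need_fresh_batch = !need_fresh_batch;
    std::printf("iter %d: %.3f %.3f\n", iter, a, b);
  }
  return 0;
}
```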
13 changes: 8 additions & 5 deletions aten/src/ATen/native/cuda/MultinomialKernel.cu
@@ -300,7 +300,11 @@ sampleMultinomialOnce(int64_t* dest,
}
}
void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n_sample, const bool with_replacement, c10::optional<Generator> generator) {
void multinomial_with_replacement_kernel_impl(
Tensor& result,
const Tensor& self,
const int64_t n_sample,
c10::optional<Generator> generator) {
auto gen = get_generator_or_default<CUDAGeneratorImpl>(generator, cuda::detail::getDefaultCUDAGenerator());
int inputSize = self.dim();
@@ -371,7 +375,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
PhiloxCudaState rng_engine_inputs;
if (with_replacement) {
// Binary search is warp divergent (so effectively we're running
// with just a single thread), but for better utilization,
// we need each block to have at least 4 warps.
Expand Down Expand Up @@ -402,7 +405,6 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
prefixSum.data_ptr<scalar_t>(),
normDist.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
});
@@ -412,6 +414,7 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
}
}
REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl);
REGISTER_DISPATCH(
multinomial_with_replacement_stub,
&multinomial_with_replacement_kernel_impl);
}}
